Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ cov.xml
.DS_Store

uv.lock
.codex

# libraries
**/neuropixels_library_generated
Expand Down
75 changes: 75 additions & 0 deletions src/probeinterface/probegroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,81 @@ def get_contact_count(self) -> int:
n = sum(probe.get_contact_count() for probe in self.probes)
return n

def _build_contact_vector(self) -> np.ndarray:
"""
Return the channel-ordered dense view of the probegroup, computed fresh.

Private by convention: this method is intended for integration with SpikeInterface,
which needs a channel-ordered view for recording-facing queries. Fields and dtype
may evolve with consumer requirements, so user code should not depend on it directly.
For stable probegroup state, use the public `get_global_*` methods.

Invariants
----------
- Ordering: rows are sorted ascending by `device_channel_indices` using a stable
sort. Ties preserve per-probe insertion order.
- Row count: one row per *connected* contact (`device_channel_indices >= 0`).
The returned size is generally smaller than `self.get_contact_count()` when the
probegroup has unwired contacts. This matches SpikeInterface's pre-migration
`contact_vector` convention.
- Dtype: includes `probe_index`, `x`, `y`, and `z` if `ndim == 3`. Optional fields
`shank_ids` and `contact_sides` appear only when at least one probe in the group
defines them. Consumers must guard field access accordingly.
- Raises `ValueError` on empty probegroups and on probegroups with no wired
contacts.

This method builds a fresh array on every call. It is not cached. Consumers that
need to call it repeatedly in a hot loop should cache the result at the call site,
where the lifetime and invalidation story are local.
"""
if len(self.probes) == 0:
raise ValueError("Cannot build a contact_vector for an empty ProbeGroup")

has_shank_ids = any(probe.shank_ids is not None for probe in self.probes)
has_contact_sides = any(probe.contact_sides is not None for probe in self.probes)

dtype = [("probe_index", "int64"), ("x", "float64"), ("y", "float64")]
if self.ndim == 3:
dtype.append(("z", "float64"))
if has_shank_ids:
dtype.append(("shank_ids", "U64"))
if has_contact_sides:
dtype.append(("contact_sides", "U8"))

channel_index_parts = []
contact_vector_parts = []
for probe_index, probe in enumerate(self.probes):
device_channel_indices = probe.device_channel_indices
if device_channel_indices is None:
continue

device_channel_indices = np.asarray(device_channel_indices)
connected = device_channel_indices >= 0
if not np.any(connected):
continue

probe_vector = np.zeros(np.sum(connected), dtype=dtype)
probe_vector["probe_index"] = probe_index
probe_vector["x"] = probe.contact_positions[connected, 0]
probe_vector["y"] = probe.contact_positions[connected, 1]
if self.ndim == 3:
probe_vector["z"] = probe.contact_positions[connected, 2]
if has_shank_ids and probe.shank_ids is not None:
probe_vector["shank_ids"] = probe.shank_ids[connected]
if has_contact_sides and probe.contact_sides is not None:
probe_vector["contact_sides"] = probe.contact_sides[connected]

channel_index_parts.append(device_channel_indices[connected])
contact_vector_parts.append(probe_vector)

if len(contact_vector_parts) == 0:
raise ValueError("contact_vector requires at least one wired contact")

channel_indices = np.concatenate(channel_index_parts, axis=0)
contact_vector = np.concatenate(contact_vector_parts, axis=0)
order = np.argsort(channel_indices, kind="stable")
return contact_vector[order]

def to_numpy(self, complete: bool = False) -> np.ndarray:
"""
Export all probes into a numpy array.
Expand Down
79 changes: 79 additions & 0 deletions tests/test_probegroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,85 @@ def test_set_contact_ids_rejects_wrong_size():
probe.set_contact_ids(["a", "b", "c"])


def test_contact_vector_orders_connected_contacts():
from probeinterface import Probe

probe0 = Probe(ndim=2, si_units="um")
probe0.set_contacts(
positions=np.array([[10.0, 0.0], [30.0, 0.0]]),
shapes="circle",
shape_params={"radius": 5},
shank_ids=["s0", "s1"],
contact_sides=["front", "back"],
)
probe0.set_device_channel_indices([2, -1])

probe1 = Probe(ndim=2, si_units="um")
probe1.set_contacts(
positions=np.array([[0.0, 0.0], [20.0, 0.0]]),
shapes="circle",
shape_params={"radius": 5},
shank_ids=["s0", "s0"],
contact_sides=["front", "front"],
)
probe1.set_device_channel_indices([0, 1])

probegroup = ProbeGroup()
probegroup.add_probe(probe0)
probegroup.add_probe(probe1)

arr = probegroup._build_contact_vector()

assert arr.dtype.names == ("probe_index", "x", "y", "shank_ids", "contact_sides")
assert arr.size == 3
assert np.array_equal(arr["probe_index"], np.array([1, 1, 0]))
assert np.array_equal(arr["x"], np.array([0.0, 20.0, 10.0]))
assert np.array_equal(np.column_stack((arr["x"], arr["y"])), np.array([[0.0, 0.0], [20.0, 0.0], [10.0, 0.0]]))


def test_contact_vector_reflects_current_probe_state():
probegroup = ProbeGroup()
probe = generate_dummy_probe()
probe.set_device_channel_indices(np.arange(probe.get_contact_count()))
probegroup.add_probe(probe)

dense_before = probegroup._build_contact_vector()
original_positions = np.column_stack((dense_before["x"], dense_before["y"])).copy()

probe.move([5.0, 0.0])

dense_after_move = probegroup._build_contact_vector()
assert dense_after_move is not dense_before
assert np.array_equal(
np.column_stack((dense_after_move["x"], dense_after_move["y"])),
original_positions + np.array([5.0, 0.0]),
)

probe.set_shank_ids(np.array(["a"] * probe.get_contact_count()))
dense_with_shanks = probegroup._build_contact_vector()
assert "shank_ids" in dense_with_shanks.dtype.names


def test_contact_vector_requires_wired_contacts():
probegroup = ProbeGroup()
probe = generate_dummy_probe()
probegroup.add_probe(probe)

with pytest.raises(ValueError, match="requires at least one wired contact"):
probegroup._build_contact_vector()


def test_contact_vector_supports_3d_positions():
probegroup = ProbeGroup()
probe = generate_dummy_probe().to_3d()
probe.set_device_channel_indices(np.arange(probe.get_contact_count()))
probegroup.add_probe(probe)

dense = probegroup._build_contact_vector()
assert dense.dtype.names[:4] == ("probe_index", "x", "y", "z")
assert np.column_stack((dense["x"], dense["y"], dense["z"])).shape[1] == 3


# ── get_global_contact_positions() tests ────────────────────────────────────


Expand Down
Loading