figpack-spike-sorting 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- figpack_spike_sorting/__init__.py +3 -0
- figpack_spike_sorting/figpack_spike_sorting.js +815 -0
- figpack_spike_sorting/spike_sorting_extension.py +24 -0
- figpack_spike_sorting/views/AutocorrelogramItem.py +32 -0
- figpack_spike_sorting/views/Autocorrelograms.py +120 -0
- figpack_spike_sorting/views/AverageWaveforms.py +147 -0
- figpack_spike_sorting/views/CrossCorrelogramItem.py +35 -0
- figpack_spike_sorting/views/CrossCorrelograms.py +134 -0
- figpack_spike_sorting/views/RasterPlot.py +286 -0
- figpack_spike_sorting/views/RasterPlotItem.py +28 -0
- figpack_spike_sorting/views/SpikeAmplitudes.py +366 -0
- figpack_spike_sorting/views/SpikeAmplitudesItem.py +38 -0
- figpack_spike_sorting/views/UnitLocations.py +78 -0
- figpack_spike_sorting/views/UnitMetricsGraph.py +129 -0
- figpack_spike_sorting/views/UnitSimilarityScore.py +40 -0
- figpack_spike_sorting/views/UnitsTable.py +83 -0
- figpack_spike_sorting/views/UnitsTableColumn.py +40 -0
- figpack_spike_sorting/views/UnitsTableRow.py +36 -0
- figpack_spike_sorting/views/__init__.py +43 -0
- figpack_spike_sorting-0.1.0.dist-info/METADATA +42 -0
- figpack_spike_sorting-0.1.0.dist-info/RECORD +23 -0
- figpack_spike_sorting-0.1.0.dist-info/WHEEL +5 -0
- figpack_spike_sorting-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
import figpack
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def _load_javascript_code():
    """Load the JavaScript code from the built figpack_spike_sorting.js file.

    The bundle is expected to sit in the same directory as this module.

    Returns:
        The full text of figpack_spike_sorting.js, decoded as UTF-8.

    Raises:
        FileNotFoundError: If the bundle is missing. The message points at the
            expected path and the build command, and the original error is
            chained as the cause.
    """
    import os

    js_path = os.path.join(os.path.dirname(__file__), "figpack_spike_sorting.js")
    try:
        with open(js_path, "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError as err:
        # Re-raise with an actionable message; keep the low-level error as the
        # explicit cause so the traceback still shows what open() reported.
        raise FileNotFoundError(
            f"Could not find figpack_spike_sorting.js at {js_path}. "
            "Make sure to run 'npm run build' to generate the JavaScript bundle."
        ) from err
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Create and register the figpack_spike_sorting extension.
# The JavaScript bundle is read eagerly here, at import time, so importing
# this module raises FileNotFoundError when the bundle has not been built.
# NOTE(review): the extension version is "1.0.0" while the wheel is named
# 0.1.0 -- confirm whether the two version numbers are meant to differ.
spike_sorting_extension = figpack.FigpackExtension(
    name="figpack-spike-sorting",
    javascript_code=_load_javascript_code(),
    version="1.0.0",
)
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""
|
|
2
|
+
AutocorrelogramItem for spike sorting views
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Union
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class AutocorrelogramItem:
    """
    Container for the autocorrelogram of one unit
    """

    def __init__(
        self,
        *,
        unit_id: Union[str, int],
        bin_edges_sec: np.ndarray,
        bin_counts: np.ndarray,
    ):
        """
        Build an AutocorrelogramItem from keyword-only arguments

        Args:
            unit_id: Identifier for the unit, stored unchanged
            bin_edges_sec: Bin edges in seconds; copied into a float32 array
            bin_counts: Bin counts; copied into an int32 array
        """
        # np.array always copies, so the item never aliases the caller's data.
        edges = np.array(bin_edges_sec, dtype=np.float32)
        counts = np.array(bin_counts, dtype=np.int32)

        self.unit_id = unit_id
        self.bin_edges_sec = edges
        self.bin_counts = counts
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Autocorrelograms view for figpack - displays multiple autocorrelograms
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import List
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
import figpack
|
|
10
|
+
from .AutocorrelogramItem import AutocorrelogramItem
|
|
11
|
+
from ..spike_sorting_extension import spike_sorting_extension
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Autocorrelograms(figpack.ExtensionView):
    """
    A view that displays multiple autocorrelograms for spike sorting analysis
    """

    def __init__(
        self,
        *,
        autocorrelograms: List[AutocorrelogramItem],
    ):
        """
        Initialize an Autocorrelograms view

        Args:
            autocorrelograms: List of AutocorrelogramItem objects
        """
        super().__init__(
            extension=spike_sorting_extension,
            view_type="spike_sorting.Autocorrelograms",
        )

        self.autocorrelograms = autocorrelograms

    @staticmethod
    def from_sorting(sorting):
        """
        Build the view from a spikeinterface sorting object

        Args:
            sorting: A spikeinterface BaseSorting instance

        Returns:
            An Autocorrelograms view built from the autocorrelogram widget
        """
        # spikeinterface is imported lazily so it is only needed by callers
        # that actually use this constructor.
        import spikeinterface as si
        import spikeinterface.widgets as sw

        assert isinstance(sorting, si.BaseSorting), "Input must be a BaseSorting object"
        W = sw.plot_autocorrelograms(sorting)
        return Autocorrelograms.from_spikeinterface_widget(W)

    @staticmethod
    def from_spikeinterface_widget(W):
        """
        Build the view from a spikeinterface correlogram widget

        Args:
            W: Widget whose data_plot provides unit_ids, bins and correlograms

        Returns:
            An Autocorrelograms view with one item per unit, taken from the
            diagonal of the correlogram matrix
        """
        from spikeinterface.widgets.base import to_attr
        from spikeinterface.widgets.utils_sortingview import make_serializable

        dp = to_attr(W.data_plot)
        unit_ids = make_serializable(dp.unit_ids)

        # The bin edges are shared by every unit, so convert them once.
        # presumably dp.bins is in milliseconds (hence / 1000) -- TODO confirm
        bin_edges_sec = (dp.bins / 1000.0).astype("float32")

        # Only the diagonal entries [i, i] are autocorrelograms; index them
        # directly instead of scanning every (i, j) pair.
        ac_items = [
            AutocorrelogramItem(
                unit_id=unit_id,
                bin_edges_sec=bin_edges_sec,
                bin_counts=dp.correlograms[i, i].astype("int32"),
            )
            for i, unit_id in enumerate(unit_ids)
        ]

        return Autocorrelograms(autocorrelograms=ac_items)

    def _write_to_zarr_group(self, group: figpack.Group) -> None:
        """
        Write the Autocorrelograms data to a Zarr group

        Stores the item count, then (when there is at least one item) the
        shared bin edges, a 2D "bin_counts" dataset with one row per
        autocorrelogram, and per-item metadata in the group attributes.

        Args:
            group: Zarr group to write data into

        Raises:
            ValueError: If the items do not all have the same number of bins
        """
        super()._write_to_zarr_group(group)

        # Store the number of autocorrelograms
        num_autocorrelograms = len(self.autocorrelograms)
        group.attrs["num_autocorrelograms"] = num_autocorrelograms

        if num_autocorrelograms == 0:
            return

        # Get dimensions from first autocorrelogram
        num_bins = len(self.autocorrelograms[0].bin_counts)

        # All rows share one 2D array, so every item must have num_bins bins.
        # Check explicitly: a length-1 array would otherwise be silently
        # broadcast across the whole row.
        for i, autocorr in enumerate(self.autocorrelograms):
            if len(autocorr.bin_counts) != num_bins:
                raise ValueError(
                    "All autocorrelograms must have the same number of bins: "
                    f"item {i} (unit {autocorr.unit_id}) has "
                    f"{len(autocorr.bin_counts)} bins, expected {num_bins}"
                )

        # Store bin edges (taken from the first item and used for all)
        group.create_dataset(
            "bin_edges_sec",
            data=self.autocorrelograms[0].bin_edges_sec,
        )

        # Create 2D array for all bin counts
        bin_counts = np.zeros((num_autocorrelograms, num_bins), dtype=np.int32)

        # Store metadata for each autocorrelogram and populate bin counts array
        autocorrelogram_metadata = []
        for i, autocorr in enumerate(self.autocorrelograms):
            autocorrelogram_metadata.append(
                {
                    "unit_id": str(autocorr.unit_id),
                    "index": i,  # Store index to map to bin_counts array
                    "num_bins": num_bins,
                }
            )
            bin_counts[i] = autocorr.bin_counts

        # Store the bin counts as a single 2D dataset
        group.create_dataset(
            "bin_counts",
            data=bin_counts,
        )

        # Store the autocorrelogram metadata
        group.attrs["autocorrelograms"] = autocorrelogram_metadata
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
"""
|
|
2
|
+
AverageWaveforms view for figpack - displays multiple average waveforms
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import List, Optional, Union
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import figpack
|
|
9
|
+
from ..spike_sorting_extension import spike_sorting_extension
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class AverageWaveformItem:
    """
    Holds the average waveform of one unit plus optional spread data
    """

    def __init__(
        self,
        *,
        unit_id: Union[str, int],
        channel_ids: List[Union[str, int]],
        waveform: np.ndarray,
        waveform_std_dev: Optional[np.ndarray] = None,
        waveform_percentiles: Optional[List[np.ndarray]] = None,
    ):
        """
        Build an AverageWaveformItem from keyword-only arguments

        Args:
            unit_id: Identifier for the unit, stored unchanged
            channel_ids: List of channel identifiers, stored by reference
            waveform: Average waveform (num_samples x num_channels); copied
                into a float32 array
            waveform_std_dev: Optional standard deviation of the waveform;
                copied into a float32 array when given, otherwise None
            waveform_percentiles: Optional list of percentile arrays; each is
                copied into a float32 array when given, otherwise None
        """
        self.unit_id = unit_id
        self.channel_ids = channel_ids
        self.waveform = np.array(waveform, dtype=np.float32)

        if waveform_std_dev is None:
            self.waveform_std_dev = None
        else:
            self.waveform_std_dev = np.array(waveform_std_dev, dtype=np.float32)

        self.waveform_percentiles = None
        if waveform_percentiles is not None:
            converted = []
            for percentile in waveform_percentiles:
                converted.append(np.array(percentile, dtype=np.float32))
            self.waveform_percentiles = converted
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class AverageWaveforms(figpack.ExtensionView):
    """
    A view that displays multiple average waveforms for spike sorting analysis
    """

    def __init__(self, *, average_waveforms: List[AverageWaveformItem]):
        """
        Initialize an AverageWaveforms view

        Args:
            average_waveforms: List of AverageWaveformItem objects
        """
        super().__init__(
            extension=spike_sorting_extension,
            view_type="spike_sorting.AverageWaveforms",
        )
        self.average_waveforms = average_waveforms

    @staticmethod
    def from_sorting_analyzer(sorting_analyzer):
        """
        Build the view from a sorting analyzer

        Computes the "random_spikes", "waveforms", "templates" and
        "noise_levels" extensions on the analyzer, then creates one item per
        unit from the average templates.

        Note that each item's waveform_std_dev is filled with the per-channel
        noise level (constant over samples and identical for every unit); it
        is not a per-sample standard deviation of the waveforms.

        Args:
            sorting_analyzer: Analyzer providing compute(), get_extension(),
                unit_ids and recording
                (presumably a spikeinterface SortingAnalyzer -- TODO confirm)

        Returns:
            An AverageWaveforms view with one item per unit
        """
        sorting_analyzer.compute(
            ["random_spikes", "waveforms", "templates", "noise_levels"]
        )
        ext_templates = sorting_analyzer.get_extension("templates")
        # shape is num_units, num_samples, num_channels
        av_templates = ext_templates.get_data(operator="average")

        ext_noise_levels = sorting_analyzer.get_extension("noise_levels")
        noise_levels = ext_noise_levels.get_data()

        num_samples, num_channels = av_templates.shape[1], av_templates.shape[2]

        # Every sample row gets the same per-channel noise level; let numpy
        # broadcast the row instead of filling column by column.
        waveform_std_dev = np.zeros((num_samples, num_channels), dtype=np.float32)
        waveform_std_dev[:, :] = np.asarray(noise_levels)[:num_channels]

        # The channel ids do not depend on the unit, so query them once.
        channel_ids = list(sorting_analyzer.recording.get_channel_ids())

        average_waveform_items = [
            AverageWaveformItem(
                unit_id=unit_id,
                waveform=av_templates[i],
                # Give each item its own list, as before.
                channel_ids=list(channel_ids),
                waveform_std_dev=waveform_std_dev,
            )
            for i, unit_id in enumerate(sorting_analyzer.unit_ids)
        ]
        return AverageWaveforms(average_waveforms=average_waveform_items)

    def _write_to_zarr_group(self, group: figpack.Group) -> None:
        """
        Write the AverageWaveforms data to a Zarr group

        Each item i is stored under "waveform_{i}/": the waveform itself, and
        when present its standard deviation and one dataset per percentile.
        Per-item metadata (name, unit id, channel ids) goes into the group
        attributes.

        Args:
            group: Zarr group to write data into
        """
        super()._write_to_zarr_group(group)

        # Store the number of average waveforms
        group.attrs["num_average_waveforms"] = len(self.average_waveforms)

        # Store metadata for each average waveform
        average_waveform_metadata = []
        for i, waveform in enumerate(self.average_waveforms):
            waveform_name = f"waveform_{i}"

            # Store metadata; ids are stringified
            average_waveform_metadata.append(
                {
                    "name": waveform_name,
                    "unit_id": str(waveform.unit_id),
                    "channel_ids": [str(ch) for ch in waveform.channel_ids],
                }
            )

            # Create arrays for this average waveform
            group.create_dataset(
                f"{waveform_name}/waveform",
                data=waveform.waveform,
            )
            if waveform.waveform_std_dev is not None:
                group.create_dataset(
                    f"{waveform_name}/waveform_std_dev",
                    data=waveform.waveform_std_dev,
                )
            if waveform.waveform_percentiles is not None:
                for j, p in enumerate(waveform.waveform_percentiles):
                    group.create_dataset(
                        f"{waveform_name}/waveform_percentile_{j}",
                        data=p,
                        dtype=p.dtype,
                    )

        # Store the average waveform metadata
        group.attrs["average_waveforms"] = average_waveform_metadata
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CrossCorrelogramItem for spike sorting views
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Union
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class CrossCorrelogramItem:
    """
    Container for the cross-correlogram of one pair of units
    """

    def __init__(
        self,
        *,
        unit_id1: Union[str, int],
        unit_id2: Union[str, int],
        bin_edges_sec: np.ndarray,
        bin_counts: np.ndarray,
    ):
        """
        Build a CrossCorrelogramItem from keyword-only arguments

        Args:
            unit_id1: Identifier for the first unit, stored unchanged
            unit_id2: Identifier for the second unit, stored unchanged
            bin_edges_sec: Bin edges in seconds; copied into a float32 array
            bin_counts: Bin counts; copied into an int32 array
        """
        # np.array always copies, so the item never aliases the caller's data.
        edges = np.array(bin_edges_sec, dtype=np.float32)
        counts = np.array(bin_counts, dtype=np.int32)

        self.unit_id1, self.unit_id2 = unit_id1, unit_id2
        self.bin_edges_sec = edges
        self.bin_counts = counts
|
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CrossCorrelograms view for figpack - displays multiple cross-correlograms
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import List, Optional
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
import figpack
|
|
10
|
+
from ..spike_sorting_extension import spike_sorting_extension
|
|
11
|
+
|
|
12
|
+
from .CrossCorrelogramItem import CrossCorrelogramItem
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class CrossCorrelograms(figpack.ExtensionView):
    """
    A view that displays multiple cross-correlograms for spike sorting analysis
    """

    def __init__(
        self,
        *,
        cross_correlograms: List[CrossCorrelogramItem],
        hide_unit_selector: Optional[bool] = False,
    ):
        """
        Initialize a CrossCorrelograms view

        Args:
            cross_correlograms: List of CrossCorrelogramItem objects
            hide_unit_selector: Whether to hide the unit selector widget
        """
        super().__init__(
            extension=spike_sorting_extension,
            view_type="spike_sorting.CrossCorrelograms",
        )
        self.cross_correlograms = cross_correlograms
        self.hide_unit_selector = hide_unit_selector

    @staticmethod
    def from_sorting(sorting):
        """
        Build the view from a spikeinterface sorting object

        Args:
            sorting: A spikeinterface BaseSorting instance

        Returns:
            A CrossCorrelograms view built from the cross-correlogram widget
        """
        # spikeinterface is imported lazily so it is only needed by callers
        # that actually use this constructor.
        import spikeinterface as si
        import spikeinterface.widgets as sw

        assert isinstance(sorting, si.BaseSorting), "Input must be a BaseSorting object"
        W = sw.CrossCorrelogramsWidget(sorting)
        return CrossCorrelograms.from_spikeinterface_widget(W)

    @staticmethod
    def from_spikeinterface_widget(W):
        """
        Build the view from a spikeinterface correlogram widget

        One item is created for every unit pair (i, j) with j >= i whose
        similarity is at least the widget's min_similarity_for_correlograms.
        When the widget carries no similarity matrix, every pair is treated
        as having similarity 1.

        Args:
            W: Widget whose data_plot provides unit_ids, bins, correlograms,
                similarity and min_similarity_for_correlograms

        Returns:
            A CrossCorrelograms view with the unit selector shown
        """
        from spikeinterface.widgets.base import to_attr
        from spikeinterface.widgets.utils_sortingview import make_serializable

        dp = to_attr(W.data_plot)
        unit_ids = make_serializable(dp.unit_ids)
        num_units = len(unit_ids)

        if dp.similarity is not None:
            similarity = dp.similarity
        else:
            similarity = np.ones((num_units, num_units))

        # Both values are the same for every pair, so look them up once.
        # presumably dp.bins is in milliseconds (hence / 1000) -- TODO confirm
        bin_edges_sec = (dp.bins / 1000.0).astype("float32")
        min_similarity = dp.min_similarity_for_correlograms

        cc_items = []
        for i in range(num_units):
            # Upper triangle including the diagonal: each unordered pair once.
            for j in range(i, num_units):
                if similarity[i, j] >= min_similarity:
                    cc_items.append(
                        CrossCorrelogramItem(
                            unit_id1=unit_ids[i],
                            unit_id2=unit_ids[j],
                            bin_edges_sec=bin_edges_sec,
                            bin_counts=dp.correlograms[i, j].astype("int32"),
                        )
                    )

        return CrossCorrelograms(cross_correlograms=cc_items, hide_unit_selector=False)

    def _write_to_zarr_group(self, group: figpack.Group) -> None:
        """
        Write the CrossCorrelograms data to a Zarr group

        Stores the view properties and item count, then (when there is at
        least one item) the shared bin edges, a 2D "bin_counts" dataset with
        one row per cross-correlogram, and per-item metadata in the group
        attributes.

        Args:
            group: Zarr group to write data into

        Raises:
            ValueError: If the items do not all have the same number of bins
        """
        super()._write_to_zarr_group(group)

        # Set view properties
        if self.hide_unit_selector is not None:
            group.attrs["hide_unit_selector"] = self.hide_unit_selector

        # Store the number of cross-correlograms
        num_cross_correlograms = len(self.cross_correlograms)
        group.attrs["num_cross_correlograms"] = num_cross_correlograms

        if num_cross_correlograms == 0:
            return

        # Get dimensions from first cross-correlogram
        num_bins = len(self.cross_correlograms[0].bin_counts)

        # All rows share one 2D array, so every item must have num_bins bins.
        # Check explicitly: a length-1 array would otherwise be silently
        # broadcast across the whole row.
        for i, cross_corr in enumerate(self.cross_correlograms):
            if len(cross_corr.bin_counts) != num_bins:
                raise ValueError(
                    "All cross-correlograms must have the same number of bins: "
                    f"item {i} (units {cross_corr.unit_id1}, {cross_corr.unit_id2}) "
                    f"has {len(cross_corr.bin_counts)} bins, expected {num_bins}"
                )

        # Store bin edges (taken from the first item and used for all)
        group.create_dataset(
            "bin_edges_sec",
            data=self.cross_correlograms[0].bin_edges_sec,
        )

        # Create 2D array for all bin counts
        bin_counts = np.zeros((num_cross_correlograms, num_bins), dtype=np.int32)

        # Store metadata for each cross-correlogram and populate bin counts array
        cross_correlogram_metadata = []
        for i, cross_corr in enumerate(self.cross_correlograms):
            cross_correlogram_metadata.append(
                {
                    "unit_id1": str(cross_corr.unit_id1),
                    "unit_id2": str(cross_corr.unit_id2),
                    "index": i,  # Store index to map to bin_counts array
                    "num_bins": num_bins,
                }
            )
            bin_counts[i] = cross_corr.bin_counts

        # Store the bin counts as a single 2D dataset
        group.create_dataset(
            "bin_counts",
            data=bin_counts,
        )

        # Store the cross-correlogram metadata
        group.attrs["cross_correlograms"] = cross_correlogram_metadata
|