figpack 0.2.16__py3-none-any.whl → 0.2.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- figpack/__init__.py +2 -3
- figpack/core/__init__.py +2 -2
- figpack/core/_bundle_utils.py +56 -18
- figpack/core/extension_view.py +7 -25
- figpack/core/figpack_extension.py +0 -71
- figpack/figpack-figure-dist/assets/index-DBwmtEpB.js +91 -0
- figpack/figpack-figure-dist/assets/{index-D9a3K6eW.css → index-DHWczh-Q.css} +1 -1
- figpack/figpack-figure-dist/index.html +2 -2
- figpack/views/PlotlyExtension/PlotlyExtension.py +4 -50
- figpack/views/PlotlyExtension/_plotly_extension.py +46 -0
- figpack/views/PlotlyExtension/plotly_view.js +84 -80
- figpack/views/__init__.py +1 -0
- {figpack-0.2.16.dist-info → figpack-0.2.17.dist-info}/METADATA +1 -1
- figpack-0.2.17.dist-info/RECORD +43 -0
- figpack/figpack-figure-dist/assets/index-DtOnN02w.js +0 -846
- figpack/franklab/__init__.py +0 -5
- figpack/franklab/views/TrackAnimation.py +0 -154
- figpack/franklab/views/__init__.py +0 -9
- figpack/spike_sorting/__init__.py +0 -5
- figpack/spike_sorting/views/AutocorrelogramItem.py +0 -32
- figpack/spike_sorting/views/Autocorrelograms.py +0 -116
- figpack/spike_sorting/views/AverageWaveforms.py +0 -146
- figpack/spike_sorting/views/CrossCorrelogramItem.py +0 -35
- figpack/spike_sorting/views/CrossCorrelograms.py +0 -131
- figpack/spike_sorting/views/RasterPlot.py +0 -284
- figpack/spike_sorting/views/RasterPlotItem.py +0 -28
- figpack/spike_sorting/views/SpikeAmplitudes.py +0 -364
- figpack/spike_sorting/views/SpikeAmplitudesItem.py +0 -38
- figpack/spike_sorting/views/UnitMetricsGraph.py +0 -127
- figpack/spike_sorting/views/UnitSimilarityScore.py +0 -40
- figpack/spike_sorting/views/UnitsTable.py +0 -82
- figpack/spike_sorting/views/UnitsTableColumn.py +0 -40
- figpack/spike_sorting/views/UnitsTableRow.py +0 -36
- figpack/spike_sorting/views/__init__.py +0 -41
- figpack-0.2.16.dist-info/RECORD +0 -61
- {figpack-0.2.16.dist-info → figpack-0.2.17.dist-info}/WHEEL +0 -0
- {figpack-0.2.16.dist-info → figpack-0.2.17.dist-info}/entry_points.txt +0 -0
- {figpack-0.2.16.dist-info → figpack-0.2.17.dist-info}/licenses/LICENSE +0 -0
- {figpack-0.2.16.dist-info → figpack-0.2.17.dist-info}/top_level.txt +0 -0
|
@@ -1,284 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
RasterPlot view for figpack - displays multiple raster plots
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
from typing import List
|
|
6
|
-
import numpy as np
|
|
7
|
-
import zarr
|
|
8
|
-
|
|
9
|
-
from ...core.figpack_view import FigpackView
|
|
10
|
-
from ...core.zarr import Group
|
|
11
|
-
from .RasterPlotItem import RasterPlotItem
|
|
12
|
-
from .UnitsTable import UnitsTable, UnitsTableColumn, UnitsTableRow
|
|
13
|
-
from ...views.Box import Box, LayoutItem
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
class RasterPlot(FigpackView):
    """
    A view that displays multiple raster plots for spike sorting analysis
    """

    def __init__(
        self,
        *,
        start_time_sec: float,
        end_time_sec: float,
        plots: List[RasterPlotItem],
    ):
        """
        Initialize a RasterPlot view

        Args:
            start_time_sec: Start time in seconds for the plot range
            end_time_sec: End time in seconds for the plot range
            plots: List of RasterPlotItem objects, one per unit
        """
        # Coerce to plain Python floats; callers may pass numpy scalars (e.g.
        # the min/max spike times computed in from_nwb_units_table).
        self.start_time_sec = float(start_time_sec)
        self.end_time_sec = float(end_time_sec)
        self.plots = plots

    @staticmethod
    def from_nwb_units_table(
        nwb_url_or_path_or_h5py,
        *,
        units_path: str,
        include_units_selector: bool = False,
    ):
        """
        Build a RasterPlot from an NWB units table.

        Args:
            nwb_url_or_path_or_h5py: URL or path of an NWB file (opened with
                lindi), or an already-open h5py-like file object
            units_path: Path of the units table inside the file
            include_units_selector: If True, return a horizontal Box with a
                units table next to the raster plot instead of the bare view

        Returns:
            A RasterPlot, or a Box layout when include_units_selector is True.
            Units without any spikes are not plotted.

        Raises:
            ValueError: If no unit in the table has any spikes
        """
        if isinstance(nwb_url_or_path_or_h5py, str):
            import lindi

            f = lindi.LindiH5pyFile.from_hdf5_file(nwb_url_or_path_or_h5py)
        else:
            f = nwb_url_or_path_or_h5py
        units = f[units_path]
        spike_times = units["spike_times"]
        # Read the per-unit index and the unit ids once each, instead of one
        # dataset access per unit inside the loop.
        spike_times_index = units["spike_times_index"][:]
        unit_ids = units["id"][:]

        plots: List[RasterPlotItem] = []
        start_times = []
        end_times = []
        for unit_index in range(len(spike_times_index)):
            # Ragged-array layout: unit i owns spike_times[index[i-1]:index[i]].
            start_index = spike_times_index[unit_index - 1] if unit_index > 0 else 0
            end_index = spike_times_index[unit_index]
            unit_spike_times = spike_times[start_index:end_index]
            if len(unit_spike_times) == 0:
                continue
            start_times.append(unit_spike_times[0])
            end_times.append(unit_spike_times[-1])
            plots.append(
                RasterPlotItem(
                    unit_id=str(unit_ids[unit_index]),
                    spike_times_sec=unit_spike_times,
                )
            )

        if not plots:
            raise ValueError(
                f"No unit in the units table at {units_path!r} has any spikes"
            )

        view = RasterPlot(
            start_time_sec=min(start_times),
            end_time_sec=max(end_times),
            plots=plots,
        )
        if not include_units_selector:
            return view

        columns: List[UnitsTableColumn] = [
            UnitsTableColumn(key="unitId", label="Unit", dtype="int"),
        ]
        rows: List[UnitsTableRow] = [
            UnitsTableRow(unit_id=str(unit_id), values={}) for unit_id in unit_ids
        ]
        units_table = UnitsTable(columns=columns, rows=rows)
        return Box(
            direction="horizontal",
            items=[
                LayoutItem(view=units_table, max_size=150, title="Units"),
                LayoutItem(view=view, title="Raster Plot"),
            ],
        )

    def _write_to_zarr_group(self, group: Group) -> None:
        """
        Write the RasterPlot data to a Zarr group

        Args:
            group: Zarr group to write data into
        """
        group.attrs["view_type"] = "RasterPlot"
        group.attrs["start_time_sec"] = self.start_time_sec
        group.attrs["end_time_sec"] = self.end_time_sec

        unified_data = self._prepare_unified_data()

        if unified_data["total_spikes"] == 0:
            # Empty data: write empty arrays (no spike_counts_1sec) and stop.
            group.create_dataset("timestamps", data=np.array([], dtype=np.float32))
            group.create_dataset("unit_indices", data=np.array([], dtype=np.uint16))
            group.create_dataset("reference_times", data=np.array([], dtype=np.float32))
            group.create_dataset(
                "reference_indices", data=np.array([], dtype=np.uint32)
            )
            group.attrs["unit_ids"] = []
            group.attrs["total_spikes"] = 0
            return

        # Per-spike arrays are chunked in pieces of at most 2M elements.
        spike_chunks = (min(unified_data["total_spikes"], 2_000_000),)
        for name in ("timestamps", "unit_indices"):
            group.create_dataset(name, data=unified_data[name], chunks=spike_chunks)
        # Reference arrays hold at most one entry per second of data and are
        # stored as a single chunk.
        for name in ("reference_times", "reference_indices"):
            group.create_dataset(
                name, data=unified_data[name], chunks=(len(unified_data[name]),)
            )

        spike_counts = self._compute_spike_counts_1sec()
        num_bins, num_units = spike_counts.shape
        group.create_dataset(
            "spike_counts_1sec",
            data=spike_counts,
            chunks=(min(num_bins, 10000), min(num_units, 500)),
        )

        group.attrs["unit_ids"] = unified_data["unit_ids"]
        group.attrs["total_spikes"] = unified_data["total_spikes"]

    def _compute_spike_counts_1sec(self) -> np.ndarray:
        """
        Count each unit's spikes in consecutive 1-second bins.

        Bin k covers [start_time_sec + k, start_time_sec + k + 1). A spike lying
        exactly at end_time_sec is counted in the last bin; spikes outside
        [start_time_sec, end_time_sec] are ignored.

        Returns:
            uint16 array of shape (num_bins, num_units), with counts clipped
            to 65535.
        """
        duration = self.end_time_sec - self.start_time_sec
        # At least one bin, so a zero-length range still yields a non-empty
        # array and a valid (non-zero) chunk shape.
        num_bins = max(1, int(np.ceil(duration)))
        spike_counts = np.zeros((num_bins, len(self.plots)), dtype=np.uint16)
        uint16_max = np.iinfo(np.uint16).max

        for unit_idx, plot in enumerate(self.plots):
            rel_times = (
                np.asarray(plot.spike_times_sec, dtype=np.float64)
                - self.start_time_sec
            )
            in_range = (rel_times >= 0) & (rel_times < num_bins)
            # floor() rather than truncation toward zero, so spikes slightly
            # before start_time_sec are not counted in bin 0.
            bins = np.floor(rel_times[in_range]).astype(np.int64)
            counts = np.bincount(bins, minlength=num_bins)
            # Spikes exactly at end_time_sec sit on the upper edge of the last
            # bin (only possible when the duration is a whole number of seconds).
            counts[-1] += np.count_nonzero(
                (rel_times >= num_bins) & (rel_times <= duration)
            )
            spike_counts[:, unit_idx] = np.minimum(counts, uint16_max)
        return spike_counts

    @staticmethod
    def _empty_unified_data(unit_ids: list) -> dict:
        """Return the unified-data dict for the case of no spikes at all."""
        return {
            "timestamps": np.array([], dtype=np.float32),
            "unit_indices": np.array([], dtype=np.uint16),
            "reference_times": np.array([], dtype=np.float32),
            "reference_indices": np.array([], dtype=np.uint32),
            "unit_ids": unit_ids,
            "total_spikes": 0,
        }

    def _prepare_unified_data(self) -> dict:
        """
        Merge the spikes of all plots into time-sorted arrays.

        Returns:
            Dictionary with "timestamps" (float32), "unit_indices" (uint16, index
            into "unit_ids"), "reference_times" (float32), "reference_indices"
            (uint32), "unit_ids" (list of str) and "total_spikes" (int)
        """
        if not self.plots:
            return self._empty_unified_data([])

        unit_ids = [str(plot.unit_id) for plot in self.plots]
        # NOTE: if two plots share a unit_id, both map to the later index.
        unit_id_to_index = {unit_id: i for i, unit_id in enumerate(unit_ids)}

        time_parts = []
        index_parts = []
        for plot, unit_id in zip(self.plots, unit_ids):
            times = np.asarray(plot.spike_times_sec, dtype=np.float64)
            time_parts.append(times)
            index_parts.append(
                np.full(len(times), unit_id_to_index[unit_id], dtype=np.uint16)
            )

        all_times = np.concatenate(time_parts)
        if all_times.size == 0:
            return self._empty_unified_data(unit_ids)

        # Stable sort on the float64 times keeps simultaneous spikes in plot
        # order; cast to float32 only after sorting.
        order = np.argsort(all_times, kind="stable")
        timestamps = all_times[order].astype(np.float32)
        unit_indices = np.concatenate(index_parts)[order]

        reference_times, reference_indices = self._generate_reference_arrays(timestamps)

        return {
            "timestamps": timestamps,
            "unit_indices": unit_indices,
            "reference_times": reference_times,
            "reference_indices": reference_indices,
            "unit_ids": unit_ids,
            "total_spikes": int(all_times.size),
        }

    def _generate_reference_arrays(
        self, timestamps: np.ndarray, interval_sec: float = 1.0
    ) -> tuple:
        """
        Pick reference points from the sorted timestamps.

        The first timestamp is always a reference point; each following one is
        the first timestamp >= (previous reference time + interval_sec).

        Args:
            timestamps: Sorted (ascending) array of timestamps
            interval_sec: Minimum interval between reference points

        Returns:
            Tuple of (reference_times as float32, reference_indices as uint32)
        """
        num_timestamps = len(timestamps)
        if num_timestamps == 0:
            return np.array([], dtype=np.float32), np.array([], dtype=np.uint32)

        reference_indices = [0]
        current = 0
        while True:
            # One binary search per reference point instead of a Python-level
            # pass over every spike.
            threshold = timestamps[current] + interval_sec
            nxt = int(np.searchsorted(timestamps, threshold, side="left"))
            # Always advance, so a non-positive interval cannot loop forever.
            nxt = max(nxt, current + 1)
            if nxt >= num_timestamps:
                break
            reference_indices.append(nxt)
            current = nxt

        index_array = np.array(reference_indices, dtype=np.uint32)
        reference_times = np.asarray(timestamps)[index_array].astype(np.float32)
        return reference_times, index_array
|
|
@@ -1,28 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
RasterPlotItem for figpack - represents a single unit's raster plot
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
from typing import Union
|
|
6
|
-
import numpy as np
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
class RasterPlotItem:
    """
    Spike times of a single unit, for use in a RasterPlot.
    """

    def __init__(
        self,
        *,
        unit_id: Union[str, int],
        spike_times_sec: np.ndarray,
    ):
        """
        Create the item.

        Args:
            unit_id: Identifier of the unit (stored exactly as given)
            spike_times_sec: Spike times in seconds; any array-like is accepted
                and copied into a new float32 numpy array
        """
        times_f32 = np.array(spike_times_sec, dtype=np.float32)
        self.unit_id = unit_id
        self.spike_times_sec = times_f32
|
|
@@ -1,364 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
SpikeAmplitudes view for figpack - displays spike amplitudes over time
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
from typing import List
|
|
6
|
-
|
|
7
|
-
import numpy as np
|
|
8
|
-
|
|
9
|
-
from ...core.figpack_view import FigpackView
|
|
10
|
-
from ...core.zarr import Group
|
|
11
|
-
from .SpikeAmplitudesItem import SpikeAmplitudesItem
|
|
12
|
-
from .UnitsTable import UnitsTable, UnitsTableColumn, UnitsTableRow
|
|
13
|
-
from ...views.Box import Box, LayoutItem
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
class SpikeAmplitudes(FigpackView):
    """
    A view that displays spike amplitudes over time for multiple units
    """

    def __init__(
        self,
        *,
        start_time_sec: float,
        end_time_sec: float,
        plots: List[SpikeAmplitudesItem],
    ):
        """
        Initialize a SpikeAmplitudes view

        Args:
            start_time_sec: Start time of the view in seconds
            end_time_sec: End time of the view in seconds
            plots: List of SpikeAmplitudesItem objects, one per unit
        """
        # Coerce to plain Python floats (as RasterPlot does); callers may pass
        # numpy scalars, e.g. the min/max spike times in from_nwb_units_table.
        self.start_time_sec = float(start_time_sec)
        self.end_time_sec = float(end_time_sec)
        self.plots = plots

    @staticmethod
    def from_nwb_units_table(
        nwb_url_or_path_or_h5py,
        *,
        units_path: str,
        include_units_selector: bool = False,
    ):
        """
        Build a SpikeAmplitudes view from an NWB units table that has a
        "spike_amplitudes" column.

        Args:
            nwb_url_or_path_or_h5py: URL or path of an NWB file (opened with
                lindi), or an already-open h5py-like file object
            units_path: Path of the units table inside the file
            include_units_selector: If True, return a horizontal Box with a
                units table next to the view instead of the bare view

        Returns:
            A SpikeAmplitudes view, or a Box layout when include_units_selector
            is True. Units without any spikes are not plotted.

        Raises:
            ValueError: If no unit in the table has any spikes
        """
        if isinstance(nwb_url_or_path_or_h5py, str):
            import lindi

            f = lindi.LindiH5pyFile.from_hdf5_file(nwb_url_or_path_or_h5py)
        else:
            f = nwb_url_or_path_or_h5py
        units = f[units_path]
        spike_amplitudes = units["spike_amplitudes"]
        # NOTE(review): assumes spike_amplitudes uses the same ragged layout as
        # spike_times; spike_amplitudes_index is not read.
        spike_times = units["spike_times"]
        # Read the per-unit index and the unit ids once each, instead of one
        # dataset access per unit inside the loop.
        spike_times_index = units["spike_times_index"][:]
        unit_ids = units["id"][:]

        plots: List[SpikeAmplitudesItem] = []
        start_times = []
        end_times = []
        for unit_index in range(len(spike_times_index)):
            # Ragged-array layout: unit i owns [index[i-1]:index[i]].
            start_index = spike_times_index[unit_index - 1] if unit_index > 0 else 0
            end_index = spike_times_index[unit_index]
            unit_spike_amplitudes = spike_amplitudes[start_index:end_index]
            unit_spike_times = spike_times[start_index:end_index]
            if len(unit_spike_times) == 0:
                continue
            start_times.append(unit_spike_times[0])
            end_times.append(unit_spike_times[-1])
            plots.append(
                SpikeAmplitudesItem(
                    unit_id=str(unit_ids[unit_index]),
                    spike_times_sec=unit_spike_times,
                    spike_amplitudes=unit_spike_amplitudes,
                )
            )

        if not plots:
            raise ValueError(
                f"No unit in the units table at {units_path!r} has any spikes"
            )

        view = SpikeAmplitudes(
            start_time_sec=min(start_times),
            end_time_sec=max(end_times),
            plots=plots,
        )
        if not include_units_selector:
            return view

        columns: List[UnitsTableColumn] = [
            UnitsTableColumn(key="unitId", label="Unit", dtype="int"),
        ]
        rows: List[UnitsTableRow] = [
            UnitsTableRow(unit_id=str(unit_id), values={}) for unit_id in unit_ids
        ]
        units_table = UnitsTable(columns=columns, rows=rows)
        return Box(
            direction="horizontal",
            items=[
                LayoutItem(view=units_table, max_size=150, title="Units"),
                LayoutItem(view=view, title="Spike Amplitudes"),
            ],
        )

    def _write_to_zarr_group(self, group: Group) -> None:
        """
        Write the SpikeAmplitudes data to a Zarr group using unified storage format

        Args:
            group: Zarr group to write data into
        """
        group.attrs["view_type"] = "SpikeAmplitudes"
        group.attrs["start_time_sec"] = self.start_time_sec
        group.attrs["end_time_sec"] = self.end_time_sec

        unified_data = self._prepare_unified_data()

        if unified_data["total_spikes"] == 0:
            # Empty data: write empty arrays (no subsampled levels) and stop.
            group.create_dataset("timestamps", data=np.array([], dtype=np.float32))
            group.create_dataset("unit_indices", data=np.array([], dtype=np.uint16))
            group.create_dataset("amplitudes", data=np.array([], dtype=np.float32))
            group.create_dataset("reference_times", data=np.array([], dtype=np.float32))
            group.create_dataset(
                "reference_indices", data=np.array([], dtype=np.uint32)
            )
            group.attrs["unit_ids"] = []
            group.attrs["total_spikes"] = 0
            return

        self._write_spike_arrays(group, unified_data)
        group.attrs["unit_ids"] = unified_data["unit_ids"]
        group.attrs["total_spikes"] = unified_data["total_spikes"]

        subsampled_data = self._create_subsampled_data(
            unified_data["timestamps"],
            unified_data["unit_indices"],
            unified_data["amplitudes"],
        )
        if subsampled_data:
            subsampled_group = group.create_group("subsampled_data")
            for factor_name, level_data in subsampled_data.items():
                self._write_spike_arrays(
                    subsampled_group.create_group(factor_name), level_data
                )

    @staticmethod
    def _write_spike_arrays(group: Group, data: dict) -> None:
        """
        Write one level (full or subsampled) of spike arrays into a group.

        Args:
            group: Zarr group to write into
            data: Dict with non-empty "timestamps", "unit_indices",
                "amplitudes", "reference_times" and "reference_indices" arrays
        """
        # Per-spike arrays are chunked in pieces of at most 2M elements.
        spike_chunks = (min(len(data["timestamps"]), 2_000_000),)
        for name in ("timestamps", "unit_indices", "amplitudes"):
            group.create_dataset(name, data=data[name], chunks=spike_chunks)
        # Reference arrays are stored as a single chunk.
        for name in ("reference_times", "reference_indices"):
            group.create_dataset(name, data=data[name], chunks=(len(data[name]),))

    @staticmethod
    def _empty_unified_data(unit_ids: list) -> dict:
        """Return the unified-data dict for the case of no spikes at all."""
        return {
            "timestamps": np.array([], dtype=np.float32),
            "unit_indices": np.array([], dtype=np.uint16),
            "amplitudes": np.array([], dtype=np.float32),
            "reference_times": np.array([], dtype=np.float32),
            "reference_indices": np.array([], dtype=np.uint32),
            "unit_ids": unit_ids,
            "total_spikes": 0,
        }

    def _prepare_unified_data(self) -> dict:
        """
        Merge the spikes of all plots into time-sorted arrays.

        Returns:
            Dictionary with "timestamps" (float32), "unit_indices" (uint16, index
            into "unit_ids"), "amplitudes" (float32), "reference_times"
            (float32), "reference_indices" (uint32), "unit_ids" (list of str)
            and "total_spikes" (int)
        """
        if not self.plots:
            return self._empty_unified_data([])

        unit_ids = [str(plot.unit_id) for plot in self.plots]
        # NOTE: if two plots share a unit_id, both map to the later index.
        unit_id_to_index = {unit_id: i for i, unit_id in enumerate(unit_ids)}

        time_parts = []
        amplitude_parts = []
        index_parts = []
        for plot, unit_id in zip(self.plots, unit_ids):
            times = np.asarray(plot.spike_times_sec, dtype=np.float64)
            amplitudes = np.asarray(plot.spike_amplitudes, dtype=np.float64)
            # Pair times and amplitudes positionally; as with zip(), surplus
            # entries in the longer array are ignored.
            n = min(len(times), len(amplitudes))
            time_parts.append(times[:n])
            amplitude_parts.append(amplitudes[:n])
            index_parts.append(
                np.full(n, unit_id_to_index[unit_id], dtype=np.uint16)
            )

        all_times = np.concatenate(time_parts)
        if all_times.size == 0:
            return self._empty_unified_data(unit_ids)

        # Stable sort on the float64 times keeps simultaneous spikes in plot
        # order; cast to float32 only after sorting.
        order = np.argsort(all_times, kind="stable")
        timestamps = all_times[order].astype(np.float32)
        unit_indices = np.concatenate(index_parts)[order]
        amplitudes = np.concatenate(amplitude_parts)[order].astype(np.float32)

        reference_times, reference_indices = self._generate_reference_arrays(timestamps)

        return {
            "timestamps": timestamps,
            "unit_indices": unit_indices,
            "amplitudes": amplitudes,
            "reference_times": reference_times,
            "reference_indices": reference_indices,
            "unit_ids": unit_ids,
            "total_spikes": int(all_times.size),
        }

    def _generate_reference_arrays(
        self, timestamps: np.ndarray, interval_sec: float = 1.0
    ) -> tuple:
        """
        Pick reference points from the sorted timestamps.

        The first timestamp is always a reference point; each following one is
        the first timestamp >= (previous reference time + interval_sec).

        Args:
            timestamps: Sorted (ascending) array of timestamps
            interval_sec: Minimum interval between reference points

        Returns:
            Tuple of (reference_times as float32, reference_indices as uint32)
        """
        num_timestamps = len(timestamps)
        if num_timestamps == 0:
            return np.array([], dtype=np.float32), np.array([], dtype=np.uint32)

        reference_indices = [0]
        current = 0
        while True:
            # One binary search per reference point instead of a Python-level
            # pass over every spike.
            threshold = timestamps[current] + interval_sec
            nxt = int(np.searchsorted(timestamps, threshold, side="left"))
            # Always advance, so a non-positive interval cannot loop forever.
            nxt = max(nxt, current + 1)
            if nxt >= num_timestamps:
                break
            reference_indices.append(nxt)
            current = nxt

        index_array = np.array(reference_indices, dtype=np.uint32)
        reference_times = np.asarray(timestamps)[index_array].astype(np.float32)
        return reference_times, index_array

    def _create_subsampled_data(
        self,
        timestamps: np.ndarray,
        unit_indices: np.ndarray,
        amplitudes: np.ndarray,
        *,
        min_num_spikes: int = 500_000,
        step: int = 4,
    ) -> dict:
        """
        Create progressively subsampled copies of the spike arrays.

        Each level keeps every `step`-th spike of the previous level, so level k
        keeps every (step**k)-th spike of the original arrays and is stored
        under the key f"factor_{step**k}" (factor_4, factor_16, ... by default).
        A new level is created only while the previous level still has at
        least min_num_spikes spikes.

        Args:
            timestamps: Original timestamps array
            unit_indices: Original unit indices array
            amplitudes: Original amplitudes array
            min_num_spikes: Minimum size of a level for it to be subsampled again
            step: Subsampling step between consecutive levels

        Returns:
            Dictionary of subsampled data keyed by "factor_<N>"

        Raises:
            ValueError: If step < 2 or min_num_spikes < 1 (would never terminate)
        """
        if step < 2 or min_num_spikes < 1:
            raise ValueError("step must be >= 2 and min_num_spikes must be >= 1")

        subsampled_data = {}
        factor = 1
        current_timestamps = timestamps
        current_unit_indices = unit_indices
        current_amplitudes = amplitudes

        while len(current_timestamps) >= min_num_spikes:
            # Step through the *previous* level by a constant `step`; the
            # cumulative factor relative to the original is tracked separately
            # so that each key matches the data stored under it.
            keep = np.arange(0, len(current_timestamps), step)
            current_timestamps = current_timestamps[keep]
            current_unit_indices = current_unit_indices[keep]
            current_amplitudes = current_amplitudes[keep]
            factor *= step

            ref_times, ref_indices = self._generate_reference_arrays(
                current_timestamps
            )
            subsampled_data[f"factor_{factor}"] = {
                "timestamps": current_timestamps,
                "unit_indices": current_unit_indices,
                "amplitudes": current_amplitudes,
                "reference_times": ref_times,
                "reference_indices": ref_indices,
            }

        return subsampled_data
|