figpack-spike-sorting 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
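The hunks below add three new modules: a RasterPlot view, the RasterPlotItem data class it plots, and a SpikeAmplitudes view (the SpikeAmplitudesItem and UnitsTable modules they import are not part of this diff).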
@@ -0,0 +1,287 @@
+ """
+ RasterPlot view for figpack - displays multiple raster plots
+ """
+
+ from typing import List
+ import numpy as np
+
+ from .RasterPlotItem import RasterPlotItem
+ from .UnitsTable import UnitsTable, UnitsTableColumn, UnitsTableRow
+
+ import figpack
+ import figpack.views as fv
+ from ..spike_sorting_extension import spike_sorting_extension
+
+
+ class RasterPlot(figpack.ExtensionView):
+     """
+     A view that displays multiple raster plots for spike sorting analysis
+     """
+
+     def __init__(
+         self,
+         *,
+         start_time_sec: float,
+         end_time_sec: float,
+         plots: List[RasterPlotItem],
+     ):
+         """
+         Initialize a RasterPlot view
+
+         Args:
+             start_time_sec: Start time in seconds for the plot range
+             end_time_sec: End time in seconds for the plot range
+             plots: List of RasterPlotItem objects
+         """
+         super().__init__(
+             extension=spike_sorting_extension, view_type="spike_sorting.RasterPlot"
+         )
+         self.start_time_sec = float(start_time_sec)
+         self.end_time_sec = float(end_time_sec)
+         self.plots = plots
+
+     @staticmethod
+     def from_nwb_units_table(
+         nwb_url_or_path_or_h5py,
+         *,
+         units_path: str,
+         include_units_selector: bool = False,
+     ):
+         if isinstance(nwb_url_or_path_or_h5py, str):
+             import lindi
+
+             f = lindi.LindiH5pyFile.from_hdf5_file(nwb_url_or_path_or_h5py)
+         else:
+             f = nwb_url_or_path_or_h5py
+         X = f[units_path]
+         spike_times = X["spike_times"]
+         spike_times_index = X["spike_times_index"]
+         id = X["id"]
+         plots = []
+         num_units = len(spike_times_index)
+         start_times = []
+         end_times = []
+         for unit_index in range(num_units):
+             unit_id = id[unit_index]
+             if unit_index > 0:
+                 start_index = spike_times_index[unit_index - 1]
+             else:
+                 start_index = 0
+             end_index = spike_times_index[unit_index]
+             unit_spike_times = spike_times[start_index:end_index]
+             if len(unit_spike_times) == 0:
+                 continue
+             start_times.append(unit_spike_times[0])
+             end_times.append(unit_spike_times[-1])
+             plots.append(
+                 RasterPlotItem(unit_id=str(unit_id), spike_times_sec=unit_spike_times)
+             )
+         view = RasterPlot(
+             start_time_sec=min(start_times),
+             end_time_sec=max(end_times),
+             plots=plots,
+         )
+         if include_units_selector:
+             columns: List[UnitsTableColumn] = [
+                 UnitsTableColumn(key="unitId", label="Unit", dtype="int"),
+             ]
+             rows: List[UnitsTableRow] = []
+             for unit_id in id:
+                 rows.append(
+                     UnitsTableRow(
+                         unit_id=str(unit_id),
+                         values={},
+                     )
+                 )
+             units_table = UnitsTable(
+                 columns=columns,
+                 rows=rows,
+             )
+             layout = fv.Box(
+                 direction="horizontal",
+                 items=[
+                     fv.LayoutItem(view=units_table, max_size=150, title="Units"),
+                     fv.LayoutItem(view=view, title="Raster Plot"),
+                 ],
+             )
+             return layout
+         else:
+             return view
+
+     def _write_to_zarr_group(self, group: figpack.Group) -> None:
+         """
+         Write the RasterPlot data to a Zarr group using unified storage format
+
+         Args:
+             group: Zarr group to write data into
+         """
+         super()._write_to_zarr_group(group)
+
+         # Store view parameters
+         group.attrs["start_time_sec"] = self.start_time_sec
+         group.attrs["end_time_sec"] = self.end_time_sec
+
+         # Prepare unified data arrays
+         unified_data = self._prepare_unified_data()
+
+         if unified_data["total_spikes"] == 0:
+             # Handle empty data case
+             group.create_dataset("timestamps", data=np.array([], dtype=np.float32))
+             group.create_dataset("unit_indices", data=np.array([], dtype=np.uint16))
+             group.create_dataset("reference_times", data=np.array([], dtype=np.float32))
+             group.create_dataset(
+                 "reference_indices", data=np.array([], dtype=np.uint32)
+             )
+             group.attrs["unit_ids"] = []
+             group.attrs["total_spikes"] = 0
+             return
+
+         chunks = (
+             (2_000_000,)
+             if unified_data["total_spikes"] > 2_000_000
+             else (len(unified_data["timestamps"]),)
+         )
+         # Store main data arrays
+         group.create_dataset(
+             "timestamps",
+             data=unified_data["timestamps"],
+             chunks=chunks,
+         )
+         group.create_dataset(
+             "unit_indices",
+             data=unified_data["unit_indices"],
+             chunks=chunks,
+         )
+         group.create_dataset(
+             "reference_times",
+             data=unified_data["reference_times"],
+             chunks=(len(unified_data["reference_times"]),),
+         )
+         group.create_dataset(
+             "reference_indices",
+             data=unified_data["reference_indices"],
+             chunks=(len(unified_data["reference_indices"]),),
+         )
+
+         # Create spike counts array with 1-second bins
+         duration = self.end_time_sec - self.start_time_sec
+         num_bins = int(np.ceil(duration))
+         num_units = len(self.plots)
+         spike_counts = np.zeros((num_bins, num_units), dtype=np.uint16)
+
+         # Efficiently compute spike counts for each unit
+         for unit_idx, plot in enumerate(self.plots):
+             # Convert spike times to 1-second bin indices relative to the start time
+             bin_indices = (
+                 np.array(plot.spike_times_sec) - self.start_time_sec
+             ).astype(int)
+             # Count spikes in valid bins
+             valid_indices = (bin_indices >= 0) & (bin_indices < num_bins)
+             unique_bins, counts = np.unique(
+                 bin_indices[valid_indices], return_counts=True
+             )
+             spike_counts[unique_bins, unit_idx] = counts.clip(
+                 max=65535
+             )  # Clip to uint16 max
+
+         # Store spike counts array
+         group.create_dataset(
+             "spike_counts_1sec",
+             data=spike_counts,
+             chunks=(min(num_bins, 10000), min(num_units, 500)),
+         )
+
+         # Store unit ID mapping
+         group.attrs["unit_ids"] = unified_data["unit_ids"]
+         group.attrs["total_spikes"] = unified_data["total_spikes"]
+
+     def _prepare_unified_data(self) -> dict:
+         """
+         Prepare unified data arrays from all plots
+
+         Returns:
+             Dictionary containing unified arrays and metadata
+         """
+         if not self.plots:
+             return {
+                 "timestamps": np.array([], dtype=np.float32),
+                 "unit_indices": np.array([], dtype=np.uint16),
+                 "reference_times": np.array([], dtype=np.float32),
+                 "reference_indices": np.array([], dtype=np.uint32),
+                 "unit_ids": [],
+                 "total_spikes": 0,
+             }
+
+         # Create unit ID mapping
+         unit_ids = [str(plot.unit_id) for plot in self.plots]
+         unit_id_to_index = {unit_id: i for i, unit_id in enumerate(unit_ids)}
+
+         # Collect all spikes with their unit indices
+         all_spikes = []
+         for plot in self.plots:
+             unit_index = unit_id_to_index[str(plot.unit_id)]
+             for time in plot.spike_times_sec:
+                 all_spikes.append((float(time), unit_index))
+
+         if not all_spikes:
+             return {
+                 "timestamps": np.array([], dtype=np.float32),
+                 "unit_indices": np.array([], dtype=np.uint16),
+                 "reference_times": np.array([], dtype=np.float32),
+                 "reference_indices": np.array([], dtype=np.uint32),
+                 "unit_ids": unit_ids,
+                 "total_spikes": 0,
+             }
+
+         # Sort by timestamp
+         all_spikes.sort(key=lambda x: x[0])
+
+         # Extract sorted arrays
+         timestamps = np.array([spike[0] for spike in all_spikes], dtype=np.float32)
+         unit_indices = np.array([spike[1] for spike in all_spikes], dtype=np.uint16)
+
+         # Generate reference arrays
+         reference_times, reference_indices = self._generate_reference_arrays(timestamps)
+
+         return {
+             "timestamps": timestamps,
+             "unit_indices": unit_indices,
+             "reference_times": reference_times,
+             "reference_indices": reference_indices,
+             "unit_ids": unit_ids,
+             "total_spikes": len(all_spikes),
+         }
+
+     def _generate_reference_arrays(
+         self, timestamps: np.ndarray, interval_sec: float = 1.0
+     ) -> tuple:
+         """
+         Generate reference arrays using actual timestamps from the data
+
+         Args:
+             timestamps: Sorted array of timestamps
+             interval_sec: Minimum interval between reference points
+
+         Returns:
+             Tuple of (reference_times, reference_indices)
+         """
+         if len(timestamps) == 0:
+             return np.array([], dtype=np.float32), np.array([], dtype=np.uint32)
+
+         reference_times = []
+         reference_indices = []
+
+         current_ref_time = timestamps[0]
+         reference_times.append(current_ref_time)
+         reference_indices.append(0)
+
+         # Find the next reference point at least interval_sec later
+         for i, timestamp in enumerate(timestamps):
+             if timestamp >= current_ref_time + interval_sec:
+                 reference_times.append(timestamp)
+                 reference_indices.append(i)
+                 current_ref_time = timestamp
+
+         return np.array(reference_times, dtype=np.float32), np.array(
+             reference_indices, dtype=np.uint32
+         )
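For orientation, a minimal usage sketch of the RasterPlot view added above, built from synthetic spike trains. The top-level import path and the final show() call are assumptions about the package layout and display API, not confirmed by this diff:

    import numpy as np
    from figpack_spike_sorting.views import RasterPlot, RasterPlotItem  # assumed import path

    # Three units, each with 1000 random spike times in a 60-second window
    rng = np.random.default_rng(0)
    plots = [
        RasterPlotItem(unit_id=str(u), spike_times_sec=np.sort(rng.uniform(0.0, 60.0, 1000)))
        for u in range(3)
    ]
    view = RasterPlot(start_time_sec=0.0, end_time_sec=60.0, plots=plots)
    # view.show(title="Raster plot")  # display call assumed; see figpack documentation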
@@ -0,0 +1,28 @@
+ """
+ RasterPlotItem for figpack - represents a single unit's raster plot
+ """
+
+ from typing import Union
+ import numpy as np
+
+
+ class RasterPlotItem:
+     """
+     Represents spike times for a single unit in a raster plot
+     """
+
+     def __init__(
+         self,
+         *,
+         unit_id: Union[str, int],
+         spike_times_sec: np.ndarray,
+     ):
+         """
+         Initialize a RasterPlotItem
+
+         Args:
+             unit_id: Identifier for the unit
+             spike_times_sec: Numpy array of spike times in seconds
+         """
+         self.unit_id = unit_id
+         self.spike_times_sec = np.array(spike_times_sec, dtype=np.float32)
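Both from_nwb_units_table implementations (above, and in the SpikeAmplitudes hunk below) decode the NWB ragged-array convention: spike_times_index[i] holds the exclusive end offset of unit i's spikes within the flat spike_times vector, so unit i's spikes are spike_times[spike_times_index[i - 1]:spike_times_index[i]], with an implicit 0 for the first unit. A small worked example with made-up numbers:

    import numpy as np

    spike_times = np.array([0.1, 0.5, 0.9, 1.2, 2.0])  # all units, concatenated
    spike_times_index = np.array([3, 5])               # end offsets: unit 0 ends at 3, unit 1 at 5

    unit0 = spike_times[0:3]  # -> [0.1, 0.5, 0.9]
    unit1 = spike_times[3:5]  # -> [1.2, 2.0]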
@@ -0,0 +1,367 @@
+ """
+ SpikeAmplitudes view for figpack - displays spike amplitudes over time
+ """
+
+ from typing import List
+
+ import numpy as np
+
+ from .SpikeAmplitudesItem import SpikeAmplitudesItem
+ from .UnitsTable import UnitsTable, UnitsTableColumn, UnitsTableRow
+
+ import figpack
+ import figpack.views as fv
+ from ..spike_sorting_extension import spike_sorting_extension
+
+
+ class SpikeAmplitudes(figpack.ExtensionView):
+     """
+     A view that displays spike amplitudes over time for multiple units
+     """
+
+     def __init__(
+         self,
+         *,
+         start_time_sec: float,
+         end_time_sec: float,
+         plots: List[SpikeAmplitudesItem],
+     ):
+         """
+         Initialize a SpikeAmplitudes view
+
+         Args:
+             start_time_sec: Start time of the view in seconds
+             end_time_sec: End time of the view in seconds
+             plots: List of SpikeAmplitudesItem objects
+         """
+         super().__init__(
+             extension=spike_sorting_extension, view_type="spike_sorting.SpikeAmplitudes"
+         )
+         self.start_time_sec = float(start_time_sec)
+         self.end_time_sec = float(end_time_sec)
+         self.plots = plots
+
+     @staticmethod
+     def from_nwb_units_table(
+         nwb_url_or_path_or_h5py,
+         *,
+         units_path: str,
+         include_units_selector: bool = False,
+     ):
+         if isinstance(nwb_url_or_path_or_h5py, str):
+             import lindi
+
+             f = lindi.LindiH5pyFile.from_hdf5_file(nwb_url_or_path_or_h5py)
+         else:
+             f = nwb_url_or_path_or_h5py
+         X = f[units_path]
+         spike_amplitudes = X["spike_amplitudes"]
+         # spike_amplitudes_index = X["spike_amplitudes_index"]  # presumably the same as spike_times_index
+         spike_times = X["spike_times"]
+         spike_times_index = X["spike_times_index"]
+         id = X["id"]
+         plots = []
+         num_units = len(spike_times_index)
+         start_times = []
+         end_times = []
+         for unit_index in range(num_units):
+             unit_id = id[unit_index]
+             if unit_index > 0:
+                 start_index = spike_times_index[unit_index - 1]
+             else:
+                 start_index = 0
+             end_index = spike_times_index[unit_index]
+             unit_spike_amplitudes = spike_amplitudes[start_index:end_index]
+             unit_spike_times = spike_times[start_index:end_index]
+             if len(unit_spike_times) == 0:
+                 continue
+             start_times.append(unit_spike_times[0])
+             end_times.append(unit_spike_times[-1])
+             plots.append(
+                 SpikeAmplitudesItem(
+                     unit_id=str(unit_id),
+                     spike_times_sec=unit_spike_times,
+                     spike_amplitudes=unit_spike_amplitudes,
+                 )
+             )
+         view = SpikeAmplitudes(
+             start_time_sec=min(start_times),
+             end_time_sec=max(end_times),
+             plots=plots,
+         )
+         if include_units_selector:
+             columns: List[UnitsTableColumn] = [
+                 UnitsTableColumn(key="unitId", label="Unit", dtype="int"),
+             ]
+             rows: List[UnitsTableRow] = []
+             for unit_id in id:
+                 rows.append(
+                     UnitsTableRow(
+                         unit_id=str(unit_id),
+                         values={},
+                     )
+                 )
+             units_table = UnitsTable(
+                 columns=columns,
+                 rows=rows,
+             )
+             layout = fv.Box(
+                 direction="horizontal",
+                 items=[
+                     fv.LayoutItem(view=units_table, max_size=150, title="Units"),
+                     fv.LayoutItem(view=view, title="Spike Amplitudes"),
+                 ],
+             )
+             return layout
+         else:
+             return view
+
+     def _write_to_zarr_group(self, group: figpack.Group) -> None:
+         """
+         Write the SpikeAmplitudes data to a Zarr group using unified storage format
+
+         Args:
+             group: Zarr group to write data into
+         """
+         super()._write_to_zarr_group(group)
+
+         # Store view parameters
+         group.attrs["start_time_sec"] = self.start_time_sec
+         group.attrs["end_time_sec"] = self.end_time_sec
+
+         # Prepare unified data arrays
+         unified_data = self._prepare_unified_data()
+
+         if unified_data["total_spikes"] == 0:
+             # Handle empty data case
+             group.create_dataset("timestamps", data=np.array([], dtype=np.float32))
+             group.create_dataset("unit_indices", data=np.array([], dtype=np.uint16))
+             group.create_dataset("amplitudes", data=np.array([], dtype=np.float32))
+             group.create_dataset("reference_times", data=np.array([], dtype=np.float32))
+             group.create_dataset(
+                 "reference_indices", data=np.array([], dtype=np.uint32)
+             )
+             group.attrs["unit_ids"] = []
+             group.attrs["total_spikes"] = 0
+             return
+
+         chunks = (
+             (2_000_000,)
+             if unified_data["total_spikes"] > 2_000_000
+             else (len(unified_data["timestamps"]),)
+         )
+         # Store main data arrays
+         group.create_dataset(
+             "timestamps",
+             data=unified_data["timestamps"],
+             chunks=chunks,
+         )
+         group.create_dataset(
+             "unit_indices",
+             data=unified_data["unit_indices"],
+             chunks=chunks,
+         )
+         group.create_dataset(
+             "amplitudes",
+             data=unified_data["amplitudes"],
+             chunks=chunks,
+         )
+         group.create_dataset(
+             "reference_times",
+             data=unified_data["reference_times"],
+             chunks=(len(unified_data["reference_times"]),),
+         )
+         group.create_dataset(
+             "reference_indices",
+             data=unified_data["reference_indices"],
+             chunks=(len(unified_data["reference_indices"]),),
+         )
+
+         # Store unit ID mapping
+         group.attrs["unit_ids"] = unified_data["unit_ids"]
+         group.attrs["total_spikes"] = unified_data["total_spikes"]
+
+         # Create subsampled data
+         subsampled_data = self._create_subsampled_data(
+             unified_data["timestamps"],
+             unified_data["unit_indices"],
+             unified_data["amplitudes"],
+         )
+
+         if subsampled_data:
+             subsampled_group = group.create_group("subsampled_data")
+             for factor_name, data in subsampled_data.items():
+                 chunks = (
+                     (2_000_000,)
+                     if len(data["timestamps"]) > 2_000_000
+                     else (len(data["timestamps"]),)
+                 )
+                 factor_group = subsampled_group.create_group(factor_name)
+                 factor_group.create_dataset(
+                     "timestamps",
+                     data=data["timestamps"],
+                     chunks=chunks,
+                 )
+                 factor_group.create_dataset(
+                     "unit_indices",
+                     data=data["unit_indices"],
+                     chunks=chunks,
+                 )
+                 factor_group.create_dataset(
+                     "amplitudes",
+                     data=data["amplitudes"],
+                     chunks=chunks,
+                 )
+                 factor_group.create_dataset(
+                     "reference_times",
+                     data=data["reference_times"],
+                     chunks=(len(data["reference_times"]),),
+                 )
+                 factor_group.create_dataset(
+                     "reference_indices",
+                     data=data["reference_indices"],
+                     chunks=(len(data["reference_indices"]),),
+                 )
+
+     def _prepare_unified_data(self) -> dict:
+         """
+         Prepare unified data arrays from all plots
+
+         Returns:
+             Dictionary containing unified arrays and metadata
+         """
+         if not self.plots:
+             return {
+                 "timestamps": np.array([], dtype=np.float32),
+                 "unit_indices": np.array([], dtype=np.uint16),
+                 "amplitudes": np.array([], dtype=np.float32),
+                 "reference_times": np.array([], dtype=np.float32),
+                 "reference_indices": np.array([], dtype=np.uint32),
+                 "unit_ids": [],
+                 "total_spikes": 0,
+             }
+
+         # Create unit ID mapping
+         unit_ids = [str(plot.unit_id) for plot in self.plots]
+         unit_id_to_index = {unit_id: i for i, unit_id in enumerate(unit_ids)}
+
+         # Collect all spikes with their unit indices
+         all_spikes = []
+         for plot in self.plots:
+             unit_index = unit_id_to_index[str(plot.unit_id)]
+             for time, amplitude in zip(plot.spike_times_sec, plot.spike_amplitudes):
+                 all_spikes.append((float(time), unit_index, float(amplitude)))
+
+         if not all_spikes:
+             return {
+                 "timestamps": np.array([], dtype=np.float32),
+                 "unit_indices": np.array([], dtype=np.uint16),
+                 "amplitudes": np.array([], dtype=np.float32),
+                 "reference_times": np.array([], dtype=np.float32),
+                 "reference_indices": np.array([], dtype=np.uint32),
+                 "unit_ids": unit_ids,
+                 "total_spikes": 0,
+             }
+
+         # Sort by timestamp
+         all_spikes.sort(key=lambda x: x[0])
+
+         # Extract sorted arrays
+         timestamps = np.array([spike[0] for spike in all_spikes], dtype=np.float32)
+         unit_indices = np.array([spike[1] for spike in all_spikes], dtype=np.uint16)
+         amplitudes = np.array([spike[2] for spike in all_spikes], dtype=np.float32)
+
+         # Generate reference arrays
+         reference_times, reference_indices = self._generate_reference_arrays(timestamps)
+
+         return {
+             "timestamps": timestamps,
+             "unit_indices": unit_indices,
+             "amplitudes": amplitudes,
+             "reference_times": reference_times,
+             "reference_indices": reference_indices,
+             "unit_ids": unit_ids,
+             "total_spikes": len(all_spikes),
+         }
+
+     def _generate_reference_arrays(
+         self, timestamps: np.ndarray, interval_sec: float = 1.0
+     ) -> tuple:
+         """
+         Generate reference arrays using actual timestamps from the data
+
+         Args:
+             timestamps: Sorted array of timestamps
+             interval_sec: Minimum interval between reference points
+
+         Returns:
+             Tuple of (reference_times, reference_indices)
+         """
+         if len(timestamps) == 0:
+             return np.array([], dtype=np.float32), np.array([], dtype=np.uint32)
+
+         reference_times = []
+         reference_indices = []
+
+         current_ref_time = timestamps[0]
+         reference_times.append(current_ref_time)
+         reference_indices.append(0)
+
+         # Find the next reference point at least interval_sec later
+         for i, timestamp in enumerate(timestamps):
+             if timestamp >= current_ref_time + interval_sec:
+                 reference_times.append(timestamp)
+                 reference_indices.append(i)
+                 current_ref_time = timestamp
+
+         return np.array(reference_times, dtype=np.float32), np.array(
+             reference_indices, dtype=np.uint32
+         )
+
+     def _create_subsampled_data(
+         self, timestamps: np.ndarray, unit_indices: np.ndarray, amplitudes: np.ndarray
+     ) -> dict:
+         """
+         Create subsampled data with geometric progression factors
+
+         Args:
+             timestamps: Original timestamps array
+             unit_indices: Original unit indices array
+             amplitudes: Original amplitudes array
+
+         Returns:
+             Dictionary of subsampled data by factor
+         """
+         subsampled_data = {}
+         factor = 4
+         current_timestamps = timestamps
+         current_unit_indices = unit_indices
+         current_amplitudes = amplitudes
+
+         while len(current_timestamps) >= 500000:
+             # Take every 4th spike of the previous level so the cumulative factor matches the name
+             subsampled_indices = np.arange(0, len(current_timestamps), 4)
+             subsampled_timestamps = current_timestamps[subsampled_indices]
+             subsampled_unit_indices = current_unit_indices[subsampled_indices]
+             subsampled_amplitudes = current_amplitudes[subsampled_indices]
+
+             # Generate reference arrays for this subsampled level
+             ref_times, ref_indices = self._generate_reference_arrays(
+                 subsampled_timestamps
+             )
+
+             subsampled_data[f"factor_{factor}"] = {
+                 "timestamps": subsampled_timestamps,
+                 "unit_indices": subsampled_unit_indices,
+                 "amplitudes": subsampled_amplitudes,
+                 "reference_times": ref_times,
+                 "reference_indices": ref_indices,
+             }
+
+             # Prepare for next iteration
+             current_timestamps = subsampled_timestamps
+             current_unit_indices = subsampled_unit_indices
+             current_amplitudes = subsampled_amplitudes
+             factor *= 4  # Geometric progression: 4, 16, 64, 256, ...
+
+         return subsampled_data
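A note on the index structures written above: reference_times and reference_indices form a coarse index into the sorted timestamps array (each reference is an actual spike time at least interval_sec after the previous one, paired with its array position), and the factor_4, factor_16, ... groups provide progressively sparser copies of the same arrays for zoomed-out rendering. A sketch of how a reader might use the reference arrays to slice a time window without scanning the whole array (hypothetical client code, not part of the package; exact-tie edge cases ignored):

    import numpy as np

    def slice_time_window(timestamps, ref_times, ref_indices, t0, t1):
        # Bracket the window [t0, t1] with the coarse reference index
        k0 = max(int(np.searchsorted(ref_times, t0, side="right")) - 1, 0)
        k1 = int(np.searchsorted(ref_times, t1, side="right"))
        lo = int(ref_indices[k0])
        hi = int(ref_indices[k1]) if k1 < len(ref_indices) else len(timestamps)
        # Refine inside the bounded slice; a remote reader would only need
        # to fetch the chunks covering [lo, hi)
        start = lo + int(np.searchsorted(timestamps[lo:hi], t0, side="left"))
        end = lo + int(np.searchsorted(timestamps[lo:hi], t1, side="right"))
        return start, end  # spikes in the window are timestamps[start:end]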