figpack 0.2.15__py3-none-any.whl → 0.2.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. figpack/__init__.py +4 -3
  2. figpack/core/__init__.py +2 -2
  3. figpack/core/_bundle_utils.py +80 -33
  4. figpack/core/_view_figure.py +1 -1
  5. figpack/core/extension_view.py +9 -27
  6. figpack/core/figpack_extension.py +0 -71
  7. figpack/core/figpack_view.py +2 -2
  8. figpack/core/zarr.py +61 -0
  9. figpack/figpack-figure-dist/assets/index-DBwmtEpB.js +91 -0
  10. figpack/figpack-figure-dist/assets/{index-D9a3K6eW.css → index-DHWczh-Q.css} +1 -1
  11. figpack/figpack-figure-dist/index.html +2 -2
  12. figpack/views/Box.py +2 -3
  13. figpack/views/DataFrame.py +6 -12
  14. figpack/views/Gallery.py +2 -3
  15. figpack/views/Image.py +3 -9
  16. figpack/views/Markdown.py +3 -4
  17. figpack/views/MatplotlibFigure.py +2 -11
  18. figpack/views/MultiChannelTimeseries.py +2 -6
  19. figpack/views/PlotlyExtension/PlotlyExtension.py +8 -60
  20. figpack/views/PlotlyExtension/_plotly_extension.py +46 -0
  21. figpack/views/PlotlyExtension/plotly_view.js +84 -80
  22. figpack/views/Spectrogram.py +2 -7
  23. figpack/views/Splitter.py +2 -3
  24. figpack/views/TabLayout.py +2 -3
  25. figpack/views/TimeseriesGraph.py +6 -10
  26. figpack/views/__init__.py +1 -0
  27. {figpack-0.2.15.dist-info → figpack-0.2.17.dist-info}/METADATA +21 -2
  28. figpack-0.2.17.dist-info/RECORD +43 -0
  29. figpack/figpack-figure-dist/assets/index-DtOnN02w.js +0 -846
  30. figpack/franklab/__init__.py +0 -5
  31. figpack/franklab/views/TrackAnimation.py +0 -153
  32. figpack/franklab/views/__init__.py +0 -9
  33. figpack/spike_sorting/__init__.py +0 -5
  34. figpack/spike_sorting/views/AutocorrelogramItem.py +0 -32
  35. figpack/spike_sorting/views/Autocorrelograms.py +0 -118
  36. figpack/spike_sorting/views/AverageWaveforms.py +0 -147
  37. figpack/spike_sorting/views/CrossCorrelogramItem.py +0 -35
  38. figpack/spike_sorting/views/CrossCorrelograms.py +0 -132
  39. figpack/spike_sorting/views/RasterPlot.py +0 -288
  40. figpack/spike_sorting/views/RasterPlotItem.py +0 -28
  41. figpack/spike_sorting/views/SpikeAmplitudes.py +0 -374
  42. figpack/spike_sorting/views/SpikeAmplitudesItem.py +0 -38
  43. figpack/spike_sorting/views/UnitMetricsGraph.py +0 -129
  44. figpack/spike_sorting/views/UnitSimilarityScore.py +0 -40
  45. figpack/spike_sorting/views/UnitsTable.py +0 -89
  46. figpack/spike_sorting/views/UnitsTableColumn.py +0 -40
  47. figpack/spike_sorting/views/UnitsTableRow.py +0 -36
  48. figpack/spike_sorting/views/__init__.py +0 -41
  49. figpack-0.2.15.dist-info/RECORD +0 -60
  50. {figpack-0.2.15.dist-info → figpack-0.2.17.dist-info}/WHEEL +0 -0
  51. {figpack-0.2.15.dist-info → figpack-0.2.17.dist-info}/entry_points.txt +0 -0
  52. {figpack-0.2.15.dist-info → figpack-0.2.17.dist-info}/licenses/LICENSE +0 -0
  53. {figpack-0.2.15.dist-info → figpack-0.2.17.dist-info}/top_level.txt +0 -0
@@ -1,288 +0,0 @@
1
- """
2
- RasterPlot view for figpack - displays multiple raster plots
3
- """
4
-
5
- from typing import List
6
- import numpy as np
7
- import zarr
8
-
9
- from ...core.figpack_view import FigpackView
10
- from .RasterPlotItem import RasterPlotItem
11
- from .UnitsTable import UnitsTable, UnitsTableColumn, UnitsTableRow
12
- from ...views.Box import Box, LayoutItem
13
-
14
-
15
class RasterPlot(FigpackView):
    """
    A view that displays multiple raster plots for spike sorting analysis.

    Each RasterPlotItem contributes the spike times of one unit. When written
    to Zarr, all spikes are merged into a single time-sorted stream with a
    per-spike unit index. That stream is accompanied by reference arrays that
    pair sampled timestamps (at least one second apart) with their positions
    in the stream, and by a per-unit spike-count matrix with 1-second bins.
    """

    def __init__(
        self,
        *,
        start_time_sec: float,
        end_time_sec: float,
        plots: List[RasterPlotItem],
    ):
        """
        Initialize a RasterPlot view

        Args:
            start_time_sec: Start time in seconds for the plot range
            end_time_sec: End time in seconds for the plot range
            plots: List of RasterPlotItem objects, one per unit
        """
        self.start_time_sec = float(start_time_sec)
        self.end_time_sec = float(end_time_sec)
        self.plots = plots

    @staticmethod
    def from_nwb_units_table(
        nwb_url_or_path_or_h5py,
        *,
        units_path: str,
        include_units_selector: bool = False,
    ):
        """
        Build a RasterPlot from an NWB units table.

        Units without any spikes are skipped. The plot range spans from the
        earliest first spike to the latest last spike over all units.

        Args:
            nwb_url_or_path_or_h5py: URL or path of an NWB file (opened with
                lindi), or an already-open h5py-like file object
            units_path: Path of the units table inside the file
            include_units_selector: If True, return a horizontal Box that
                places a UnitsTable next to the raster plot

        Returns:
            The RasterPlot view, or a Box layout when include_units_selector
            is True

        Raises:
            ValueError: If no unit in the table has any spikes
        """
        if isinstance(nwb_url_or_path_or_h5py, str):
            import lindi

            f = lindi.LindiH5pyFile.from_hdf5_file(nwb_url_or_path_or_h5py)
        else:
            f = nwb_url_or_path_or_h5py
        units = f[units_path]
        spike_times = units["spike_times"]
        # Ragged column: spike_times_index[i] is the end offset (exclusive) of
        # unit i's spikes within the flat spike_times array. The small index
        # and id columns are read once instead of element by element.
        spike_times_index = units["spike_times_index"][:]
        unit_ids = units["id"][:]

        plots: List[RasterPlotItem] = []
        start_times = []
        end_times = []
        for unit_index in range(len(spike_times_index)):
            unit_id = unit_ids[unit_index]
            start_index = spike_times_index[unit_index - 1] if unit_index > 0 else 0
            end_index = spike_times_index[unit_index]
            unit_spike_times = spike_times[start_index:end_index]
            if len(unit_spike_times) == 0:
                continue
            start_times.append(unit_spike_times[0])
            end_times.append(unit_spike_times[-1])
            plots.append(
                RasterPlotItem(unit_id=str(unit_id), spike_times_sec=unit_spike_times)
            )

        if not plots:
            raise ValueError(f"No spikes found in units table at {units_path!r}")

        view = RasterPlot(
            start_time_sec=min(start_times),
            end_time_sec=max(end_times),
            plots=plots,
        )
        if not include_units_selector:
            return view

        columns: List[UnitsTableColumn] = [
            UnitsTableColumn(key="unitId", label="Unit", dtype="int"),
        ]
        rows: List[UnitsTableRow] = [
            UnitsTableRow(unit_id=str(unit_id), values={}) for unit_id in unit_ids
        ]
        units_table = UnitsTable(columns=columns, rows=rows)
        return Box(
            direction="horizontal",
            items=[
                LayoutItem(view=units_table, max_size=150, title="Units"),
                LayoutItem(view=view, title="Raster Plot"),
            ],
        )

    def _write_to_zarr_group(self, group: zarr.Group) -> None:
        """
        Write the RasterPlot data to a Zarr group

        Args:
            group: Zarr group to write data into
        """
        group.attrs["view_type"] = "RasterPlot"
        group.attrs["start_time_sec"] = self.start_time_sec
        group.attrs["end_time_sec"] = self.end_time_sec

        unified_data = self._prepare_unified_data()
        total_spikes = unified_data["total_spikes"]

        if total_spikes == 0:
            # Handle empty data case
            group.create_dataset("timestamps", data=np.array([], dtype=np.float32))
            group.create_dataset("unit_indices", data=np.array([], dtype=np.uint16))
            group.create_dataset("reference_times", data=np.array([], dtype=np.float32))
            group.create_dataset(
                "reference_indices", data=np.array([], dtype=np.uint32)
            )
            # Keep the ids of units that exist but have no spikes
            group.attrs["unit_ids"] = unified_data["unit_ids"]
            group.attrs["total_spikes"] = 0
            return

        # Chunk length is capped at 2,000,000; smaller arrays use a single chunk
        chunks = (min(total_spikes, 2_000_000),)
        group.create_dataset(
            "timestamps",
            data=unified_data["timestamps"],
            dtype=np.float32,
            chunks=chunks,
        )
        group.create_dataset(
            "unit_indices",
            data=unified_data["unit_indices"],
            dtype=np.uint16,
            chunks=chunks,
        )
        group.create_dataset(
            "reference_times",
            data=unified_data["reference_times"],
            dtype=np.float32,
            chunks=(len(unified_data["reference_times"]),),
        )
        group.create_dataset(
            "reference_indices",
            data=unified_data["reference_indices"],
            dtype=np.uint32,
            chunks=(len(unified_data["reference_indices"]),),
        )

        spike_counts = self._compute_spike_counts_1sec()
        num_bins, num_units = spike_counts.shape
        group.create_dataset(
            "spike_counts_1sec",
            data=spike_counts,
            dtype=np.uint16,
            chunks=(min(num_bins, 10000), min(num_units, 500)),
        )

        group.attrs["unit_ids"] = unified_data["unit_ids"]
        group.attrs["total_spikes"] = total_spikes

    def _compute_spike_counts_1sec(self) -> np.ndarray:
        """
        Count each unit's spikes in consecutive 1-second bins.

        Bin k covers [start_time_sec + k, start_time_sec + k + 1). The array
        is sized so that a spike at exactly end_time_sec still falls in a bin,
        and it always has at least one bin. Spikes outside the binned range
        are ignored, and counts are clipped to the uint16 maximum.

        Returns:
            uint16 array of shape (num_bins, len(self.plots))
        """
        duration = self.end_time_sec - self.start_time_sec
        num_bins = max(1, int(np.floor(duration)) + 1)
        num_units = len(self.plots)
        spike_counts = np.zeros((num_bins, num_units), dtype=np.uint16)
        uint16_max = np.iinfo(np.uint16).max

        for unit_idx, plot in enumerate(self.plots):
            offsets = (
                np.asarray(plot.spike_times_sec, dtype=np.float64).reshape(-1)
                - self.start_time_sec
            )
            # floor (not truncation toward zero) so that spikes shortly before
            # start_time_sec are not counted in bin 0
            bin_indices = np.floor(offsets).astype(np.int64)
            valid = (bin_indices >= 0) & (bin_indices < num_bins)
            counts = np.bincount(bin_indices[valid], minlength=num_bins)
            spike_counts[:, unit_idx] = np.minimum(counts, uint16_max)
        return spike_counts

    def _prepare_unified_data(self) -> dict:
        """
        Merge the spikes of all plots into one time-sorted stream.

        Returns:
            Dictionary with keys:
                timestamps: float32 spike times, sorted ascending
                unit_indices: uint16 index into unit_ids for each spike
                reference_times, reference_indices: output of
                    _generate_reference_arrays for the sorted timestamps
                unit_ids: str id of each plot, in plot order
                total_spikes: number of spikes in the stream

        Raises:
            ValueError: If there are more units than a uint16 index can address
        """
        unit_ids = [str(plot.unit_id) for plot in self.plots]
        if len(unit_ids) > int(np.iinfo(np.uint16).max) + 1:
            raise ValueError(
                f"RasterPlot supports at most 65536 units, got {len(unit_ids)}"
            )
        # Duplicate ids map to the index of their last occurrence
        unit_id_to_index = {unit_id: i for i, unit_id in enumerate(unit_ids)}

        times_per_plot = [
            np.asarray(plot.spike_times_sec, dtype=np.float64).reshape(-1)
            for plot in self.plots
        ]
        counts_per_plot = [len(times) for times in times_per_plot]
        total_spikes = int(sum(counts_per_plot))

        if total_spikes == 0:
            return {
                "timestamps": np.array([], dtype=np.float32),
                "unit_indices": np.array([], dtype=np.uint16),
                "reference_times": np.array([], dtype=np.float32),
                "reference_indices": np.array([], dtype=np.uint32),
                "unit_ids": unit_ids,
                "total_spikes": 0,
            }

        all_times = np.concatenate(times_per_plot)
        all_unit_indices = np.repeat(
            np.array([unit_id_to_index[u] for u in unit_ids], dtype=np.int64),
            counts_per_plot,
        )
        # Stable sort: spikes with equal times keep plot order
        order = np.argsort(all_times, kind="stable")
        timestamps = all_times[order].astype(np.float32)
        unit_indices = all_unit_indices[order].astype(np.uint16)

        reference_times, reference_indices = self._generate_reference_arrays(timestamps)

        return {
            "timestamps": timestamps,
            "unit_indices": unit_indices,
            "reference_times": reference_times,
            "reference_indices": reference_indices,
            "unit_ids": unit_ids,
            "total_spikes": total_spikes,
        }

    def _generate_reference_arrays(
        self, timestamps: np.ndarray, interval_sec: float = 1.0
    ) -> tuple:
        """
        Pick reference points from a sorted timestamp array.

        The first timestamp is always a reference point. Each following
        reference point is the first later timestamp that is at least
        interval_sec after the previous reference point.

        Args:
            timestamps: Timestamps sorted in ascending order
            interval_sec: Minimum spacing between consecutive reference points

        Returns:
            Tuple (reference_times, reference_indices): float32 timestamps of
            the reference points and their uint32 positions in `timestamps`
        """
        ts = np.asarray(timestamps)
        num_timestamps = len(ts)
        if num_timestamps == 0:
            return np.array([], dtype=np.float32), np.array([], dtype=np.uint32)

        indices = [0]
        current = 0
        # Binary search for each next reference point; the loop runs once per
        # reference point instead of once per timestamp.
        while current + 1 < num_timestamps:
            target = float(ts[current]) + interval_sec
            search_from = current + 1
            current = search_from + int(
                np.searchsorted(ts[search_from:], target, side="left")
            )
            if current >= num_timestamps:
                break
            indices.append(current)

        reference_indices = np.array(indices, dtype=np.uint32)
        reference_times = ts[reference_indices].astype(np.float32)
        return reference_times, reference_indices
@@ -1,28 +0,0 @@
1
- """
2
- RasterPlotItem for figpack - represents a single unit's raster plot
3
- """
4
-
5
- from typing import Union
6
- import numpy as np
7
-
8
-
9
class RasterPlotItem:
    """
    Spike train of a single unit, used as one entry of a RasterPlot.

    Attributes:
        unit_id: Identifier of the unit, stored exactly as given
        spike_times_sec: float32 copy of the unit's spike times in seconds
    """

    def __init__(
        self,
        *,
        unit_id: Union[str, int],
        spike_times_sec: np.ndarray,
    ):
        """
        Store a unit's identifier together with its spike times.

        Args:
            unit_id: Identifier for the unit
            spike_times_sec: Array-like of spike times in seconds. It is
                always copied and converted to float32.
        """
        times_float32 = np.array(spike_times_sec, dtype=np.float32)
        self.unit_id = unit_id
        self.spike_times_sec = times_float32
@@ -1,374 +0,0 @@
1
- """
2
- SpikeAmplitudes view for figpack - displays spike amplitudes over time
3
- """
4
-
5
- from typing import List
6
-
7
- import numpy as np
8
- import zarr
9
-
10
- from ...core.figpack_view import FigpackView
11
- from .SpikeAmplitudesItem import SpikeAmplitudesItem
12
- from .UnitsTable import UnitsTable, UnitsTableColumn, UnitsTableRow
13
- from ...views.Box import Box, LayoutItem
14
-
15
-
16
class SpikeAmplitudes(FigpackView):
    """
    A view that displays spike amplitudes over time for multiple units.

    When written to Zarr, the spikes of all units are merged into a single
    time-sorted stream of (timestamp, unit index, amplitude). That stream is
    accompanied by reference arrays and, for large streams, by subsampled
    copies of the stream.
    """

    def __init__(
        self,
        *,
        start_time_sec: float,
        end_time_sec: float,
        plots: List[SpikeAmplitudesItem],
    ):
        """
        Initialize a SpikeAmplitudes view

        Args:
            start_time_sec: Start time of the view in seconds
            end_time_sec: End time of the view in seconds
            plots: List of SpikeAmplitudesItem objects, one per unit
        """
        # Coerce to plain floats (as RasterPlot does) so that NumPy scalars,
        # e.g. values read from an NWB file, are stored as ordinary numbers
        # in the Zarr attributes.
        self.start_time_sec = float(start_time_sec)
        self.end_time_sec = float(end_time_sec)
        self.plots = plots

    @staticmethod
    def from_nwb_units_table(
        nwb_url_or_path_or_h5py,
        *,
        units_path: str,
        include_units_selector: bool = False,
    ):
        """
        Build a SpikeAmplitudes view from an NWB units table.

        Units without any spikes are skipped. The view range spans from the
        earliest first spike to the latest last spike over all units.

        Args:
            nwb_url_or_path_or_h5py: URL or path of an NWB file (opened with
                lindi), or an already-open h5py-like file object
            units_path: Path of the units table inside the file
            include_units_selector: If True, return a horizontal Box that
                places a UnitsTable next to the amplitudes view

        Returns:
            The SpikeAmplitudes view, or a Box layout when
            include_units_selector is True

        Raises:
            ValueError: If no unit in the table has any spikes
        """
        if isinstance(nwb_url_or_path_or_h5py, str):
            import lindi

            f = lindi.LindiH5pyFile.from_hdf5_file(nwb_url_or_path_or_h5py)
        else:
            f = nwb_url_or_path_or_h5py
        units = f[units_path]
        spike_amplitudes = units["spike_amplitudes"]
        # NOTE(review): spike_amplitudes is sliced with spike_times_index,
        # assuming its own index column matches spike_times_index.
        spike_times = units["spike_times"]
        # Ragged column: spike_times_index[i] is the end offset (exclusive) of
        # unit i's spikes. The small index and id columns are read once.
        spike_times_index = units["spike_times_index"][:]
        unit_ids = units["id"][:]

        plots: List[SpikeAmplitudesItem] = []
        start_times = []
        end_times = []
        for unit_index in range(len(spike_times_index)):
            unit_id = unit_ids[unit_index]
            start_index = spike_times_index[unit_index - 1] if unit_index > 0 else 0
            end_index = spike_times_index[unit_index]
            unit_spike_amplitudes = spike_amplitudes[start_index:end_index]
            unit_spike_times = spike_times[start_index:end_index]
            if len(unit_spike_times) == 0:
                continue
            start_times.append(unit_spike_times[0])
            end_times.append(unit_spike_times[-1])
            plots.append(
                SpikeAmplitudesItem(
                    unit_id=str(unit_id),
                    spike_times_sec=unit_spike_times,
                    spike_amplitudes=unit_spike_amplitudes,
                )
            )

        if not plots:
            raise ValueError(f"No spikes found in units table at {units_path!r}")

        view = SpikeAmplitudes(
            start_time_sec=min(start_times),
            end_time_sec=max(end_times),
            plots=plots,
        )
        if not include_units_selector:
            return view

        columns: List[UnitsTableColumn] = [
            UnitsTableColumn(key="unitId", label="Unit", dtype="int"),
        ]
        rows: List[UnitsTableRow] = [
            UnitsTableRow(unit_id=str(unit_id), values={}) for unit_id in unit_ids
        ]
        units_table = UnitsTable(columns=columns, rows=rows)
        return Box(
            direction="horizontal",
            items=[
                LayoutItem(view=units_table, max_size=150, title="Units"),
                LayoutItem(view=view, title="Spike Amplitudes"),
            ],
        )

    def _write_to_zarr_group(self, group: zarr.Group) -> None:
        """
        Write the SpikeAmplitudes data to a Zarr group using unified storage format

        Args:
            group: Zarr group to write data into
        """
        group.attrs["view_type"] = "SpikeAmplitudes"
        group.attrs["start_time_sec"] = self.start_time_sec
        group.attrs["end_time_sec"] = self.end_time_sec

        unified_data = self._prepare_unified_data()

        if unified_data["total_spikes"] == 0:
            # Handle empty data case
            group.create_dataset("timestamps", data=np.array([], dtype=np.float32))
            group.create_dataset("unit_indices", data=np.array([], dtype=np.uint16))
            group.create_dataset("amplitudes", data=np.array([], dtype=np.float32))
            group.create_dataset("reference_times", data=np.array([], dtype=np.float32))
            group.create_dataset(
                "reference_indices", data=np.array([], dtype=np.uint32)
            )
            # Keep the ids of units that exist but have no spikes
            group.attrs["unit_ids"] = unified_data["unit_ids"]
            group.attrs["total_spikes"] = 0
            return

        self._write_spike_stream(group, unified_data)

        group.attrs["unit_ids"] = unified_data["unit_ids"]
        group.attrs["total_spikes"] = unified_data["total_spikes"]

        subsampled_data = self._create_subsampled_data(
            unified_data["timestamps"],
            unified_data["unit_indices"],
            unified_data["amplitudes"],
        )
        if subsampled_data:
            subsampled_group = group.create_group("subsampled_data")
            for factor_name, level_data in subsampled_data.items():
                factor_group = subsampled_group.create_group(factor_name)
                self._write_spike_stream(factor_group, level_data)

    @staticmethod
    def _write_spike_stream(group: zarr.Group, data: dict) -> None:
        """
        Write one non-empty spike stream (full or subsampled) into `group`.

        Args:
            group: Zarr group that receives the datasets
            data: Dict with "timestamps", "unit_indices", "amplitudes",
                "reference_times" and "reference_indices" arrays
        """
        # Chunk length is capped at 2,000,000; smaller arrays use a single chunk
        chunks = (min(len(data["timestamps"]), 2_000_000),)
        group.create_dataset(
            "timestamps", data=data["timestamps"], dtype=np.float32, chunks=chunks
        )
        group.create_dataset(
            "unit_indices", data=data["unit_indices"], dtype=np.uint16, chunks=chunks
        )
        group.create_dataset(
            "amplitudes", data=data["amplitudes"], dtype=np.float32, chunks=chunks
        )
        group.create_dataset(
            "reference_times",
            data=data["reference_times"],
            dtype=np.float32,
            chunks=(len(data["reference_times"]),),
        )
        group.create_dataset(
            "reference_indices",
            data=data["reference_indices"],
            dtype=np.uint32,
            chunks=(len(data["reference_indices"]),),
        )

    def _prepare_unified_data(self) -> dict:
        """
        Merge the spikes of all plots into one time-sorted stream.

        For each plot, spike times and amplitudes are paired position by
        position. If the two arrays differ in length, the extra entries are
        ignored (the same truncation `zip` performs).

        Returns:
            Dictionary with keys:
                timestamps: float32 spike times, sorted ascending
                unit_indices: uint16 index into unit_ids for each spike
                amplitudes: float32 amplitude of each spike
                reference_times, reference_indices: output of
                    _generate_reference_arrays for the sorted timestamps
                unit_ids: str id of each plot, in plot order
                total_spikes: number of spikes in the stream

        Raises:
            ValueError: If there are more units than a uint16 index can address
        """
        unit_ids = [str(plot.unit_id) for plot in self.plots]
        if len(unit_ids) > int(np.iinfo(np.uint16).max) + 1:
            raise ValueError(
                f"SpikeAmplitudes supports at most 65536 units, got {len(unit_ids)}"
            )
        # Duplicate ids map to the index of their last occurrence
        unit_id_to_index = {unit_id: i for i, unit_id in enumerate(unit_ids)}

        times_per_plot = []
        amplitudes_per_plot = []
        for plot in self.plots:
            times = np.asarray(plot.spike_times_sec, dtype=np.float64).reshape(-1)
            amps = np.asarray(plot.spike_amplitudes, dtype=np.float64).reshape(-1)
            paired = min(len(times), len(amps))
            times_per_plot.append(times[:paired])
            amplitudes_per_plot.append(amps[:paired])
        counts_per_plot = [len(times) for times in times_per_plot]
        total_spikes = int(sum(counts_per_plot))

        if total_spikes == 0:
            return {
                "timestamps": np.array([], dtype=np.float32),
                "unit_indices": np.array([], dtype=np.uint16),
                "amplitudes": np.array([], dtype=np.float32),
                "reference_times": np.array([], dtype=np.float32),
                "reference_indices": np.array([], dtype=np.uint32),
                "unit_ids": unit_ids,
                "total_spikes": 0,
            }

        all_times = np.concatenate(times_per_plot)
        all_amplitudes = np.concatenate(amplitudes_per_plot)
        all_unit_indices = np.repeat(
            np.array([unit_id_to_index[u] for u in unit_ids], dtype=np.int64),
            counts_per_plot,
        )
        # Stable sort: spikes with equal times keep plot order
        order = np.argsort(all_times, kind="stable")
        timestamps = all_times[order].astype(np.float32)
        unit_indices = all_unit_indices[order].astype(np.uint16)
        amplitudes = all_amplitudes[order].astype(np.float32)

        reference_times, reference_indices = self._generate_reference_arrays(timestamps)

        return {
            "timestamps": timestamps,
            "unit_indices": unit_indices,
            "amplitudes": amplitudes,
            "reference_times": reference_times,
            "reference_indices": reference_indices,
            "unit_ids": unit_ids,
            "total_spikes": total_spikes,
        }

    def _generate_reference_arrays(
        self, timestamps: np.ndarray, interval_sec: float = 1.0
    ) -> tuple:
        """
        Pick reference points from a sorted timestamp array.

        The first timestamp is always a reference point. Each following
        reference point is the first later timestamp that is at least
        interval_sec after the previous reference point.

        Args:
            timestamps: Timestamps sorted in ascending order
            interval_sec: Minimum spacing between consecutive reference points

        Returns:
            Tuple (reference_times, reference_indices): float32 timestamps of
            the reference points and their uint32 positions in `timestamps`
        """
        ts = np.asarray(timestamps)
        num_timestamps = len(ts)
        if num_timestamps == 0:
            return np.array([], dtype=np.float32), np.array([], dtype=np.uint32)

        indices = [0]
        current = 0
        # Binary search for each next reference point; the loop runs once per
        # reference point instead of once per timestamp.
        while current + 1 < num_timestamps:
            target = float(ts[current]) + interval_sec
            search_from = current + 1
            current = search_from + int(
                np.searchsorted(ts[search_from:], target, side="left")
            )
            if current >= num_timestamps:
                break
            indices.append(current)

        reference_indices = np.array(indices, dtype=np.uint32)
        reference_times = ts[reference_indices].astype(np.float32)
        return reference_times, reference_indices

    def _create_subsampled_data(
        self,
        timestamps: np.ndarray,
        unit_indices: np.ndarray,
        amplitudes: np.ndarray,
        *,
        min_spikes: int = 500_000,
        step: int = 4,
    ) -> dict:
        """
        Create subsampled data with geometric progression factors.

        Level "factor_{f}" keeps every f-th spike of the ORIGINAL arrays, for
        f = step, step**2, step**3, ... (4, 16, 64, ... by default). A further
        level is generated as long as the previous level (starting with the
        full data) still has at least min_spikes spikes.

        Args:
            timestamps: Original timestamps array (sorted)
            unit_indices: Original unit indices array
            amplitudes: Original amplitudes array
            min_spikes: Minimum size of a level for a coarser level to be made
            step: Ratio between successive subsampling factors (must be >= 2)

        Returns:
            Dictionary mapping "factor_{f}" to that level's arrays
            (timestamps, unit_indices, amplitudes, reference_times,
            reference_indices)

        Raises:
            ValueError: If step is less than 2
        """
        if step < 2:
            raise ValueError(f"step must be at least 2, got {step}")

        subsampled_data = {}
        factor = 1
        previous_level_size = len(timestamps)
        while previous_level_size >= min_spikes:
            factor *= step
            # Subsample the original arrays directly, so that the label
            # factor_{f} really means "every f-th spike"
            level_timestamps = np.ascontiguousarray(timestamps[::factor])
            level_unit_indices = np.ascontiguousarray(unit_indices[::factor])
            level_amplitudes = np.ascontiguousarray(amplitudes[::factor])

            ref_times, ref_indices = self._generate_reference_arrays(level_timestamps)

            subsampled_data[f"factor_{factor}"] = {
                "timestamps": level_timestamps,
                "unit_indices": level_unit_indices,
                "amplitudes": level_amplitudes,
                "reference_times": ref_times,
                "reference_indices": ref_indices,
            }
            previous_level_size = len(level_timestamps)

        return subsampled_data