figpack-0.2.16-py3-none-any.whl → figpack-0.2.18-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of figpack might be problematic.

Files changed (45)
  1. figpack/__init__.py +2 -3
  2. figpack/cli.py +74 -0
  3. figpack/core/__init__.py +2 -2
  4. figpack/core/_bundle_utils.py +85 -18
  5. figpack/core/_file_handler.py +192 -0
  6. figpack/core/_server_manager.py +42 -6
  7. figpack/core/_show_view.py +1 -1
  8. figpack/core/_view_figure.py +43 -12
  9. figpack/core/extension_view.py +7 -25
  10. figpack/core/figpack_extension.py +0 -71
  11. figpack/extensions.py +356 -0
  12. figpack/figpack-figure-dist/assets/{index-D9a3K6eW.css → index-BJUFDPIM.css} +1 -1
  13. figpack/figpack-figure-dist/assets/index-nBpxgXXT.js +91 -0
  14. figpack/figpack-figure-dist/index.html +2 -2
  15. figpack/views/PlotlyExtension/PlotlyExtension.py +4 -50
  16. figpack/views/PlotlyExtension/_plotly_extension.py +46 -0
  17. figpack/views/PlotlyExtension/plotly_view.js +84 -80
  18. figpack/views/__init__.py +1 -0
  19. {figpack-0.2.16.dist-info → figpack-0.2.18.dist-info}/METADATA +1 -1
  20. figpack-0.2.18.dist-info/RECORD +45 -0
  21. figpack/figpack-figure-dist/assets/index-DtOnN02w.js +0 -846
  22. figpack/franklab/__init__.py +0 -5
  23. figpack/franklab/views/TrackAnimation.py +0 -154
  24. figpack/franklab/views/__init__.py +0 -9
  25. figpack/spike_sorting/__init__.py +0 -5
  26. figpack/spike_sorting/views/AutocorrelogramItem.py +0 -32
  27. figpack/spike_sorting/views/Autocorrelograms.py +0 -116
  28. figpack/spike_sorting/views/AverageWaveforms.py +0 -146
  29. figpack/spike_sorting/views/CrossCorrelogramItem.py +0 -35
  30. figpack/spike_sorting/views/CrossCorrelograms.py +0 -131
  31. figpack/spike_sorting/views/RasterPlot.py +0 -284
  32. figpack/spike_sorting/views/RasterPlotItem.py +0 -28
  33. figpack/spike_sorting/views/SpikeAmplitudes.py +0 -364
  34. figpack/spike_sorting/views/SpikeAmplitudesItem.py +0 -38
  35. figpack/spike_sorting/views/UnitMetricsGraph.py +0 -127
  36. figpack/spike_sorting/views/UnitSimilarityScore.py +0 -40
  37. figpack/spike_sorting/views/UnitsTable.py +0 -82
  38. figpack/spike_sorting/views/UnitsTableColumn.py +0 -40
  39. figpack/spike_sorting/views/UnitsTableRow.py +0 -36
  40. figpack/spike_sorting/views/__init__.py +0 -41
  41. figpack-0.2.16.dist-info/RECORD +0 -61
  42. {figpack-0.2.16.dist-info → figpack-0.2.18.dist-info}/WHEEL +0 -0
  43. {figpack-0.2.16.dist-info → figpack-0.2.18.dist-info}/entry_points.txt +0 -0
  44. {figpack-0.2.16.dist-info → figpack-0.2.18.dist-info}/licenses/LICENSE +0 -0
  45. {figpack-0.2.16.dist-info → figpack-0.2.18.dist-info}/top_level.txt +0 -0
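
For readers who want to reproduce a file-level comparison like the table above, here is a minimal sketch using only the Python standard library. The local wheel filenames and the `pip download` commands are assumptions, not part of this diff.

```python
# Minimal sketch: list files added/removed between two wheels.
# Assumes both wheels were fetched first, e.g.:
#   pip download figpack==0.2.16 --no-deps
#   pip download figpack==0.2.18 --no-deps
import zipfile

OLD = "figpack-0.2.16-py3-none-any.whl"  # assumed local filename
NEW = "figpack-0.2.18-py3-none-any.whl"  # assumed local filename

with zipfile.ZipFile(OLD) as old_whl, zipfile.ZipFile(NEW) as new_whl:
    old_names = set(old_whl.namelist())
    new_names = set(new_whl.namelist())

for name in sorted(new_names - old_names):
    print("added:  ", name)
for name in sorted(old_names - new_names):
    print("removed:", name)
```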
figpack/spike_sorting/views/RasterPlot.py
@@ -1,284 +0,0 @@
- """
- RasterPlot view for figpack - displays multiple raster plots
- """
-
- from typing import List
- import numpy as np
- import zarr
-
- from ...core.figpack_view import FigpackView
- from ...core.zarr import Group
- from .RasterPlotItem import RasterPlotItem
- from .UnitsTable import UnitsTable, UnitsTableColumn, UnitsTableRow
- from ...views.Box import Box, LayoutItem
-
-
- class RasterPlot(FigpackView):
-     """
-     A view that displays multiple raster plots for spike sorting analysis
-     """
-
-     def __init__(
-         self,
-         *,
-         start_time_sec: float,
-         end_time_sec: float,
-         plots: List[RasterPlotItem],
-     ):
-         """
-         Initialize a RasterPlot view
-
-         Args:
-             start_time_sec: Start time in seconds for the plot range
-             end_time_sec: End time in seconds for the plot range
-             plots: List of RasterPlotItem objects
-             height: Height of the plot in pixels (default: 500)
-         """
-         self.start_time_sec = float(start_time_sec)
-         self.end_time_sec = float(end_time_sec)
-         self.plots = plots
-
-     @staticmethod
-     def from_nwb_units_table(
-         nwb_url_or_path_or_h5py,
-         *,
-         units_path: str,
-         include_units_selector: bool = False,
-     ):
-         if isinstance(nwb_url_or_path_or_h5py, str):
-             import lindi
-
-             f = lindi.LindiH5pyFile.from_hdf5_file(nwb_url_or_path_or_h5py)
-         else:
-             f = nwb_url_or_path_or_h5py
-         X = f[units_path]
-         spike_times = X["spike_times"]
-         spike_times_index = X["spike_times_index"]
-         id = X["id"]
-         plots = []
-         num_units = len(spike_times_index)
-         start_times = []
-         end_times = []
-         for unit_index in range(num_units):
-             unit_id = id[unit_index]
-             if unit_index > 0:
-                 start_index = spike_times_index[unit_index - 1]
-             else:
-                 start_index = 0
-             end_index = spike_times_index[unit_index]
-             unit_spike_times = spike_times[start_index:end_index]
-             if len(unit_spike_times) == 0:
-                 continue
-             start_times.append(unit_spike_times[0])
-             end_times.append(unit_spike_times[-1])
-             plots.append(
-                 RasterPlotItem(unit_id=str(unit_id), spike_times_sec=unit_spike_times)
-             )
-         view = RasterPlot(
-             start_time_sec=min(start_times),
-             end_time_sec=max(end_times),
-             plots=plots,
-         )
-         if include_units_selector:
-             columns: List[UnitsTableColumn] = [
-                 UnitsTableColumn(key="unitId", label="Unit", dtype="int"),
-             ]
-             rows: List[UnitsTableRow] = []
-             for unit_id in id:
-                 rows.append(
-                     UnitsTableRow(
-                         unit_id=str(unit_id),
-                         values={},
-                     )
-                 )
-             units_table = UnitsTable(
-                 columns=columns,
-                 rows=rows,
-             )
-             layout = Box(
-                 direction="horizontal",
-                 items=[
-                     LayoutItem(view=units_table, max_size=150, title="Units"),
-                     LayoutItem(view=view, title="Spike Amplitudes"),
-                 ],
-             )
-             return layout
-         else:
-             return view
-
-     def _write_to_zarr_group(self, group: Group) -> None:
-         """
-         Args:
-             group: Zarr group to write data into
-         """
-         # Set the view type
-         group.attrs["view_type"] = "RasterPlot"
-
-         # Store view parameters
-         group.attrs["start_time_sec"] = self.start_time_sec
-         group.attrs["end_time_sec"] = self.end_time_sec
-
-         # Prepare unified data arrays
-         unified_data = self._prepare_unified_data()
-
-         if unified_data["total_spikes"] == 0:
-             # Handle empty data case
-             group.create_dataset("timestamps", data=np.array([], dtype=np.float32))
-             group.create_dataset("unit_indices", data=np.array([], dtype=np.uint16))
-             group.create_dataset("reference_times", data=np.array([], dtype=np.float32))
-             group.create_dataset(
-                 "reference_indices", data=np.array([], dtype=np.uint32)
-             )
-             group.attrs["unit_ids"] = []
-             group.attrs["total_spikes"] = 0
-             return
-
-         chunks = (
-             (2_000_000,)
-             if unified_data["total_spikes"] > 2_000_000
-             else (len(unified_data["timestamps"]),)
-         )
-         # Store main data arrays
-         group.create_dataset(
-             "timestamps",
-             data=unified_data["timestamps"],
-             chunks=chunks,
-         )
-         group.create_dataset(
-             "unit_indices",
-             data=unified_data["unit_indices"],
-             chunks=chunks,
-         )
-         group.create_dataset(
-             "reference_times",
-             data=unified_data["reference_times"],
-             chunks=(len(unified_data["reference_times"]),),
-         )
-         group.create_dataset(
-             "reference_indices",
-             data=unified_data["reference_indices"],
-             chunks=(len(unified_data["reference_indices"]),),
-         )
-
-         # Create spike counts array with 1-second bins
-         duration = self.end_time_sec - self.start_time_sec
-         num_bins = int(np.ceil(duration))
-         num_units = len(self.plots)
-         spike_counts = np.zeros((num_bins, num_units), dtype=np.uint16)
-
-         # Efficiently compute spike counts for each unit
-         for unit_idx, plot in enumerate(self.plots):
-             # Convert spike times to bin indices
-             bin_indices = (
-                 (np.array(plot.spike_times_sec) - self.start_time_sec)
-             ).astype(int)
-             # Count spikes in valid bins
-             valid_indices = (bin_indices >= 0) & (bin_indices < num_bins)
-             unique_bins, counts = np.unique(
-                 bin_indices[valid_indices], return_counts=True
-             )
-             spike_counts[unique_bins, unit_idx] = counts.clip(
-                 max=65535
-             )  # Clip to uint16 max
-
-         # Store spike counts array
-         group.create_dataset(
-             "spike_counts_1sec",
-             data=spike_counts,
-             chunks=(min(num_bins, 10000), min(num_units, 500)),
-         )
-
-         # Store unit ID mapping
-         group.attrs["unit_ids"] = unified_data["unit_ids"]
-         group.attrs["total_spikes"] = unified_data["total_spikes"]
-
-     def _prepare_unified_data(self) -> dict:
-         """
-         Prepare unified data arrays from all plots
-
-         Returns:
-             Dictionary containing unified arrays and metadata
-         """
-         if not self.plots:
-             return {
-                 "timestamps": np.array([], dtype=np.float32),
-                 "unit_indices": np.array([], dtype=np.uint16),
-                 "reference_times": np.array([], dtype=np.float32),
-                 "reference_indices": np.array([], dtype=np.uint32),
-                 "unit_ids": [],
-                 "total_spikes": 0,
-             }
-
-         # Create unit ID mapping
-         unit_ids = [str(plot.unit_id) for plot in self.plots]
-         unit_id_to_index = {unit_id: i for i, unit_id in enumerate(unit_ids)}
-
-         # Collect all spikes with their unit indices
-         all_spikes = []
-         for plot in self.plots:
-             unit_index = unit_id_to_index[str(plot.unit_id)]
-             for time in plot.spike_times_sec:
-                 all_spikes.append((float(time), unit_index))
-
-         if not all_spikes:
-             return {
-                 "timestamps": np.array([], dtype=np.float32),
-                 "unit_indices": np.array([], dtype=np.uint16),
-                 "reference_times": np.array([], dtype=np.float32),
-                 "reference_indices": np.array([], dtype=np.uint32),
-                 "unit_ids": unit_ids,
-                 "total_spikes": 0,
-             }
-
-         # Sort by timestamp
-         all_spikes.sort(key=lambda x: x[0])
-
-         # Extract sorted arrays
-         timestamps = np.array([spike[0] for spike in all_spikes], dtype=np.float32)
-         unit_indices = np.array([spike[1] for spike in all_spikes], dtype=np.uint16)
-
-         # Generate reference arrays
-         reference_times, reference_indices = self._generate_reference_arrays(timestamps)
-
-         return {
-             "timestamps": timestamps,
-             "unit_indices": unit_indices,
-             "reference_times": reference_times,
-             "reference_indices": reference_indices,
-             "unit_ids": unit_ids,
-             "total_spikes": len(all_spikes),
-         }
-
-     def _generate_reference_arrays(
-         self, timestamps: np.ndarray, interval_sec: float = 1.0
-     ) -> tuple:
-         """
-         Generate reference arrays using actual timestamps from the data
-
-         Args:
-             timestamps: Sorted array of timestamps
-             interval_sec: Minimum interval between reference points
-
-         Returns:
-             Tuple of (reference_times, reference_indices)
-         """
-         if len(timestamps) == 0:
-             return np.array([], dtype=np.float32), np.array([], dtype=np.uint32)
-
-         reference_times = []
-         reference_indices = []
-
-         current_ref_time = timestamps[0]
-         reference_times.append(current_ref_time)
-         reference_indices.append(0)
-
-         # Find the next reference point at least interval_sec later
-         for i, timestamp in enumerate(timestamps):
-             if timestamp >= current_ref_time + interval_sec:
-                 reference_times.append(timestamp)
-                 reference_indices.append(i)
-                 current_ref_time = timestamp
-
-         return np.array(reference_times, dtype=np.float32), np.array(
-             reference_indices, dtype=np.uint32
-         )
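
The removed `from_nwb_units_table` above reads the NWB ragged-array layout, in which `spike_times` is a single flat array covering all units and `spike_times_index[i]` holds the end offset of unit i's slice (unit 0 starts at offset 0). Here is a self-contained sketch of that slicing with made-up toy data, independent of figpack and lindi:

```python
import numpy as np

# Toy NWB-style ragged layout: flat values plus an array of end offsets.
spike_times = np.array([0.1, 0.5, 0.9, 1.2, 3.0, 3.1], dtype=np.float32)
spike_times_index = np.array([3, 4, 6])  # unit 0 -> [0:3], unit 1 -> [3:4], unit 2 -> [4:6]

for unit_index in range(len(spike_times_index)):
    start = 0 if unit_index == 0 else spike_times_index[unit_index - 1]
    end = spike_times_index[unit_index]
    print(f"unit {unit_index}: {spike_times[start:end]}")
# unit 0: [0.1 0.5 0.9]
# unit 1: [1.2]
# unit 2: [3.  3.1]
```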
figpack/spike_sorting/views/RasterPlotItem.py
@@ -1,28 +0,0 @@
- """
- RasterPlotItem for figpack - represents a single unit's raster plot
- """
-
- from typing import Union
- import numpy as np
-
-
- class RasterPlotItem:
-     """
-     Represents spike times for a single unit in a raster plot
-     """
-
-     def __init__(
-         self,
-         *,
-         unit_id: Union[str, int],
-         spike_times_sec: np.ndarray,
-     ):
-         """
-         Initialize a RasterPlotItem
-
-         Args:
-             unit_id: Identifier for the unit
-             spike_times_sec: Numpy array of spike times in seconds
-         """
-         self.unit_id = unit_id
-         self.spike_times_sec = np.array(spike_times_sec, dtype=np.float32)
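
Both removed time-series views (RasterPlot above, SpikeAmplitudes below) carry an identical `_generate_reference_arrays` helper: it emits sparse (time, index) checkpoints into the sorted timestamp array, one at the first timestamp and a new one whenever at least `interval_sec` has elapsed since the previous checkpoint, so a renderer can seek to a time window without scanning every spike. A standalone sketch of the same logic with toy data:

```python
import numpy as np

def generate_reference_arrays(timestamps: np.ndarray, interval_sec: float = 1.0):
    """Sparse (time, index) checkpoints over a sorted timestamp array."""
    if len(timestamps) == 0:
        return np.array([], dtype=np.float32), np.array([], dtype=np.uint32)
    ref_times = [timestamps[0]]
    ref_indices = [0]
    current = timestamps[0]
    for i, t in enumerate(timestamps):
        if t >= current + interval_sec:  # next checkpoint at least interval_sec later
            ref_times.append(t)
            ref_indices.append(i)
            current = t
    return np.array(ref_times, dtype=np.float32), np.array(ref_indices, dtype=np.uint32)

times = np.array([0.0, 0.2, 0.9, 1.4, 1.5, 3.0], dtype=np.float32)
print(generate_reference_arrays(times))  # ([0.0, 1.4, 3.0], [0, 3, 5])
```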
figpack/spike_sorting/views/SpikeAmplitudes.py
@@ -1,364 +0,0 @@
- """
- SpikeAmplitudes view for figpack - displays spike amplitudes over time
- """
-
- from typing import List
-
- import numpy as np
-
- from ...core.figpack_view import FigpackView
- from ...core.zarr import Group
- from .SpikeAmplitudesItem import SpikeAmplitudesItem
- from .UnitsTable import UnitsTable, UnitsTableColumn, UnitsTableRow
- from ...views.Box import Box, LayoutItem
-
-
- class SpikeAmplitudes(FigpackView):
-     """
-     A view that displays spike amplitudes over time for multiple units
-     """
-
-     def __init__(
-         self,
-         *,
-         start_time_sec: float,
-         end_time_sec: float,
-         plots: List[SpikeAmplitudesItem],
-     ):
-         """
-         Initialize a SpikeAmplitudes view
-
-         Args:
-             start_time_sec: Start time of the view in seconds
-             end_time_sec: End time of the view in seconds
-             plots: List of SpikeAmplitudesItem objects
-         """
-         self.start_time_sec = start_time_sec
-         self.end_time_sec = end_time_sec
-         self.plots = plots
-
-     @staticmethod
-     def from_nwb_units_table(
-         nwb_url_or_path_or_h5py,
-         *,
-         units_path: str,
-         include_units_selector: bool = False,
-     ):
-         if isinstance(nwb_url_or_path_or_h5py, str):
-             import lindi
-
-             f = lindi.LindiH5pyFile.from_hdf5_file(nwb_url_or_path_or_h5py)
-         else:
-             f = nwb_url_or_path_or_h5py
-         X = f[units_path]
-         spike_amplitudes = X["spike_amplitudes"]
-         # spike_amplitudes_index = X["spike_amplitudes_index"]  # presumably the same as spike_times_index
-         spike_times = X["spike_times"]
-         spike_times_index = X["spike_times_index"]
-         id = X["id"]
-         plots = []
-         num_units = len(spike_times_index)
-         start_times = []
-         end_times = []
-         for unit_index in range(num_units):
-             unit_id = id[unit_index]
-             if unit_index > 0:
-                 start_index = spike_times_index[unit_index - 1]
-             else:
-                 start_index = 0
-             end_index = spike_times_index[unit_index]
-             unit_spike_amplitudes = spike_amplitudes[start_index:end_index]
-             unit_spike_times = spike_times[start_index:end_index]
-             if len(unit_spike_times) == 0:
-                 continue
-             start_times.append(unit_spike_times[0])
-             end_times.append(unit_spike_times[-1])
-             plots.append(
-                 SpikeAmplitudesItem(
-                     unit_id=str(unit_id),
-                     spike_times_sec=unit_spike_times,
-                     spike_amplitudes=unit_spike_amplitudes,
-                 )
-             )
-         view = SpikeAmplitudes(
-             start_time_sec=min(start_times),
-             end_time_sec=max(end_times),
-             plots=plots,
-         )
-         if include_units_selector:
-             columns: List[UnitsTableColumn] = [
-                 UnitsTableColumn(key="unitId", label="Unit", dtype="int"),
-             ]
-             rows: List[UnitsTableRow] = []
-             for unit_id in id:
-                 rows.append(
-                     UnitsTableRow(
-                         unit_id=str(unit_id),
-                         values={},
-                     )
-                 )
-             units_table = UnitsTable(
-                 columns=columns,
-                 rows=rows,
-             )
-             layout = Box(
-                 direction="horizontal",
-                 items=[
-                     LayoutItem(view=units_table, max_size=150, title="Units"),
-                     LayoutItem(view=view, title="Spike Amplitudes"),
-                 ],
-             )
-             return layout
-         else:
-             return view
-
-     def _write_to_zarr_group(self, group: Group) -> None:
-         """
-         Write the SpikeAmplitudes data to a Zarr group using unified storage format
-
-         Args:
-             group: Zarr group to write data into
-         """
-         # Set the view type
-         group.attrs["view_type"] = "SpikeAmplitudes"
-
-         # Store view parameters
-         group.attrs["start_time_sec"] = self.start_time_sec
-         group.attrs["end_time_sec"] = self.end_time_sec
-
-         # Prepare unified data arrays
-         unified_data = self._prepare_unified_data()
-
-         if unified_data["total_spikes"] == 0:
-             # Handle empty data case
-             group.create_dataset("timestamps", data=np.array([], dtype=np.float32))
-             group.create_dataset("unit_indices", data=np.array([], dtype=np.uint16))
-             group.create_dataset("amplitudes", data=np.array([], dtype=np.float32))
-             group.create_dataset("reference_times", data=np.array([], dtype=np.float32))
-             group.create_dataset(
-                 "reference_indices", data=np.array([], dtype=np.uint32)
-             )
-             group.attrs["unit_ids"] = []
-             group.attrs["total_spikes"] = 0
-             return
-
-         chunks = (
-             (2_000_000,)
-             if unified_data["total_spikes"] > 2_000_000
-             else (len(unified_data["timestamps"]),)
-         )
-         # Store main data arrays
-         group.create_dataset(
-             "timestamps",
-             data=unified_data["timestamps"],
-             chunks=chunks,
-         )
-         group.create_dataset(
-             "unit_indices",
-             data=unified_data["unit_indices"],
-             chunks=chunks,
-         )
-         group.create_dataset(
-             "amplitudes",
-             data=unified_data["amplitudes"],
-             chunks=chunks,
-         )
-         group.create_dataset(
-             "reference_times",
-             data=unified_data["reference_times"],
-             chunks=(len(unified_data["reference_times"]),),
-         )
-         group.create_dataset(
-             "reference_indices",
-             data=unified_data["reference_indices"],
-             chunks=(len(unified_data["reference_indices"]),),
-         )
-
-         # Store unit ID mapping
-         group.attrs["unit_ids"] = unified_data["unit_ids"]
-         group.attrs["total_spikes"] = unified_data["total_spikes"]
-
-         # Create subsampled data
-         subsampled_data = self._create_subsampled_data(
-             unified_data["timestamps"],
-             unified_data["unit_indices"],
-             unified_data["amplitudes"],
-         )
-
-         if subsampled_data:
-             subsampled_group = group.create_group("subsampled_data")
-             for factor_name, data in subsampled_data.items():
-                 chunks = (
-                     (2_000_000,)
-                     if len(data["timestamps"]) > 2_000_000
-                     else (len(data["timestamps"]),)
-                 )
-                 factor_group = subsampled_group.create_group(factor_name)
-                 factor_group.create_dataset(
-                     "timestamps",
-                     data=data["timestamps"],
-                     chunks=chunks,
-                 )
-                 factor_group.create_dataset(
-                     "unit_indices",
-                     data=data["unit_indices"],
-                     chunks=chunks,
-                 )
-                 factor_group.create_dataset(
-                     "amplitudes",
-                     data=data["amplitudes"],
-                     chunks=chunks,
-                 )
-                 factor_group.create_dataset(
-                     "reference_times",
-                     data=data["reference_times"],
-                     chunks=(len(data["reference_times"]),),
-                 )
-                 factor_group.create_dataset(
-                     "reference_indices",
-                     data=data["reference_indices"],
-                     chunks=(len(data["reference_indices"]),),
-                 )
-
-     def _prepare_unified_data(self) -> dict:
-         """
-         Prepare unified data arrays from all plots
-
-         Returns:
-             Dictionary containing unified arrays and metadata
-         """
-         if not self.plots:
-             return {
-                 "timestamps": np.array([], dtype=np.float32),
-                 "unit_indices": np.array([], dtype=np.uint16),
-                 "amplitudes": np.array([], dtype=np.float32),
-                 "reference_times": np.array([], dtype=np.float32),
-                 "reference_indices": np.array([], dtype=np.uint32),
-                 "unit_ids": [],
-                 "total_spikes": 0,
-             }
-
-         # Create unit ID mapping
-         unit_ids = [str(plot.unit_id) for plot in self.plots]
-         unit_id_to_index = {unit_id: i for i, unit_id in enumerate(unit_ids)}
-
-         # Collect all spikes with their unit indices
-         all_spikes = []
-         for plot in self.plots:
-             unit_index = unit_id_to_index[str(plot.unit_id)]
-             for time, amplitude in zip(plot.spike_times_sec, plot.spike_amplitudes):
-                 all_spikes.append((float(time), unit_index, float(amplitude)))
-
-         if not all_spikes:
-             return {
-                 "timestamps": np.array([], dtype=np.float32),
-                 "unit_indices": np.array([], dtype=np.uint16),
-                 "amplitudes": np.array([], dtype=np.float32),
-                 "reference_times": np.array([], dtype=np.float32),
-                 "reference_indices": np.array([], dtype=np.uint32),
-                 "unit_ids": unit_ids,
-                 "total_spikes": 0,
-             }
-
-         # Sort by timestamp
-         all_spikes.sort(key=lambda x: x[0])
-
-         # Extract sorted arrays
-         timestamps = np.array([spike[0] for spike in all_spikes], dtype=np.float32)
-         unit_indices = np.array([spike[1] for spike in all_spikes], dtype=np.uint16)
-         amplitudes = np.array([spike[2] for spike in all_spikes], dtype=np.float32)
-
-         # Generate reference arrays
-         reference_times, reference_indices = self._generate_reference_arrays(timestamps)
-
-         return {
-             "timestamps": timestamps,
-             "unit_indices": unit_indices,
-             "amplitudes": amplitudes,
-             "reference_times": reference_times,
-             "reference_indices": reference_indices,
-             "unit_ids": unit_ids,
-             "total_spikes": len(all_spikes),
-         }
-
-     def _generate_reference_arrays(
-         self, timestamps: np.ndarray, interval_sec: float = 1.0
-     ) -> tuple:
-         """
-         Generate reference arrays using actual timestamps from the data
-
-         Args:
-             timestamps: Sorted array of timestamps
-             interval_sec: Minimum interval between reference points
-
-         Returns:
-             Tuple of (reference_times, reference_indices)
-         """
-         if len(timestamps) == 0:
-             return np.array([], dtype=np.float32), np.array([], dtype=np.uint32)
-
-         reference_times = []
-         reference_indices = []
-
-         current_ref_time = timestamps[0]
-         reference_times.append(current_ref_time)
-         reference_indices.append(0)
-
-         # Find the next reference point at least interval_sec later
-         for i, timestamp in enumerate(timestamps):
-             if timestamp >= current_ref_time + interval_sec:
-                 reference_times.append(timestamp)
-                 reference_indices.append(i)
-                 current_ref_time = timestamp
-
-         return np.array(reference_times, dtype=np.float32), np.array(
-             reference_indices, dtype=np.uint32
-         )
-
-     def _create_subsampled_data(
-         self, timestamps: np.ndarray, unit_indices: np.ndarray, amplitudes: np.ndarray
-     ) -> dict:
-         """
-         Create subsampled data with geometric progression factors
-
-         Args:
-             timestamps: Original timestamps array
-             unit_indices: Original unit indices array
-             amplitudes: Original amplitudes array
-
-         Returns:
-             Dictionary of subsampled data by factor
-         """
-         subsampled_data = {}
-         factor = 4
-         current_timestamps = timestamps
-         current_unit_indices = unit_indices
-         current_amplitudes = amplitudes
-
-         while len(current_timestamps) >= 500000:
-             # Create subsampled version by taking every Nth spike
-             subsampled_indices = np.arange(0, len(current_timestamps), factor)
-             subsampled_timestamps = current_timestamps[subsampled_indices]
-             subsampled_unit_indices = current_unit_indices[subsampled_indices]
-             subsampled_amplitudes = current_amplitudes[subsampled_indices]
-
-             # Generate reference arrays for this subsampled level
-             ref_times, ref_indices = self._generate_reference_arrays(
-                 subsampled_timestamps
-             )
-
-             subsampled_data[f"factor_{factor}"] = {
-                 "timestamps": subsampled_timestamps,
-                 "unit_indices": subsampled_unit_indices,
-                 "amplitudes": subsampled_amplitudes,
-                 "reference_times": ref_times,
-                 "reference_indices": ref_indices,
-             }
-
-             # Prepare for next iteration
-             current_timestamps = subsampled_timestamps
-             current_unit_indices = subsampled_unit_indices
-             current_amplitudes = subsampled_amplitudes
-             factor *= 4  # Geometric progression: 4, 16, 64, 256, ...
-
-         return subsampled_data
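
The `_create_subsampled_data` helper at the end of the removed file builds progressively coarser copies of the data for low zoom levels: while the current level still contains at least 500,000 spikes, it keeps every Nth entry, with N growing geometrically (4, 16, 64, ...). Note that each stride is applied to the previous level's arrays, so the stored factor names compound relative to the original data. A condensed sketch of the same loop over a single array:

```python
import numpy as np

def subsample_levels(timestamps: np.ndarray, threshold: int = 500_000) -> dict:
    """Geometric-progression subsampling, mirroring the removed helper."""
    levels = {}
    factor = 4
    current = timestamps
    while len(current) >= threshold:
        current = current[::factor]        # every Nth entry of the *previous* level
        levels[f"factor_{factor}"] = current
        factor *= 4                        # 4, 16, 64, 256, ...
    return levels

ts = np.sort(np.random.default_rng(0).uniform(0, 3600, size=2_500_000)).astype(np.float32)
for name, level in subsample_levels(ts).items():
    print(name, len(level))
# factor_4 625000, factor_16 39063; the loop stops once a level falls below the threshold
```

In the full version, the same subsampled indices are applied in lockstep to the unit_indices and amplitudes arrays, and reference arrays are regenerated per level.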