figpack 0.2.7__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of figpack might be problematic. Click here for more details.
- figpack/__init__.py +1 -1
- figpack/figpack-figure-dist/assets/{index-CTBd5_Gw.js → index-CjiTpC6i.js} +87 -87
- figpack/figpack-figure-dist/index.html +1 -1
- figpack/spike_sorting/views/RasterPlot.py +237 -26
- figpack/spike_sorting/views/SpikeAmplitudes.py +321 -36
- figpack/views/Spectrogram.py +223 -0
- figpack/views/__init__.py +1 -0
- {figpack-0.2.7.dist-info → figpack-0.2.8.dist-info}/METADATA +1 -1
- {figpack-0.2.7.dist-info → figpack-0.2.8.dist-info}/RECORD +13 -12
- {figpack-0.2.7.dist-info → figpack-0.2.8.dist-info}/WHEEL +0 -0
- {figpack-0.2.7.dist-info → figpack-0.2.8.dist-info}/entry_points.txt +0 -0
- {figpack-0.2.7.dist-info → figpack-0.2.8.dist-info}/licenses/LICENSE +0 -0
- {figpack-0.2.7.dist-info → figpack-0.2.8.dist-info}/top_level.txt +0 -0
|
@@ -2,13 +2,15 @@
|
|
|
2
2
|
SpikeAmplitudes view for figpack - displays spike amplitudes over time
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from typing import List
|
|
5
|
+
from typing import List
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
import zarr
|
|
9
9
|
|
|
10
10
|
from ...core.figpack_view import FigpackView
|
|
11
11
|
from .SpikeAmplitudesItem import SpikeAmplitudesItem
|
|
12
|
+
from .UnitsTable import UnitsTable, UnitsTableColumn, UnitsTableRow
|
|
13
|
+
from ...views.Box import Box, LayoutItem
|
|
12
14
|
|
|
13
15
|
|
|
14
16
|
class SpikeAmplitudes(FigpackView):
|
|
@@ -22,8 +24,6 @@ class SpikeAmplitudes(FigpackView):
|
|
|
22
24
|
start_time_sec: float,
|
|
23
25
|
end_time_sec: float,
|
|
24
26
|
plots: List[SpikeAmplitudesItem],
|
|
25
|
-
hide_unit_selector: bool = False,
|
|
26
|
-
height: int = 500,
|
|
27
27
|
):
|
|
28
28
|
"""
|
|
29
29
|
Initialize a SpikeAmplitudes view
|
|
@@ -32,18 +32,89 @@ class SpikeAmplitudes(FigpackView):
|
|
|
32
32
|
start_time_sec: Start time of the view in seconds
|
|
33
33
|
end_time_sec: End time of the view in seconds
|
|
34
34
|
plots: List of SpikeAmplitudesItem objects
|
|
35
|
-
hide_unit_selector: Whether to hide the unit selector
|
|
36
|
-
height: Height of the view in pixels
|
|
37
35
|
"""
|
|
38
36
|
self.start_time_sec = start_time_sec
|
|
39
37
|
self.end_time_sec = end_time_sec
|
|
40
38
|
self.plots = plots
|
|
41
|
-
|
|
42
|
-
|
|
39
|
+
|
|
40
|
+
@staticmethod
def from_nwb_units_table(
    nwb_url_or_path_or_h5py,
    *,
    units_path: str,
    include_units_selector: bool = False,
):
    """
    Build a SpikeAmplitudes view from a units table in an NWB file.

    The table must contain ``spike_times``, ``spike_times_index``,
    ``spike_amplitudes`` and ``id`` datasets. Spike amplitudes are assumed to
    share ``spike_times_index`` with the spike times (a separate
    ``spike_amplitudes_index`` is not read). Units without spikes get no plot.

    Args:
        nwb_url_or_path_or_h5py: A URL or path string (opened with
            ``lindi.LindiH5pyFile.from_hdf5_file``) or an already-open
            h5py-like file object.
        units_path: Path of the units table within the file.
        include_units_selector: If True, return a horizontal ``Box`` with a
            ``UnitsTable`` listing every unit id next to the view.

    Returns:
        The ``SpikeAmplitudes`` view, or a ``Box`` layout when
        ``include_units_selector`` is True.

    Raises:
        ValueError: If no unit in the table has any spikes.
    """
    if isinstance(nwb_url_or_path_or_h5py, str):
        import lindi

        f = lindi.LindiH5pyFile.from_hdf5_file(nwb_url_or_path_or_h5py)
    else:
        f = nwb_url_or_path_or_h5py
    units = f[units_path]

    # Read each column once. Indexing an h5py/lindi dataset element by element
    # issues a separate read per access, which is very slow for remote files.
    spike_amplitudes = units["spike_amplitudes"][:]
    spike_times = units["spike_times"][:]
    spike_times_index = units["spike_times_index"][:]
    unit_ids = units["id"][:]

    plots = []
    start_times = []
    end_times = []
    # spike_times_index holds the exclusive end offset of each unit's spikes;
    # each unit starts where the previous one ended.
    start_index = 0
    for unit_id, end_index in zip(unit_ids, spike_times_index):
        end_index = int(end_index)
        unit_spike_times = spike_times[start_index:end_index]
        unit_spike_amplitudes = spike_amplitudes[start_index:end_index]
        start_index = end_index
        if len(unit_spike_times) == 0:
            continue
        start_times.append(unit_spike_times[0])
        end_times.append(unit_spike_times[-1])
        plots.append(
            SpikeAmplitudesItem(
                unit_id=str(unit_id),
                spike_times_sec=unit_spike_times,
                spike_amplitudes=unit_spike_amplitudes,
            )
        )

    if not plots:
        raise ValueError(f"No spikes found in units table at {units_path!r}")

    view = SpikeAmplitudes(
        start_time_sec=min(start_times),
        end_time_sec=max(end_times),
        plots=plots,
    )
    if not include_units_selector:
        return view

    columns: List[UnitsTableColumn] = [
        UnitsTableColumn(key="unitId", label="Unit", dtype="int"),
    ]
    rows: List[UnitsTableRow] = [
        UnitsTableRow(unit_id=str(unit_id), values={}) for unit_id in unit_ids
    ]
    units_table = UnitsTable(columns=columns, rows=rows)
    return Box(
        direction="horizontal",
        items=[
            LayoutItem(view=units_table, max_size=150, title="Units"),
            LayoutItem(view=view, title="Spike Amplitudes"),
        ],
    )
|
|
43
114
|
|
|
44
115
|
def _write_to_zarr_group(self, group: zarr.Group) -> None:
|
|
45
116
|
"""
|
|
46
|
-
Write the SpikeAmplitudes data to a Zarr group
|
|
117
|
+
Write the SpikeAmplitudes data to a Zarr group using unified storage format
|
|
47
118
|
|
|
48
119
|
Args:
|
|
49
120
|
group: Zarr group to write data into
|
|
@@ -54,36 +125,250 @@ class SpikeAmplitudes(FigpackView):
|
|
|
54
125
|
# Store view parameters
|
|
55
126
|
group.attrs["start_time_sec"] = self.start_time_sec
|
|
56
127
|
group.attrs["end_time_sec"] = self.end_time_sec
|
|
57
|
-
group.attrs["hide_unit_selector"] = self.hide_unit_selector
|
|
58
|
-
group.attrs["height"] = self.height
|
|
59
|
-
|
|
60
|
-
# Store the number of plots
|
|
61
|
-
group.attrs["num_plots"] = len(self.plots)
|
|
62
|
-
|
|
63
|
-
# Store metadata for each plot
|
|
64
|
-
plot_metadata = []
|
|
65
|
-
for i, plot in enumerate(self.plots):
|
|
66
|
-
plot_name = f"plot_{i}"
|
|
67
|
-
|
|
68
|
-
# Store metadata
|
|
69
|
-
metadata = {
|
|
70
|
-
"name": plot_name,
|
|
71
|
-
"unit_id": str(plot.unit_id),
|
|
72
|
-
"num_spikes": len(plot.spike_times_sec),
|
|
73
|
-
}
|
|
74
|
-
plot_metadata.append(metadata)
|
|
75
128
|
|
|
76
|
-
|
|
129
|
+
# Prepare unified data arrays
|
|
130
|
+
unified_data = self._prepare_unified_data()
|
|
131
|
+
|
|
132
|
+
if unified_data["total_spikes"] == 0:
|
|
133
|
+
# Handle empty data case
|
|
134
|
+
group.create_dataset("timestamps", data=np.array([], dtype=np.float32))
|
|
135
|
+
group.create_dataset("unit_indices", data=np.array([], dtype=np.uint16))
|
|
136
|
+
group.create_dataset("amplitudes", data=np.array([], dtype=np.float32))
|
|
137
|
+
group.create_dataset("reference_times", data=np.array([], dtype=np.float32))
|
|
77
138
|
group.create_dataset(
|
|
78
|
-
|
|
79
|
-
data=plot.spike_times_sec,
|
|
80
|
-
dtype=np.float32,
|
|
139
|
+
"reference_indices", data=np.array([], dtype=np.uint32)
|
|
81
140
|
)
|
|
82
|
-
group.
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
141
|
+
group.attrs["unit_ids"] = []
|
|
142
|
+
group.attrs["total_spikes"] = 0
|
|
143
|
+
return
|
|
144
|
+
|
|
145
|
+
chunks = (
|
|
146
|
+
(2_000_000,)
|
|
147
|
+
if unified_data["total_spikes"] > 2_000_000
|
|
148
|
+
else (len(unified_data["timestamps"]),)
|
|
149
|
+
)
|
|
150
|
+
# Store main data arrays
|
|
151
|
+
group.create_dataset(
|
|
152
|
+
"timestamps",
|
|
153
|
+
data=unified_data["timestamps"],
|
|
154
|
+
dtype=np.float32,
|
|
155
|
+
chunks=chunks,
|
|
156
|
+
)
|
|
157
|
+
group.create_dataset(
|
|
158
|
+
"unit_indices",
|
|
159
|
+
data=unified_data["unit_indices"],
|
|
160
|
+
dtype=np.uint16,
|
|
161
|
+
chunks=chunks,
|
|
162
|
+
)
|
|
163
|
+
group.create_dataset(
|
|
164
|
+
"amplitudes",
|
|
165
|
+
data=unified_data["amplitudes"],
|
|
166
|
+
dtype=np.float32,
|
|
167
|
+
chunks=chunks,
|
|
168
|
+
)
|
|
169
|
+
group.create_dataset(
|
|
170
|
+
"reference_times",
|
|
171
|
+
data=unified_data["reference_times"],
|
|
172
|
+
dtype=np.float32,
|
|
173
|
+
chunks=(len(unified_data["reference_times"]),),
|
|
174
|
+
)
|
|
175
|
+
group.create_dataset(
|
|
176
|
+
"reference_indices",
|
|
177
|
+
data=unified_data["reference_indices"],
|
|
178
|
+
dtype=np.uint32,
|
|
179
|
+
chunks=(len(unified_data["reference_indices"]),),
|
|
180
|
+
)
|
|
181
|
+
|
|
182
|
+
# Store unit ID mapping
|
|
183
|
+
group.attrs["unit_ids"] = unified_data["unit_ids"]
|
|
184
|
+
group.attrs["total_spikes"] = unified_data["total_spikes"]
|
|
185
|
+
|
|
186
|
+
# Create subsampled data
|
|
187
|
+
subsampled_data = self._create_subsampled_data(
|
|
188
|
+
unified_data["timestamps"],
|
|
189
|
+
unified_data["unit_indices"],
|
|
190
|
+
unified_data["amplitudes"],
|
|
191
|
+
)
|
|
192
|
+
|
|
193
|
+
if subsampled_data:
|
|
194
|
+
subsampled_group = group.create_group("subsampled_data")
|
|
195
|
+
for factor_name, data in subsampled_data.items():
|
|
196
|
+
chunks = (
|
|
197
|
+
(2_000_000,)
|
|
198
|
+
if len(data["timestamps"]) > 2_000_000
|
|
199
|
+
else (len(data["timestamps"]),)
|
|
200
|
+
)
|
|
201
|
+
factor_group = subsampled_group.create_group(factor_name)
|
|
202
|
+
factor_group.create_dataset(
|
|
203
|
+
"timestamps",
|
|
204
|
+
data=data["timestamps"],
|
|
205
|
+
dtype=np.float32,
|
|
206
|
+
chunks=chunks,
|
|
207
|
+
)
|
|
208
|
+
factor_group.create_dataset(
|
|
209
|
+
"unit_indices",
|
|
210
|
+
data=data["unit_indices"],
|
|
211
|
+
dtype=np.uint16,
|
|
212
|
+
chunks=chunks,
|
|
213
|
+
)
|
|
214
|
+
factor_group.create_dataset(
|
|
215
|
+
"amplitudes",
|
|
216
|
+
data=data["amplitudes"],
|
|
217
|
+
dtype=np.float32,
|
|
218
|
+
chunks=chunks,
|
|
219
|
+
)
|
|
220
|
+
factor_group.create_dataset(
|
|
221
|
+
"reference_times",
|
|
222
|
+
data=data["reference_times"],
|
|
223
|
+
dtype=np.float32,
|
|
224
|
+
chunks=(len(data["reference_times"]),),
|
|
225
|
+
)
|
|
226
|
+
factor_group.create_dataset(
|
|
227
|
+
"reference_indices",
|
|
228
|
+
data=data["reference_indices"],
|
|
229
|
+
dtype=np.uint32,
|
|
230
|
+
chunks=(len(data["reference_indices"]),),
|
|
231
|
+
)
|
|
232
|
+
|
|
233
|
+
def _prepare_unified_data(self) -> dict:
    """
    Merge the spikes of all plots into flat arrays sorted by time.

    Returns:
        Dictionary with:
            "timestamps" (float32), "unit_indices" (uint16), "amplitudes"
                (float32): one entry per spike, sorted by spike time; spikes
                with equal times keep plot order, then within-plot order.
            "reference_times", "reference_indices": the result of
                ``self._generate_reference_arrays`` on the sorted timestamps
                (empty arrays when there are no spikes).
            "unit_ids": ``str(plot.unit_id)`` for each plot, in plot order;
                a spike's unit index points into this list.
            "total_spikes": number of spikes.

    Raises:
        ValueError: If there are more plots than a uint16 unit index can address.
    """
    unit_ids = [str(plot.unit_id) for plot in self.plots]

    def _empty_result() -> dict:
        # Shape-compatible result for the "no spikes" case.
        return {
            "timestamps": np.array([], dtype=np.float32),
            "unit_indices": np.array([], dtype=np.uint16),
            "amplitudes": np.array([], dtype=np.float32),
            "reference_times": np.array([], dtype=np.float32),
            "reference_indices": np.array([], dtype=np.uint32),
            "unit_ids": unit_ids,
            "total_spikes": 0,
        }

    if not unit_ids:
        return _empty_result()

    max_units = int(np.iinfo(np.uint16).max) + 1
    if len(unit_ids) > max_units:
        raise ValueError(
            f"SpikeAmplitudes supports at most {max_units} units, got {len(unit_ids)}"
        )

    # If two plots share a unit id, both map to the index of the last such plot.
    unit_id_to_index = {unit_id: i for i, unit_id in enumerate(unit_ids)}

    time_parts = []
    amplitude_parts = []
    index_parts = []
    for plot, unit_id in zip(self.plots, unit_ids):
        times = np.asarray(plot.spike_times_sec, dtype=np.float64)
        amps = np.asarray(plot.spike_amplitudes, dtype=np.float64)
        # Pair times and amplitudes positionally; any unmatched tail is dropped.
        count = min(len(times), len(amps))
        time_parts.append(times[:count])
        amplitude_parts.append(amps[:count])
        index_parts.append(
            np.full(count, unit_id_to_index[unit_id], dtype=np.uint16)
        )

    all_times = np.concatenate(time_parts)
    if all_times.size == 0:
        return _empty_result()

    # Sort every spike by time in one vectorized pass; kind="stable" keeps
    # spikes with equal times in plot order, then within-plot order.
    order = np.argsort(all_times, kind="stable")
    timestamps = all_times[order].astype(np.float32)
    unit_indices = np.concatenate(index_parts)[order]
    amplitudes = np.concatenate(amplitude_parts)[order].astype(np.float32)

    reference_times, reference_indices = self._generate_reference_arrays(timestamps)

    return {
        "timestamps": timestamps,
        "unit_indices": unit_indices,
        "amplitudes": amplitudes,
        "reference_times": reference_times,
        "reference_indices": reference_indices,
        "unit_ids": unit_ids,
        "total_spikes": int(all_times.size),
    }
|
|
293
|
+
|
|
294
|
+
def _generate_reference_arrays(
    self, timestamps: np.ndarray, interval_sec: float = 1.0
) -> tuple:
    """
    Pick sparse "reference" spikes that index into a sorted timestamp array.

    The first spike is always a reference. Scanning forward, a spike becomes
    the next reference as soon as its time is at least ``interval_sec`` past
    the time of the previous reference.

    Args:
        timestamps: Timestamps sorted in ascending order.
        interval_sec: Minimum spacing between consecutive references.

    Returns:
        Tuple (reference_times as float32, reference_indices as uint32).
    """
    if len(timestamps) == 0:
        return np.array([], dtype=np.float32), np.array([], dtype=np.uint32)

    last_ref = timestamps[0]
    picked_times = [last_ref]
    picked_positions = [0]

    for position, candidate in enumerate(timestamps):
        if candidate >= last_ref + interval_sec:
            picked_positions.append(position)
            picked_times.append(candidate)
            last_ref = candidate

    times_out = np.array(picked_times, dtype=np.float32)
    positions_out = np.array(picked_positions, dtype=np.uint32)
    return times_out, positions_out
|
|
327
|
+
|
|
328
|
+
def _create_subsampled_data(
    self, timestamps: np.ndarray, unit_indices: np.ndarray, amplitudes: np.ndarray
) -> dict:
    """
    Build a pyramid of decimated copies of the time-sorted spike arrays.

    Each level keeps every 4th spike of the previous level. The level stored
    under ``f"factor_{f}"`` therefore contains exactly every ``f``-th spike
    of the input (f = 4, 16, 64, ...). A new level is produced only while
    the previous level still has at least 500,000 spikes.

    Args:
        timestamps: Time-sorted spike times of all units.
        unit_indices: Unit index of each spike (same length as ``timestamps``).
        amplitudes: Amplitude of each spike (same length as ``timestamps``).

    Returns:
        Dict mapping ``"factor_<f>"`` to a dict with keys "timestamps",
        "unit_indices", "amplitudes", "reference_times" and
        "reference_indices". The last two come from
        ``self._generate_reference_arrays`` on that level's timestamps.
        The dict is empty when the input has fewer than 500,000 spikes.
    """
    step = 4  # decimation between consecutive levels
    min_spikes_to_subsample = 500_000

    subsampled_data = {}
    factor = step
    level_timestamps = timestamps
    level_unit_indices = unit_indices
    level_amplitudes = amplitudes

    while len(level_timestamps) >= min_spikes_to_subsample:
        # Decimate the previous level by `step`, not by the cumulative
        # `factor`, so this level holds exactly every `factor`-th input spike.
        level_timestamps = level_timestamps[::step]
        level_unit_indices = level_unit_indices[::step]
        level_amplitudes = level_amplitudes[::step]

        ref_times, ref_indices = self._generate_reference_arrays(level_timestamps)

        subsampled_data[f"factor_{factor}"] = {
            "timestamps": level_timestamps,
            "unit_indices": level_unit_indices,
            "amplitudes": level_amplitudes,
            "reference_times": ref_times,
            "reference_indices": ref_indices,
        }
        factor *= step  # Geometric progression: 4, 16, 64, 256, ...

    return subsampled_data
|
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Spectrogram visualization component
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import math
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import zarr
|
|
10
|
+
|
|
11
|
+
from ..core.figpack_view import FigpackView
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Spectrogram(FigpackView):
    """
    A spectrogram visualization component for time-frequency data.

    Besides the full-resolution data, a pyramid of time-downsampled copies
    (max over groups of 4, 16, 64, ... timepoints) is stored for efficient
    rendering.
    """

    def __init__(
        self,
        *,
        start_time_sec: float,
        sampling_frequency_hz: float,
        frequency_min_hz: float,
        frequency_delta_hz: float,
        data: np.ndarray,
    ):
        """
        Initialize a Spectrogram view.

        Args:
            start_time_sec: Starting time in seconds.
            sampling_frequency_hz: Sampling rate of the time axis in Hz.
            frequency_min_hz: Frequency of the first bin in Hz.
            frequency_delta_hz: Frequency bin spacing in Hz.
            data: N×M numpy array where N is timepoints and M is frequency
                bins; stored as float32.

        Raises:
            AssertionError: If ``data`` is not 2-D, or the sampling frequency
                or frequency delta is not positive.
            ValueError: If ``data`` is empty (raised by ``np.nanmin``).
        """
        assert data.ndim == 2, "Data must be a 2D array (timepoints × frequencies)"
        assert sampling_frequency_hz > 0, "Sampling frequency must be positive"
        assert frequency_delta_hz > 0, "Frequency delta must be positive"

        self.start_time_sec = start_time_sec
        self.sampling_frequency_hz = sampling_frequency_hz
        self.frequency_min_hz = frequency_min_hz
        self.frequency_delta_hz = frequency_delta_hz
        self.data = data.astype(np.float32)  # Ensure float32 for efficiency

        self.n_timepoints, self.n_frequencies = self.data.shape

        # Frequency (Hz) of each column of the data
        self.frequency_bins = (
            frequency_min_hz + np.arange(self.n_frequencies) * frequency_delta_hz
        )

        # Color-scale range, taken from the float32 values that are actually
        # stored so that it matches the written data exactly.
        self.data_min = float(np.nanmin(self.data))
        self.data_max = float(np.nanmax(self.data))

        # Prepare downsampled arrays for efficient rendering
        self.downsampled_data = self._compute_downsampled_data()

    def _compute_downsampled_data(self) -> dict:
        """
        Compute max-downsampled copies of ``self.data`` along the time axis.

        The factor-4 level is built from the raw data whenever there are at
        least 4 timepoints. Each further level (16, 64, ...) is built from the
        previous one while ``factor < n_timepoints / 1000``.

        Returns:
            dict: {factor: (ceil(N/factor), M) float32 array}. Each row is the
            NaN-ignoring maximum over ``factor`` consecutive timepoints (the
            last group may be shorter).
        """
        n_timepoints = self.data.shape[0]
        downsampled = {}

        if n_timepoints < 4:
            # Fewer than 4 timepoints: no downsampled levels are produced.
            return downsampled

        def _max_over_groups_of_4(x: np.ndarray) -> np.ndarray:
            """Reduce (B, M) -> (ceil(B/4), M) by taking the max of every 4 rows."""
            n_rows, n_cols = x.shape
            n_groups = math.ceil(n_rows / 4)
            pad = n_groups * 4 - n_rows
            # Pad the time axis with NaNs so nanmax ignores the padded tail.
            padded = np.pad(
                x, ((0, pad), (0, 0)), mode="constant", constant_values=np.nan
            )
            grouped = padded.reshape(n_groups, 4, n_cols)  # (B', 4, M)
            return np.nanmax(grouped, axis=1).astype(np.float32)  # (B', M)

        # Factor 4 comes from the raw data. Later levels come from the previous
        # level, since a max of maxes equals the max over the combined group.
        factor = 4
        level = _max_over_groups_of_4(self.data)
        downsampled[factor] = level

        factor *= 4  # -> 16
        while factor < n_timepoints / 1000:
            level = _max_over_groups_of_4(level)
            downsampled[factor] = level
            factor *= 4

        return downsampled

    def _calculate_optimal_chunk_size(
        self, shape: tuple, target_size_mb: float = 5.0
    ) -> tuple:
        """
        Choose a Zarr chunk shape of roughly ``target_size_mb`` per chunk.

        Chunks span all frequency bins. Along time, the chunk length is the
        largest power of two whose float32 chunk fits in ``target_size_mb``.
        It is capped at ``n_timepoints`` and is never less than 1.

        Args:
            shape: Array shape (n_timepoints, n_frequencies)
            target_size_mb: Target chunk size in MiB

        Returns:
            Tuple (chunk_timepoints, n_frequencies)
        """
        bytes_per_element = 4  # float32
        target_size_bytes = target_size_mb * 1024 * 1024

        n_timepoints, n_frequencies = shape
        bytes_per_timepoint = max(n_frequencies, 1) * bytes_per_element

        # At least one timepoint per chunk, even when a single row exceeds the
        # target size (this also avoids log2(0) for very wide spectra).
        max_timepoints_per_chunk = max(
            int(target_size_bytes // bytes_per_timepoint), 1
        )

        # Largest power of two not exceeding the budget
        chunk_timepoints = 2 ** math.floor(math.log2(max_timepoints_per_chunk))
        # Never larger than the array, but Zarr needs a chunk length >= 1
        chunk_timepoints = max(min(chunk_timepoints, n_timepoints), 1)

        return (chunk_timepoints, n_frequencies)

    def _write_to_zarr_group(self, group: zarr.Group) -> None:
        """
        Write the spectrogram data to a Zarr group.

        This writes the metadata attributes, the ``frequency_bins`` and
        ``data`` datasets, and one ``data_ds_<factor>`` dataset per
        downsampled level. The factors are listed in the
        ``downsample_factors`` attribute, and every dataset is
        blosc/lz4-compressed. A short summary of the stored arrays is printed.

        Args:
            group: Zarr group to write data into
        """

        def _blosc_lz4() -> dict:
            # Fresh compression kwargs for each dataset.
            return {
                "compression": "blosc",
                "compression_opts": {"cname": "lz4", "clevel": 5, "shuffle": 1},
            }

        group.attrs["view_type"] = "Spectrogram"

        # Store metadata
        group.attrs["start_time_sec"] = self.start_time_sec
        group.attrs["sampling_frequency_hz"] = self.sampling_frequency_hz
        group.attrs["frequency_min_hz"] = self.frequency_min_hz
        group.attrs["frequency_delta_hz"] = self.frequency_delta_hz
        group.attrs["n_timepoints"] = self.n_timepoints
        group.attrs["n_frequencies"] = self.n_frequencies
        group.attrs["data_min"] = self.data_min
        group.attrs["data_max"] = self.data_max

        # Store frequency bins
        group.create_dataset(
            "frequency_bins",
            data=self.frequency_bins.astype(np.float32),
            **_blosc_lz4(),
        )

        # Store original data with optimal chunking
        original_chunks = self._calculate_optimal_chunk_size(self.data.shape)
        group.create_dataset(
            "data",
            data=self.data,
            chunks=original_chunks,
            **_blosc_lz4(),
        )

        # Store downsampled data arrays
        downsample_factors = list(self.downsampled_data.keys())
        group.attrs["downsample_factors"] = downsample_factors

        # Compute each level's chunk shape once; it is reused in the summary.
        ds_chunks_by_factor = {}
        for factor, downsampled_array in self.downsampled_data.items():
            ds_chunks = self._calculate_optimal_chunk_size(downsampled_array.shape)
            ds_chunks_by_factor[factor] = ds_chunks
            group.create_dataset(
                f"data_ds_{factor}",
                data=downsampled_array,
                chunks=ds_chunks,
                **_blosc_lz4(),
            )

        print(f"Stored Spectrogram with {len(downsample_factors)} downsampled levels:")
        print(f" Original: {self.data.shape} (chunks: {original_chunks})")
        for factor in downsample_factors:
            ds_shape = self.downsampled_data[factor].shape
            print(f" Factor {factor}: {ds_shape} (chunks: {ds_chunks_by_factor[factor]})")
|
figpack/views/__init__.py
CHANGED
|
@@ -8,6 +8,7 @@ from .Markdown import Markdown
|
|
|
8
8
|
from .MatplotlibFigure import MatplotlibFigure
|
|
9
9
|
from .MultiChannelTimeseries import MultiChannelTimeseries
|
|
10
10
|
from .PlotlyFigure import PlotlyFigure
|
|
11
|
+
from .Spectrogram import Spectrogram
|
|
11
12
|
from .Splitter import Splitter
|
|
12
13
|
from .TabLayout import TabLayout
|
|
13
14
|
from .TabLayoutItem import TabLayoutItem
|