figpack 0.2.7__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of figpack might be problematic. Click here for more details.
- figpack/__init__.py +1 -1
- figpack/figpack-figure-dist/assets/{index-CTBd5_Gw.js → index-CjiTpC6i.js} +87 -87
- figpack/figpack-figure-dist/index.html +1 -1
- figpack/spike_sorting/views/RasterPlot.py +237 -26
- figpack/spike_sorting/views/SpikeAmplitudes.py +321 -36
- figpack/views/Spectrogram.py +223 -0
- figpack/views/__init__.py +1 -0
- {figpack-0.2.7.dist-info → figpack-0.2.8.dist-info}/METADATA +1 -1
- {figpack-0.2.7.dist-info → figpack-0.2.8.dist-info}/RECORD +13 -12
- {figpack-0.2.7.dist-info → figpack-0.2.8.dist-info}/WHEEL +0 -0
- {figpack-0.2.7.dist-info → figpack-0.2.8.dist-info}/entry_points.txt +0 -0
- {figpack-0.2.7.dist-info → figpack-0.2.8.dist-info}/licenses/LICENSE +0 -0
- {figpack-0.2.7.dist-info → figpack-0.2.8.dist-info}/top_level.txt +0 -0
|
@@ -5,7 +5,7 @@
|
|
|
5
5
|
<link rel="icon" type="image/png" href="./assets/neurosift-logo-CLsuwLMO.png" />
|
|
6
6
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
|
7
7
|
<title>figpack figure</title>
|
|
8
|
-
<script type="module" crossorigin src="./assets/index-CTBd5_Gw.js"></script>
|
|
8
|
+
<script type="module" crossorigin src="./assets/index-CjiTpC6i.js"></script>
|
|
9
9
|
<link rel="stylesheet" crossorigin href="./assets/index-Cmae55E4.css">
|
|
10
10
|
</head>
|
|
11
11
|
<body>
|
|
@@ -8,6 +8,8 @@ import zarr
|
|
|
8
8
|
|
|
9
9
|
from ...core.figpack_view import FigpackView
|
|
10
10
|
from .RasterPlotItem import RasterPlotItem
|
|
11
|
+
from .UnitsTable import UnitsTable, UnitsTableColumn, UnitsTableRow
|
|
12
|
+
from ...views.Box import Box, LayoutItem
|
|
11
13
|
|
|
12
14
|
|
|
13
15
|
class RasterPlot(FigpackView):
|
|
@@ -21,7 +23,6 @@ class RasterPlot(FigpackView):
|
|
|
21
23
|
start_time_sec: float,
|
|
22
24
|
end_time_sec: float,
|
|
23
25
|
plots: List[RasterPlotItem],
|
|
24
|
-
height: int = 500,
|
|
25
26
|
):
|
|
26
27
|
"""
|
|
27
28
|
Initialize a RasterPlot view
|
|
@@ -35,43 +36,253 @@ class RasterPlot(FigpackView):
|
|
|
35
36
|
self.start_time_sec = float(start_time_sec)
|
|
36
37
|
self.end_time_sec = float(end_time_sec)
|
|
37
38
|
self.plots = plots
|
|
38
|
-
|
|
39
|
+
|
|
40
|
+
@staticmethod
def from_nwb_units_table(
    nwb_url_or_path_or_h5py,
    *,
    units_path: str,
    include_units_selector: bool = False,
):
    """
    Create a RasterPlot from an NWB units table

    Args:
        nwb_url_or_path_or_h5py: A string (opened with
            lindi.LindiH5pyFile.from_hdf5_file) or an already-open
            h5py-like file object
        units_path: Path of the units table inside the file
        include_units_selector: If True, return a horizontal Box with a
            UnitsTable selector next to the raster plot

    Returns:
        The RasterPlot view, or a Box layout containing it when
        include_units_selector is True

    Raises:
        ValueError: If no unit in the table has any spikes
    """
    if isinstance(nwb_url_or_path_or_h5py, str):
        import lindi

        f = lindi.LindiH5pyFile.from_hdf5_file(nwb_url_or_path_or_h5py)
    else:
        f = nwb_url_or_path_or_h5py
    X = f[units_path]
    spike_times = X["spike_times"]
    # Read the index and id columns once instead of issuing one element
    # read per unit (each element read is a separate I/O for lazy files)
    spike_times_index = np.asarray(X["spike_times_index"][:])
    unit_ids = np.asarray(X["id"][:])

    plots = []
    start_times = []
    end_times = []
    for unit_index in range(len(spike_times_index)):
        unit_id = unit_ids[unit_index]
        # spike_times_index[k] is the exclusive end offset of unit k's spikes
        start_index = int(spike_times_index[unit_index - 1]) if unit_index > 0 else 0
        end_index = int(spike_times_index[unit_index])
        unit_spike_times = spike_times[start_index:end_index]
        if len(unit_spike_times) == 0:
            # Nothing to draw for a unit without spikes
            continue
        start_times.append(unit_spike_times[0])
        end_times.append(unit_spike_times[-1])
        plots.append(
            RasterPlotItem(unit_id=str(unit_id), spike_times_sec=unit_spike_times)
        )

    if not plots:
        raise ValueError(f"No spikes found in units table at {units_path!r}")

    view = RasterPlot(
        start_time_sec=min(start_times),
        end_time_sec=max(end_times),
        plots=plots,
    )
    if not include_units_selector:
        return view

    columns: List[UnitsTableColumn] = [
        UnitsTableColumn(key="unitId", label="Unit", dtype="int"),
    ]
    # NOTE(review): rows include units without spikes, which have no raster
    # plot; confirm this is intended for the selector
    rows: List[UnitsTableRow] = [
        UnitsTableRow(unit_id=str(unit_id), values={}) for unit_id in unit_ids
    ]
    units_table = UnitsTable(
        columns=columns,
        rows=rows,
    )
    return Box(
        direction="horizontal",
        items=[
            LayoutItem(view=units_table, max_size=150, title="Units"),
            LayoutItem(view=view, title="Raster Plot"),
        ],
    )
|
|
39
107
|
|
|
40
108
|
def _write_to_zarr_group(self, group: zarr.Group) -> None:
|
|
41
109
|
"""
|
|
42
|
-
Write the RasterPlot data to a Zarr group
|
|
43
|
-
|
|
44
110
|
Args:
|
|
45
111
|
group: Zarr group to write data into
|
|
46
112
|
"""
|
|
47
113
|
# Set the view type
|
|
48
114
|
group.attrs["view_type"] = "RasterPlot"
|
|
49
115
|
|
|
50
|
-
# Store view
|
|
116
|
+
# Store view parameters
|
|
51
117
|
group.attrs["start_time_sec"] = self.start_time_sec
|
|
52
118
|
group.attrs["end_time_sec"] = self.end_time_sec
|
|
53
|
-
group.attrs["height"] = self.height
|
|
54
|
-
group.attrs["num_plots"] = len(self.plots)
|
|
55
|
-
|
|
56
|
-
# Store metadata for each plot
|
|
57
|
-
plot_metadata = []
|
|
58
|
-
for i, plot in enumerate(self.plots):
|
|
59
|
-
plot_name = f"plot_{i}"
|
|
60
|
-
|
|
61
|
-
# Store metadata
|
|
62
|
-
metadata = {
|
|
63
|
-
"name": plot_name,
|
|
64
|
-
"unit_id": str(plot.unit_id),
|
|
65
|
-
"num_spikes": len(plot.spike_times_sec),
|
|
66
|
-
}
|
|
67
|
-
plot_metadata.append(metadata)
|
|
68
119
|
|
|
69
|
-
|
|
120
|
+
# Prepare unified data arrays
|
|
121
|
+
unified_data = self._prepare_unified_data()
|
|
122
|
+
|
|
123
|
+
if unified_data["total_spikes"] == 0:
|
|
124
|
+
# Handle empty data case
|
|
125
|
+
group.create_dataset("timestamps", data=np.array([], dtype=np.float32))
|
|
126
|
+
group.create_dataset("unit_indices", data=np.array([], dtype=np.uint16))
|
|
127
|
+
group.create_dataset("reference_times", data=np.array([], dtype=np.float32))
|
|
70
128
|
group.create_dataset(
|
|
71
|
-
|
|
72
|
-
data=plot.spike_times_sec,
|
|
73
|
-
dtype=np.float32,
|
|
129
|
+
"reference_indices", data=np.array([], dtype=np.uint32)
|
|
74
130
|
)
|
|
131
|
+
group.attrs["unit_ids"] = []
|
|
132
|
+
group.attrs["total_spikes"] = 0
|
|
133
|
+
return
|
|
134
|
+
|
|
135
|
+
chunks = (
|
|
136
|
+
(2_000_000,)
|
|
137
|
+
if unified_data["total_spikes"] > 2_000_000
|
|
138
|
+
else (len(unified_data["timestamps"]),)
|
|
139
|
+
)
|
|
140
|
+
# Store main data arrays
|
|
141
|
+
group.create_dataset(
|
|
142
|
+
"timestamps",
|
|
143
|
+
data=unified_data["timestamps"],
|
|
144
|
+
dtype=np.float32,
|
|
145
|
+
chunks=chunks,
|
|
146
|
+
)
|
|
147
|
+
group.create_dataset(
|
|
148
|
+
"unit_indices",
|
|
149
|
+
data=unified_data["unit_indices"],
|
|
150
|
+
dtype=np.uint16,
|
|
151
|
+
chunks=chunks,
|
|
152
|
+
)
|
|
153
|
+
group.create_dataset(
|
|
154
|
+
"reference_times",
|
|
155
|
+
data=unified_data["reference_times"],
|
|
156
|
+
dtype=np.float32,
|
|
157
|
+
chunks=(len(unified_data["reference_times"]),),
|
|
158
|
+
)
|
|
159
|
+
group.create_dataset(
|
|
160
|
+
"reference_indices",
|
|
161
|
+
data=unified_data["reference_indices"],
|
|
162
|
+
dtype=np.uint32,
|
|
163
|
+
chunks=(len(unified_data["reference_indices"]),),
|
|
164
|
+
)
|
|
165
|
+
|
|
166
|
+
# Create spike counts array with 1-second bins
|
|
167
|
+
duration = self.end_time_sec - self.start_time_sec
|
|
168
|
+
num_bins = int(np.ceil(duration))
|
|
169
|
+
num_units = len(self.plots)
|
|
170
|
+
spike_counts = np.zeros((num_bins, num_units), dtype=np.uint16)
|
|
171
|
+
|
|
172
|
+
# Efficiently compute spike counts for each unit
|
|
173
|
+
for unit_idx, plot in enumerate(self.plots):
|
|
174
|
+
# Convert spike times to bin indices
|
|
175
|
+
bin_indices = (
|
|
176
|
+
(np.array(plot.spike_times_sec) - self.start_time_sec)
|
|
177
|
+
).astype(int)
|
|
178
|
+
# Count spikes in valid bins
|
|
179
|
+
valid_indices = (bin_indices >= 0) & (bin_indices < num_bins)
|
|
180
|
+
unique_bins, counts = np.unique(
|
|
181
|
+
bin_indices[valid_indices], return_counts=True
|
|
182
|
+
)
|
|
183
|
+
spike_counts[unique_bins, unit_idx] = counts.clip(
|
|
184
|
+
max=65535
|
|
185
|
+
) # Clip to uint16 max
|
|
186
|
+
|
|
187
|
+
# Store spike counts array
|
|
188
|
+
group.create_dataset(
|
|
189
|
+
"spike_counts_1sec",
|
|
190
|
+
data=spike_counts,
|
|
191
|
+
dtype=np.uint16,
|
|
192
|
+
chunks=(min(num_bins, 10000), min(num_units, 500)),
|
|
193
|
+
)
|
|
194
|
+
|
|
195
|
+
# Store unit ID mapping
|
|
196
|
+
group.attrs["unit_ids"] = unified_data["unit_ids"]
|
|
197
|
+
group.attrs["total_spikes"] = unified_data["total_spikes"]
|
|
198
|
+
|
|
199
|
+
def _prepare_unified_data(self) -> dict:
|
|
200
|
+
"""
|
|
201
|
+
Prepare unified data arrays from all plots
|
|
202
|
+
|
|
203
|
+
Returns:
|
|
204
|
+
Dictionary containing unified arrays and metadata
|
|
205
|
+
"""
|
|
206
|
+
if not self.plots:
|
|
207
|
+
return {
|
|
208
|
+
"timestamps": np.array([], dtype=np.float32),
|
|
209
|
+
"unit_indices": np.array([], dtype=np.uint16),
|
|
210
|
+
"reference_times": np.array([], dtype=np.float32),
|
|
211
|
+
"reference_indices": np.array([], dtype=np.uint32),
|
|
212
|
+
"unit_ids": [],
|
|
213
|
+
"total_spikes": 0,
|
|
214
|
+
}
|
|
215
|
+
|
|
216
|
+
# Create unit ID mapping
|
|
217
|
+
unit_ids = [str(plot.unit_id) for plot in self.plots]
|
|
218
|
+
unit_id_to_index = {unit_id: i for i, unit_id in enumerate(unit_ids)}
|
|
219
|
+
|
|
220
|
+
# Collect all spikes with their unit indices
|
|
221
|
+
all_spikes = []
|
|
222
|
+
for plot in self.plots:
|
|
223
|
+
unit_index = unit_id_to_index[str(plot.unit_id)]
|
|
224
|
+
for time in plot.spike_times_sec:
|
|
225
|
+
all_spikes.append((float(time), unit_index))
|
|
226
|
+
|
|
227
|
+
if not all_spikes:
|
|
228
|
+
return {
|
|
229
|
+
"timestamps": np.array([], dtype=np.float32),
|
|
230
|
+
"unit_indices": np.array([], dtype=np.uint16),
|
|
231
|
+
"reference_times": np.array([], dtype=np.float32),
|
|
232
|
+
"reference_indices": np.array([], dtype=np.uint32),
|
|
233
|
+
"unit_ids": unit_ids,
|
|
234
|
+
"total_spikes": 0,
|
|
235
|
+
}
|
|
236
|
+
|
|
237
|
+
# Sort by timestamp
|
|
238
|
+
all_spikes.sort(key=lambda x: x[0])
|
|
239
|
+
|
|
240
|
+
# Extract sorted arrays
|
|
241
|
+
timestamps = np.array([spike[0] for spike in all_spikes], dtype=np.float32)
|
|
242
|
+
unit_indices = np.array([spike[1] for spike in all_spikes], dtype=np.uint16)
|
|
243
|
+
|
|
244
|
+
# Generate reference arrays
|
|
245
|
+
reference_times, reference_indices = self._generate_reference_arrays(timestamps)
|
|
246
|
+
|
|
247
|
+
return {
|
|
248
|
+
"timestamps": timestamps,
|
|
249
|
+
"unit_indices": unit_indices,
|
|
250
|
+
"reference_times": reference_times,
|
|
251
|
+
"reference_indices": reference_indices,
|
|
252
|
+
"unit_ids": unit_ids,
|
|
253
|
+
"total_spikes": len(all_spikes),
|
|
254
|
+
}
|
|
255
|
+
|
|
256
|
+
def _generate_reference_arrays(
|
|
257
|
+
self, timestamps: np.ndarray, interval_sec: float = 1.0
|
|
258
|
+
) -> tuple:
|
|
259
|
+
"""
|
|
260
|
+
Generate reference arrays using actual timestamps from the data
|
|
261
|
+
|
|
262
|
+
Args:
|
|
263
|
+
timestamps: Sorted array of timestamps
|
|
264
|
+
interval_sec: Minimum interval between reference points
|
|
265
|
+
|
|
266
|
+
Returns:
|
|
267
|
+
Tuple of (reference_times, reference_indices)
|
|
268
|
+
"""
|
|
269
|
+
if len(timestamps) == 0:
|
|
270
|
+
return np.array([], dtype=np.float32), np.array([], dtype=np.uint32)
|
|
271
|
+
|
|
272
|
+
reference_times = []
|
|
273
|
+
reference_indices = []
|
|
274
|
+
|
|
275
|
+
current_ref_time = timestamps[0]
|
|
276
|
+
reference_times.append(current_ref_time)
|
|
277
|
+
reference_indices.append(0)
|
|
278
|
+
|
|
279
|
+
# Find the next reference point at least interval_sec later
|
|
280
|
+
for i, timestamp in enumerate(timestamps):
|
|
281
|
+
if timestamp >= current_ref_time + interval_sec:
|
|
282
|
+
reference_times.append(timestamp)
|
|
283
|
+
reference_indices.append(i)
|
|
284
|
+
current_ref_time = timestamp
|
|
75
285
|
|
|
76
|
-
|
|
77
|
-
|
|
286
|
+
return np.array(reference_times, dtype=np.float32), np.array(
|
|
287
|
+
reference_indices, dtype=np.uint32
|
|
288
|
+
)
|