figpack 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of figpack has been flagged as potentially problematic; consult the package registry's advisory page for details.

@@ -5,7 +5,7 @@
5
5
  <link rel="icon" type="image/png" href="./assets/neurosift-logo-CLsuwLMO.png" />
6
6
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
7
  <title>figpack figure</title>
8
- <script type="module" crossorigin src="./assets/index-HXdk2TtM.js"></script>
8
+ <script type="module" crossorigin src="./assets/index-CjiTpC6i.js"></script>
9
9
  <link rel="stylesheet" crossorigin href="./assets/index-Cmae55E4.css">
10
10
  </head>
11
11
  <body>
@@ -8,6 +8,8 @@ import zarr
8
8
 
9
9
  from ...core.figpack_view import FigpackView
10
10
  from .RasterPlotItem import RasterPlotItem
11
+ from .UnitsTable import UnitsTable, UnitsTableColumn, UnitsTableRow
12
+ from ...views.Box import Box, LayoutItem
11
13
 
12
14
 
13
15
  class RasterPlot(FigpackView):
@@ -21,7 +23,6 @@ class RasterPlot(FigpackView):
21
23
  start_time_sec: float,
22
24
  end_time_sec: float,
23
25
  plots: List[RasterPlotItem],
24
- height: int = 500,
25
26
  ):
26
27
  """
27
28
  Initialize a RasterPlot view
@@ -35,43 +36,253 @@ class RasterPlot(FigpackView):
35
36
  self.start_time_sec = float(start_time_sec)
36
37
  self.end_time_sec = float(end_time_sec)
37
38
  self.plots = plots
38
- self.height = height
39
+
40
+ @staticmethod
41
+ def from_nwb_units_table(
42
+ nwb_url_or_path_or_h5py,
43
+ *,
44
+ units_path: str,
45
+ include_units_selector: bool = False,
46
+ ):
47
+ if isinstance(nwb_url_or_path_or_h5py, str):
48
+ import lindi
49
+
50
+ f = lindi.LindiH5pyFile.from_hdf5_file(nwb_url_or_path_or_h5py)
51
+ else:
52
+ f = nwb_url_or_path_or_h5py
53
+ X = f[units_path]
54
+ spike_times = X["spike_times"]
55
+ spike_times_index = X["spike_times_index"]
56
+ id = X["id"]
57
+ plots = []
58
+ num_units = len(spike_times_index)
59
+ start_times = []
60
+ end_times = []
61
+ for unit_index in range(num_units):
62
+ unit_id = id[unit_index]
63
+ if unit_index > 0:
64
+ start_index = spike_times_index[unit_index - 1]
65
+ else:
66
+ start_index = 0
67
+ end_index = spike_times_index[unit_index]
68
+ unit_spike_times = spike_times[start_index:end_index]
69
+ if len(unit_spike_times) == 0:
70
+ continue
71
+ start_times.append(unit_spike_times[0])
72
+ end_times.append(unit_spike_times[-1])
73
+ plots.append(
74
+ RasterPlotItem(unit_id=str(unit_id), spike_times_sec=unit_spike_times)
75
+ )
76
+ view = RasterPlot(
77
+ start_time_sec=min(start_times),
78
+ end_time_sec=max(end_times),
79
+ plots=plots,
80
+ )
81
+ if include_units_selector:
82
+ columns: List[UnitsTableColumn] = [
83
+ UnitsTableColumn(key="unitId", label="Unit", dtype="int"),
84
+ ]
85
+ rows: List[UnitsTableRow] = []
86
+ for unit_id in id:
87
+ rows.append(
88
+ UnitsTableRow(
89
+ unit_id=str(unit_id),
90
+ values={},
91
+ )
92
+ )
93
+ units_table = UnitsTable(
94
+ columns=columns,
95
+ rows=rows,
96
+ )
97
+ layout = Box(
98
+ direction="horizontal",
99
+ items=[
100
+ LayoutItem(view=units_table, max_size=150, title="Units"),
101
+ LayoutItem(view=view, title="Spike Amplitudes"),
102
+ ],
103
+ )
104
+ return layout
105
+ else:
106
+ return view
39
107
 
40
108
  def _write_to_zarr_group(self, group: zarr.Group) -> None:
41
109
  """
42
- Write the RasterPlot data to a Zarr group
43
-
44
110
  Args:
45
111
  group: Zarr group to write data into
46
112
  """
47
113
  # Set the view type
48
114
  group.attrs["view_type"] = "RasterPlot"
49
115
 
50
- # Store view-level attributes
116
+ # Store view parameters
51
117
  group.attrs["start_time_sec"] = self.start_time_sec
52
118
  group.attrs["end_time_sec"] = self.end_time_sec
53
- group.attrs["height"] = self.height
54
- group.attrs["num_plots"] = len(self.plots)
55
-
56
- # Store metadata for each plot
57
- plot_metadata = []
58
- for i, plot in enumerate(self.plots):
59
- plot_name = f"plot_{i}"
60
-
61
- # Store metadata
62
- metadata = {
63
- "name": plot_name,
64
- "unit_id": str(plot.unit_id),
65
- "num_spikes": len(plot.spike_times_sec),
66
- }
67
- plot_metadata.append(metadata)
68
119
 
69
- # Create arrays for this plot
120
+ # Prepare unified data arrays
121
+ unified_data = self._prepare_unified_data()
122
+
123
+ if unified_data["total_spikes"] == 0:
124
+ # Handle empty data case
125
+ group.create_dataset("timestamps", data=np.array([], dtype=np.float32))
126
+ group.create_dataset("unit_indices", data=np.array([], dtype=np.uint16))
127
+ group.create_dataset("reference_times", data=np.array([], dtype=np.float32))
70
128
  group.create_dataset(
71
- f"{plot_name}/spike_times_sec",
72
- data=plot.spike_times_sec,
73
- dtype=np.float32,
129
+ "reference_indices", data=np.array([], dtype=np.uint32)
74
130
  )
131
+ group.attrs["unit_ids"] = []
132
+ group.attrs["total_spikes"] = 0
133
+ return
134
+
135
+ chunks = (
136
+ (2_000_000,)
137
+ if unified_data["total_spikes"] > 2_000_000
138
+ else (len(unified_data["timestamps"]),)
139
+ )
140
+ # Store main data arrays
141
+ group.create_dataset(
142
+ "timestamps",
143
+ data=unified_data["timestamps"],
144
+ dtype=np.float32,
145
+ chunks=chunks,
146
+ )
147
+ group.create_dataset(
148
+ "unit_indices",
149
+ data=unified_data["unit_indices"],
150
+ dtype=np.uint16,
151
+ chunks=chunks,
152
+ )
153
+ group.create_dataset(
154
+ "reference_times",
155
+ data=unified_data["reference_times"],
156
+ dtype=np.float32,
157
+ chunks=(len(unified_data["reference_times"]),),
158
+ )
159
+ group.create_dataset(
160
+ "reference_indices",
161
+ data=unified_data["reference_indices"],
162
+ dtype=np.uint32,
163
+ chunks=(len(unified_data["reference_indices"]),),
164
+ )
165
+
166
+ # Create spike counts array with 1-second bins
167
+ duration = self.end_time_sec - self.start_time_sec
168
+ num_bins = int(np.ceil(duration))
169
+ num_units = len(self.plots)
170
+ spike_counts = np.zeros((num_bins, num_units), dtype=np.uint16)
171
+
172
+ # Efficiently compute spike counts for each unit
173
+ for unit_idx, plot in enumerate(self.plots):
174
+ # Convert spike times to bin indices
175
+ bin_indices = (
176
+ (np.array(plot.spike_times_sec) - self.start_time_sec)
177
+ ).astype(int)
178
+ # Count spikes in valid bins
179
+ valid_indices = (bin_indices >= 0) & (bin_indices < num_bins)
180
+ unique_bins, counts = np.unique(
181
+ bin_indices[valid_indices], return_counts=True
182
+ )
183
+ spike_counts[unique_bins, unit_idx] = counts.clip(
184
+ max=65535
185
+ ) # Clip to uint16 max
186
+
187
+ # Store spike counts array
188
+ group.create_dataset(
189
+ "spike_counts_1sec",
190
+ data=spike_counts,
191
+ dtype=np.uint16,
192
+ chunks=(min(num_bins, 10000), min(num_units, 500)),
193
+ )
194
+
195
+ # Store unit ID mapping
196
+ group.attrs["unit_ids"] = unified_data["unit_ids"]
197
+ group.attrs["total_spikes"] = unified_data["total_spikes"]
198
+
199
+ def _prepare_unified_data(self) -> dict:
200
+ """
201
+ Prepare unified data arrays from all plots
202
+
203
+ Returns:
204
+ Dictionary containing unified arrays and metadata
205
+ """
206
+ if not self.plots:
207
+ return {
208
+ "timestamps": np.array([], dtype=np.float32),
209
+ "unit_indices": np.array([], dtype=np.uint16),
210
+ "reference_times": np.array([], dtype=np.float32),
211
+ "reference_indices": np.array([], dtype=np.uint32),
212
+ "unit_ids": [],
213
+ "total_spikes": 0,
214
+ }
215
+
216
+ # Create unit ID mapping
217
+ unit_ids = [str(plot.unit_id) for plot in self.plots]
218
+ unit_id_to_index = {unit_id: i for i, unit_id in enumerate(unit_ids)}
219
+
220
+ # Collect all spikes with their unit indices
221
+ all_spikes = []
222
+ for plot in self.plots:
223
+ unit_index = unit_id_to_index[str(plot.unit_id)]
224
+ for time in plot.spike_times_sec:
225
+ all_spikes.append((float(time), unit_index))
226
+
227
+ if not all_spikes:
228
+ return {
229
+ "timestamps": np.array([], dtype=np.float32),
230
+ "unit_indices": np.array([], dtype=np.uint16),
231
+ "reference_times": np.array([], dtype=np.float32),
232
+ "reference_indices": np.array([], dtype=np.uint32),
233
+ "unit_ids": unit_ids,
234
+ "total_spikes": 0,
235
+ }
236
+
237
+ # Sort by timestamp
238
+ all_spikes.sort(key=lambda x: x[0])
239
+
240
+ # Extract sorted arrays
241
+ timestamps = np.array([spike[0] for spike in all_spikes], dtype=np.float32)
242
+ unit_indices = np.array([spike[1] for spike in all_spikes], dtype=np.uint16)
243
+
244
+ # Generate reference arrays
245
+ reference_times, reference_indices = self._generate_reference_arrays(timestamps)
246
+
247
+ return {
248
+ "timestamps": timestamps,
249
+ "unit_indices": unit_indices,
250
+ "reference_times": reference_times,
251
+ "reference_indices": reference_indices,
252
+ "unit_ids": unit_ids,
253
+ "total_spikes": len(all_spikes),
254
+ }
255
+
256
+ def _generate_reference_arrays(
257
+ self, timestamps: np.ndarray, interval_sec: float = 1.0
258
+ ) -> tuple:
259
+ """
260
+ Generate reference arrays using actual timestamps from the data
261
+
262
+ Args:
263
+ timestamps: Sorted array of timestamps
264
+ interval_sec: Minimum interval between reference points
265
+
266
+ Returns:
267
+ Tuple of (reference_times, reference_indices)
268
+ """
269
+ if len(timestamps) == 0:
270
+ return np.array([], dtype=np.float32), np.array([], dtype=np.uint32)
271
+
272
+ reference_times = []
273
+ reference_indices = []
274
+
275
+ current_ref_time = timestamps[0]
276
+ reference_times.append(current_ref_time)
277
+ reference_indices.append(0)
278
+
279
+ # Find the next reference point at least interval_sec later
280
+ for i, timestamp in enumerate(timestamps):
281
+ if timestamp >= current_ref_time + interval_sec:
282
+ reference_times.append(timestamp)
283
+ reference_indices.append(i)
284
+ current_ref_time = timestamp
75
285
 
76
- # Store the plot metadata
77
- group.attrs["plots"] = plot_metadata
286
+ return np.array(reference_times, dtype=np.float32), np.array(
287
+ reference_indices, dtype=np.uint32
288
+ )