figpack 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of figpack might be problematic.

@@ -2,13 +2,15 @@
 SpikeAmplitudes view for figpack - displays spike amplitudes over time
 """
 
-from typing import List, Optional
+from typing import List
 
 import numpy as np
 import zarr
 
 from ...core.figpack_view import FigpackView
 from .SpikeAmplitudesItem import SpikeAmplitudesItem
+from .UnitsTable import UnitsTable, UnitsTableColumn, UnitsTableRow
+from ...views.Box import Box, LayoutItem
 
 
 class SpikeAmplitudes(FigpackView):
@@ -22,8 +24,6 @@ class SpikeAmplitudes(FigpackView):
         start_time_sec: float,
         end_time_sec: float,
         plots: List[SpikeAmplitudesItem],
-        hide_unit_selector: bool = False,
-        height: int = 500,
     ):
         """
         Initialize a SpikeAmplitudes view
@@ -32,18 +32,89 @@ class SpikeAmplitudes(FigpackView):
             start_time_sec: Start time of the view in seconds
             end_time_sec: End time of the view in seconds
             plots: List of SpikeAmplitudesItem objects
-            hide_unit_selector: Whether to hide the unit selector
-            height: Height of the view in pixels
         """
         self.start_time_sec = start_time_sec
         self.end_time_sec = end_time_sec
         self.plots = plots
-        self.hide_unit_selector = hide_unit_selector
-        self.height = height
+
+    @staticmethod
+    def from_nwb_units_table(
+        nwb_url_or_path_or_h5py,
+        *,
+        units_path: str,
+        include_units_selector: bool = False,
+    ):
+        if isinstance(nwb_url_or_path_or_h5py, str):
+            import lindi
+
+            f = lindi.LindiH5pyFile.from_hdf5_file(nwb_url_or_path_or_h5py)
+        else:
+            f = nwb_url_or_path_or_h5py
+        X = f[units_path]
+        spike_amplitudes = X["spike_amplitudes"]
+        # spike_amplitudes_index = X["spike_amplitudes_index"] # presumably the same as spike_times_index
+        spike_times = X["spike_times"]
+        spike_times_index = X["spike_times_index"]
+        id = X["id"]
+        plots = []
+        num_units = len(spike_times_index)
+        start_times = []
+        end_times = []
+        for unit_index in range(num_units):
+            unit_id = id[unit_index]
+            if unit_index > 0:
+                start_index = spike_times_index[unit_index - 1]
+            else:
+                start_index = 0
+            end_index = spike_times_index[unit_index]
+            unit_spike_amplitudes = spike_amplitudes[start_index:end_index]
+            unit_spike_times = spike_times[start_index:end_index]
+            if len(unit_spike_times) == 0:
+                continue
+            start_times.append(unit_spike_times[0])
+            end_times.append(unit_spike_times[-1])
+            plots.append(
+                SpikeAmplitudesItem(
+                    unit_id=str(unit_id),
+                    spike_times_sec=unit_spike_times,
+                    spike_amplitudes=unit_spike_amplitudes,
+                )
+            )
+        view = SpikeAmplitudes(
+            start_time_sec=min(start_times),
+            end_time_sec=max(end_times),
+            plots=plots,
+        )
+        if include_units_selector:
+            columns: List[UnitsTableColumn] = [
+                UnitsTableColumn(key="unitId", label="Unit", dtype="int"),
+            ]
+            rows: List[UnitsTableRow] = []
+            for unit_id in id:
+                rows.append(
+                    UnitsTableRow(
+                        unit_id=str(unit_id),
+                        values={},
+                    )
+                )
+            units_table = UnitsTable(
+                columns=columns,
+                rows=rows,
+            )
+            layout = Box(
+                direction="horizontal",
+                items=[
+                    LayoutItem(view=units_table, max_size=150, title="Units"),
+                    LayoutItem(view=view, title="Spike Amplitudes"),
+                ],
+            )
+            return layout
+        else:
+            return view
 
     def _write_to_zarr_group(self, group: zarr.Group) -> None:
         """
-        Write the SpikeAmplitudes data to a Zarr group
+        Write the SpikeAmplitudes data to a Zarr group using unified storage format
 
         Args:
             group: Zarr group to write data into
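
For orientation, here is a minimal usage sketch of the new `from_nwb_units_table` helper added in the hunk above. The import path, the NWB URL, and the `units_path` value are placeholders (the exact module path is not visible in this diff), and displaying the result is assumed to go through figpack's usual view display mechanism.

```python
# Usage sketch (not part of the diff). The import path is a guess based on the
# relative imports above; adjust to wherever SpikeAmplitudes lives in the package.
from figpack.spike_sorting.views import SpikeAmplitudes

nwb_url = "https://example.org/session.nwb"  # placeholder NWB file URL or local path

view_or_layout = SpikeAmplitudes.from_nwb_units_table(
    nwb_url,  # a str is opened with lindi.LindiH5pyFile.from_hdf5_file; an open h5py-like file also works
    units_path="/units",  # group containing spike_times, spike_times_index, spike_amplitudes, and id
    include_units_selector=True,  # True: returns a Box layout with a UnitsTable beside the plot
)
# With include_units_selector=False the bare SpikeAmplitudes view is returned instead.
```
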
@@ -54,36 +125,250 @@ class SpikeAmplitudes(FigpackView):
         # Store view parameters
         group.attrs["start_time_sec"] = self.start_time_sec
         group.attrs["end_time_sec"] = self.end_time_sec
-        group.attrs["hide_unit_selector"] = self.hide_unit_selector
-        group.attrs["height"] = self.height
-
-        # Store the number of plots
-        group.attrs["num_plots"] = len(self.plots)
-
-        # Store metadata for each plot
-        plot_metadata = []
-        for i, plot in enumerate(self.plots):
-            plot_name = f"plot_{i}"
-
-            # Store metadata
-            metadata = {
-                "name": plot_name,
-                "unit_id": str(plot.unit_id),
-                "num_spikes": len(plot.spike_times_sec),
-            }
-            plot_metadata.append(metadata)
 
-            # Create arrays for this plot
+        # Prepare unified data arrays
+        unified_data = self._prepare_unified_data()
+
+        if unified_data["total_spikes"] == 0:
+            # Handle empty data case
+            group.create_dataset("timestamps", data=np.array([], dtype=np.float32))
+            group.create_dataset("unit_indices", data=np.array([], dtype=np.uint16))
+            group.create_dataset("amplitudes", data=np.array([], dtype=np.float32))
+            group.create_dataset("reference_times", data=np.array([], dtype=np.float32))
             group.create_dataset(
-                f"{plot_name}/spike_times_sec",
-                data=plot.spike_times_sec,
-                dtype=np.float32,
+                "reference_indices", data=np.array([], dtype=np.uint32)
             )
-            group.create_dataset(
-                f"{plot_name}/spike_amplitudes",
-                data=plot.spike_amplitudes,
-                dtype=np.float32,
+            group.attrs["unit_ids"] = []
+            group.attrs["total_spikes"] = 0
+            return
+
+        chunks = (
+            (2_000_000,)
+            if unified_data["total_spikes"] > 2_000_000
+            else (len(unified_data["timestamps"]),)
+        )
+        # Store main data arrays
+        group.create_dataset(
+            "timestamps",
+            data=unified_data["timestamps"],
+            dtype=np.float32,
+            chunks=chunks,
+        )
+        group.create_dataset(
+            "unit_indices",
+            data=unified_data["unit_indices"],
+            dtype=np.uint16,
+            chunks=chunks,
+        )
+        group.create_dataset(
+            "amplitudes",
+            data=unified_data["amplitudes"],
+            dtype=np.float32,
+            chunks=chunks,
+        )
+        group.create_dataset(
+            "reference_times",
+            data=unified_data["reference_times"],
+            dtype=np.float32,
+            chunks=(len(unified_data["reference_times"]),),
+        )
+        group.create_dataset(
+            "reference_indices",
+            data=unified_data["reference_indices"],
+            dtype=np.uint32,
+            chunks=(len(unified_data["reference_indices"]),),
+        )
+
+        # Store unit ID mapping
+        group.attrs["unit_ids"] = unified_data["unit_ids"]
+        group.attrs["total_spikes"] = unified_data["total_spikes"]
+
+        # Create subsampled data
+        subsampled_data = self._create_subsampled_data(
+            unified_data["timestamps"],
+            unified_data["unit_indices"],
+            unified_data["amplitudes"],
+        )
+
+        if subsampled_data:
+            subsampled_group = group.create_group("subsampled_data")
+            for factor_name, data in subsampled_data.items():
+                chunks = (
+                    (2_000_000,)
+                    if len(data["timestamps"]) > 2_000_000
+                    else (len(data["timestamps"]),)
+                )
+                factor_group = subsampled_group.create_group(factor_name)
+                factor_group.create_dataset(
+                    "timestamps",
+                    data=data["timestamps"],
+                    dtype=np.float32,
+                    chunks=chunks,
+                )
+                factor_group.create_dataset(
+                    "unit_indices",
+                    data=data["unit_indices"],
+                    dtype=np.uint16,
+                    chunks=chunks,
+                )
+                factor_group.create_dataset(
+                    "amplitudes",
+                    data=data["amplitudes"],
+                    dtype=np.float32,
+                    chunks=chunks,
+                )
+                factor_group.create_dataset(
+                    "reference_times",
+                    data=data["reference_times"],
+                    dtype=np.float32,
+                    chunks=(len(data["reference_times"]),),
+                )
+                factor_group.create_dataset(
+                    "reference_indices",
+                    data=data["reference_indices"],
+                    dtype=np.uint32,
+                    chunks=(len(data["reference_indices"]),),
+                )
+
+    def _prepare_unified_data(self) -> dict:
+        """
+        Prepare unified data arrays from all plots
+
+        Returns:
+            Dictionary containing unified arrays and metadata
+        """
+        if not self.plots:
+            return {
+                "timestamps": np.array([], dtype=np.float32),
+                "unit_indices": np.array([], dtype=np.uint16),
+                "amplitudes": np.array([], dtype=np.float32),
+                "reference_times": np.array([], dtype=np.float32),
+                "reference_indices": np.array([], dtype=np.uint32),
+                "unit_ids": [],
+                "total_spikes": 0,
+            }
+
+        # Create unit ID mapping
+        unit_ids = [str(plot.unit_id) for plot in self.plots]
+        unit_id_to_index = {unit_id: i for i, unit_id in enumerate(unit_ids)}
+
+        # Collect all spikes with their unit indices
+        all_spikes = []
+        for plot in self.plots:
+            unit_index = unit_id_to_index[str(plot.unit_id)]
+            for time, amplitude in zip(plot.spike_times_sec, plot.spike_amplitudes):
+                all_spikes.append((float(time), unit_index, float(amplitude)))
+
+        if not all_spikes:
+            return {
+                "timestamps": np.array([], dtype=np.float32),
+                "unit_indices": np.array([], dtype=np.uint16),
+                "amplitudes": np.array([], dtype=np.float32),
+                "reference_times": np.array([], dtype=np.float32),
+                "reference_indices": np.array([], dtype=np.uint32),
+                "unit_ids": unit_ids,
+                "total_spikes": 0,
+            }
+
+        # Sort by timestamp
+        all_spikes.sort(key=lambda x: x[0])
+
+        # Extract sorted arrays
+        timestamps = np.array([spike[0] for spike in all_spikes], dtype=np.float32)
+        unit_indices = np.array([spike[1] for spike in all_spikes], dtype=np.uint16)
+        amplitudes = np.array([spike[2] for spike in all_spikes], dtype=np.float32)
+
+        # Generate reference arrays
+        reference_times, reference_indices = self._generate_reference_arrays(timestamps)
+
+        return {
+            "timestamps": timestamps,
+            "unit_indices": unit_indices,
+            "amplitudes": amplitudes,
+            "reference_times": reference_times,
+            "reference_indices": reference_indices,
+            "unit_ids": unit_ids,
+            "total_spikes": len(all_spikes),
+        }
+
+    def _generate_reference_arrays(
+        self, timestamps: np.ndarray, interval_sec: float = 1.0
+    ) -> tuple:
+        """
+        Generate reference arrays using actual timestamps from the data
+
+        Args:
+            timestamps: Sorted array of timestamps
+            interval_sec: Minimum interval between reference points
+
+        Returns:
+            Tuple of (reference_times, reference_indices)
+        """
+        if len(timestamps) == 0:
+            return np.array([], dtype=np.float32), np.array([], dtype=np.uint32)
+
+        reference_times = []
+        reference_indices = []
+
+        current_ref_time = timestamps[0]
+        reference_times.append(current_ref_time)
+        reference_indices.append(0)
+
+        # Find the next reference point at least interval_sec later
+        for i, timestamp in enumerate(timestamps):
+            if timestamp >= current_ref_time + interval_sec:
+                reference_times.append(timestamp)
+                reference_indices.append(i)
+                current_ref_time = timestamp
+
+        return np.array(reference_times, dtype=np.float32), np.array(
+            reference_indices, dtype=np.uint32
+        )
+
+    def _create_subsampled_data(
+        self, timestamps: np.ndarray, unit_indices: np.ndarray, amplitudes: np.ndarray
+    ) -> dict:
+        """
+        Create subsampled data with geometric progression factors
+
+        Args:
+            timestamps: Original timestamps array
+            unit_indices: Original unit indices array
+            amplitudes: Original amplitudes array
+
+        Returns:
+            Dictionary of subsampled data by factor
+        """
+        subsampled_data = {}
+        factor = 4
+        current_timestamps = timestamps
+        current_unit_indices = unit_indices
+        current_amplitudes = amplitudes
+
+        while len(current_timestamps) >= 500000:
+            # Create subsampled version by taking every Nth spike
+            subsampled_indices = np.arange(0, len(current_timestamps), factor)
+            subsampled_timestamps = current_timestamps[subsampled_indices]
+            subsampled_unit_indices = current_unit_indices[subsampled_indices]
+            subsampled_amplitudes = current_amplitudes[subsampled_indices]
+
+            # Generate reference arrays for this subsampled level
+            ref_times, ref_indices = self._generate_reference_arrays(
+                subsampled_timestamps
             )
 
-        # Store the plot metadata
-        group.attrs["plots"] = plot_metadata
+            subsampled_data[f"factor_{factor}"] = {
+                "timestamps": subsampled_timestamps,
+                "unit_indices": subsampled_unit_indices,
+                "amplitudes": subsampled_amplitudes,
+                "reference_times": ref_times,
+                "reference_indices": ref_indices,
+            }
+
+            # Prepare for next iteration
+            current_timestamps = subsampled_timestamps
+            current_unit_indices = subsampled_unit_indices
+            current_amplitudes = subsampled_amplitudes
+            factor *= 4  # Geometric progression: 4, 16, 64, 256, ...
+
+        return subsampled_data
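
The unified storage format above writes one sorted `timestamps` array with parallel `unit_indices` and `amplitudes`, plus `reference_times`/`reference_indices` landmarks spaced at least one second apart. The sketch below shows how a reader could use those landmarks to pull out a time window without scanning the full arrays; it illustrates the layout only, is not figpack's actual frontend reader, and the `load_window` helper is hypothetical.

```python
import numpy as np
import zarr


def load_window(group: zarr.Group, t1: float, t2: float):
    """Hypothetical reader: return spikes with t1 <= t < t2 from the unified arrays."""
    ref_times = group["reference_times"][:]
    ref_indices = group["reference_indices"][:]
    n = group["timestamps"].shape[0]
    if n == 0:
        empty = np.array([], dtype=np.float32)
        return empty, np.array([], dtype=np.uint16), empty
    # Coarse bounds from the roughly 1-second landmarks, so only a small slice is read.
    k1 = max(int(np.searchsorted(ref_times, t1, side="left")) - 1, 0)
    k2 = int(np.searchsorted(ref_times, t2, side="left"))
    start = int(ref_indices[k1])
    end = int(ref_indices[k2]) if k2 < len(ref_indices) else n
    timestamps = group["timestamps"][start:end]
    # Fine selection within the coarse slice.
    sel = (timestamps >= t1) & (timestamps < t2)
    return (
        timestamps[sel],
        group["unit_indices"][start:end][sel],
        group["amplitudes"][start:end][sel],
    )
```

The `factor_4`, `factor_16`, ... groups under `subsampled_data` carry the same five arrays, so the same lookup applies at coarser zoom levels.
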
@@ -0,0 +1,109 @@
+"""
+DataFrame view for figpack - displays pandas DataFrames as interactive tables
+"""
+
+import json
+from typing import Any, Dict, Union
+
+import numpy as np
+import pandas as pd
+import zarr
+
+from ..core.figpack_view import FigpackView
+
+
+class DataFrame(FigpackView):
+    """
+    A DataFrame visualization component for displaying pandas DataFrames as interactive tables
+    """
+
+    def __init__(self, df: pd.DataFrame):
+        """
+        Initialize a DataFrame view
+
+        Args:
+            df: The pandas DataFrame to display
+
+        Raises:
+            ValueError: If df is not a pandas DataFrame
+        """
+        if not isinstance(df, pd.DataFrame):
+            raise ValueError("df must be a pandas DataFrame")
+
+        self.df = df
+
+    def _write_to_zarr_group(self, group: zarr.Group) -> None:
+        """
+        Write the DataFrame data to a Zarr group
+
+        Args:
+            group: Zarr group to write data into
+        """
+        # Set the view type
+        group.attrs["view_type"] = "DataFrame"
+
+        try:
+            # Convert DataFrame to CSV string
+            csv_string = self.df.to_csv(index=False)
+
+            # Convert CSV string to bytes and store in numpy array
+            csv_bytes = csv_string.encode("utf-8")
+            csv_array = np.frombuffer(csv_bytes, dtype=np.uint8)
+
+            # Store the CSV data as compressed array
+            group.create_dataset(
+                "csv_data",
+                data=csv_array,
+                dtype=np.uint8,
+                chunks=True,
+                compressor=zarr.Blosc(
+                    cname="zstd", clevel=3, shuffle=zarr.Blosc.SHUFFLE
+                ),
+            )
+
+            # Store metadata about the DataFrame
+            group.attrs["data_size"] = len(csv_bytes)
+            group.attrs["row_count"] = len(self.df)
+            group.attrs["column_count"] = len(self.df.columns)
+
+            # Store column information
+            column_info = []
+            for col in self.df.columns:
+                dtype_str = str(self.df[col].dtype)
+                # Simplify dtype names for frontend
+                if dtype_str.startswith("int"):
+                    simple_dtype = "integer"
+                elif dtype_str.startswith("float"):
+                    simple_dtype = "float"
+                elif dtype_str.startswith("bool"):
+                    simple_dtype = "boolean"
+                elif dtype_str.startswith("datetime"):
+                    simple_dtype = "datetime"
+                elif dtype_str == "object":
+                    # Check if it's actually strings
+                    if self.df[col].dtype == "object":
+                        simple_dtype = "string"
+                    else:
+                        simple_dtype = "object"
+                else:
+                    simple_dtype = "string"
+
+                column_info.append(
+                    {"name": str(col), "dtype": dtype_str, "simple_dtype": simple_dtype}
+                )
+
+            # Store column info as JSON string
+            column_info_json = json.dumps(column_info)
+            group.attrs["column_info"] = column_info_json
+
+        except Exception as e:
+            # If DataFrame processing fails, store error information
+            group.attrs["error"] = f"Failed to process DataFrame: {str(e)}"
+            group.attrs["row_count"] = 0
+            group.attrs["column_count"] = 0
+            group.attrs["data_size"] = 0
+            group.attrs["column_info"] = "[]"
+            # Create empty array as placeholder
+            group.create_dataset(
+                "csv_data", data=np.array([], dtype=np.uint8), dtype=np.uint8
+            )
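
Finally, a usage sketch for the new DataFrame view added in this release. The import path is an assumption based on the file's relative imports, and displaying the view is assumed to go through figpack's usual show mechanism, which is not part of this diff.

```python
import pandas as pd

from figpack.views import DataFrame  # hypothetical import path; verify in the installed package

df = pd.DataFrame(
    {
        "unit_id": [1, 2, 3],
        "firing_rate_hz": [5.2, 0.8, 12.4],
        "quality": ["good", "mua", "good"],
    }
)

view = DataFrame(df)  # raises ValueError if df is not a pandas DataFrame
# When written to zarr, the view stores the table as compressed CSV bytes plus
# per-column dtype metadata; display it via figpack's usual show mechanism,
# e.g. view.show(title="Unit metrics"), assuming FigpackView provides show().
```
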