figpack-0.2.6-py3-none-any.whl → figpack-0.2.8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: the registry flags this version of figpack as possibly problematic.
- figpack/__init__.py +1 -1
- figpack/figpack-figure-dist/assets/{index-HXdk2TtM.js → index-CjiTpC6i.js} +90 -89
- figpack/figpack-figure-dist/index.html +1 -1
- figpack/spike_sorting/views/RasterPlot.py +237 -26
- figpack/spike_sorting/views/SpikeAmplitudes.py +321 -36
- figpack/views/DataFrame.py +109 -0
- figpack/views/Spectrogram.py +223 -0
- figpack/views/__init__.py +2 -0
- {figpack-0.2.6.dist-info → figpack-0.2.8.dist-info}/METADATA +2 -1
- {figpack-0.2.6.dist-info → figpack-0.2.8.dist-info}/RECORD +14 -12
- {figpack-0.2.6.dist-info → figpack-0.2.8.dist-info}/WHEEL +0 -0
- {figpack-0.2.6.dist-info → figpack-0.2.8.dist-info}/entry_points.txt +0 -0
- {figpack-0.2.6.dist-info → figpack-0.2.8.dist-info}/licenses/LICENSE +0 -0
- {figpack-0.2.6.dist-info → figpack-0.2.8.dist-info}/top_level.txt +0 -0

figpack/spike_sorting/views/SpikeAmplitudes.py

@@ -2,13 +2,15 @@
 SpikeAmplitudes view for figpack - displays spike amplitudes over time
 """
 
-from typing import List
+from typing import List
 
 import numpy as np
 import zarr
 
 from ...core.figpack_view import FigpackView
 from .SpikeAmplitudesItem import SpikeAmplitudesItem
+from .UnitsTable import UnitsTable, UnitsTableColumn, UnitsTableRow
+from ...views.Box import Box, LayoutItem
 
 
 class SpikeAmplitudes(FigpackView):
@@ -22,8 +24,6 @@ class SpikeAmplitudes(FigpackView):
         start_time_sec: float,
         end_time_sec: float,
         plots: List[SpikeAmplitudesItem],
-        hide_unit_selector: bool = False,
-        height: int = 500,
     ):
         """
         Initialize a SpikeAmplitudes view
@@ -32,18 +32,89 @@ class SpikeAmplitudes(FigpackView):
             start_time_sec: Start time of the view in seconds
             end_time_sec: End time of the view in seconds
             plots: List of SpikeAmplitudesItem objects
-            hide_unit_selector: Whether to hide the unit selector
-            height: Height of the view in pixels
         """
         self.start_time_sec = start_time_sec
         self.end_time_sec = end_time_sec
         self.plots = plots
-
-
+
+    @staticmethod
+    def from_nwb_units_table(
+        nwb_url_or_path_or_h5py,
+        *,
+        units_path: str,
+        include_units_selector: bool = False,
+    ):
+        if isinstance(nwb_url_or_path_or_h5py, str):
+            import lindi
+
+            f = lindi.LindiH5pyFile.from_hdf5_file(nwb_url_or_path_or_h5py)
+        else:
+            f = nwb_url_or_path_or_h5py
+        X = f[units_path]
+        spike_amplitudes = X["spike_amplitudes"]
+        # spike_amplitudes_index = X["spike_amplitudes_index"]  # presumably the same as spike_times_index
+        spike_times = X["spike_times"]
+        spike_times_index = X["spike_times_index"]
+        id = X["id"]
+        plots = []
+        num_units = len(spike_times_index)
+        start_times = []
+        end_times = []
+        for unit_index in range(num_units):
+            unit_id = id[unit_index]
+            if unit_index > 0:
+                start_index = spike_times_index[unit_index - 1]
+            else:
+                start_index = 0
+            end_index = spike_times_index[unit_index]
+            unit_spike_amplitudes = spike_amplitudes[start_index:end_index]
+            unit_spike_times = spike_times[start_index:end_index]
+            if len(unit_spike_times) == 0:
+                continue
+            start_times.append(unit_spike_times[0])
+            end_times.append(unit_spike_times[-1])
+            plots.append(
+                SpikeAmplitudesItem(
+                    unit_id=str(unit_id),
+                    spike_times_sec=unit_spike_times,
+                    spike_amplitudes=unit_spike_amplitudes,
+                )
+            )
+        view = SpikeAmplitudes(
+            start_time_sec=min(start_times),
+            end_time_sec=max(end_times),
+            plots=plots,
+        )
+        if include_units_selector:
+            columns: List[UnitsTableColumn] = [
+                UnitsTableColumn(key="unitId", label="Unit", dtype="int"),
+            ]
+            rows: List[UnitsTableRow] = []
+            for unit_id in id:
+                rows.append(
+                    UnitsTableRow(
+                        unit_id=str(unit_id),
+                        values={},
+                    )
+                )
+            units_table = UnitsTable(
+                columns=columns,
+                rows=rows,
+            )
+            layout = Box(
+                direction="horizontal",
+                items=[
+                    LayoutItem(view=units_table, max_size=150, title="Units"),
+                    LayoutItem(view=view, title="Spike Amplitudes"),
+                ],
+            )
+            return layout
+        else:
+            return view
 
     def _write_to_zarr_group(self, group: zarr.Group) -> None:
         """
-        Write the SpikeAmplitudes data to a Zarr group
+        Write the SpikeAmplitudes data to a Zarr group using unified storage format
 
         Args:
             group: Zarr group to write data into
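
For orientation, here is a minimal usage sketch of the new from_nwb_units_table constructor. The NWB file name and the "units" path are placeholders, and the import path simply mirrors the module location in the file list above; only the keyword arguments come from the diff itself.

from figpack.spike_sorting.views.SpikeAmplitudes import SpikeAmplitudes

# Placeholder NWB file: a string argument is opened via lindi.LindiH5pyFile.from_hdf5_file,
# while an already-open h5py/lindi file object is used as-is.
result = SpikeAmplitudes.from_nwb_units_table(
    "sub-001_ses-01.nwb",
    units_path="units",
    include_units_selector=True,
)
# include_units_selector=True returns a Box layout (UnitsTable beside the SpikeAmplitudes view);
# include_units_selector=False returns the bare SpikeAmplitudes view.
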
@@ -54,36 +125,250 @@ class SpikeAmplitudes(FigpackView):
         # Store view parameters
         group.attrs["start_time_sec"] = self.start_time_sec
         group.attrs["end_time_sec"] = self.end_time_sec
-        group.attrs["hide_unit_selector"] = self.hide_unit_selector
-        group.attrs["height"] = self.height
-
-        # Store the number of plots
-        group.attrs["num_plots"] = len(self.plots)
-
-        # Store metadata for each plot
-        plot_metadata = []
-        for i, plot in enumerate(self.plots):
-            plot_name = f"plot_{i}"
-
-            # Store metadata
-            metadata = {
-                "name": plot_name,
-                "unit_id": str(plot.unit_id),
-                "num_spikes": len(plot.spike_times_sec),
-            }
-            plot_metadata.append(metadata)
 
-
+        # Prepare unified data arrays
+        unified_data = self._prepare_unified_data()
+
+        if unified_data["total_spikes"] == 0:
+            # Handle empty data case
+            group.create_dataset("timestamps", data=np.array([], dtype=np.float32))
+            group.create_dataset("unit_indices", data=np.array([], dtype=np.uint16))
+            group.create_dataset("amplitudes", data=np.array([], dtype=np.float32))
+            group.create_dataset("reference_times", data=np.array([], dtype=np.float32))
             group.create_dataset(
-
-                data=plot.spike_times_sec,
-                dtype=np.float32,
+                "reference_indices", data=np.array([], dtype=np.uint32)
             )
-            group.
-
-
-
+            group.attrs["unit_ids"] = []
+            group.attrs["total_spikes"] = 0
+            return
+
+        chunks = (
+            (2_000_000,)
+            if unified_data["total_spikes"] > 2_000_000
+            else (len(unified_data["timestamps"]),)
+        )
+        # Store main data arrays
+        group.create_dataset(
+            "timestamps",
+            data=unified_data["timestamps"],
+            dtype=np.float32,
+            chunks=chunks,
+        )
+        group.create_dataset(
+            "unit_indices",
+            data=unified_data["unit_indices"],
+            dtype=np.uint16,
+            chunks=chunks,
+        )
+        group.create_dataset(
+            "amplitudes",
+            data=unified_data["amplitudes"],
+            dtype=np.float32,
+            chunks=chunks,
+        )
+        group.create_dataset(
+            "reference_times",
+            data=unified_data["reference_times"],
+            dtype=np.float32,
+            chunks=(len(unified_data["reference_times"]),),
+        )
+        group.create_dataset(
+            "reference_indices",
+            data=unified_data["reference_indices"],
+            dtype=np.uint32,
+            chunks=(len(unified_data["reference_indices"]),),
+        )
+
+        # Store unit ID mapping
+        group.attrs["unit_ids"] = unified_data["unit_ids"]
+        group.attrs["total_spikes"] = unified_data["total_spikes"]
+
+        # Create subsampled data
+        subsampled_data = self._create_subsampled_data(
+            unified_data["timestamps"],
+            unified_data["unit_indices"],
+            unified_data["amplitudes"],
+        )
+
+        if subsampled_data:
+            subsampled_group = group.create_group("subsampled_data")
+            for factor_name, data in subsampled_data.items():
+                chunks = (
+                    (2_000_000,)
+                    if len(data["timestamps"]) > 2_000_000
+                    else (len(data["timestamps"]),)
+                )
+                factor_group = subsampled_group.create_group(factor_name)
+                factor_group.create_dataset(
+                    "timestamps",
+                    data=data["timestamps"],
+                    dtype=np.float32,
+                    chunks=chunks,
+                )
+                factor_group.create_dataset(
+                    "unit_indices",
+                    data=data["unit_indices"],
+                    dtype=np.uint16,
+                    chunks=chunks,
+                )
+                factor_group.create_dataset(
+                    "amplitudes",
+                    data=data["amplitudes"],
+                    dtype=np.float32,
+                    chunks=chunks,
+                )
+                factor_group.create_dataset(
+                    "reference_times",
+                    data=data["reference_times"],
+                    dtype=np.float32,
+                    chunks=(len(data["reference_times"]),),
+                )
+                factor_group.create_dataset(
+                    "reference_indices",
+                    data=data["reference_indices"],
+                    dtype=np.uint32,
+                    chunks=(len(data["reference_indices"]),),
+                )
+
+    def _prepare_unified_data(self) -> dict:
+        """
+        Prepare unified data arrays from all plots
+
+        Returns:
+            Dictionary containing unified arrays and metadata
+        """
+        if not self.plots:
+            return {
+                "timestamps": np.array([], dtype=np.float32),
+                "unit_indices": np.array([], dtype=np.uint16),
+                "amplitudes": np.array([], dtype=np.float32),
+                "reference_times": np.array([], dtype=np.float32),
+                "reference_indices": np.array([], dtype=np.uint32),
+                "unit_ids": [],
+                "total_spikes": 0,
+            }
+
+        # Create unit ID mapping
+        unit_ids = [str(plot.unit_id) for plot in self.plots]
+        unit_id_to_index = {unit_id: i for i, unit_id in enumerate(unit_ids)}
+
+        # Collect all spikes with their unit indices
+        all_spikes = []
+        for plot in self.plots:
+            unit_index = unit_id_to_index[str(plot.unit_id)]
+            for time, amplitude in zip(plot.spike_times_sec, plot.spike_amplitudes):
+                all_spikes.append((float(time), unit_index, float(amplitude)))
+
+        if not all_spikes:
+            return {
+                "timestamps": np.array([], dtype=np.float32),
+                "unit_indices": np.array([], dtype=np.uint16),
+                "amplitudes": np.array([], dtype=np.float32),
+                "reference_times": np.array([], dtype=np.float32),
+                "reference_indices": np.array([], dtype=np.uint32),
+                "unit_ids": unit_ids,
+                "total_spikes": 0,
+            }
+
+        # Sort by timestamp
+        all_spikes.sort(key=lambda x: x[0])
+
+        # Extract sorted arrays
+        timestamps = np.array([spike[0] for spike in all_spikes], dtype=np.float32)
+        unit_indices = np.array([spike[1] for spike in all_spikes], dtype=np.uint16)
+        amplitudes = np.array([spike[2] for spike in all_spikes], dtype=np.float32)
+
+        # Generate reference arrays
+        reference_times, reference_indices = self._generate_reference_arrays(timestamps)
+
+        return {
+            "timestamps": timestamps,
+            "unit_indices": unit_indices,
+            "amplitudes": amplitudes,
+            "reference_times": reference_times,
+            "reference_indices": reference_indices,
+            "unit_ids": unit_ids,
+            "total_spikes": len(all_spikes),
+        }
+
+    def _generate_reference_arrays(
+        self, timestamps: np.ndarray, interval_sec: float = 1.0
+    ) -> tuple:
+        """
+        Generate reference arrays using actual timestamps from the data
+
+        Args:
+            timestamps: Sorted array of timestamps
+            interval_sec: Minimum interval between reference points
+
+        Returns:
+            Tuple of (reference_times, reference_indices)
+        """
+        if len(timestamps) == 0:
+            return np.array([], dtype=np.float32), np.array([], dtype=np.uint32)
+
+        reference_times = []
+        reference_indices = []
+
+        current_ref_time = timestamps[0]
+        reference_times.append(current_ref_time)
+        reference_indices.append(0)
+
+        # Find the next reference point at least interval_sec later
+        for i, timestamp in enumerate(timestamps):
+            if timestamp >= current_ref_time + interval_sec:
+                reference_times.append(timestamp)
+                reference_indices.append(i)
+                current_ref_time = timestamp
+
+        return np.array(reference_times, dtype=np.float32), np.array(
+            reference_indices, dtype=np.uint32
+        )
+
+    def _create_subsampled_data(
+        self, timestamps: np.ndarray, unit_indices: np.ndarray, amplitudes: np.ndarray
+    ) -> dict:
+        """
+        Create subsampled data with geometric progression factors
+
+        Args:
+            timestamps: Original timestamps array
+            unit_indices: Original unit indices array
+            amplitudes: Original amplitudes array
+
+        Returns:
+            Dictionary of subsampled data by factor
+        """
+        subsampled_data = {}
+        factor = 4
+        current_timestamps = timestamps
+        current_unit_indices = unit_indices
+        current_amplitudes = amplitudes
+
+        while len(current_timestamps) >= 500000:
+            # Create subsampled version by taking every Nth spike
+            subsampled_indices = np.arange(0, len(current_timestamps), factor)
+            subsampled_timestamps = current_timestamps[subsampled_indices]
+            subsampled_unit_indices = current_unit_indices[subsampled_indices]
+            subsampled_amplitudes = current_amplitudes[subsampled_indices]
+
+            # Generate reference arrays for this subsampled level
+            ref_times, ref_indices = self._generate_reference_arrays(
+                subsampled_timestamps
             )
 
-
-
+            subsampled_data[f"factor_{factor}"] = {
+                "timestamps": subsampled_timestamps,
+                "unit_indices": subsampled_unit_indices,
+                "amplitudes": subsampled_amplitudes,
+                "reference_times": ref_times,
+                "reference_indices": ref_indices,
+            }
+
+            # Prepare for next iteration
+            current_timestamps = subsampled_timestamps
+            current_unit_indices = subsampled_unit_indices
+            current_amplitudes = subsampled_amplitudes
+            factor *= 4  # Geometric progression: 4, 16, 64, 256, ...
+
+        return subsampled_data
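
The reference_times/reference_indices pairs and the subsampled_data/factor_* groups appear intended to let the frontend seek along the time axis and fall back to coarser levels without scanning every spike. A rough reader-side sketch follows; the store path is a placeholder, while the dataset and attribute names are exactly those written above.

import zarr

g = zarr.open_group("spike_amplitudes.zarr", mode="r")  # placeholder path to the written group
unit_ids = g.attrs["unit_ids"]        # list of unit id strings
timestamps = g["timestamps"][:]       # float32, globally sorted by time
unit_indices = g["unit_indices"][:]   # uint16 indices into unit_ids
amplitudes = g["amplitudes"][:]       # float32, one value per spike

# Example: spikes belonging to the first unit
mask = unit_indices == 0
unit0_times, unit0_amps = timestamps[mask], amplitudes[mask]
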
figpack/views/DataFrame.py (new file)

@@ -0,0 +1,109 @@
+"""
+DataFrame view for figpack - displays pandas DataFrames as interactive tables
+"""
+
+import json
+from typing import Any, Dict, Union
+
+import numpy as np
+import pandas as pd
+import zarr
+
+from ..core.figpack_view import FigpackView
+
+
+class DataFrame(FigpackView):
+    """
+    A DataFrame visualization component for displaying pandas DataFrames as interactive tables
+    """
+
+    def __init__(self, df: pd.DataFrame):
+        """
+        Initialize a DataFrame view
+
+        Args:
+            df: The pandas DataFrame to display
+
+        Raises:
+            ValueError: If df is not a pandas DataFrame
+        """
+        if not isinstance(df, pd.DataFrame):
+            raise ValueError("df must be a pandas DataFrame")
+
+        self.df = df
+
+    def _write_to_zarr_group(self, group: zarr.Group) -> None:
+        """
+        Write the DataFrame data to a Zarr group
+
+        Args:
+            group: Zarr group to write data into
+        """
+        # Set the view type
+        group.attrs["view_type"] = "DataFrame"
+
+        try:
+            # Convert DataFrame to CSV string
+            csv_string = self.df.to_csv(index=False)
+
+            # Convert CSV string to bytes and store in numpy array
+            csv_bytes = csv_string.encode("utf-8")
+            csv_array = np.frombuffer(csv_bytes, dtype=np.uint8)
+
+            # Store the CSV data as compressed array
+            group.create_dataset(
+                "csv_data",
+                data=csv_array,
+                dtype=np.uint8,
+                chunks=True,
+                compressor=zarr.Blosc(
+                    cname="zstd", clevel=3, shuffle=zarr.Blosc.SHUFFLE
+                ),
+            )
+
+            # Store metadata about the DataFrame
+            group.attrs["data_size"] = len(csv_bytes)
+            group.attrs["row_count"] = len(self.df)
+            group.attrs["column_count"] = len(self.df.columns)
+
+            # Store column information
+            column_info = []
+            for col in self.df.columns:
+                dtype_str = str(self.df[col].dtype)
+                # Simplify dtype names for frontend
+                if dtype_str.startswith("int"):
+                    simple_dtype = "integer"
+                elif dtype_str.startswith("float"):
+                    simple_dtype = "float"
+                elif dtype_str.startswith("bool"):
+                    simple_dtype = "boolean"
+                elif dtype_str.startswith("datetime"):
+                    simple_dtype = "datetime"
+                elif dtype_str == "object":
+                    # Check if it's actually strings
+                    if self.df[col].dtype == "object":
+                        simple_dtype = "string"
+                    else:
+                        simple_dtype = "object"
+                else:
+                    simple_dtype = "string"
+
+                column_info.append(
+                    {"name": str(col), "dtype": dtype_str, "simple_dtype": simple_dtype}
+                )
+
+            # Store column info as JSON string
+            column_info_json = json.dumps(column_info)
+            group.attrs["column_info"] = column_info_json
+
+        except Exception as e:
+            # If DataFrame processing fails, store error information
+            group.attrs["error"] = f"Failed to process DataFrame: {str(e)}"
+            group.attrs["row_count"] = 0
+            group.attrs["column_count"] = 0
+            group.attrs["data_size"] = 0
+            group.attrs["column_info"] = "[]"
+            # Create empty array as placeholder
+            group.create_dataset(
+                "csv_data", data=np.array([], dtype=np.uint8), dtype=np.uint8
+            )