figpack 0.2.15__py3-none-any.whl → 0.2.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- figpack/__init__.py +4 -3
- figpack/core/__init__.py +2 -2
- figpack/core/_bundle_utils.py +80 -33
- figpack/core/_view_figure.py +1 -1
- figpack/core/extension_view.py +9 -27
- figpack/core/figpack_extension.py +0 -71
- figpack/core/figpack_view.py +2 -2
- figpack/core/zarr.py +61 -0
- figpack/figpack-figure-dist/assets/index-DBwmtEpB.js +91 -0
- figpack/figpack-figure-dist/assets/{index-D9a3K6eW.css → index-DHWczh-Q.css} +1 -1
- figpack/figpack-figure-dist/index.html +2 -2
- figpack/views/Box.py +2 -3
- figpack/views/DataFrame.py +6 -12
- figpack/views/Gallery.py +2 -3
- figpack/views/Image.py +3 -9
- figpack/views/Markdown.py +3 -4
- figpack/views/MatplotlibFigure.py +2 -11
- figpack/views/MultiChannelTimeseries.py +2 -6
- figpack/views/PlotlyExtension/PlotlyExtension.py +8 -60
- figpack/views/PlotlyExtension/_plotly_extension.py +46 -0
- figpack/views/PlotlyExtension/plotly_view.js +84 -80
- figpack/views/Spectrogram.py +2 -7
- figpack/views/Splitter.py +2 -3
- figpack/views/TabLayout.py +2 -3
- figpack/views/TimeseriesGraph.py +6 -10
- figpack/views/__init__.py +1 -0
- {figpack-0.2.15.dist-info → figpack-0.2.17.dist-info}/METADATA +21 -2
- figpack-0.2.17.dist-info/RECORD +43 -0
- figpack/figpack-figure-dist/assets/index-DtOnN02w.js +0 -846
- figpack/franklab/__init__.py +0 -5
- figpack/franklab/views/TrackAnimation.py +0 -153
- figpack/franklab/views/__init__.py +0 -9
- figpack/spike_sorting/__init__.py +0 -5
- figpack/spike_sorting/views/AutocorrelogramItem.py +0 -32
- figpack/spike_sorting/views/Autocorrelograms.py +0 -118
- figpack/spike_sorting/views/AverageWaveforms.py +0 -147
- figpack/spike_sorting/views/CrossCorrelogramItem.py +0 -35
- figpack/spike_sorting/views/CrossCorrelograms.py +0 -132
- figpack/spike_sorting/views/RasterPlot.py +0 -288
- figpack/spike_sorting/views/RasterPlotItem.py +0 -28
- figpack/spike_sorting/views/SpikeAmplitudes.py +0 -374
- figpack/spike_sorting/views/SpikeAmplitudesItem.py +0 -38
- figpack/spike_sorting/views/UnitMetricsGraph.py +0 -129
- figpack/spike_sorting/views/UnitSimilarityScore.py +0 -40
- figpack/spike_sorting/views/UnitsTable.py +0 -89
- figpack/spike_sorting/views/UnitsTableColumn.py +0 -40
- figpack/spike_sorting/views/UnitsTableRow.py +0 -36
- figpack/spike_sorting/views/__init__.py +0 -41
- figpack-0.2.15.dist-info/RECORD +0 -60
- {figpack-0.2.15.dist-info → figpack-0.2.17.dist-info}/WHEEL +0 -0
- {figpack-0.2.15.dist-info → figpack-0.2.17.dist-info}/entry_points.txt +0 -0
- {figpack-0.2.15.dist-info → figpack-0.2.17.dist-info}/licenses/LICENSE +0 -0
- {figpack-0.2.15.dist-info → figpack-0.2.17.dist-info}/top_level.txt +0 -0
figpack/franklab/__init__.py
DELETED
|
@@ -1,153 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
TrackAnimation view for figpack - displays animated tracking data
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
from typing import Optional
|
|
6
|
-
|
|
7
|
-
import numpy as np
|
|
8
|
-
import zarr
|
|
9
|
-
|
|
10
|
-
from ...core.figpack_view import FigpackView
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class TrackAnimation(FigpackView):
|
|
14
|
-
"""
|
|
15
|
-
A track animation visualization component for displaying animal tracking data
|
|
16
|
-
"""
|
|
17
|
-
|
|
18
|
-
def __init__(
|
|
19
|
-
self,
|
|
20
|
-
*,
|
|
21
|
-
bin_height: float,
|
|
22
|
-
bin_width: float,
|
|
23
|
-
frame_bounds: np.ndarray,
|
|
24
|
-
locations: np.ndarray,
|
|
25
|
-
values: np.ndarray,
|
|
26
|
-
xcount: int,
|
|
27
|
-
ycount: int,
|
|
28
|
-
xmin: float,
|
|
29
|
-
ymin: float,
|
|
30
|
-
head_direction: np.ndarray,
|
|
31
|
-
positions: np.ndarray,
|
|
32
|
-
timestamps: np.ndarray,
|
|
33
|
-
track_bin_corners: np.ndarray,
|
|
34
|
-
sampling_frequency_hz: float,
|
|
35
|
-
timestamp_start: float,
|
|
36
|
-
total_recording_frame_length: int,
|
|
37
|
-
track_bin_height: float,
|
|
38
|
-
track_bin_width: float,
|
|
39
|
-
xmax: float,
|
|
40
|
-
ymax: float,
|
|
41
|
-
):
|
|
42
|
-
"""
|
|
43
|
-
Initialize a TrackAnimation view
|
|
44
|
-
|
|
45
|
-
Args:
|
|
46
|
-
bin_height: Height of spatial bins
|
|
47
|
-
bin_width: Width of spatial bins
|
|
48
|
-
frame_bounds: Array of frame boundaries
|
|
49
|
-
locations: Array of spatial locations
|
|
50
|
-
values: Array of values at each location
|
|
51
|
-
xcount: Number of bins in x direction
|
|
52
|
-
ycount: Number of bins in y direction
|
|
53
|
-
xmin: Minimum x coordinate
|
|
54
|
-
ymin: Minimum y coordinate
|
|
55
|
-
head_direction: Array of head direction angles
|
|
56
|
-
positions: Array of position coordinates (2D)
|
|
57
|
-
timestamps: Array of timestamps
|
|
58
|
-
track_bin_corners: Array of track bin corner coordinates
|
|
59
|
-
sampling_frequency_hz: Sampling frequency in Hz
|
|
60
|
-
timestamp_start: Start timestamp
|
|
61
|
-
total_recording_frame_length: Total number of frames
|
|
62
|
-
track_bin_height: Height of track bins
|
|
63
|
-
track_bin_width: Width of track bins
|
|
64
|
-
xmax: Maximum x coordinate
|
|
65
|
-
ymax: Maximum y coordinate
|
|
66
|
-
"""
|
|
67
|
-
# Validate input arrays
|
|
68
|
-
assert isinstance(
|
|
69
|
-
frame_bounds, np.ndarray
|
|
70
|
-
), "frame_bounds must be a numpy array"
|
|
71
|
-
assert isinstance(locations, np.ndarray), "locations must be a numpy array"
|
|
72
|
-
assert isinstance(values, np.ndarray), "values must be a numpy array"
|
|
73
|
-
assert isinstance(
|
|
74
|
-
head_direction, np.ndarray
|
|
75
|
-
), "head_direction must be a numpy array"
|
|
76
|
-
assert isinstance(positions, np.ndarray), "positions must be a numpy array"
|
|
77
|
-
assert isinstance(timestamps, np.ndarray), "timestamps must be a numpy array"
|
|
78
|
-
assert isinstance(
|
|
79
|
-
track_bin_corners, np.ndarray
|
|
80
|
-
), "track_bin_corners must be a numpy array"
|
|
81
|
-
|
|
82
|
-
assert len(locations) == len(
|
|
83
|
-
values
|
|
84
|
-
), "locations and values must have same length"
|
|
85
|
-
assert len(head_direction) == len(
|
|
86
|
-
timestamps
|
|
87
|
-
), "head_direction and timestamps must have same length"
|
|
88
|
-
assert positions.shape[1] == len(
|
|
89
|
-
timestamps
|
|
90
|
-
), "positions second dimension must match timestamps length"
|
|
91
|
-
assert positions.shape[0] == 2, "positions must have shape (2, N)"
|
|
92
|
-
|
|
93
|
-
# Store spatial binning parameters
|
|
94
|
-
self.bin_height = bin_height
|
|
95
|
-
self.bin_width = bin_width
|
|
96
|
-
self.xcount = xcount
|
|
97
|
-
self.ycount = ycount
|
|
98
|
-
self.xmin = xmin
|
|
99
|
-
self.ymin = ymin
|
|
100
|
-
self.xmax = xmax
|
|
101
|
-
self.ymax = ymax
|
|
102
|
-
|
|
103
|
-
# Store arrays
|
|
104
|
-
self.frame_bounds = frame_bounds
|
|
105
|
-
self.locations = locations
|
|
106
|
-
self.values = values
|
|
107
|
-
self.head_direction = head_direction
|
|
108
|
-
self.positions = positions
|
|
109
|
-
self.timestamps = timestamps
|
|
110
|
-
self.track_bin_corners = track_bin_corners
|
|
111
|
-
|
|
112
|
-
# Store metadata
|
|
113
|
-
self.sampling_frequency_hz = sampling_frequency_hz
|
|
114
|
-
self.timestamp_start = timestamp_start
|
|
115
|
-
self.total_recording_frame_length = total_recording_frame_length
|
|
116
|
-
self.track_bin_height = track_bin_height
|
|
117
|
-
self.track_bin_width = track_bin_width
|
|
118
|
-
|
|
119
|
-
def _write_to_zarr_group(self, group: zarr.Group) -> None:
|
|
120
|
-
"""
|
|
121
|
-
Write the track animation data to a Zarr group
|
|
122
|
-
|
|
123
|
-
Args:
|
|
124
|
-
group: Zarr group to write data into
|
|
125
|
-
"""
|
|
126
|
-
# Set view type
|
|
127
|
-
group.attrs["view_type"] = "TrackAnimation"
|
|
128
|
-
|
|
129
|
-
# Store spatial binning parameters
|
|
130
|
-
group.attrs["bin_height"] = self.bin_height
|
|
131
|
-
group.attrs["bin_width"] = self.bin_width
|
|
132
|
-
group.attrs["xcount"] = self.xcount
|
|
133
|
-
group.attrs["ycount"] = self.ycount
|
|
134
|
-
group.attrs["xmin"] = self.xmin
|
|
135
|
-
group.attrs["ymin"] = self.ymin
|
|
136
|
-
group.attrs["xmax"] = self.xmax
|
|
137
|
-
group.attrs["ymax"] = self.ymax
|
|
138
|
-
|
|
139
|
-
# Store metadata
|
|
140
|
-
group.attrs["sampling_frequency_hz"] = self.sampling_frequency_hz
|
|
141
|
-
group.attrs["timestamp_start"] = self.timestamp_start
|
|
142
|
-
group.attrs["total_recording_frame_length"] = self.total_recording_frame_length
|
|
143
|
-
group.attrs["track_bin_height"] = self.track_bin_height
|
|
144
|
-
group.attrs["track_bin_width"] = self.track_bin_width
|
|
145
|
-
|
|
146
|
-
# Store arrays as datasets
|
|
147
|
-
group.create_dataset("frame_bounds", data=self.frame_bounds)
|
|
148
|
-
group.create_dataset("locations", data=self.locations)
|
|
149
|
-
group.create_dataset("values", data=self.values)
|
|
150
|
-
group.create_dataset("head_direction", data=self.head_direction)
|
|
151
|
-
group.create_dataset("positions", data=self.positions)
|
|
152
|
-
group.create_dataset("timestamps", data=self.timestamps)
|
|
153
|
-
group.create_dataset("track_bin_corners", data=self.track_bin_corners)
|
|
@@ -1,32 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
AutocorrelogramItem for spike sorting views
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
from typing import Union
|
|
6
|
-
|
|
7
|
-
import numpy as np
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
class AutocorrelogramItem:
|
|
11
|
-
"""
|
|
12
|
-
Represents a single autocorrelogram for a unit
|
|
13
|
-
"""
|
|
14
|
-
|
|
15
|
-
def __init__(
|
|
16
|
-
self,
|
|
17
|
-
*,
|
|
18
|
-
unit_id: Union[str, int],
|
|
19
|
-
bin_edges_sec: np.ndarray,
|
|
20
|
-
bin_counts: np.ndarray,
|
|
21
|
-
):
|
|
22
|
-
"""
|
|
23
|
-
Initialize an AutocorrelogramItem
|
|
24
|
-
|
|
25
|
-
Args:
|
|
26
|
-
unit_id: Identifier for the unit
|
|
27
|
-
bin_edges_sec: Array of bin edges in seconds
|
|
28
|
-
bin_counts: Array of bin counts
|
|
29
|
-
"""
|
|
30
|
-
self.unit_id = unit_id
|
|
31
|
-
self.bin_edges_sec = np.array(bin_edges_sec, dtype=np.float32)
|
|
32
|
-
self.bin_counts = np.array(bin_counts, dtype=np.int32)
|
|
@@ -1,118 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Autocorrelograms view for figpack - displays multiple autocorrelograms
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
from typing import List, Optional
|
|
6
|
-
|
|
7
|
-
import numpy as np
|
|
8
|
-
import zarr
|
|
9
|
-
|
|
10
|
-
from ...core.figpack_view import FigpackView
|
|
11
|
-
from .AutocorrelogramItem import AutocorrelogramItem
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
class Autocorrelograms(FigpackView):
|
|
15
|
-
"""
|
|
16
|
-
A view that displays multiple autocorrelograms for spike sorting analysis
|
|
17
|
-
"""
|
|
18
|
-
|
|
19
|
-
def __init__(
|
|
20
|
-
self,
|
|
21
|
-
*,
|
|
22
|
-
autocorrelograms: List[AutocorrelogramItem],
|
|
23
|
-
):
|
|
24
|
-
"""
|
|
25
|
-
Initialize an Autocorrelograms view
|
|
26
|
-
|
|
27
|
-
Args:
|
|
28
|
-
autocorrelograms: List of AutocorrelogramItem objects
|
|
29
|
-
"""
|
|
30
|
-
self.autocorrelograms = autocorrelograms
|
|
31
|
-
|
|
32
|
-
@staticmethod
|
|
33
|
-
def from_sorting(sorting):
|
|
34
|
-
import spikeinterface as si
|
|
35
|
-
import spikeinterface.widgets as sw
|
|
36
|
-
|
|
37
|
-
assert isinstance(sorting, si.BaseSorting), "Input must be a BaseSorting object"
|
|
38
|
-
W = sw.plot_autocorrelograms(sorting)
|
|
39
|
-
return Autocorrelograms.from_spikeinterface_widget(W)
|
|
40
|
-
|
|
41
|
-
@staticmethod
|
|
42
|
-
def from_spikeinterface_widget(W):
|
|
43
|
-
from spikeinterface.widgets.base import to_attr
|
|
44
|
-
from spikeinterface.widgets.utils_sortingview import make_serializable
|
|
45
|
-
|
|
46
|
-
from .AutocorrelogramItem import AutocorrelogramItem
|
|
47
|
-
|
|
48
|
-
data_plot = W.data_plot
|
|
49
|
-
|
|
50
|
-
dp = to_attr(data_plot)
|
|
51
|
-
|
|
52
|
-
unit_ids = make_serializable(dp.unit_ids)
|
|
53
|
-
|
|
54
|
-
ac_items = []
|
|
55
|
-
for i in range(len(unit_ids)):
|
|
56
|
-
for j in range(i, len(unit_ids)):
|
|
57
|
-
if i == j:
|
|
58
|
-
ac_items.append(
|
|
59
|
-
AutocorrelogramItem(
|
|
60
|
-
unit_id=unit_ids[i],
|
|
61
|
-
bin_edges_sec=(dp.bins / 1000.0).astype("float32"),
|
|
62
|
-
bin_counts=dp.correlograms[i, j].astype("int32"),
|
|
63
|
-
)
|
|
64
|
-
)
|
|
65
|
-
|
|
66
|
-
view = Autocorrelograms(autocorrelograms=ac_items)
|
|
67
|
-
return view
|
|
68
|
-
|
|
69
|
-
def _write_to_zarr_group(self, group: zarr.Group) -> None:
|
|
70
|
-
"""
|
|
71
|
-
Write the Autocorrelograms data to a Zarr group
|
|
72
|
-
|
|
73
|
-
Args:
|
|
74
|
-
group: Zarr group to write data into
|
|
75
|
-
"""
|
|
76
|
-
# Set the view type
|
|
77
|
-
group.attrs["view_type"] = "Autocorrelograms"
|
|
78
|
-
|
|
79
|
-
# Store the number of autocorrelograms
|
|
80
|
-
num_autocorrelograms = len(self.autocorrelograms)
|
|
81
|
-
group.attrs["num_autocorrelograms"] = num_autocorrelograms
|
|
82
|
-
|
|
83
|
-
if num_autocorrelograms == 0:
|
|
84
|
-
return
|
|
85
|
-
|
|
86
|
-
# Get dimensions from first autocorrelogram
|
|
87
|
-
num_bins = len(self.autocorrelograms[0].bin_counts)
|
|
88
|
-
|
|
89
|
-
# Store bin edges (same for all autocorrelograms)
|
|
90
|
-
group.create_dataset(
|
|
91
|
-
"bin_edges_sec",
|
|
92
|
-
data=self.autocorrelograms[0].bin_edges_sec,
|
|
93
|
-
dtype=np.float32,
|
|
94
|
-
)
|
|
95
|
-
|
|
96
|
-
# Create 2D array for all bin counts
|
|
97
|
-
bin_counts = np.zeros((num_autocorrelograms, num_bins), dtype=np.int32)
|
|
98
|
-
|
|
99
|
-
# Store metadata for each autocorrelogram and populate bin counts array
|
|
100
|
-
autocorrelogram_metadata = []
|
|
101
|
-
for i, autocorr in enumerate(self.autocorrelograms):
|
|
102
|
-
metadata = {
|
|
103
|
-
"unit_id": str(autocorr.unit_id),
|
|
104
|
-
"index": i, # Store index to map to bin_counts array
|
|
105
|
-
"num_bins": num_bins,
|
|
106
|
-
}
|
|
107
|
-
autocorrelogram_metadata.append(metadata)
|
|
108
|
-
bin_counts[i] = autocorr.bin_counts
|
|
109
|
-
|
|
110
|
-
# Store the bin counts as a single 2D dataset
|
|
111
|
-
group.create_dataset(
|
|
112
|
-
"bin_counts",
|
|
113
|
-
data=bin_counts,
|
|
114
|
-
dtype=np.int32,
|
|
115
|
-
)
|
|
116
|
-
|
|
117
|
-
# Store the autocorrelogram metadata
|
|
118
|
-
group.attrs["autocorrelograms"] = autocorrelogram_metadata
|
|
@@ -1,147 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
AverageWaveforms view for figpack - displays multiple average waveforms
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
from typing import List, Optional, Union
|
|
6
|
-
|
|
7
|
-
import numpy as np
|
|
8
|
-
import zarr
|
|
9
|
-
|
|
10
|
-
from ...core.figpack_view import FigpackView
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class AverageWaveformItem:
|
|
14
|
-
"""
|
|
15
|
-
Represents a single average waveform for a unit
|
|
16
|
-
"""
|
|
17
|
-
|
|
18
|
-
def __init__(
|
|
19
|
-
self,
|
|
20
|
-
*,
|
|
21
|
-
unit_id: Union[str, int],
|
|
22
|
-
channel_ids: List[Union[str, int]],
|
|
23
|
-
waveform: np.ndarray,
|
|
24
|
-
waveform_std_dev: Optional[np.ndarray] = None,
|
|
25
|
-
waveform_percentiles: Optional[List[np.ndarray]] = None,
|
|
26
|
-
):
|
|
27
|
-
"""
|
|
28
|
-
Initialize an AverageWaveformItem
|
|
29
|
-
|
|
30
|
-
Args:
|
|
31
|
-
unit_id: Identifier for the unit
|
|
32
|
-
channel_ids: List of channel identifiers
|
|
33
|
-
waveform: 2D numpy array representing the average waveform (num_samples x num_channels)
|
|
34
|
-
waveform_std_dev: Optional 2D numpy array representing the standard deviation of the waveform
|
|
35
|
-
waveform_percentiles: Optional list of 2D numpy arrays representing percentiles of the waveform
|
|
36
|
-
"""
|
|
37
|
-
self.unit_id = unit_id
|
|
38
|
-
self.channel_ids = channel_ids
|
|
39
|
-
self.waveform = np.array(waveform, dtype=np.float32)
|
|
40
|
-
self.waveform_std_dev = (
|
|
41
|
-
np.array(waveform_std_dev, dtype=np.float32)
|
|
42
|
-
if waveform_std_dev is not None
|
|
43
|
-
else None
|
|
44
|
-
)
|
|
45
|
-
if waveform_percentiles is not None:
|
|
46
|
-
self.waveform_percentiles = [
|
|
47
|
-
np.array(p, dtype=np.float32) for p in waveform_percentiles
|
|
48
|
-
]
|
|
49
|
-
else:
|
|
50
|
-
self.waveform_percentiles = None
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
class AverageWaveforms(FigpackView):
|
|
54
|
-
"""
|
|
55
|
-
A view that displays multiple average waveforms for spike sorting analysis
|
|
56
|
-
"""
|
|
57
|
-
|
|
58
|
-
def __init__(self, *, average_waveforms: List[AverageWaveformItem]):
|
|
59
|
-
"""
|
|
60
|
-
Initialize an AverageWaveforms view
|
|
61
|
-
|
|
62
|
-
Args:
|
|
63
|
-
average_waveforms: List of AverageWaveformItem objects
|
|
64
|
-
"""
|
|
65
|
-
self.average_waveforms = average_waveforms
|
|
66
|
-
|
|
67
|
-
@staticmethod
|
|
68
|
-
def from_sorting_analyzer(sorting_analyzer):
|
|
69
|
-
sorting_analyzer.compute(
|
|
70
|
-
["random_spikes", "waveforms", "templates", "noise_levels"]
|
|
71
|
-
)
|
|
72
|
-
ext_templates = sorting_analyzer.get_extension("templates")
|
|
73
|
-
# shape is num_units, num_samples, num_channels
|
|
74
|
-
av_templates = ext_templates.get_data(operator="average")
|
|
75
|
-
|
|
76
|
-
ext_noise_levels = sorting_analyzer.get_extension("noise_levels")
|
|
77
|
-
noise_levels = ext_noise_levels.get_data()
|
|
78
|
-
|
|
79
|
-
waveform_std_dev = np.zeros(
|
|
80
|
-
(av_templates.shape[1], av_templates.shape[2]), dtype=np.float32
|
|
81
|
-
)
|
|
82
|
-
for i in range(av_templates.shape[2]):
|
|
83
|
-
waveform_std_dev[:, i] = noise_levels[i]
|
|
84
|
-
|
|
85
|
-
average_waveform_items = []
|
|
86
|
-
for i, unit_id in enumerate(sorting_analyzer.unit_ids):
|
|
87
|
-
waveform = av_templates[i]
|
|
88
|
-
channel_ids = list(sorting_analyzer.recording.get_channel_ids())
|
|
89
|
-
average_waveform_items.append(
|
|
90
|
-
AverageWaveformItem(
|
|
91
|
-
unit_id=unit_id,
|
|
92
|
-
waveform=waveform,
|
|
93
|
-
channel_ids=channel_ids,
|
|
94
|
-
waveform_std_dev=waveform_std_dev,
|
|
95
|
-
)
|
|
96
|
-
)
|
|
97
|
-
view = AverageWaveforms(average_waveforms=average_waveform_items)
|
|
98
|
-
return view
|
|
99
|
-
|
|
100
|
-
def _write_to_zarr_group(self, group: zarr.Group) -> None:
|
|
101
|
-
"""
|
|
102
|
-
Write the AverageWaveforms data to a Zarr group
|
|
103
|
-
|
|
104
|
-
Args:
|
|
105
|
-
group: Zarr group to write data into
|
|
106
|
-
"""
|
|
107
|
-
# Set the view type
|
|
108
|
-
group.attrs["view_type"] = "AverageWaveforms"
|
|
109
|
-
|
|
110
|
-
# Store the number of average waveforms
|
|
111
|
-
group.attrs["num_average_waveforms"] = len(self.average_waveforms)
|
|
112
|
-
|
|
113
|
-
# Store metadata for each average waveform
|
|
114
|
-
average_waveform_metadata = []
|
|
115
|
-
for i, waveform in enumerate(self.average_waveforms):
|
|
116
|
-
waveform_name = f"waveform_{i}"
|
|
117
|
-
|
|
118
|
-
# Store metadata
|
|
119
|
-
metadata = {
|
|
120
|
-
"name": waveform_name,
|
|
121
|
-
"unit_id": str(waveform.unit_id),
|
|
122
|
-
"channel_ids": [str(ch) for ch in waveform.channel_ids],
|
|
123
|
-
}
|
|
124
|
-
average_waveform_metadata.append(metadata)
|
|
125
|
-
|
|
126
|
-
# Create arrays for this average waveform
|
|
127
|
-
group.create_dataset(
|
|
128
|
-
f"{waveform_name}/waveform",
|
|
129
|
-
data=waveform.waveform,
|
|
130
|
-
dtype=waveform.waveform.dtype,
|
|
131
|
-
)
|
|
132
|
-
if waveform.waveform_std_dev is not None:
|
|
133
|
-
group.create_dataset(
|
|
134
|
-
f"{waveform_name}/waveform_std_dev",
|
|
135
|
-
data=waveform.waveform_std_dev,
|
|
136
|
-
dtype=waveform.waveform_std_dev.dtype,
|
|
137
|
-
)
|
|
138
|
-
if waveform.waveform_percentiles is not None:
|
|
139
|
-
for j, p in enumerate(waveform.waveform_percentiles):
|
|
140
|
-
group.create_dataset(
|
|
141
|
-
f"{waveform_name}/waveform_percentile_{j}",
|
|
142
|
-
data=p,
|
|
143
|
-
dtype=p.dtype,
|
|
144
|
-
)
|
|
145
|
-
|
|
146
|
-
# Store the average waveform metadata
|
|
147
|
-
group.attrs["average_waveforms"] = average_waveform_metadata
|
|
@@ -1,35 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
CrossCorrelogramItem for spike sorting views
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
from typing import Union
|
|
6
|
-
|
|
7
|
-
import numpy as np
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
class CrossCorrelogramItem:
|
|
11
|
-
"""
|
|
12
|
-
Represents a single cross-correlogram between two units
|
|
13
|
-
"""
|
|
14
|
-
|
|
15
|
-
def __init__(
|
|
16
|
-
self,
|
|
17
|
-
*,
|
|
18
|
-
unit_id1: Union[str, int],
|
|
19
|
-
unit_id2: Union[str, int],
|
|
20
|
-
bin_edges_sec: np.ndarray,
|
|
21
|
-
bin_counts: np.ndarray,
|
|
22
|
-
):
|
|
23
|
-
"""
|
|
24
|
-
Initialize a CrossCorrelogramItem
|
|
25
|
-
|
|
26
|
-
Args:
|
|
27
|
-
unit_id1: Identifier for the first unit
|
|
28
|
-
unit_id2: Identifier for the second unit
|
|
29
|
-
bin_edges_sec: Array of bin edges in seconds
|
|
30
|
-
bin_counts: Array of bin counts
|
|
31
|
-
"""
|
|
32
|
-
self.unit_id1 = unit_id1
|
|
33
|
-
self.unit_id2 = unit_id2
|
|
34
|
-
self.bin_edges_sec = np.array(bin_edges_sec, dtype=np.float32)
|
|
35
|
-
self.bin_counts = np.array(bin_counts, dtype=np.int32)
|
|
@@ -1,132 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
CrossCorrelograms view for figpack - displays multiple cross-correlograms
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
from typing import List, Optional
|
|
6
|
-
|
|
7
|
-
import numpy as np
|
|
8
|
-
import zarr
|
|
9
|
-
|
|
10
|
-
from ...core.figpack_view import FigpackView
|
|
11
|
-
from .CrossCorrelogramItem import CrossCorrelogramItem
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
class CrossCorrelograms(FigpackView):
|
|
15
|
-
"""
|
|
16
|
-
A view that displays multiple cross-correlograms for spike sorting analysis
|
|
17
|
-
"""
|
|
18
|
-
|
|
19
|
-
def __init__(
|
|
20
|
-
self,
|
|
21
|
-
*,
|
|
22
|
-
cross_correlograms: List[CrossCorrelogramItem],
|
|
23
|
-
hide_unit_selector: Optional[bool] = False,
|
|
24
|
-
):
|
|
25
|
-
"""
|
|
26
|
-
Initialize a CrossCorrelograms view
|
|
27
|
-
|
|
28
|
-
Args:
|
|
29
|
-
cross_correlograms: List of CrossCorrelogramItem objects
|
|
30
|
-
hide_unit_selector: Whether to hide the unit selector widget
|
|
31
|
-
"""
|
|
32
|
-
self.cross_correlograms = cross_correlograms
|
|
33
|
-
self.hide_unit_selector = hide_unit_selector
|
|
34
|
-
|
|
35
|
-
@staticmethod
|
|
36
|
-
def from_sorting(sorting):
|
|
37
|
-
import spikeinterface as si
|
|
38
|
-
import spikeinterface.widgets as sw
|
|
39
|
-
|
|
40
|
-
assert isinstance(sorting, si.BaseSorting), "Input must be a BaseSorting object"
|
|
41
|
-
W = sw.CrossCorrelogramsWidget(sorting)
|
|
42
|
-
return CrossCorrelograms.from_spikeinterface_widget(W)
|
|
43
|
-
|
|
44
|
-
@staticmethod
|
|
45
|
-
def from_spikeinterface_widget(W):
|
|
46
|
-
from spikeinterface.widgets.base import to_attr
|
|
47
|
-
from spikeinterface.widgets.utils_sortingview import make_serializable
|
|
48
|
-
|
|
49
|
-
from .CrossCorrelogramItem import CrossCorrelogramItem
|
|
50
|
-
|
|
51
|
-
data_plot = W.data_plot
|
|
52
|
-
|
|
53
|
-
dp = to_attr(data_plot)
|
|
54
|
-
|
|
55
|
-
unit_ids = make_serializable(dp.unit_ids)
|
|
56
|
-
|
|
57
|
-
if dp.similarity is not None:
|
|
58
|
-
similarity = dp.similarity
|
|
59
|
-
else:
|
|
60
|
-
similarity = np.ones((len(unit_ids), len(unit_ids)))
|
|
61
|
-
|
|
62
|
-
cc_items = []
|
|
63
|
-
for i in range(len(unit_ids)):
|
|
64
|
-
for j in range(i, len(unit_ids)):
|
|
65
|
-
if similarity[i, j] >= dp.min_similarity_for_correlograms:
|
|
66
|
-
cc_items.append(
|
|
67
|
-
CrossCorrelogramItem(
|
|
68
|
-
unit_id1=unit_ids[i],
|
|
69
|
-
unit_id2=unit_ids[j],
|
|
70
|
-
bin_edges_sec=(dp.bins / 1000.0).astype("float32"),
|
|
71
|
-
bin_counts=dp.correlograms[i, j].astype("int32"),
|
|
72
|
-
)
|
|
73
|
-
)
|
|
74
|
-
|
|
75
|
-
view = CrossCorrelograms(cross_correlograms=cc_items, hide_unit_selector=False)
|
|
76
|
-
return view
|
|
77
|
-
|
|
78
|
-
def _write_to_zarr_group(self, group: zarr.Group) -> None:
|
|
79
|
-
"""
|
|
80
|
-
Write the CrossCorrelograms data to a Zarr group
|
|
81
|
-
|
|
82
|
-
Args:
|
|
83
|
-
group: Zarr group to write data into
|
|
84
|
-
"""
|
|
85
|
-
# Set the view type
|
|
86
|
-
group.attrs["view_type"] = "CrossCorrelograms"
|
|
87
|
-
|
|
88
|
-
# Set view properties
|
|
89
|
-
if self.hide_unit_selector is not None:
|
|
90
|
-
group.attrs["hide_unit_selector"] = self.hide_unit_selector
|
|
91
|
-
|
|
92
|
-
# Store the number of cross-correlograms
|
|
93
|
-
num_cross_correlograms = len(self.cross_correlograms)
|
|
94
|
-
group.attrs["num_cross_correlograms"] = num_cross_correlograms
|
|
95
|
-
|
|
96
|
-
if num_cross_correlograms == 0:
|
|
97
|
-
return
|
|
98
|
-
|
|
99
|
-
# Get dimensions from first cross-correlogram
|
|
100
|
-
num_bins = len(self.cross_correlograms[0].bin_counts)
|
|
101
|
-
|
|
102
|
-
# Store bin edges (same for all cross-correlograms)
|
|
103
|
-
group.create_dataset(
|
|
104
|
-
"bin_edges_sec",
|
|
105
|
-
data=self.cross_correlograms[0].bin_edges_sec,
|
|
106
|
-
dtype=np.float32,
|
|
107
|
-
)
|
|
108
|
-
|
|
109
|
-
# Create 2D array for all bin counts
|
|
110
|
-
bin_counts = np.zeros((num_cross_correlograms, num_bins), dtype=np.int32)
|
|
111
|
-
|
|
112
|
-
# Store metadata for each cross-correlogram and populate bin counts array
|
|
113
|
-
cross_correlogram_metadata = []
|
|
114
|
-
for i, cross_corr in enumerate(self.cross_correlograms):
|
|
115
|
-
metadata = {
|
|
116
|
-
"unit_id1": str(cross_corr.unit_id1),
|
|
117
|
-
"unit_id2": str(cross_corr.unit_id2),
|
|
118
|
-
"index": i, # Store index to map to bin_counts array
|
|
119
|
-
"num_bins": num_bins,
|
|
120
|
-
}
|
|
121
|
-
cross_correlogram_metadata.append(metadata)
|
|
122
|
-
bin_counts[i] = cross_corr.bin_counts
|
|
123
|
-
|
|
124
|
-
# Store the bin counts as a single 2D dataset
|
|
125
|
-
group.create_dataset(
|
|
126
|
-
"bin_counts",
|
|
127
|
-
data=bin_counts,
|
|
128
|
-
dtype=np.int32,
|
|
129
|
-
)
|
|
130
|
-
|
|
131
|
-
# Store the cross-correlogram metadata
|
|
132
|
-
group.attrs["cross_correlograms"] = cross_correlogram_metadata
|