figpack-spike-sorting 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,24 @@
1
+ import figpack
2
+
3
+
4
+ def _load_javascript_code():
5
+ """Load the JavaScript code from the built figpack_spike_sorting.js file"""
6
+ import os
7
+
8
+ js_path = os.path.join(os.path.dirname(__file__), "figpack_spike_sorting.js")
9
+ try:
10
+ with open(js_path, "r", encoding="utf-8") as f:
11
+ return f.read()
12
+ except FileNotFoundError:
13
+ raise FileNotFoundError(
14
+ f"Could not find figpack_spike_sorting.js at {js_path}. "
15
+ "Make sure to run 'npm run build' to generate the JavaScript bundle."
16
+ )
17
+
18
+
19
# Extension object shared by the spike-sorting views in this package; each view
# passes it as ``extension=`` when constructing its ``figpack.ExtensionView``.
spike_sorting_extension = figpack.FigpackExtension(
    name="figpack-spike-sorting",
    version="1.0.0",
    # The JS bundle is read from disk at import time, so importing this module
    # raises FileNotFoundError if the bundle has not been built.
    javascript_code=_load_javascript_code(),
)
@@ -0,0 +1,32 @@
1
+ """
2
+ AutocorrelogramItem for spike sorting views
3
+ """
4
+
5
+ from typing import Union
6
+
7
+ import numpy as np
8
+
9
+
10
class AutocorrelogramItem:
    """
    One unit's autocorrelogram: per-bin counts plus the bin edges they span.

    The array inputs are copied into new NumPy arrays (float32 edges,
    int32 counts).
    """

    def __init__(
        self,
        *,
        unit_id: Union[str, int],
        bin_edges_sec: np.ndarray,
        bin_counts: np.ndarray,
    ):
        """
        Create an AutocorrelogramItem.

        Args:
            unit_id: Identifier of the unit this autocorrelogram belongs to
            bin_edges_sec: Bin edges in seconds (stored as float32)
            bin_counts: Count for each bin (stored as int32)
        """
        edges = np.array(bin_edges_sec, dtype=np.float32)
        counts = np.array(bin_counts, dtype=np.int32)

        self.unit_id = unit_id
        self.bin_edges_sec = edges
        self.bin_counts = counts
@@ -0,0 +1,120 @@
1
+ """
2
+ Autocorrelograms view for figpack - displays multiple autocorrelograms
3
+ """
4
+
5
+ from typing import List
6
+
7
+ import numpy as np
8
+
9
+ import figpack
10
+ from .AutocorrelogramItem import AutocorrelogramItem
11
+ from ..spike_sorting_extension import spike_sorting_extension
12
+
13
+
14
class Autocorrelograms(figpack.ExtensionView):
    """
    A view that displays multiple autocorrelograms for spike sorting analysis.

    Serialized as a single shared ``bin_edges_sec`` dataset (taken from the
    first item), a 2-D ``bin_counts`` dataset with one row per
    autocorrelogram, and per-item metadata in the group attributes.
    """

    def __init__(
        self,
        *,
        autocorrelograms: List[AutocorrelogramItem],
    ):
        """
        Initialize an Autocorrelograms view

        Args:
            autocorrelograms: List of AutocorrelogramItem objects. Every item
                must have the same number of bins; only the first item's bin
                edges are written.
        """
        super().__init__(
            extension=spike_sorting_extension,
            view_type="spike_sorting.Autocorrelograms",
        )

        self.autocorrelograms = autocorrelograms

    @staticmethod
    def from_sorting(sorting):
        """
        Build an Autocorrelograms view from a spikeinterface sorting.

        Args:
            sorting: A ``spikeinterface`` ``BaseSorting`` object.

        Returns:
            An Autocorrelograms view with one item per unit.
        """
        import spikeinterface as si
        import spikeinterface.widgets as sw

        assert isinstance(sorting, si.BaseSorting), "Input must be a BaseSorting object"
        W = sw.plot_autocorrelograms(sorting)
        return Autocorrelograms.from_spikeinterface_widget(W)

    @staticmethod
    def from_spikeinterface_widget(W):
        """
        Build an Autocorrelograms view from a spikeinterface correlogram widget.

        Args:
            W: Widget whose ``data_plot`` provides ``unit_ids``, ``bins`` and
               ``correlograms`` (indexed as ``correlograms[i, j]`` for a pair
               of units). ``bins`` are divided by 1000 to produce seconds, so
               they are presumably in milliseconds.

        Returns:
            An Autocorrelograms view with one item per unit, in ``unit_ids``
            order.
        """
        from spikeinterface.widgets.base import to_attr
        from spikeinterface.widgets.utils_sortingview import make_serializable

        dp = to_attr(W.data_plot)
        unit_ids = make_serializable(dp.unit_ids)

        # Bin edges are shared by every unit: convert once, not per unit.
        bin_edges_sec = (dp.bins / 1000.0).astype("float32")

        # An autocorrelogram is the diagonal entry (unit vs itself) of the
        # correlogram matrix, so a single pass over the units is enough.
        ac_items = [
            AutocorrelogramItem(
                unit_id=unit_id,
                bin_edges_sec=bin_edges_sec,
                bin_counts=dp.correlograms[i, i].astype("int32"),
            )
            for i, unit_id in enumerate(unit_ids)
        ]

        return Autocorrelograms(autocorrelograms=ac_items)

    def _write_to_zarr_group(self, group: figpack.Group) -> None:
        """
        Write the Autocorrelograms data to a Zarr group

        Args:
            group: Zarr group to write data into

        Raises:
            ValueError: If the autocorrelograms do not all have the same
                number of bins. This check runs before anything is written.
        """
        num_autocorrelograms = len(self.autocorrelograms)
        num_bins = (
            len(self.autocorrelograms[0].bin_counts) if num_autocorrelograms else 0
        )
        # Validate up front. Otherwise NumPy broadcasting would silently
        # replicate a length-1 item across all bins, or fail with an opaque
        # error after the group has been partially written.
        for i, autocorr in enumerate(self.autocorrelograms):
            if len(autocorr.bin_counts) != num_bins:
                raise ValueError(
                    f"Autocorrelogram {i} (unit {autocorr.unit_id}) has "
                    f"{len(autocorr.bin_counts)} bins; expected {num_bins} "
                    "to match the first autocorrelogram"
                )

        super()._write_to_zarr_group(group)

        # Store the number of autocorrelograms
        group.attrs["num_autocorrelograms"] = num_autocorrelograms

        if num_autocorrelograms == 0:
            return

        # Store bin edges (assumed identical for all autocorrelograms)
        group.create_dataset(
            "bin_edges_sec",
            data=self.autocorrelograms[0].bin_edges_sec,
        )

        # One row of counts per autocorrelogram
        bin_counts = np.zeros((num_autocorrelograms, num_bins), dtype=np.int32)

        # Store metadata for each autocorrelogram and populate bin counts array
        autocorrelogram_metadata = []
        for i, autocorr in enumerate(self.autocorrelograms):
            autocorrelogram_metadata.append(
                {
                    "unit_id": str(autocorr.unit_id),
                    "index": i,  # Row of this item in the bin_counts array
                    "num_bins": num_bins,
                }
            )
            bin_counts[i] = autocorr.bin_counts

        # Store the bin counts as a single 2D dataset
        group.create_dataset(
            "bin_counts",
            data=bin_counts,
        )

        # Store the autocorrelogram metadata
        group.attrs["autocorrelograms"] = autocorrelogram_metadata
@@ -0,0 +1,147 @@
1
+ """
2
+ AverageWaveforms view for figpack - displays multiple average waveforms
3
+ """
4
+
5
+ from typing import List, Optional, Union
6
+
7
+ import numpy as np
8
+ import figpack
9
+ from ..spike_sorting_extension import spike_sorting_extension
10
+
11
+
12
class AverageWaveformItem:
    """
    The average waveform of one unit across a set of channels, with optional
    spread information (standard deviation and/or percentiles).

    All array inputs are copied into new float32 NumPy arrays. ``channel_ids``
    is stored as given.
    """

    def __init__(
        self,
        *,
        unit_id: Union[str, int],
        channel_ids: List[Union[str, int]],
        waveform: np.ndarray,
        waveform_std_dev: Optional[np.ndarray] = None,
        waveform_percentiles: Optional[List[np.ndarray]] = None,
    ):
        """
        Create an AverageWaveformItem.

        Args:
            unit_id: Identifier of the unit
            channel_ids: Identifiers of the channels the waveform covers
            waveform: 2D array with the average waveform (num_samples x num_channels)
            waveform_std_dev: Optional 2D array with the per-sample standard deviation
            waveform_percentiles: Optional list of 2D arrays, one per percentile
        """

        def to_float32(values):
            # Always a fresh float32 copy, independent of the caller's array.
            return np.array(values, dtype=np.float32)

        self.unit_id = unit_id
        self.channel_ids = channel_ids
        self.waveform = to_float32(waveform)

        if waveform_std_dev is None:
            self.waveform_std_dev = None
        else:
            self.waveform_std_dev = to_float32(waveform_std_dev)

        self.waveform_percentiles = (
            None
            if waveform_percentiles is None
            else list(map(to_float32, waveform_percentiles))
        )
50
+
51
+
52
class AverageWaveforms(figpack.ExtensionView):
    """
    A view that displays multiple average waveforms for spike sorting analysis.

    Each item is written under its own ``waveform_{i}/`` path in the Zarr
    group, and per-item metadata is stored in the group attributes.
    """

    def __init__(self, *, average_waveforms: List[AverageWaveformItem]):
        """
        Initialize an AverageWaveforms view

        Args:
            average_waveforms: List of AverageWaveformItem objects
        """
        super().__init__(
            extension=spike_sorting_extension,
            view_type="spike_sorting.AverageWaveforms",
        )
        self.average_waveforms = average_waveforms

    @staticmethod
    def from_sorting_analyzer(sorting_analyzer):
        """
        Build an AverageWaveforms view from a spikeinterface SortingAnalyzer.

        Computes the "random_spikes", "waveforms", "templates" and
        "noise_levels" extensions on the analyzer. It then creates one item per
        unit, using that unit's average template as the waveform. Each
        channel's noise level is repeated across all samples and used as the
        std-dev band for every unit.

        Args:
            sorting_analyzer: A spikeinterface SortingAnalyzer.

        Returns:
            An AverageWaveforms view with one item per unit.
        """
        sorting_analyzer.compute(
            ["random_spikes", "waveforms", "templates", "noise_levels"]
        )
        ext_templates = sorting_analyzer.get_extension("templates")
        # shape is num_units, num_samples, num_channels
        av_templates = ext_templates.get_data(operator="average")

        ext_noise_levels = sorting_analyzer.get_extension("noise_levels")
        noise_levels = ext_noise_levels.get_data()

        num_samples, num_channels = av_templates.shape[1], av_templates.shape[2]
        # Broadcast each channel's noise level down its column (all samples).
        waveform_std_dev = np.zeros((num_samples, num_channels), dtype=np.float32)
        waveform_std_dev[:, :] = np.asarray(noise_levels, dtype=np.float32)[
            :num_channels
        ]

        # Channel ids do not depend on the unit, so fetch them only once.
        recording_channel_ids = list(sorting_analyzer.recording.get_channel_ids())

        average_waveform_items = [
            AverageWaveformItem(
                unit_id=unit_id,
                waveform=av_templates[i],
                # Give each item its own list: AverageWaveformItem keeps the
                # list object it receives rather than copying it.
                channel_ids=list(recording_channel_ids),
                waveform_std_dev=waveform_std_dev,
            )
            for i, unit_id in enumerate(sorting_analyzer.unit_ids)
        ]
        return AverageWaveforms(average_waveforms=average_waveform_items)

    def _write_to_zarr_group(self, group: figpack.Group) -> None:
        """
        Write the AverageWaveforms data to a Zarr group

        Args:
            group: Zarr group to write data into
        """
        super()._write_to_zarr_group(group)

        # Store the number of average waveforms
        group.attrs["num_average_waveforms"] = len(self.average_waveforms)

        # Store metadata for each average waveform
        average_waveform_metadata = []
        for i, waveform in enumerate(self.average_waveforms):
            waveform_name = f"waveform_{i}"

            # Store metadata
            metadata = {
                "name": waveform_name,
                "unit_id": str(waveform.unit_id),
                "channel_ids": [str(ch) for ch in waveform.channel_ids],
            }
            average_waveform_metadata.append(metadata)

            # Create arrays for this average waveform
            group.create_dataset(
                f"{waveform_name}/waveform",
                data=waveform.waveform,
            )
            if waveform.waveform_std_dev is not None:
                group.create_dataset(
                    f"{waveform_name}/waveform_std_dev",
                    data=waveform.waveform_std_dev,
                )
            if waveform.waveform_percentiles is not None:
                # Percentiles are stored positionally; their count is not
                # recorded in the metadata.
                for j, p in enumerate(waveform.waveform_percentiles):
                    group.create_dataset(
                        f"{waveform_name}/waveform_percentile_{j}",
                        data=p,
                        dtype=p.dtype,
                    )

        # Store the average waveform metadata
        group.attrs["average_waveforms"] = average_waveform_metadata
@@ -0,0 +1,35 @@
1
+ """
2
+ CrossCorrelogramItem for spike sorting views
3
+ """
4
+
5
+ from typing import Union
6
+
7
+ import numpy as np
8
+
9
+
10
class CrossCorrelogramItem:
    """
    The cross-correlogram of an ordered pair of units: per-bin counts plus the
    bin edges they span.

    The array inputs are copied into new NumPy arrays (float32 edges,
    int32 counts).
    """

    def __init__(
        self,
        *,
        unit_id1: Union[str, int],
        unit_id2: Union[str, int],
        bin_edges_sec: np.ndarray,
        bin_counts: np.ndarray,
    ):
        """
        Create a CrossCorrelogramItem.

        Args:
            unit_id1: Identifier of the first unit of the pair
            unit_id2: Identifier of the second unit of the pair
            bin_edges_sec: Bin edges in seconds (stored as float32)
            bin_counts: Count for each bin (stored as int32)
        """
        edges = np.array(bin_edges_sec, dtype=np.float32)
        counts = np.array(bin_counts, dtype=np.int32)

        self.unit_id1, self.unit_id2 = unit_id1, unit_id2
        self.bin_edges_sec = edges
        self.bin_counts = counts
@@ -0,0 +1,134 @@
1
+ """
2
+ CrossCorrelograms view for figpack - displays multiple cross-correlograms
3
+ """
4
+
5
+ from typing import List, Optional
6
+
7
+ import numpy as np
8
+
9
+ import figpack
10
+ from ..spike_sorting_extension import spike_sorting_extension
11
+
12
+ from .CrossCorrelogramItem import CrossCorrelogramItem
13
+
14
+
15
class CrossCorrelograms(figpack.ExtensionView):
    """
    A view that displays multiple cross-correlograms for spike sorting analysis.

    Serialized as a single shared ``bin_edges_sec`` dataset (taken from the
    first item), a 2-D ``bin_counts`` dataset with one row per
    cross-correlogram, and per-item metadata in the group attributes.
    """

    def __init__(
        self,
        *,
        cross_correlograms: List[CrossCorrelogramItem],
        hide_unit_selector: Optional[bool] = False,
    ):
        """
        Initialize a CrossCorrelograms view

        Args:
            cross_correlograms: List of CrossCorrelogramItem objects. Every item
                must have the same number of bins; only the first item's bin
                edges are written.
            hide_unit_selector: Whether to hide the unit selector widget. If
                None, the attribute is not written.
        """
        super().__init__(
            extension=spike_sorting_extension,
            view_type="spike_sorting.CrossCorrelograms",
        )
        self.cross_correlograms = cross_correlograms
        self.hide_unit_selector = hide_unit_selector

    @staticmethod
    def from_sorting(sorting):
        """
        Build a CrossCorrelograms view from a spikeinterface sorting.

        Args:
            sorting: A ``spikeinterface`` ``BaseSorting`` object.

        Returns:
            A CrossCorrelograms view built from a ``CrossCorrelogramsWidget``.
        """
        import spikeinterface as si
        import spikeinterface.widgets as sw

        assert isinstance(sorting, si.BaseSorting), "Input must be a BaseSorting object"
        W = sw.CrossCorrelogramsWidget(sorting)
        return CrossCorrelograms.from_spikeinterface_widget(W)

    @staticmethod
    def from_spikeinterface_widget(W):
        """
        Build a CrossCorrelograms view from a spikeinterface correlogram widget.

        Only pairs (i, j) with i <= j whose similarity is at least
        ``dp.min_similarity_for_correlograms`` are included.

        Args:
            W: Widget whose ``data_plot`` provides ``unit_ids``, ``bins``,
               ``correlograms`` (indexed as ``correlograms[i, j]``),
               ``similarity`` (may be None) and
               ``min_similarity_for_correlograms``. ``bins`` are divided by
               1000 to produce seconds, so they are presumably in milliseconds.

        Returns:
            A CrossCorrelograms view with the unit selector shown.
        """
        from spikeinterface.widgets.base import to_attr
        from spikeinterface.widgets.utils_sortingview import make_serializable

        dp = to_attr(W.data_plot)
        unit_ids = make_serializable(dp.unit_ids)
        num_units = len(unit_ids)

        if dp.similarity is not None:
            similarity = dp.similarity
        else:
            # Without a similarity matrix, every pair gets similarity 1 and
            # passes the threshold (unless the threshold exceeds 1).
            similarity = np.ones((num_units, num_units))

        # Bin edges are shared by every pair: convert once, not per pair.
        bin_edges_sec = (dp.bins / 1000.0).astype("float32")

        cc_items = []
        # Upper triangle including the diagonal: each unordered pair once,
        # plus every unit paired with itself.
        for i in range(num_units):
            for j in range(i, num_units):
                if similarity[i, j] >= dp.min_similarity_for_correlograms:
                    cc_items.append(
                        CrossCorrelogramItem(
                            unit_id1=unit_ids[i],
                            unit_id2=unit_ids[j],
                            bin_edges_sec=bin_edges_sec,
                            bin_counts=dp.correlograms[i, j].astype("int32"),
                        )
                    )

        return CrossCorrelograms(cross_correlograms=cc_items, hide_unit_selector=False)

    def _write_to_zarr_group(self, group: figpack.Group) -> None:
        """
        Write the CrossCorrelograms data to a Zarr group

        Args:
            group: Zarr group to write data into

        Raises:
            ValueError: If the cross-correlograms do not all have the same
                number of bins. This check runs before anything is written.
        """
        num_cross_correlograms = len(self.cross_correlograms)
        num_bins = (
            len(self.cross_correlograms[0].bin_counts) if num_cross_correlograms else 0
        )
        # Validate up front. Otherwise NumPy broadcasting would silently
        # replicate a length-1 item across all bins, or fail with an opaque
        # error after the group has been partially written.
        for i, cross_corr in enumerate(self.cross_correlograms):
            if len(cross_corr.bin_counts) != num_bins:
                raise ValueError(
                    f"Cross-correlogram {i} (units {cross_corr.unit_id1}, "
                    f"{cross_corr.unit_id2}) has {len(cross_corr.bin_counts)} "
                    f"bins; expected {num_bins} to match the first cross-correlogram"
                )

        super()._write_to_zarr_group(group)

        # Set view properties
        if self.hide_unit_selector is not None:
            group.attrs["hide_unit_selector"] = self.hide_unit_selector

        # Store the number of cross-correlograms
        group.attrs["num_cross_correlograms"] = num_cross_correlograms

        if num_cross_correlograms == 0:
            return

        # Store bin edges (assumed identical for all cross-correlograms)
        group.create_dataset(
            "bin_edges_sec",
            data=self.cross_correlograms[0].bin_edges_sec,
        )

        # One row of counts per cross-correlogram
        bin_counts = np.zeros((num_cross_correlograms, num_bins), dtype=np.int32)

        # Store metadata for each cross-correlogram and populate bin counts array
        cross_correlogram_metadata = []
        for i, cross_corr in enumerate(self.cross_correlograms):
            cross_correlogram_metadata.append(
                {
                    "unit_id1": str(cross_corr.unit_id1),
                    "unit_id2": str(cross_corr.unit_id2),
                    "index": i,  # Row of this item in the bin_counts array
                    "num_bins": num_bins,
                }
            )
            bin_counts[i] = cross_corr.bin_counts

        # Store the bin counts as a single 2D dataset
        group.create_dataset(
            "bin_counts",
            data=bin_counts,
        )

        # Store the cross-correlogram metadata
        group.attrs["cross_correlograms"] = cross_correlogram_metadata