CATSort 0.1.3 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
catsort-0.1.3/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2026 lucasbeziers
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
catsort-0.1.3/PKG-INFO ADDED
@@ -0,0 +1,87 @@
+ Metadata-Version: 2.4
+ Name: CATSort
+ Version: 0.1.3
+ Summary: A collision-aware template matching spike sorter.
+ Author-email: Lucas Beziers <lucas.beziers.pro@gmail.com>
+ License: MIT
+ Project-URL: Homepage, https://github.com/lucasbeziers/CATSort
+ Project-URL: Repository, https://github.com/lucasbeziers/CATSort
+ Project-URL: Issues, https://github.com/lucasbeziers/CATSort/issues
+ Keywords: spike sorting,neuroscience,electrophysiology,collisions,template matching
+ Classifier: Development Status :: 4 - Beta
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
+ Classifier: Topic :: Scientific/Engineering
+ Requires-Python: >=3.9
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: isosplit6
+ Requires-Dist: numba
+ Requires-Dist: numpy
+ Requires-Dist: scikit-learn
+ Requires-Dist: scipy
+ Requires-Dist: spikeinterface>=0.103.1
+ Dynamic: license-file
+
+ # CATSort
+
+ **CATSort** (Collision-Aware Template-matching Sort) is a robust spike sorter designed to handle overlapping spikes (collisions) with high precision: a dedicated collision-handling stage runs before clustering, and spikes are then extracted by template matching.
+
+ ## Key Features
+
+ - **Collision Handling**: Automatically identifies and flags collided spikes using multi-criterion feature analysis (amplitude, width, energy).
+ - **Template Matching**: Robust spike extraction using template-based matching (currently the 'wobble' method).
+ - **Flexible Schemes**: Choose between an `adaptive` threshold optimization and an `original` fixed MAD (Median Absolute Deviation) multiplier scheme.
+ - **SpikeInterface Integration**: Fully compatible with the [SpikeInterface](https://github.com/SpikeInterface/spikeinterface) ecosystem.
+
+ ## Installation
+
+ You can install `catsort` via pip:
+
+ ```bash
+ pip install catsort
+ ```
+
+ Or from source:
+
+ ```bash
+ git clone https://github.com/lucasbeziers/CATSort.git
+ cd CATSort
+ pip install -e .
+ ```
+
+ ## Quick Start
+
+ ```python
+ import spikeinterface.extractors as se
+ from catsort import run_catsort
+
+ # Load your recording
+ recording = se.read_binary("path_to_data.dat", sampling_frequency=30000, num_channels=384, dtype="int16")
+
+ # Run CATSort
+ sorting = run_catsort(recording)
+
+ # The result is a SpikeInterface Sorting object
+ print(sorting)
+ ```
+
+ ## Parameters
+
+ CATSort offers several parameters to fine-tune its behavior:
+
+ - `scheme`: `'original'` or `'adaptive'`.
+   - `'adaptive'` uses temporally detected collisions to optimize the feature thresholds.
+   - `'original'` uses fixed MAD multipliers.
+ - `mad_multiplier_amplitude`, `mad_multiplier_width`, `mad_multiplier_energy`: (Defaults: `7.0`, `10.0`, `15.0`) Used when `scheme='original'`.
+ - `detect_threshold`: Spike detection threshold in standard deviations (Default: `5`).
+
+ ## License
+
+ MIT License. See [LICENSE](LICENSE) for details.
catsort-0.1.3/README.md ADDED
@@ -0,0 +1,56 @@
+ # CATSort
+
+ **CATSort** (Collision-Aware Template-matching Sort) is a robust spike sorter designed to handle overlapping spikes (collisions) with high precision: a dedicated collision-handling stage runs before clustering, and spikes are then extracted by template matching.
+
+ ## Key Features
+
+ - **Collision Handling**: Automatically identifies and flags collided spikes using multi-criterion feature analysis (amplitude, width, energy).
+ - **Template Matching**: Robust spike extraction using template-based matching (currently the 'wobble' method).
+ - **Flexible Schemes**: Choose between an `adaptive` threshold optimization and an `original` fixed MAD (Median Absolute Deviation) multiplier scheme.
+ - **SpikeInterface Integration**: Fully compatible with the [SpikeInterface](https://github.com/SpikeInterface/spikeinterface) ecosystem.
+
+ ## Installation
+
+ You can install `catsort` via pip:
+
+ ```bash
+ pip install catsort
+ ```
+
+ Or from source:
+
+ ```bash
+ git clone https://github.com/lucasbeziers/CATSort.git
+ cd CATSort
+ pip install -e .
+ ```
+
+ ## Quick Start
+
+ ```python
+ import spikeinterface.extractors as se
+ from catsort import run_catsort
+
+ # Load your recording
+ recording = se.read_binary("path_to_data.dat", sampling_frequency=30000, num_channels=384, dtype="int16")
+
+ # Run CATSort
+ sorting = run_catsort(recording)
+
+ # The result is a SpikeInterface Sorting object
+ print(sorting)
+ ```
+
+ ## Parameters
+
+ CATSort offers several parameters to fine-tune its behavior:
+
+ - `scheme`: `'original'` or `'adaptive'`.
+   - `'adaptive'` uses temporally detected collisions to optimize the feature thresholds.
+   - `'original'` uses fixed MAD multipliers.
+ - `mad_multiplier_amplitude`, `mad_multiplier_width`, `mad_multiplier_energy`: (Defaults: `7.0`, `10.0`, `15.0`) Used when `scheme='original'`.
+ - `detect_threshold`: Spike detection threshold in standard deviations (Default: `5`).
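+
+ For example, to tweak the fixed-MAD scheme, pass a partial `params` dict to `run_catsort`; missing keys fall back to `DEFAULT_PARAMS`. The values below are illustrative, not tuned recommendations:
+
+ ```python
+ from catsort import run_catsort
+
+ sorting = run_catsort(recording, params={
+     "scheme": "original",
+     "mad_multiplier_amplitude": 6.0,  # stricter amplitude flagging (illustrative)
+     "detect_threshold": 6,
+ })
+ ```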
+
+ ## License
+
+ MIT License. See [LICENSE](LICENSE) for details.
catsort-0.1.3/pyproject.toml ADDED
@@ -0,0 +1,52 @@
+ [project]
+ name = "CATSort"
+ version = "0.1.3"
+ description = "A collision-aware template matching spike sorter."
+ readme = "README.md"
+ requires-python = ">=3.9"
+ license = { text = "MIT" }
+ authors = [
+     { name="Lucas Beziers", email="lucas.beziers.pro@gmail.com" },
+ ]
+ keywords = ["spike sorting", "neuroscience", "electrophysiology", "collisions", "template matching"]
+ classifiers = [
+     "Development Status :: 4 - Beta",
+     "Intended Audience :: Science/Research",
+     "License :: OSI Approved :: MIT License",
+     "Programming Language :: Python :: 3",
+     "Programming Language :: Python :: 3.9",
+     "Programming Language :: Python :: 3.10",
+     "Programming Language :: Python :: 3.11",
+     "Programming Language :: Python :: 3.12",
+     "Programming Language :: Python :: 3.13",
+     "Topic :: Scientific/Engineering",
+ ]
+ dependencies = [
+     "isosplit6",
+     "numba",
+     "numpy",
+     "scikit-learn",
+     "scipy",
+     "spikeinterface>=0.103.1",
+ ]
+
+ [project.urls]
+ Homepage = "https://github.com/lucasbeziers/CATSort"
+ Repository = "https://github.com/lucasbeziers/CATSort"
+ Issues = "https://github.com/lucasbeziers/CATSort/issues"
+
+ [build-system]
+ requires = ["setuptools>=61.0"]
+ build-backend = "setuptools.build_meta"
+
+ [tool.setuptools.packages.find]
+ where = ["src"]
+
+ [dependency-groups]
+ dev = [
+     "ipykernel>=6.31.0",
+     "mearec",
+     "pandas",
+     "pytest",
+     "twine",
+ ]
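Note that the dev tools above are declared as a PEP 735 `[dependency-groups]` table rather than an optional-dependencies extra, so installing them requires a PEP 735-aware installer. A minimal sketch, assuming a recent `uv`:

```bash
# Install the project together with the 'dev' dependency group (PEP 735)
uv sync --group dev
```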
catsort-0.1.3/setup.cfg ADDED
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
catsort-0.1.3/src/CATSort.egg-info/PKG-INFO ADDED
@@ -0,0 +1,87 @@
+ Metadata-Version: 2.4
+ Name: CATSort
+ Version: 0.1.3
+ Summary: A collision-aware template matching spike sorter.
+ Author-email: Lucas Beziers <lucas.beziers.pro@gmail.com>
+ License: MIT
+ Project-URL: Homepage, https://github.com/lucasbeziers/CATSort
+ Project-URL: Repository, https://github.com/lucasbeziers/CATSort
+ Project-URL: Issues, https://github.com/lucasbeziers/CATSort/issues
+ Keywords: spike sorting,neuroscience,electrophysiology,collisions,template matching
+ Classifier: Development Status :: 4 - Beta
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
+ Classifier: Topic :: Scientific/Engineering
+ Requires-Python: >=3.9
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: isosplit6
+ Requires-Dist: numba
+ Requires-Dist: numpy
+ Requires-Dist: scikit-learn
+ Requires-Dist: scipy
+ Requires-Dist: spikeinterface>=0.103.1
+ Dynamic: license-file
+
+ # CATSort
+
+ **CATSort** (Collision-Aware Template-matching Sort) is a robust spike sorter designed to handle overlapping spikes (collisions) with high precision: a dedicated collision-handling stage runs before clustering, and spikes are then extracted by template matching.
+
+ ## Key Features
+
+ - **Collision Handling**: Automatically identifies and flags collided spikes using multi-criterion feature analysis (amplitude, width, energy).
+ - **Template Matching**: Robust spike extraction using template-based matching (currently the 'wobble' method).
+ - **Flexible Schemes**: Choose between an `adaptive` threshold optimization and an `original` fixed MAD (Median Absolute Deviation) multiplier scheme.
+ - **SpikeInterface Integration**: Fully compatible with the [SpikeInterface](https://github.com/SpikeInterface/spikeinterface) ecosystem.
+
+ ## Installation
+
+ You can install `catsort` via pip:
+
+ ```bash
+ pip install catsort
+ ```
+
+ Or from source:
+
+ ```bash
+ git clone https://github.com/lucasbeziers/CATSort.git
+ cd CATSort
+ pip install -e .
+ ```
+
+ ## Quick Start
+
+ ```python
+ import spikeinterface.extractors as se
+ from catsort import run_catsort
+
+ # Load your recording
+ recording = se.read_binary("path_to_data.dat", sampling_frequency=30000, num_channels=384, dtype="int16")
+
+ # Run CATSort
+ sorting = run_catsort(recording)
+
+ # The result is a SpikeInterface Sorting object
+ print(sorting)
+ ```
+
+ ## Parameters
+
+ CATSort offers several parameters to fine-tune its behavior:
+
+ - `scheme`: `'original'` or `'adaptive'`.
+   - `'adaptive'` uses temporally detected collisions to optimize the feature thresholds.
+   - `'original'` uses fixed MAD multipliers.
+ - `mad_multiplier_amplitude`, `mad_multiplier_width`, `mad_multiplier_energy`: (Defaults: `7.0`, `10.0`, `15.0`) Used when `scheme='original'`.
+ - `detect_threshold`: Spike detection threshold in standard deviations (Default: `5`).
+
+ ## License
+
+ MIT License. See [LICENSE](LICENSE) for details.
catsort-0.1.3/src/CATSort.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,14 @@
+ LICENSE
+ README.md
+ pyproject.toml
+ src/CATSort.egg-info/PKG-INFO
+ src/CATSort.egg-info/SOURCES.txt
+ src/CATSort.egg-info/dependency_links.txt
+ src/CATSort.egg-info/requires.txt
+ src/CATSort.egg-info/top_level.txt
+ src/catsort/__init__.py
+ src/catsort/sorter.py
+ src/catsort/core/clustering.py
+ src/catsort/core/collision.py
+ src/catsort/core/utils.py
+ tests/test_sorter.py
catsort-0.1.3/src/CATSort.egg-info/requires.txt ADDED
@@ -0,0 +1,6 @@
+ isosplit6
+ numba
+ numpy
+ scikit-learn
+ scipy
+ spikeinterface>=0.103.1
catsort-0.1.3/src/CATSort.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ catsort
catsort-0.1.3/src/catsort/__init__.py ADDED
@@ -0,0 +1,5 @@
+ # CATSort
+
+ __version__ = "0.1.3"
+
+ from .sorter import run_catsort, DEFAULT_PARAMS
catsort-0.1.3/src/catsort/core/clustering.py ADDED
@@ -0,0 +1,71 @@
+ # Adapted from MountainSort5 (https://github.com/flatironinstitute/mountainsort5)
+ # Licensed under Apache-2.0
+
+ import numpy as np
+ from sklearn import decomposition
+ from isosplit6 import isosplit6
+ from scipy.spatial.distance import squareform
+ from scipy.cluster.hierarchy import linkage, cut_tree
+ from numpy.typing import NDArray
+ from typing import Optional
+
+ def compute_pca_features(X: NDArray, npca: int) -> NDArray:
+     # Cap the number of components at both the sample count and the dimensionality
+     L = X.shape[0]
+     D = X.shape[1]
+     npca_2 = np.minimum(np.minimum(npca, L), D)
+     if L == 0 or D == 0:
+         return np.zeros((0, npca_2), dtype=np.float32)
+     pca = decomposition.PCA(n_components=npca_2)
+     return pca.fit_transform(X)
+
+ def isosplit6_subdivision_method(X: NDArray, npca_per_subdivision: int, inds: Optional[NDArray] = None) -> NDArray:
+     # Recursively cluster X (optionally restricted to rows `inds`) with isosplit6
+     # on a local PCA projection, then bisect the clusters and recurse on each half.
+     if inds is not None:
+         X_sub = X[inds]
+     else:
+         X_sub = X
+
+     L = X_sub.shape[0]
+     if L == 0:
+         return np.zeros((0,), dtype=np.int32)
+
+     features = compute_pca_features(X_sub, npca=npca_per_subdivision)
+     labels = isosplit6(features)
+
+     K = int(np.max(labels)) if len(labels) > 0 else 0
+
+     if K <= 1:
+         return labels
+
+     # Median centroid per cluster (isosplit6 labels are 1-based)
+     centroids = np.zeros((K, X.shape[1]), dtype=np.float32)
+     for k in range(1, K + 1):
+         centroids[k - 1] = np.median(X_sub[labels == k], axis=0)
+
+     dists = np.sqrt(np.sum((centroids[:, None, :] - centroids[None, :, :]) ** 2, axis=2))
+     dists_condensed = squareform(dists)
+
+     # Bisect the clusters via single-linkage hierarchical clustering on centroid distances
+     Z = linkage(dists_condensed, method='single', metric='euclidean')
+     clusters0 = cut_tree(Z, n_clusters=2)
+
+     cluster_inds_1 = np.where(clusters0 == 0)[0] + 1
+     cluster_inds_2 = np.where(clusters0 == 1)[0] + 1
+
+     inds1 = np.where(np.isin(labels, cluster_inds_1))[0]
+     inds2 = np.where(np.isin(labels, cluster_inds_2))[0]
+
+     if inds is not None:
+         inds1_b = inds[inds1]
+         inds2_b = inds[inds2]
+     else:
+         inds1_b = inds1
+         inds2_b = inds2
+
+     labels1 = isosplit6_subdivision_method(X, npca_per_subdivision=npca_per_subdivision, inds=inds1_b)
+     labels2 = isosplit6_subdivision_method(X, npca_per_subdivision=npca_per_subdivision, inds=inds2_b)
+
+     # Merge the two recursive labelings, offsetting the labels from the second half
+     K1 = int(np.max(labels1))
+     ret_labels = np.zeros(L, dtype=np.int32)
+     ret_labels[inds1] = labels1
+     ret_labels[inds2] = labels2 + K1
+     return ret_labels
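A quick way to see `isosplit6_subdivision_method` in action is to feed it well-separated synthetic blobs. This is a standalone sketch, assuming `isosplit6` is installed; the data is made up, not CATSort output:

```python
import numpy as np
from catsort.core.clustering import isosplit6_subdivision_method

# Three well-separated synthetic 2-D clusters (hypothetical data)
rng = np.random.default_rng(0)
X = np.vstack([
    rng.normal(loc=(0, 0), scale=0.1, size=(200, 2)),
    rng.normal(loc=(5, 0), scale=0.1, size=(200, 2)),
    rng.normal(loc=(0, 5), scale=0.1, size=(200, 2)),
]).astype(np.float32)

labels = isosplit6_subdivision_method(X, npca_per_subdivision=2)
print(np.unique(labels))  # expect roughly [1 2 3]; labels are 1-based
```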
catsort-0.1.3/src/catsort/core/collision.py ADDED
@@ -0,0 +1,151 @@
+ import numpy as np
+ from scipy.signal import resample
+ from scipy.stats import median_abs_deviation
+
+
+ def compute_fixed_thresholds(
+     features: dict,
+     mad_multipliers: dict
+ ) -> dict:
+     """
+     Compute thresholds for collision features using fixed MAD multipliers.
+     Threshold = median + MAD_multiplier * MAD
+
+     Args:
+         features: Dictionary with feature values (amplitude, width, energy)
+         mad_multipliers: Dictionary with MAD multipliers for each feature
+
+     Returns:
+         Dictionary with thresholds for each feature
+     """
+     thresholds = {}
+     for criterion, values in features.items():
+         if criterion in mad_multipliers:
+             median_val = np.median(values)
+             mad_val = median_abs_deviation(values)
+             thresholds[criterion] = median_val + mad_multipliers[criterion] * mad_val
+         else:
+             # Default to max value (no flagging) if multiplier not specified
+             thresholds[criterion] = np.max(values)
+     return thresholds
+
+ def longest_true_runs(arr: np.ndarray) -> np.ndarray:
+     """
+     For each row in a 2D boolean array, find the longest consecutive run of True values
+     and return a boolean array of the same shape with only that run set to True.
+     """
+     out = np.zeros_like(arr, dtype=bool)
+     for i, row in enumerate(arr):
+         # Find start/end indices of True runs
+         padded = np.r_[False, row, False]
+         edges = np.flatnonzero(padded[1:] != padded[:-1])
+         starts, ends = edges[::2], edges[1::2]
+         if len(starts) == 0:
+             continue
+         lengths = ends - starts
+         j = np.argmax(lengths)
+         out[i, starts[j]:ends[j]] = True
+     return out
+
+ def compute_collision_features(
+     peaks_traces: np.ndarray,
+     sampling_frequency: float,
+     width_threshold_amplitude: float = 0.5
+ ) -> dict:
+     """
+     Compute collision features (amplitude, width, energy) for each peak.
+     """
+     amplitudes = np.abs(peaks_traces).max(axis=1)
+     energies = np.sum(peaks_traces**2, axis=1)
+
+     # Upsample for sub-sample width precision (each row is one peak's 1D trace)
+     if peaks_traces.ndim == 2:
+         num_samples = peaks_traces.shape[1]
+         resample_factor = 8
+         peaks_traces_resampled = resample(peaks_traces, num=num_samples * resample_factor, axis=1)
+         # Compute widths on resampled traces
+         under_width_threshold = peaks_traces_resampled < -width_threshold_amplitude * amplitudes[:, np.newaxis]
+         longest_under = longest_true_runs(under_width_threshold)
+         widths_samples = np.sum(longest_under, axis=1) / resample_factor
+     else:
+         under_width_threshold = peaks_traces < -width_threshold_amplitude * amplitudes[:, np.newaxis]
+         widths_samples = np.sum(longest_true_runs(under_width_threshold), axis=1)
+
+     widths_ms = widths_samples * (1000 / sampling_frequency)
+
+     return {
+         "amplitude": amplitudes,
+         "width": widths_ms,
+         "energy": energies
+     }
+
+ def detect_temporal_collisions(
+     sample_indices: np.ndarray,
+     channel_indices: np.ndarray,
+     sampling_frequency: float,
+     refractory_period_ms: float = 2.0
+ ) -> np.ndarray:
+     """
+     Identify spikes that are too close to each other on the same channel.
+     """
+     prev_diffs = np.full_like(sample_indices, np.inf, dtype=float)
+     next_diffs = np.full_like(sample_indices, np.inf, dtype=float)
+
+     for ch in np.unique(channel_indices):
+         mask = channel_indices == ch
+         idx = np.where(mask)[0]
+         samples = sample_indices[mask]
+
+         # Sort by time within this channel
+         order = np.argsort(samples)
+         sorted_idx = idx[order]
+         sorted_samples = samples[order]
+
+         # Distance (in samples) to the previous/next spike on the same channel
+         if len(sorted_samples) > 1:
+             diffs = np.diff(sorted_samples)
+             prev_diffs[sorted_idx[1:]] = diffs
+             next_diffs[sorted_idx[:-1]] = diffs
+
+     closest_sample_diff = np.minimum(np.abs(prev_diffs), np.abs(next_diffs))
+     closest_ms = closest_sample_diff * 1000 / sampling_frequency
+     return closest_ms < refractory_period_ms
+
+ def optimize_collision_thresholds(
+     features: dict,
+     temporal_collisions: np.ndarray,
+     false_positive_tolerance: float = 0.05
+ ) -> dict:
+     """
+     Find optimal thresholds for collision features based on temporal collisions.
+     """
+     non_collision_count = (~temporal_collisions).sum()
+     max_fp = false_positive_tolerance * non_collision_count
+
+     optimized_thresholds = {}
+
+     for criterion, values in features.items():
+         unique_vals = np.sort(np.unique(values))
+         best_thr = unique_vals[-1]  # Default to max (nothing flagged)
+         best_tp = -1
+
+         # Scan candidate thresholds; subsample if there are too many unique values
+         if len(unique_vals) > 1000:
+             candidates = unique_vals[::len(unique_vals)//1000]
+         else:
+             candidates = unique_vals
+
+         # Keep the threshold flagging the most temporal collisions while staying
+         # within the false-positive budget
+         for thr in candidates:
+             flagged = values > thr
+             tp = np.count_nonzero(flagged & temporal_collisions)
+             fp = np.count_nonzero(flagged & ~temporal_collisions)
+
+             if fp <= max_fp and tp > best_tp:
+                 best_tp = tp
+                 best_thr = thr
+
+         optimized_thresholds[criterion] = best_thr
+
+     return optimized_thresholds
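To make the two thresholding building blocks concrete, here is a small synthetic run of `compute_fixed_thresholds` and `detect_temporal_collisions`; the numbers are invented for illustration:

```python
import numpy as np
from catsort.core.collision import compute_fixed_thresholds, detect_temporal_collisions

# Fixed scheme: threshold = median + multiplier * MAD, per feature
rng = np.random.default_rng(1)
features = {"amplitude": rng.normal(50, 5, size=1000)}
thr = compute_fixed_thresholds(features, {"amplitude": 7.0})
print(thr["amplitude"])  # ~ 50 + 7 * MAD, i.e. well above the bulk of values

# Temporal collisions: two spikes on channel 0 only 30 samples (1 ms) apart
samples = np.array([1000, 1030, 9000])
channels = np.array([0, 0, 0])
flags = detect_temporal_collisions(samples, channels, sampling_frequency=30000, refractory_period_ms=2.0)
print(flags)  # [ True  True False]
```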
catsort-0.1.3/src/catsort/core/utils.py ADDED
@@ -0,0 +1,70 @@
+ import numpy as np
+
+ def get_snippet(
+     traces: np.ndarray,
+     index: int,
+     n_before: int, n_after: int
+ ) -> np.ndarray:
+     """
+     Get a snippet of the traces around a specific index.
+     Zero-pad where the window extends past the recording edges.
+     """
+     n_channels = traces.shape[1]
+     snippet = np.zeros((n_before + n_after, n_channels))
+
+     start = index - n_before
+     end = index + n_after
+
+     # If the snippet is fully within bounds, extract it directly
+     if start >= 0 and end <= traces.shape[0]:
+         snippet = traces[start:end, :]
+
+     # If the snippet is partially out of bounds, fill with zeros where necessary
+     else:
+         valid_start = max(start, 0)
+         valid_end = min(end, traces.shape[0])
+         insert_start = valid_start - start
+         insert_end = insert_start + (valid_end - valid_start)
+         snippet[insert_start:insert_end, :] = traces[valid_start:valid_end, :]
+
+     return snippet  # shape (n_before+n_after, n_channels)
+
+ def get_peaks_traces_all_channels(
+     peaks: np.ndarray,
+     traces: np.ndarray,
+     n_before: int, n_after: int
+ ) -> np.ndarray:
+     """
+     Extract snippets of traces around detected peaks.
+
+     Output shape: (n_peaks, n_before + n_after, n_channels)
+     """
+     n_channels = traces.shape[1]
+     complete_peaks = np.zeros((len(peaks), n_before+n_after, n_channels))
+
+     for i, peak in enumerate(peaks):
+         sample_index = peak['sample_index']
+         snippet = get_snippet(traces, sample_index, n_before, n_after)
+         complete_peaks[i] = snippet
+     return complete_peaks
+
+
+ def get_peaks_traces_best_channel(
+     peaks: np.ndarray,
+     traces: np.ndarray,
+     n_before: int, n_after: int
+ ) -> np.ndarray:
+     """
+     Extract snippets of traces around detected peaks.
+     Keep only the channel with the highest amplitude.
+
+     Output shape: (n_peaks, n_before + n_after)
+     """
+     complete_peaks = np.zeros((len(peaks), n_before+n_after))
+
+     for i, peak in enumerate(peaks):
+         sample_index = peak['sample_index']
+         snippet = get_snippet(traces, sample_index, n_before, n_after)  # shape (n_before+n_after, n_channels)
+         best_channel = np.argmax(np.abs(snippet).max(axis=0))
+         complete_peaks[i] = snippet[:, best_channel]  # shape (n_before+n_after)
+     return complete_peaks
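The zero-padding behaviour of `get_snippet` is easiest to see at a recording edge. A minimal check with a toy trace array:

```python
import numpy as np
from catsort.core.utils import get_snippet

# Toy "recording": 10 samples, 2 channels, values equal to the sample index
traces = np.tile(np.arange(10, dtype=float)[:, None], (1, 2))

# A peak at index 1 with n_before=3 reaches before the start of the recording,
# so the first two rows of the snippet are zero-filled
snippet = get_snippet(traces, index=1, n_before=3, n_after=2)
print(snippet[:, 0])  # [0. 0. 0. 1. 2.]
```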
catsort-0.1.3/src/catsort/sorter.py ADDED
@@ -0,0 +1,212 @@
+ import numpy as np
+ from typing import Optional
+ from spikeinterface.core import BaseRecording, NumpySorting, SortingAnalyzer, Templates, ChannelSparsity, create_sorting_analyzer
+ from spikeinterface.sortingcomponents.peak_detection import detect_peaks
+ import spikeinterface.sortingcomponents.matching as sm
+ from sklearn.decomposition import PCA
+
+ from catsort.core.utils import get_peaks_traces_best_channel, get_peaks_traces_all_channels
+ from catsort.core.collision import (
+     compute_collision_features,
+     detect_temporal_collisions,
+     optimize_collision_thresholds,
+     compute_fixed_thresholds
+ )
+ from catsort.core.clustering import isosplit6_subdivision_method
+
+ DEFAULT_PARAMS = {
+     # Detection
+     'detect_threshold': 5,
+     'exclude_sweep_ms': 0.2,
+     'radius_um': 100,
+
+     # Collision analysis
+     'ms_before_spike_detected': 1.0,
+     'ms_after_spike_detected': 1.0,
+     'refractory_period': 2.0,
+     'scheme': 'original',  # 'original' or 'adaptive'
+
+     # Original scheme parameters
+     'mad_multiplier_amplitude': 7.0,
+     'mad_multiplier_width': 10.0,
+     'mad_multiplier_energy': 15.0,
+     # Adaptive scheme parameters
+     'false_positive_tolerance': 0.05,  # Used when scheme='adaptive'
+
+     # Clustering
+     'n_pca_components': 10,
+     'npca_per_subdivision': 10,
+
+     # Template matching
+     'tm_method': 'wobble',  # 'wobble' only for now
+
+     # Template matching (Wobble)
+     'threshold_wobble': 5000,
+     'jitter_factor_wobble': 24,
+     'refractory_period_ms_wobble': 2.0,
+ }
+
+ def get_sorting_analyzer_with_computations(
+     sorting: NumpySorting,
+     recording: BaseRecording,
+     ms_before: float, ms_after: float
+ ) -> SortingAnalyzer:
+     sorting_analyzer = create_sorting_analyzer(sorting, recording, return_in_uV=True)
+     sorting_analyzer.compute("random_spikes", method="uniform", max_spikes_per_unit=np.inf)
+     sorting_analyzer.compute("waveforms", ms_before=ms_before, ms_after=ms_after)
+     sorting_analyzer.compute("templates", operators=["average", "median", "std"])
+     return sorting_analyzer
+
+ def run_catsort(recording: BaseRecording, params: Optional[dict] = None) -> NumpySorting:
+     """
+     Main entry point for CATSort (Collision-Aware Template-matching Sort).
+
+     Args:
+         recording: spikeinterface recording object
+         params: dictionary of parameters (optional); missing keys fall back to DEFAULT_PARAMS
+
+     Returns:
+         sorting: spikeinterface sorting object
+     """
+     if params is None:
+         params = DEFAULT_PARAMS
+     else:
+         # Merge with defaults
+         full_params = DEFAULT_PARAMS.copy()
+         full_params.update(params)
+         params = full_params
+
+     print("Step 1: Detecting spikes...")
+     peaks_detected = detect_peaks(
+         recording=recording,
+         method="locally_exclusive",
+         method_kwargs={
+             "peak_sign": "neg",
+             "detect_threshold": params['detect_threshold'],
+             "exclude_sweep_ms": params['exclude_sweep_ms'],
+             "radius_um": params['radius_um'],
+         },
+     )
+
+     sampling_freq = recording.get_sampling_frequency()
+     traces = recording.get_traces()
+
+     n_before = int(sampling_freq * params['ms_before_spike_detected'] * 0.001)
+     n_after = int(sampling_freq * params['ms_after_spike_detected'] * 0.001)
+
+     print("Step 2: Collision handling...")
+     # Extract best channel traces for collision feature computation
+     traces_best = get_peaks_traces_best_channel(peaks_detected, traces, n_before, n_after)
+
+     # Temporal collisions
+     too_close = detect_temporal_collisions(
+         peaks_detected['sample_index'],
+         peaks_detected['channel_index'],
+         sampling_freq,
+         params['refractory_period']
+     )
+
+     # Compute features and thresholds based on scheme
+     collision_features = compute_collision_features(traces_best, sampling_freq)
+
+     if params['scheme'] == 'adaptive':
+         thresholds = optimize_collision_thresholds(
+             collision_features,
+             too_close,
+             params['false_positive_tolerance']
+         )
+     elif params['scheme'] == 'original':
+         mad_multipliers = {
+             'amplitude': params['mad_multiplier_amplitude'],
+             'width': params['mad_multiplier_width'],
+             'energy': params['mad_multiplier_energy']
+         }
+         thresholds = compute_fixed_thresholds(collision_features, mad_multipliers)
+     else:
+         raise ValueError(f"Unknown scheme: {params['scheme']}. Must be 'adaptive' or 'original'.")
+
+     print(f" Scheme: {params['scheme']}")
+
+     # Flag collisions: temporally too close OR any feature above its threshold
+     is_collision = too_close.copy()
+     print(f" Temporal collisions: {np.sum(too_close)}")
+     for crit, val in collision_features.items():
+         flagged_by_crit = val > thresholds[crit]
+         flagged_by_crit_not_too_close = flagged_by_crit & ~too_close
+         print(f" {crit}: {np.sum(flagged_by_crit_not_too_close)} additional collisions")
+         is_collision |= flagged_by_crit
+
+     print(f" Total flagged: {np.sum(is_collision)} collisions out of {len(peaks_detected)} spikes")
+
+     print("Step 3: Clustering non-collided spikes...")
+     mask_not_collided = ~is_collision
+     traces_all_not_collided = get_peaks_traces_all_channels(peaks_detected[mask_not_collided], traces, n_before, n_after)
+
+     # PCA and Clustering
+     num_spikes, num_samples, num_channels = traces_all_not_collided.shape
+     concatenated = traces_all_not_collided.reshape(num_spikes, -1)
+     pca = PCA(n_components=params['n_pca_components'])
+     features_not_collided = pca.fit_transform(concatenated)
+
+     labels_not_collided = isosplit6_subdivision_method(
+         features_not_collided,
+         npca_per_subdivision=params['npca_per_subdivision']
+     )
+
+     # Create sorting with clusters for template computation
+     samples_clean = peaks_detected[mask_not_collided]['sample_index']
+     labels_clean = labels_not_collided
+
+     sorting_clean = NumpySorting.from_samples_and_labels(
+         samples_list=[samples_clean],
+         labels_list=[labels_clean],
+         sampling_frequency=sampling_freq
+     )
+
+     print("Step 4: Template Matching...")
+     # Compute templates from clean clusters
+     analyzer = get_sorting_analyzer_with_computations(
+         sorting_clean, recording,
+         params['ms_before_spike_detected'],
+         params['ms_after_spike_detected']
+     )
+
+     templates_ext = analyzer.get_extension('templates')
+     sparsity = ChannelSparsity.create_dense(analyzer)
+
+     templates = Templates(
+         templates_array=templates_ext.data['average'],
+         sampling_frequency=sampling_freq,
+         nbefore=templates_ext.nbefore,
+         is_in_uV=True,
+         sparsity_mask=sparsity.mask,
+         channel_ids=analyzer.channel_ids,
+         unit_ids=analyzer.unit_ids,
+         probe=analyzer.get_probe()
+     )
+
+     if params['tm_method'] == 'wobble':
+         spikes_tm = sm.find_spikes_from_templates(
+             recording=recording,
+             templates=templates,
+             method='wobble',
+             method_kwargs={
+                 "parameters": {
+                     "threshold": params['threshold_wobble'],
+                     "jitter_factor": params['jitter_factor_wobble'],
+                     "refractory_period_frames": int(sampling_freq * params['refractory_period_ms_wobble'] * 0.001),
+                     "scale_amplitudes": True
+                 },
+             }
+         )
+     else:
+         raise ValueError(f"Unknown template matching method: {params['tm_method']}")
+
+     final_sorting = NumpySorting.from_samples_and_labels(
+         samples_list=[spikes_tm['sample_index']],
+         labels_list=[spikes_tm['cluster_index']],
+         sampling_frequency=sampling_freq
+     )
+
+     return final_sorting
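The flagging rule in Step 2 is a simple union: a spike is a collision if it is temporally too close to a neighbour, or if any feature exceeds its threshold. A standalone sketch of the same combination logic, with invented arrays:

```python
import numpy as np

# Hypothetical flags and features for four spikes
too_close = np.array([True, False, False, False])
features = {
    "amplitude": np.array([40.0, 95.0, 42.0, 41.0]),
    "width":     np.array([0.30, 0.35, 1.80, 0.33]),
}
thresholds = {"amplitude": 80.0, "width": 1.0}

# Same OR-combination as in run_catsort's Step 2
is_collision = too_close.copy()
for crit, values in features.items():
    is_collision |= values > thresholds[crit]
print(is_collision)  # [ True  True  True False]
```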
catsort-0.1.3/tests/test_sorter.py ADDED
@@ -0,0 +1,82 @@
+ from spikeinterface.extractors import toy_example
+ from spikeinterface.comparison import compare_sorter_to_ground_truth
+
+ from catsort import sorter
+
+
+ def test_tetrode():
+     recording, sorting_gt = toy_example(
+         duration=10,
+         num_channels=4,
+         num_units=5,
+         sampling_frequency=30000,
+         num_segments=1
+     )
+
+     default_parameters = sorter.DEFAULT_PARAMS
+     sorting = sorter.run_catsort(recording, params=default_parameters)
+     assert True  # smoke test: passes if run_catsort completes without raising
+
+ def test_performance_tetrode():
+     recording, sorting_gt = toy_example(
+         duration=10,
+         num_channels=4,
+         num_units=5,
+         sampling_frequency=30000,
+         num_segments=1
+     )
+
+     default_parameters = sorter.DEFAULT_PARAMS
+     sorting = sorter.run_catsort(recording, params=default_parameters)
+     comparison = compare_sorter_to_ground_truth(sorting, sorting_gt, match_score=0.01)
+     perf = comparison.get_performance()
+
+     assert perf['accuracy'].mean() > 0.5
+     assert perf['recall'].mean() > 0.5
+     assert perf['precision'].mean() > 0.5
+
+ def test_monotrode():
+     recording, sorting_gt = toy_example(
+         duration=10,
+         num_channels=1,
+         num_units=5,
+         sampling_frequency=30000,
+         num_segments=1
+     )
+
+     default_parameters = sorter.DEFAULT_PARAMS
+     sorting = sorter.run_catsort(recording, params=default_parameters)
+     assert True  # smoke test: passes if run_catsort completes without raising
+
+ def test_performance_monotrode():
+     recording, sorting_gt = toy_example(
+         duration=10,
+         num_channels=1,
+         num_units=5,
+         sampling_frequency=30000,
+         num_segments=1
+     )
+
+     default_parameters = sorter.DEFAULT_PARAMS
+     sorting = sorter.run_catsort(recording, params=default_parameters)
+     comparison = compare_sorter_to_ground_truth(sorting, sorting_gt, match_score=0.01)
+     perf = comparison.get_performance()
+
+     assert perf['accuracy'].mean() > 0.25
+     assert perf['recall'].mean() > 0.25
+     assert perf['precision'].mean() > 0.25
+
+ def test_scheme_adaptive():
+     recording, sorting_gt = toy_example(
+         duration=10,
+         num_channels=4,
+         num_units=5,
+         sampling_frequency=30000,
+         num_segments=1
+     )
+
+     adaptive_parameters = sorter.DEFAULT_PARAMS.copy()
+     adaptive_parameters['scheme'] = 'adaptive'
+     sorting = sorter.run_catsort(recording, params=adaptive_parameters)
+     assert True  # smoke test: passes if run_catsort completes without raising
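With `pytest` available (it is listed in the dev dependency group above), the suite runs directly from the repository root:

```bash
pytest tests/test_sorter.py -v
```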