CATSort 0.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of CATSort might be problematic. Click here for more details.
- catsort-0.1.2/LICENSE +21 -0
- catsort-0.1.2/PKG-INFO +87 -0
- catsort-0.1.2/README.md +56 -0
- catsort-0.1.2/pyproject.toml +52 -0
- catsort-0.1.2/setup.cfg +4 -0
- catsort-0.1.2/src/CATSort.egg-info/PKG-INFO +87 -0
- catsort-0.1.2/src/CATSort.egg-info/SOURCES.txt +14 -0
- catsort-0.1.2/src/CATSort.egg-info/dependency_links.txt +1 -0
- catsort-0.1.2/src/CATSort.egg-info/requires.txt +6 -0
- catsort-0.1.2/src/CATSort.egg-info/top_level.txt +1 -0
- catsort-0.1.2/src/catsort/__init__.py +5 -0
- catsort-0.1.2/src/catsort/core/clustering.py +71 -0
- catsort-0.1.2/src/catsort/core/collision.py +151 -0
- catsort-0.1.2/src/catsort/core/utils.py +70 -0
- catsort-0.1.2/src/catsort/sorter.py +212 -0
- catsort-0.1.2/tests/test_sorter.py +82 -0
catsort-0.1.2/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 lucasbeziers
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
catsort-0.1.2/PKG-INFO
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: CATSort
|
|
3
|
+
Version: 0.1.2
|
|
4
|
+
Summary: A collision-aware template matching spike sorter.
|
|
5
|
+
Author-email: Lucas Beziers <lucas.beziers.pro@gmail.com>
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/lucasbeziers/CATSort
|
|
8
|
+
Project-URL: Repository, https://github.com/lucasbeziers/CATSort
|
|
9
|
+
Project-URL: Issues, https://github.com/lucasbeziers/CATSort/issues
|
|
10
|
+
Keywords: spike sorting,neuroscience,electrophysiology,collisions,template matching
|
|
11
|
+
Classifier: Development Status :: 4 - Beta
|
|
12
|
+
Classifier: Intended Audience :: Science/Research
|
|
13
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
14
|
+
Classifier: Programming Language :: Python :: 3
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
20
|
+
Classifier: Topic :: Scientific/Engineering
|
|
21
|
+
Requires-Python: >=3.9
|
|
22
|
+
Description-Content-Type: text/markdown
|
|
23
|
+
License-File: LICENSE
|
|
24
|
+
Requires-Dist: isosplit6
|
|
25
|
+
Requires-Dist: numba
|
|
26
|
+
Requires-Dist: numpy
|
|
27
|
+
Requires-Dist: scikit-learn
|
|
28
|
+
Requires-Dist: scipy
|
|
29
|
+
Requires-Dist: spikeinterface>=0.103.1
|
|
30
|
+
Dynamic: license-file
|
|
31
|
+
|
|
32
|
+
# CATSort
|
|
33
|
+
|
|
34
|
+
**CATSort** (Collision-Aware Template-matching Sort) is a robust spike sorter designed to handle overlapping spikes (collisions) with high precision using a specific collision-handling stage before clustering followed by template matching.
|
|
35
|
+
|
|
36
|
+
## Key Features
|
|
37
|
+
|
|
38
|
+
- **Collision Handling**: Automatically identifies and flags collided spikes using multi-criterion feature analysis (amplitude, width, energy).
|
|
39
|
+
- **Template Matching**: Robust spike extraction using template-based matching (including 'wobble' for now).
|
|
40
|
+
- **Flexible Schemes**: Choose between an `adaptive` threshold optimization or an `original` fixed MAD (Median Absolute Deviation) multiplier scheme.
|
|
41
|
+
- **SpikeInterface Integration**: Fully compatible with the [SpikeInterface](https://github.com/SpikeInterface/spikeinterface) ecosystem.
|
|
42
|
+
|
|
43
|
+
## Installation
|
|
44
|
+
|
|
45
|
+
You can install `catsort` via pip:
|
|
46
|
+
|
|
47
|
+
```bash
|
|
48
|
+
pip install catsort
|
|
49
|
+
```
|
|
50
|
+
|
|
51
|
+
Or from source:
|
|
52
|
+
|
|
53
|
+
```bash
|
|
54
|
+
git clone https://github.com/lucasbeziers/CATSort.git
|
|
55
|
+
cd CATSort
|
|
56
|
+
pip install -e .
|
|
57
|
+
```
|
|
58
|
+
|
|
59
|
+
## Quick Start
|
|
60
|
+
|
|
61
|
+
```python
|
|
62
|
+
import spikeinterface.extractors as se
|
|
63
|
+
from catsort import run_catsort
|
|
64
|
+
|
|
65
|
+
# Load your recording
|
|
66
|
+
recording = se.read_binary("path_to_data.dat", sampling_frequency=30000, num_channels=384, dtype="int16")
|
|
67
|
+
|
|
68
|
+
# Run CATSort
|
|
69
|
+
sorting = run_catsort(recording)
|
|
70
|
+
|
|
71
|
+
# The result is a SpikeInterface Sorting object
|
|
72
|
+
print(sorting)
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
## Parameters
|
|
76
|
+
|
|
77
|
+
CATSort offers several parameters to fine-tune its behavior:
|
|
78
|
+
|
|
79
|
+
- `scheme`: `'original'` or `'adaptive'`.
|
|
80
|
+
- `'adaptive'` uses temporal collisions to optimize thresholds.
|
|
81
|
+
- `'original'` uses fixed MAD multipliers.
|
|
82
|
+
- `mad_multiplier_amplitude`, `mad_multiplier_width`, `mad_multiplier_energy`: (Default: `7.0`, `10.0`, `15.0`) Used when `scheme='original'`.
|
|
83
|
+
- `detect_threshold`: Spike detection threshold in standard deviations (Default: `5`).
|
|
84
|
+
|
|
85
|
+
## License
|
|
86
|
+
|
|
87
|
+
MIT License. See [LICENSE](LICENSE) for details.
|
catsort-0.1.2/README.md
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
# CATSort
|
|
2
|
+
|
|
3
|
+
**CATSort** (Collision-Aware Template-matching Sort) is a robust spike sorter designed to handle overlapping spikes (collisions) with high precision. It runs a dedicated collision-handling stage before clustering, and then extracts spikes with template matching.
|
|
4
|
+
|
|
5
|
+
## Key Features
|
|
6
|
+
|
|
7
|
+
- **Collision Handling**: Automatically identifies and flags collided spikes using multi-criterion feature analysis (amplitude, width, energy).
|
|
8
|
+
- **Template Matching**: Robust spike extraction using template-based matching (currently only the 'wobble' method is supported).
|
|
9
|
+
- **Flexible Schemes**: Choose between an `adaptive` threshold optimization or an `original` fixed MAD (Median Absolute Deviation) multiplier scheme.
|
|
10
|
+
- **SpikeInterface Integration**: Fully compatible with the [SpikeInterface](https://github.com/SpikeInterface/spikeinterface) ecosystem.
|
|
11
|
+
|
|
12
|
+
## Installation
|
|
13
|
+
|
|
14
|
+
You can install `catsort` via pip:
|
|
15
|
+
|
|
16
|
+
```bash
|
|
17
|
+
pip install catsort
|
|
18
|
+
```
|
|
19
|
+
|
|
20
|
+
Or from source:
|
|
21
|
+
|
|
22
|
+
```bash
|
|
23
|
+
git clone https://github.com/lucasbeziers/CATSort.git
|
|
24
|
+
cd CATSort
|
|
25
|
+
pip install -e .
|
|
26
|
+
```
|
|
27
|
+
|
|
28
|
+
## Quick Start
|
|
29
|
+
|
|
30
|
+
```python
|
|
31
|
+
import spikeinterface.extractors as se
|
|
32
|
+
from catsort import run_catsort
|
|
33
|
+
|
|
34
|
+
# Load your recording
|
|
35
|
+
recording = se.read_binary("path_to_data.dat", sampling_frequency=30000, num_channels=384, dtype="int16")
|
|
36
|
+
|
|
37
|
+
# Run CATSort
|
|
38
|
+
sorting = run_catsort(recording)
|
|
39
|
+
|
|
40
|
+
# The result is a SpikeInterface Sorting object
|
|
41
|
+
print(sorting)
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
## Parameters
|
|
45
|
+
|
|
46
|
+
CATSort offers several parameters to fine-tune its behavior:
|
|
47
|
+
|
|
48
|
+
- `scheme`: `'original'` or `'adaptive'`.
|
|
49
|
+
- `'adaptive'` uses temporal collisions to optimize thresholds.
|
|
50
|
+
- `'original'` uses fixed MAD multipliers.
|
|
51
|
+
- `mad_multiplier_amplitude`, `mad_multiplier_width`, `mad_multiplier_energy`: (Default: `7.0`, `10.0`, `15.0`) Used when `scheme='original'`.
|
|
52
|
+
- `detect_threshold`: Spike detection threshold in standard deviations (Default: `5`).
|
|
53
|
+
|
|
54
|
+
## License
|
|
55
|
+
|
|
56
|
+
MIT License. See [LICENSE](LICENSE) for details.
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "CATSort"
|
|
3
|
+
version = "0.1.2"
|
|
4
|
+
description = "A collision-aware template matching spike sorter."
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.9"
|
|
7
|
+
license = { text = "MIT" }
|
|
8
|
+
authors = [
|
|
9
|
+
{ name="Lucas Beziers", email="lucas.beziers.pro@gmail.com" },
|
|
10
|
+
]
|
|
11
|
+
keywords = ["spike sorting", "neuroscience", "electrophysiology", "collisions", "template matching"]
|
|
12
|
+
classifiers = [
|
|
13
|
+
"Development Status :: 4 - Beta",
|
|
14
|
+
"Intended Audience :: Science/Research",
|
|
15
|
+
"License :: OSI Approved :: MIT License",
|
|
16
|
+
"Programming Language :: Python :: 3",
|
|
17
|
+
"Programming Language :: Python :: 3.9",
|
|
18
|
+
"Programming Language :: Python :: 3.10",
|
|
19
|
+
"Programming Language :: Python :: 3.11",
|
|
20
|
+
"Programming Language :: Python :: 3.12",
|
|
21
|
+
"Programming Language :: Python :: 3.13",
|
|
22
|
+
"Topic :: Scientific/Engineering",
|
|
23
|
+
]
|
|
24
|
+
dependencies = [
|
|
25
|
+
"isosplit6",
|
|
26
|
+
"numba",
|
|
27
|
+
"numpy",
|
|
28
|
+
"scikit-learn",
|
|
29
|
+
"scipy",
|
|
30
|
+
"spikeinterface>=0.103.1",
|
|
31
|
+
]
|
|
32
|
+
|
|
33
|
+
[project.urls]
|
|
34
|
+
Homepage = "https://github.com/lucasbeziers/CATSort"
|
|
35
|
+
Repository = "https://github.com/lucasbeziers/CATSort"
|
|
36
|
+
Issues = "https://github.com/lucasbeziers/CATSort/issues"
|
|
37
|
+
|
|
38
|
+
[build-system]
|
|
39
|
+
requires = ["setuptools>=61.0"]
|
|
40
|
+
build-backend = "setuptools.build_meta"
|
|
41
|
+
|
|
42
|
+
[tool.setuptools.packages.find]
|
|
43
|
+
where = ["src"]
|
|
44
|
+
|
|
45
|
+
[dependency-groups]
|
|
46
|
+
dev = [
|
|
47
|
+
"ipykernel>=6.31.0",
|
|
48
|
+
"mearec",
|
|
49
|
+
"pandas",
|
|
50
|
+
"pytest",
|
|
51
|
+
"twine",
|
|
52
|
+
]
|
catsort-0.1.2/setup.cfg
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: CATSort
|
|
3
|
+
Version: 0.1.2
|
|
4
|
+
Summary: A collision-aware template matching spike sorter.
|
|
5
|
+
Author-email: Lucas Beziers <lucas.beziers.pro@gmail.com>
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/lucasbeziers/CATSort
|
|
8
|
+
Project-URL: Repository, https://github.com/lucasbeziers/CATSort
|
|
9
|
+
Project-URL: Issues, https://github.com/lucasbeziers/CATSort/issues
|
|
10
|
+
Keywords: spike sorting,neuroscience,electrophysiology,collisions,template matching
|
|
11
|
+
Classifier: Development Status :: 4 - Beta
|
|
12
|
+
Classifier: Intended Audience :: Science/Research
|
|
13
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
14
|
+
Classifier: Programming Language :: Python :: 3
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
20
|
+
Classifier: Topic :: Scientific/Engineering
|
|
21
|
+
Requires-Python: >=3.9
|
|
22
|
+
Description-Content-Type: text/markdown
|
|
23
|
+
License-File: LICENSE
|
|
24
|
+
Requires-Dist: isosplit6
|
|
25
|
+
Requires-Dist: numba
|
|
26
|
+
Requires-Dist: numpy
|
|
27
|
+
Requires-Dist: scikit-learn
|
|
28
|
+
Requires-Dist: scipy
|
|
29
|
+
Requires-Dist: spikeinterface>=0.103.1
|
|
30
|
+
Dynamic: license-file
|
|
31
|
+
|
|
32
|
+
# CATSort
|
|
33
|
+
|
|
34
|
+
**CATSort** (Collision-Aware Template-matching Sort) is a robust spike sorter designed to handle overlapping spikes (collisions) with high precision using a specific collision-handling stage before clustering followed by template matching.
|
|
35
|
+
|
|
36
|
+
## Key Features
|
|
37
|
+
|
|
38
|
+
- **Collision Handling**: Automatically identifies and flags collided spikes using multi-criterion feature analysis (amplitude, width, energy).
|
|
39
|
+
- **Template Matching**: Robust spike extraction using template-based matching (including 'wobble' for now).
|
|
40
|
+
- **Flexible Schemes**: Choose between an `adaptive` threshold optimization or an `original` fixed MAD (Median Absolute Deviation) multiplier scheme.
|
|
41
|
+
- **SpikeInterface Integration**: Fully compatible with the [SpikeInterface](https://github.com/SpikeInterface/spikeinterface) ecosystem.
|
|
42
|
+
|
|
43
|
+
## Installation
|
|
44
|
+
|
|
45
|
+
You can install `catsort` via pip:
|
|
46
|
+
|
|
47
|
+
```bash
|
|
48
|
+
pip install catsort
|
|
49
|
+
```
|
|
50
|
+
|
|
51
|
+
Or from source:
|
|
52
|
+
|
|
53
|
+
```bash
|
|
54
|
+
git clone https://github.com/lucasbeziers/CATSort.git
|
|
55
|
+
cd CATSort
|
|
56
|
+
pip install -e .
|
|
57
|
+
```
|
|
58
|
+
|
|
59
|
+
## Quick Start
|
|
60
|
+
|
|
61
|
+
```python
|
|
62
|
+
import spikeinterface.extractors as se
|
|
63
|
+
from catsort import run_catsort
|
|
64
|
+
|
|
65
|
+
# Load your recording
|
|
66
|
+
recording = se.read_binary("path_to_data.dat", sampling_frequency=30000, num_channels=384, dtype="int16")
|
|
67
|
+
|
|
68
|
+
# Run CATSort
|
|
69
|
+
sorting = run_catsort(recording)
|
|
70
|
+
|
|
71
|
+
# The result is a SpikeInterface Sorting object
|
|
72
|
+
print(sorting)
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
## Parameters
|
|
76
|
+
|
|
77
|
+
CATSort offers several parameters to fine-tune its behavior:
|
|
78
|
+
|
|
79
|
+
- `scheme`: `'original'` or `'adaptive'`.
|
|
80
|
+
- `'adaptive'` uses temporal collisions to optimize thresholds.
|
|
81
|
+
- `'original'` uses fixed MAD multipliers.
|
|
82
|
+
- `mad_multiplier_amplitude`, `mad_multiplier_width`, `mad_multiplier_energy`: (Default: `7.0`, `10.0`, `15.0`) Used when `scheme='original'`.
|
|
83
|
+
- `detect_threshold`: Spike detection threshold in standard deviations (Default: `5`).
|
|
84
|
+
|
|
85
|
+
## License
|
|
86
|
+
|
|
87
|
+
MIT License. See [LICENSE](LICENSE) for details.
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
src/CATSort.egg-info/PKG-INFO
|
|
5
|
+
src/CATSort.egg-info/SOURCES.txt
|
|
6
|
+
src/CATSort.egg-info/dependency_links.txt
|
|
7
|
+
src/CATSort.egg-info/requires.txt
|
|
8
|
+
src/CATSort.egg-info/top_level.txt
|
|
9
|
+
src/catsort/__init__.py
|
|
10
|
+
src/catsort/sorter.py
|
|
11
|
+
src/catsort/core/clustering.py
|
|
12
|
+
src/catsort/core/collision.py
|
|
13
|
+
src/catsort/core/utils.py
|
|
14
|
+
tests/test_sorter.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
catsort
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
# Adapted from MountainSort5 (https://github.com/flatironinstitute/mountainsort5)
|
|
2
|
+
# Licensed under Apache-2.0
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
from sklearn import decomposition
|
|
6
|
+
from isosplit6 import isosplit6
|
|
7
|
+
from scipy.spatial.distance import squareform
|
|
8
|
+
from scipy.cluster.hierarchy import linkage, cut_tree
|
|
9
|
+
from numpy.typing import NDArray
|
|
10
|
+
from typing import Optional
|
|
11
|
+
|
|
12
|
+
def compute_pca_features(X: NDArray, npca: int) -> NDArray:
    """
    Project the rows of X onto their leading principal components.

    Args:
        X: 2D array; rows are observations, columns are features.
        npca: Requested number of components. It is capped at both the
            number of rows and the number of columns of X.

    Returns:
        The PCA-transformed data. If X has zero rows or zero columns, an
        empty float32 array with zero rows is returned instead.
    """
    n_rows, n_cols = X.shape[0], X.shape[1]
    # Never ask for more components than the data can support.
    n_components = np.minimum(np.minimum(npca, n_rows), n_cols)

    if n_rows != 0 and n_cols != 0:
        return decomposition.PCA(n_components=n_components).fit_transform(X)

    # Degenerate input: nothing to fit, hand back an empty feature matrix.
    return np.zeros((0, n_components), dtype=np.float32)
|
|
20
|
+
|
|
21
|
+
def isosplit6_subdivision_method(X: NDArray, npca_per_subdivision: int, inds: Optional[NDArray] = None) -> NDArray:
    """
    Recursively cluster the rows of X with isosplit6 on PCA features.

    The selected rows are clustered once. If more than one cluster is found,
    the cluster centroids (per-feature medians) are split into two groups by
    single-linkage hierarchical clustering, and each group is clustered
    again by a recursive call with freshly computed PCA features.

    Args:
        X: 2D array of observations (rows) by features (columns).
        npca_per_subdivision: Number of PCA components computed at each
            recursion level.
        inds: Optional indices into the rows of X restricting the work to a
            subset. When None, all rows are used.

    Returns:
        One label per selected row. Labels of the second branch are offset
        by the largest label of the first branch.
    """
    subset = X if inds is None else X[inds]

    n_points = subset.shape[0]
    if n_points == 0:
        return np.zeros((0,), dtype=np.int32)

    labels = isosplit6(compute_pca_features(subset, npca=npca_per_subdivision))
    n_clusters = int(np.max(labels)) if len(labels) > 0 else 0

    # A single cluster (or none) cannot be subdivided any further.
    if n_clusters <= 1:
        return labels

    # Median waveform of every cluster, in the original feature space.
    centroids = np.zeros((n_clusters, X.shape[1]), dtype=np.float32)
    for row, label in enumerate(range(1, n_clusters + 1)):
        centroids[row] = np.median(subset[labels == label], axis=0)

    # Pairwise Euclidean distances between centroids, then a two-way cut of
    # the single-linkage tree built on them.
    deltas = centroids[:, None, :] - centroids[None, :, :]
    pairwise = np.sqrt(np.sum(deltas ** 2, axis=2))
    tree = linkage(squareform(pairwise), method='single', metric='euclidean')
    membership = cut_tree(tree, n_clusters=2)

    # Cluster labels are 1-based, hence the +1 on the centroid row indices.
    group_a = np.where(membership == 0)[0] + 1
    group_b = np.where(membership == 1)[0] + 1

    # Positions (within the current subset) of the points in each branch.
    local_a = np.where(np.isin(labels, group_a))[0]
    local_b = np.where(np.isin(labels, group_b))[0]

    # Translate subset positions back to row indices of the full matrix X.
    if inds is None:
        global_a, global_b = local_a, local_b
    else:
        global_a, global_b = inds[local_a], inds[local_b]

    labels_a = isosplit6_subdivision_method(X, npca_per_subdivision=npca_per_subdivision, inds=global_a)
    labels_b = isosplit6_subdivision_method(X, npca_per_subdivision=npca_per_subdivision, inds=global_b)

    # Shift the second branch so its labels do not collide with the first.
    offset = int(np.max(labels_a))
    merged = np.zeros(n_points, dtype=np.int32)
    merged[local_a] = labels_a
    merged[local_b] = labels_b + offset
    return merged
|
|
70
|
+
|
|
71
|
+
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from scipy.signal import resample
|
|
3
|
+
from scipy.stats import median_abs_deviation
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def compute_fixed_thresholds(
    features: dict,
    mad_multipliers: dict
) -> dict:
    """
    Derive one collision threshold per feature from fixed MAD multipliers.

    For a feature with a multiplier the threshold is
    ``median + multiplier * MAD`` of its values. A feature without a
    multiplier gets its own maximum as threshold.

    Args:
        features: Mapping from feature name to its array of values.
        mad_multipliers: Mapping from feature name to MAD multiplier.

    Returns:
        Mapping from feature name to threshold value.
    """
    def _threshold(name, vals):
        if name not in mad_multipliers:
            # No multiplier given: use the maximum so that a strict
            # "greater than" comparison flags nothing.
            return np.max(vals)
        centre = np.median(vals)
        spread = median_abs_deviation(vals)
        return centre + mad_multipliers[name] * spread

    return {name: _threshold(name, vals) for name, vals in features.items()}
|
|
31
|
+
|
|
32
|
+
def longest_true_runs(arr: np.ndarray) -> np.ndarray:
    """
    Keep only the longest run of consecutive True values in each row.

    Args:
        arr: 2D boolean array.

    Returns:
        Boolean array of the same shape in which, for every row, only the
        longest True run is set. On ties the earliest run is kept; a row
        without any True value stays all False.
    """
    result = np.zeros_like(arr, dtype=bool)
    for row_idx in range(len(arr)):
        # Frame the row with False so runs touching either border still
        # produce both a rising and a falling edge.
        framed = np.concatenate(([False], arr[row_idx], [False]))
        boundaries = np.flatnonzero(framed[1:] != framed[:-1])
        if boundaries.size == 0:
            continue
        # Edges alternate: even positions open a run, odd positions close it.
        run_starts = boundaries[0::2]
        run_ends = boundaries[1::2]
        longest = np.argmax(run_ends - run_starts)
        result[row_idx, run_starts[longest]:run_ends[longest]] = True
    return result
|
|
49
|
+
|
|
50
|
+
def compute_collision_features(
    peaks_traces: np.ndarray,
    sampling_frequency: float,
    width_threshold_amplitude: float = 0.5,
    resample_factor: int = 8
) -> dict:
    """
    Compute collision features (amplitude, width, energy) for each peak.

    Args:
        peaks_traces: Array of waveform snippets, one row per peak and one
            column per sample.
        sampling_frequency: Sampling frequency used to convert the width
            from samples to milliseconds.
        width_threshold_amplitude: Fraction of the peak amplitude that
            defines the width level. The width is the longest stretch
            during which the trace stays below
            ``-width_threshold_amplitude * amplitude``.
        resample_factor: Integer up-sampling factor applied to 2D input
            before measuring the width (default 8, the value that was
            previously hard-coded).

    Returns:
        Dictionary with keys "amplitude" (largest absolute value per
        peak), "width" (in milliseconds) and "energy" (sum of squares per
        peak).

    Raises:
        ValueError: If resample_factor is not a positive integer.
    """
    if int(resample_factor) != resample_factor or resample_factor < 1:
        raise ValueError("resample_factor must be a positive integer")
    resample_factor = int(resample_factor)

    amplitudes = np.abs(peaks_traces).max(axis=1)
    energies = np.sum(peaks_traces**2, axis=1)

    # Level a trace has to drop below to count towards the width
    # (negative: only downward deflections are measured).
    width_level = -width_threshold_amplitude * amplitudes[:, np.newaxis]

    if peaks_traces.ndim == 2:
        # Up-sample so the width is resolved finer than one sample.
        num_samples = peaks_traces.shape[1]
        peaks_traces_resampled = resample(peaks_traces, num=num_samples * resample_factor, axis=1)
        under_width_threshold = peaks_traces_resampled < width_level
        longest_under = longest_true_runs(under_width_threshold)
        # Convert the run length back to original-sample units.
        widths_samples = np.sum(longest_under, axis=1) / resample_factor
    else:
        # NOTE(review): kept from the original for non-2D input, without
        # resampling; it is unclear whether this path works for any
        # input shape -- confirm before relying on it.
        under_width_threshold = peaks_traces < width_level
        widths_samples = np.sum(longest_true_runs(under_width_threshold), axis=1)

    widths_ms = widths_samples * (1000 / sampling_frequency)

    return {
        "amplitude": amplitudes,
        "width": widths_ms,
        "energy": energies
    }
|
|
81
|
+
|
|
82
|
+
def detect_temporal_collisions(
    sample_indices: np.ndarray,
    channel_indices: np.ndarray,
    sampling_frequency: float,
    refractory_period_ms: float = 2.0
) -> np.ndarray:
    """
    Flag spikes whose nearest neighbour on the same channel is too close.

    Args:
        sample_indices: Sample index of every spike.
        channel_indices: Channel index of every spike (same length).
        sampling_frequency: Sampling frequency used to convert sample
            gaps to milliseconds.
        refractory_period_ms: Gaps strictly shorter than this many
            milliseconds count as a collision.

    Returns:
        Boolean array, True where the closest same-channel spike lies
        within the refractory period. A spike alone on its channel is
        never flagged.
    """
    # Gap to the previous / next spike on the same channel; infinite when
    # there is no such neighbour.
    gap_before = np.full_like(sample_indices, np.inf, dtype=float)
    gap_after = np.full_like(sample_indices, np.inf, dtype=float)

    for channel in np.unique(channel_indices):
        members = np.where(channel_indices == channel)[0]
        # Positions of this channel's spikes, ordered by time.
        by_time = members[np.argsort(sample_indices[members])]
        if len(by_time) <= 1:
            continue
        gaps = np.diff(sample_indices[by_time])
        gap_before[by_time[1:]] = gaps
        gap_after[by_time[:-1]] = gaps

    nearest = np.minimum(np.abs(gap_before), np.abs(gap_after))
    nearest_ms = nearest * 1000 / sampling_frequency
    return nearest_ms < refractory_period_ms
|
|
113
|
+
|
|
114
|
+
def optimize_collision_thresholds(
    features: dict,
    temporal_collisions: np.ndarray,
    false_positive_tolerance: float = 0.05
) -> dict:
    """
    Pick, per feature, the threshold that best recovers temporal collisions.

    A spike is flagged when its feature value is strictly above the
    threshold. Among the candidate thresholds whose number of flagged
    non-collision spikes stays within the allowed budget, the one flagging
    the most temporal collisions wins (lowest such candidate on ties).

    Args:
        features: Mapping from feature name to its array of values.
        temporal_collisions: Boolean array marking the spikes known to be
            temporal collisions.
        false_positive_tolerance: Allowed false positives, as a fraction
            of the number of non-collision spikes.

    Returns:
        Mapping from feature name to the chosen threshold. If no candidate
        fits the budget, the feature's maximum is used (nothing flagged).
    """
    clean = ~temporal_collisions
    fp_budget = false_positive_tolerance * clean.sum()

    result = {}
    for name, vals in features.items():
        # np.unique already returns its values in ascending order.
        levels = np.unique(vals)

        # Scan every distinct value, or an evenly strided subset of them
        # when there are more than 1000.
        stride = len(levels) // 1000 if len(levels) > 1000 else 1

        chosen = levels[-1]
        chosen_hits = -1
        for level in levels[::stride]:
            above = vals > level
            if np.count_nonzero(above & clean) > fp_budget:
                continue
            hits = np.count_nonzero(above & temporal_collisions)
            if hits > chosen_hits:
                chosen_hits = hits
                chosen = level

        result[name] = chosen

    return result
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
def get_snippet(
    traces: np.ndarray,
    index: int,
    n_before: int, n_after: int
) -> np.ndarray:
    """
    Get a snippet of the traces around a specific index.

    The window covers samples ``index - n_before`` (inclusive) to
    ``index + n_after`` (exclusive). Any part of the window that falls
    outside the traces is filled with zeros, including the case where the
    whole window lies outside.

    Args:
        traces: 2D array, samples along axis 0 and channels along axis 1.
        index: Sample index the window is positioned on.
        n_before: Number of samples taken before index.
        n_after: Number of samples taken from index onwards.

    Returns:
        Array of shape (n_before + n_after, n_channels). When the window
        is fully inside the traces this is a slice of traces itself
        (same dtype, no copy); otherwise it is a new zero-padded array.
    """
    n_samples = traces.shape[0]
    start = index - n_before
    end = index + n_after

    # If the snippet is fully within bounds, extract it directly
    if start >= 0 and end <= n_samples:
        return traces[start:end, :]

    # Otherwise start from zeros and copy in whatever part overlaps.
    snippet = np.zeros((n_before + n_after, traces.shape[1]))

    # Clamp both ends into [0, n_samples] and never let the range invert:
    # a window entirely outside the traces must give an empty overlap, not
    # a source and destination slice of different lengths.
    valid_start = min(max(start, 0), n_samples)
    valid_end = max(min(end, n_samples), valid_start)

    if valid_end > valid_start:
        insert_start = valid_start - start
        insert_end = insert_start + (valid_end - valid_start)
        snippet[insert_start:insert_end, :] = traces[valid_start:valid_end, :]

    return snippet  # shape (n_before+n_after, n_channels)
|
|
31
|
+
|
|
32
|
+
def get_peaks_traces_all_channels(
    peaks: np.ndarray,
    traces: np.ndarray,
    n_before: int, n_after: int
) -> np.ndarray:
    """
    Cut a multi-channel window out of the traces around every peak.

    Args:
        peaks: Sequence of peak records; each must expose a
            'sample_index' field.
        traces: 2D array, samples along axis 0 and channels along axis 1.
        n_before: Samples taken before each peak's sample index.
        n_after: Samples taken from each peak's sample index onwards.

    Returns:
        Array of shape (n_peaks, n_before + n_after, n_channels).
    """
    window = n_before + n_after
    stacked = np.zeros((len(peaks), window, traces.shape[1]))

    for row in range(len(peaks)):
        centre = peaks[row]['sample_index']
        stacked[row] = get_snippet(traces, centre, n_before, n_after)

    return stacked
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def get_peaks_traces_best_channel(
    peaks: np.ndarray,
    traces: np.ndarray,
    n_before: int, n_after: int
) -> np.ndarray:
    """
    Cut a single-channel window out of the traces around every peak.

    For each peak the channel kept is the one with the largest absolute
    value inside the extracted window.

    Args:
        peaks: Sequence of peak records; each must expose a
            'sample_index' field.
        traces: 2D array, samples along axis 0 and channels along axis 1.
        n_before: Samples taken before each peak's sample index.
        n_after: Samples taken from each peak's sample index onwards.

    Returns:
        Array of shape (n_peaks, n_before + n_after).
    """
    best_traces = np.zeros((len(peaks), n_before + n_after))

    for row in range(len(peaks)):
        centre = peaks[row]['sample_index']
        window = get_snippet(traces, centre, n_before, n_after)  # (n_before+n_after, n_channels)
        # Largest absolute deflection of each channel inside the window.
        peak_per_channel = np.abs(window).max(axis=0)
        strongest = np.argmax(peak_per_channel)
        best_traces[row] = window[:, strongest]

    return best_traces
|
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from typing import Optional
|
|
3
|
+
from spikeinterface.core import BaseRecording, NumpySorting, SortingAnalyzer, Templates, ChannelSparsity, create_sorting_analyzer
|
|
4
|
+
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
|
|
5
|
+
import spikeinterface.sortingcomponents.matching as sm
|
|
6
|
+
from sklearn.decomposition import PCA
|
|
7
|
+
|
|
8
|
+
from catsort.core.utils import get_peaks_traces_best_channel, get_peaks_traces_all_channels
|
|
9
|
+
from catsort.core.collision import (
|
|
10
|
+
compute_collision_features,
|
|
11
|
+
detect_temporal_collisions,
|
|
12
|
+
optimize_collision_thresholds,
|
|
13
|
+
compute_fixed_thresholds
|
|
14
|
+
)
|
|
15
|
+
from catsort.core.clustering import isosplit6_subdivision_method
|
|
16
|
+
|
|
17
|
+
# Default parameter set of the sorter; callers may override any subset.
DEFAULT_PARAMS = dict(
    # --- Peak detection ---
    detect_threshold=5,
    exclude_sweep_ms=0.2,
    radius_um=100,

    # --- Collision analysis ---
    # Window extracted around each detected spike, in milliseconds.
    ms_before_spike_detected=1.0,
    ms_after_spike_detected=1.0,
    refractory_period=2.0,
    scheme='original',  # either 'original' or 'adaptive'

    # --- 'original' scheme: fixed MAD multipliers per feature ---
    mad_multiplier_amplitude=7.0,
    mad_multiplier_width=10.0,
    mad_multiplier_energy=15.0,

    # --- 'adaptive' scheme ---
    false_positive_tolerance=0.05,  # only used when scheme='adaptive'

    # --- Clustering ---
    n_pca_components=10,
    npca_per_subdivision=10,

    # --- Template matching ---
    tm_method='wobble',  # 'wobble' is the only method for now

    # --- Template matching, wobble-specific settings ---
    threshold_wobble=5000,
    jitter_factor_wobble=24,
    refractory_period_ms_wobble=2.0,
)
|
|
48
|
+
|
|
49
|
+
def get_sorting_analyzer_with_computations(
    sorting: NumpySorting,
    recording: BaseRecording,
    ms_before: float, ms_after: float
) -> SortingAnalyzer:
    """
    Build a sorting analyzer and run the extensions needed for templates.

    The analyzer is created with ``return_in_uV=True`` and the
    "random_spikes", "waveforms" and "templates" extensions are then
    computed, in that order.

    Args:
        sorting: Sorting to analyze.
        recording: Recording the sorting refers to.
        ms_before: Forwarded to the "waveforms" extension as ms_before.
        ms_after: Forwarded to the "waveforms" extension as ms_after.

    Returns:
        The analyzer with the three extensions computed.
    """
    # Extension name -> keyword arguments, in the order they are computed.
    pipeline = (
        # max_spikes_per_unit=np.inf: presumably so that no spikes are
        # subsampled -- confirm against the spikeinterface docs.
        ("random_spikes", {"method": "uniform", "max_spikes_per_unit": np.inf}),
        ("waveforms", {"ms_before": ms_before, "ms_after": ms_after}),
        ("templates", {"operators": ["average", "median", "std"]}),
    )

    analyzer = create_sorting_analyzer(sorting, recording, return_in_uV=True)
    for extension_name, extension_kwargs in pipeline:
        analyzer.compute(extension_name, **extension_kwargs)
    return analyzer
|
|
59
|
+
|
|
60
|
+
def run_catsort(recording: BaseRecording, params: Optional[dict] = None) -> NumpySorting:
    """
    Main entry point for Catsort (Collision Aware Template-matching sort).

    The pipeline has four steps:
      1. detect negative peaks with the "locally_exclusive" method;
      2. flag collisions, either because peaks are temporally too close or
         because a collision feature exceeds its threshold;
      3. cluster the non-collided spikes (PCA followed by the isosplit6
         subdivision method);
      4. compute templates from the clean clusters and run template matching
         on the whole recording.

    Args:
        recording: spikeinterface recording object. Traces are fetched with a
            single ``recording.get_traces()`` call (no segment index is given).
        params: dictionary of parameters (optional). Missing keys are taken
            from ``DEFAULT_PARAMS``; the caller's dictionary is never modified.

    Returns:
        sorting: spikeinterface sorting object built from the sample indices
            and cluster indices returned by template matching.

    Raises:
        ValueError: if ``params['scheme']`` or ``params['tm_method']`` is not
            supported, if no peak is detected, or if every detected spike is
            flagged as a collision.
    """
    # Always work on a fresh merged copy so DEFAULT_PARAMS is never aliased.
    params = {**DEFAULT_PARAMS, **(params or {})}

    # Validate the options up front: failing here is much cheaper than failing
    # after detection and clustering have already run.
    if params['scheme'] not in ('adaptive', 'original'):
        raise ValueError(f"Unknown scheme: {params['scheme']}. Must be 'adaptive' or 'original'.")
    if params['tm_method'] != 'wobble':
        raise ValueError(f"Unknown template matching method: {params['tm_method']}")

    print("Step 1: Detecting spikes...")
    peaks_detected = detect_peaks(
        recording=recording,
        method="locally_exclusive",
        method_kwargs={
            "peak_sign": "neg",
            "detect_threshold": params['detect_threshold'],
            "exclude_sweep_ms": params['exclude_sweep_ms'],
            "radius_um": params['radius_um'],
        },
    )
    if len(peaks_detected) == 0:
        raise ValueError("No spikes were detected; consider lowering 'detect_threshold'.")

    sampling_freq = recording.get_sampling_frequency()
    traces = recording.get_traces()

    # Window size around each peak, converted from milliseconds to samples.
    n_before = int(sampling_freq * params['ms_before_spike_detected'] * 0.001)
    n_after = int(sampling_freq * params['ms_after_spike_detected'] * 0.001)

    print("Step 2: Collision handling...")
    # Extract best channel traces for collision feature computation
    traces_best = get_peaks_traces_best_channel(peaks_detected, traces, n_before, n_after)

    # Temporal collisions
    too_close = detect_temporal_collisions(
        peaks_detected['sample_index'],
        peaks_detected['channel_index'],
        sampling_freq,
        params['refractory_period']
    )

    # Compute features and thresholds based on scheme
    collision_features = compute_collision_features(traces_best, sampling_freq)

    if params['scheme'] == 'adaptive':
        thresholds = optimize_collision_thresholds(
            collision_features,
            too_close,
            params['false_positive_tolerance']
        )
    else:
        # 'original' scheme (the only other value accepted above).
        mad_multipliers = {
            'amplitude': params['mad_multiplier_amplitude'],
            'width': params['mad_multiplier_width'],
            'energy': params['mad_multiplier_energy']
        }
        thresholds = compute_fixed_thresholds(collision_features, mad_multipliers)

    print(f" Scheme: {params['scheme']}")

    # Flag collisions: start from the temporal ones, then add every spike
    # whose feature value exceeds the threshold of that feature.
    is_collision = too_close.copy()
    print(f" Temporal collisions: {np.sum(too_close)}")
    for crit, val in collision_features.items():
        flagged_by_crit = val > thresholds[crit]
        flagged_by_crit_not_too_close = flagged_by_crit & ~too_close
        print(f" {crit}: {np.sum(flagged_by_crit_not_too_close)} additional collisions")
        is_collision |= flagged_by_crit

    print(f" Total flagged: {np.sum(is_collision)} collisions out of {len(peaks_detected)} spikes")

    print("Step 3: Clustering non-collided spikes...")
    peaks_clean = peaks_detected[~is_collision]
    traces_all_not_collided = get_peaks_traces_all_channels(peaks_clean, traces, n_before, n_after)

    # PCA and Clustering
    num_spikes = traces_all_not_collided.shape[0]
    if num_spikes == 0:
        raise ValueError(
            "All detected spikes were flagged as collisions; no clean spikes are left for clustering."
        )
    # One row per spike, all remaining axes flattened into the feature axis.
    concatenated = traces_all_not_collided.reshape(num_spikes, -1)

    # An integer n_components larger than min(n_samples, n_features) makes PCA
    # raise, so cap it to what the data supports.
    n_components = params['n_pca_components']
    if isinstance(n_components, (int, np.integer)):
        n_components = min(n_components, *concatenated.shape)
    pca = PCA(n_components=n_components)
    features_not_collided = pca.fit_transform(concatenated)

    labels_not_collided = isosplit6_subdivision_method(
        features_not_collided,
        npca_per_subdivision=params['npca_per_subdivision']
    )

    # Create sorting with clusters for template computation
    sorting_clean = NumpySorting.from_samples_and_labels(
        samples_list=[peaks_clean['sample_index']],
        labels_list=[labels_not_collided],
        sampling_frequency=sampling_freq
    )

    print("Step 4: Template Matching...")
    # Compute templates from clean clusters
    analyzer = get_sorting_analyzer_with_computations(
        sorting_clean, recording,
        params['ms_before_spike_detected'],
        params['ms_after_spike_detected']
    )

    templates_ext = analyzer.get_extension('templates')
    sparsity = ChannelSparsity.create_dense(analyzer)

    templates = Templates(
        templates_array=templates_ext.data['average'],
        sampling_frequency=sampling_freq,
        nbefore=templates_ext.nbefore,
        is_in_uV=True,
        sparsity_mask=sparsity.mask,
        channel_ids=analyzer.channel_ids,
        unit_ids=analyzer.unit_ids,
        probe=analyzer.get_probe()
    )

    # 'wobble' is the only template matching method accepted above.
    spikes_tm = sm.find_spikes_from_templates(
        recording=recording,
        templates=templates,
        method='wobble',
        method_kwargs={
            "parameters": {
                "threshold": params['threshold_wobble'],
                "jitter_factor": params['jitter_factor_wobble'],
                "refractory_period_frames": int(sampling_freq * params['refractory_period_ms_wobble'] * 0.001),
                "scale_amplitudes": True
            },
        }
    )

    # Unit labels of the result are the cluster indices reported by template matching.
    final_sorting = NumpySorting.from_samples_and_labels(
        samples_list=[spikes_tm['sample_index']],
        labels_list=[spikes_tm['cluster_index']],
        sampling_frequency=sampling_freq
    )

    return final_sorting
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
from spikeinterface.core import generate_recording, generate_sorting, generate_snippets
|
|
2
|
+
from spikeinterface.extractors import toy_example
|
|
3
|
+
from spikeinterface.comparison import compare_sorter_to_ground_truth
|
|
4
|
+
|
|
5
|
+
from catsort import sorter
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def test_tetrode():
    """Smoke test: run_catsort completes on a 4-channel toy recording with the default parameters."""
    recording, _ = toy_example(
        duration=10,
        num_channels=4,
        num_units=5,
        sampling_frequency=30000,
        num_segments=1,
    )

    sorting = sorter.run_catsort(recording, params=sorter.DEFAULT_PARAMS)

    # run_catsort builds its result from a single samples/labels list and the
    # recording's sampling frequency, so both must be reflected in the output.
    assert sorting.get_num_segments() == 1
    assert sorting.get_sampling_frequency() == recording.get_sampling_frequency()
|
|
20
|
+
|
|
21
|
+
def test_performance_tetrode():
    """Compare the default-parameter sorting of a 4-channel toy recording with its ground truth."""
    toy_settings = {
        "duration": 10,
        "num_channels": 4,
        "num_units": 5,
        "sampling_frequency": 30000,
        "num_segments": 1,
    }
    recording, sorting_gt = toy_example(**toy_settings)

    result = sorter.run_catsort(recording, params=sorter.DEFAULT_PARAMS)
    perf = compare_sorter_to_ground_truth(result, sorting_gt, match_score=0.01).get_performance()

    # The mean of each performance column must exceed 0.5.
    for metric in ('accuracy', 'recall', 'precision'):
        assert perf[metric].mean() > 0.5
|
|
38
|
+
|
|
39
|
+
def test_monotrode():
    """Smoke test: run_catsort completes on a single-channel toy recording with the default parameters."""
    recording, _ = toy_example(
        duration=10,
        num_channels=1,
        num_units=5,
        sampling_frequency=30000,
        num_segments=1,
    )

    sorting = sorter.run_catsort(recording, params=sorter.DEFAULT_PARAMS)

    # run_catsort builds its result from a single samples/labels list and the
    # recording's sampling frequency, so both must be reflected in the output.
    assert sorting.get_num_segments() == 1
    assert sorting.get_sampling_frequency() == recording.get_sampling_frequency()
|
|
51
|
+
|
|
52
|
+
def test_performance_monotrode():
    """Compare the default-parameter sorting of a single-channel toy recording with its ground truth."""
    toy_settings = {
        "duration": 10,
        "num_channels": 1,
        "num_units": 5,
        "sampling_frequency": 30000,
        "num_segments": 1,
    }
    recording, sorting_gt = toy_example(**toy_settings)

    result = sorter.run_catsort(recording, params=sorter.DEFAULT_PARAMS)
    perf = compare_sorter_to_ground_truth(result, sorting_gt, match_score=0.01).get_performance()

    # The mean of each performance column must exceed 0.25.
    for metric in ('accuracy', 'recall', 'precision'):
        assert perf[metric].mean() > 0.25
|
|
69
|
+
|
|
70
|
+
def test_scheme_adaptive():
    """Smoke test: run_catsort completes on a 4-channel toy recording with scheme='adaptive'."""
    recording, _ = toy_example(
        duration=10,
        num_channels=4,
        num_units=5,
        sampling_frequency=30000,
        num_segments=1,
    )

    # Work on a copy so the module-level defaults keep scheme='original'.
    adaptive_parameters = sorter.DEFAULT_PARAMS.copy()
    adaptive_parameters['scheme'] = 'adaptive'
    sorting = sorter.run_catsort(recording, params=adaptive_parameters)

    # run_catsort builds its result from a single samples/labels list and the
    # recording's sampling frequency, so both must be reflected in the output.
    assert sorting.get_num_segments() == 1
    assert sorting.get_sampling_frequency() == recording.get_sampling_frequency()
    # The shared defaults must be untouched by the override above.
    assert sorter.DEFAULT_PARAMS['scheme'] == 'original'
|