canns 0.14.3__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/analyzer/data/asa/__init__.py +56 -21
- canns/analyzer/data/asa/coho.py +21 -0
- canns/analyzer/data/asa/cohomap.py +453 -0
- canns/analyzer/data/asa/cohomap_vectors.py +365 -0
- canns/analyzer/data/asa/cohospace.py +155 -1165
- canns/analyzer/data/asa/cohospace_phase_centers.py +119 -0
- canns/analyzer/data/asa/cohospace_scatter.py +1115 -0
- canns/analyzer/data/asa/embedding.py +5 -7
- canns/analyzer/data/asa/fr.py +1 -8
- canns/analyzer/data/asa/path.py +70 -0
- canns/analyzer/data/asa/plotting.py +5 -30
- canns/analyzer/data/asa/utils.py +160 -0
- canns/analyzer/data/cell_classification/__init__.py +10 -0
- canns/analyzer/data/cell_classification/core/__init__.py +4 -0
- canns/analyzer/data/cell_classification/core/btn.py +272 -0
- canns/analyzer/data/cell_classification/visualization/__init__.py +3 -0
- canns/analyzer/data/cell_classification/visualization/btn_plots.py +241 -0
- canns/analyzer/visualization/__init__.py +2 -0
- canns/analyzer/visualization/core/config.py +20 -0
- canns/analyzer/visualization/theta_sweep_plots.py +142 -0
- canns/pipeline/asa/runner.py +19 -19
- canns/pipeline/asa_gui/__init__.py +5 -3
- canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +3 -1
- canns/pipeline/asa_gui/core/runner.py +23 -23
- canns/pipeline/asa_gui/views/pages/preprocess_page.py +7 -12
- {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/METADATA +1 -1
- {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/RECORD +30 -23
- canns/analyzer/data/asa/filters.py +0 -208
- {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/WHEEL +0 -0
- {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/entry_points.txt +0 -0
- {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -6,7 +6,6 @@ import numpy as np
|
|
|
6
6
|
from scipy.ndimage import gaussian_filter1d
|
|
7
7
|
|
|
8
8
|
from .config import DataLoadError, ProcessingError, SpikeEmbeddingConfig
|
|
9
|
-
from .filters import _gaussian_filter1d
|
|
10
9
|
|
|
11
10
|
|
|
12
11
|
def embed_spike_trains(spike_trains, config: SpikeEmbeddingConfig | None = None, **kwargs):
|
|
@@ -31,10 +30,9 @@ def embed_spike_trains(spike_trains, config: SpikeEmbeddingConfig | None = None,
|
|
|
31
30
|
Returns
|
|
32
31
|
-------
|
|
33
32
|
tuple
|
|
34
|
-
``(spikes_bin, xx, yy, tt)``
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
otherwise ``None``.
|
|
33
|
+
``(spikes_bin, xx, yy, tt)``. ``spikes_bin`` is a (T, N) binned spike matrix.
|
|
34
|
+
``xx``, ``yy``, ``tt`` are position/time arrays when ``speed_filter=True``,
|
|
35
|
+
otherwise ``None``.
|
|
38
36
|
|
|
39
37
|
Examples
|
|
40
38
|
--------
|
|
@@ -263,8 +261,8 @@ def _load_pos(t, x, y, res=100000, dt=1000):
|
|
|
263
261
|
xx = np.interp(tt, t, x)
|
|
264
262
|
yy = np.interp(tt, t, y)
|
|
265
263
|
|
|
266
|
-
xxs =
|
|
267
|
-
yys =
|
|
264
|
+
xxs = gaussian_filter1d(xx - np.min(xx), sigma=100)
|
|
265
|
+
yys = gaussian_filter1d(yy - np.min(yy), sigma=100)
|
|
268
266
|
dx = (xxs[1:] - xxs[:-1]) * 100
|
|
269
267
|
dy = (yys[1:] - yys[:-1]) * 100
|
|
270
268
|
speed = np.sqrt(dx**2 + dy**2) / 0.01
|
canns/analyzer/data/asa/fr.py
CHANGED
|
@@ -1,18 +1,11 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import os
|
|
4
3
|
from dataclasses import dataclass
|
|
5
4
|
|
|
6
5
|
import numpy as np
|
|
7
6
|
|
|
8
7
|
from ...visualization.core import PlotConfig, finalize_figure
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
def _ensure_parent_dir(save_path: str | None) -> None:
|
|
12
|
-
if save_path:
|
|
13
|
-
parent = os.path.dirname(save_path)
|
|
14
|
-
if parent:
|
|
15
|
-
os.makedirs(parent, exist_ok=True)
|
|
8
|
+
from .utils import _ensure_parent_dir
|
|
16
9
|
|
|
17
10
|
|
|
18
11
|
def _slice_range(r: tuple[int, int] | None, length: int) -> slice:
|
canns/analyzer/data/asa/path.py
CHANGED
|
@@ -240,6 +240,76 @@ def interp_coords_to_full_1d(idx_map: np.ndarray, coords1: np.ndarray, T_full: i
|
|
|
240
240
|
return np.mod(out, 2 * np.pi)[:, None]
|
|
241
241
|
|
|
242
242
|
|
|
243
|
+
def _align_activity_to_coords(
    coords: np.ndarray,
    activity: np.ndarray,
    times: np.ndarray | None = None,
    *,
    label: str = "activity",
    auto_filter: bool = True,
) -> np.ndarray:
    """
    Align activity to coords by optional time indices and validate lengths.

    Parameters
    ----------
    coords : ndarray
        Decoded coordinates array; its first axis is the time axis.
    activity : ndarray
        Activity matrix (firing rate or spikes); its first axis is the time axis.
    times : ndarray, optional
        Optional time indices to align activity to coords when coords are computed
        on a subset of timepoints. Used directly as ``activity[times]``.
    label : str
        Label for error messages (default: "activity").
    auto_filter : bool
        If True, ``times`` is None, ``activity`` is 2-D and lengths mismatch, keep
        only rows with at least one positive entry (mimicking decode's zero-spike
        filtering), provided that yields exactly ``len(coords)`` rows.

    Returns
    -------
    ndarray
        Aligned activity array.

    Raises
    ------
    ValueError
        If ``activity`` cannot be indexed with ``times``, or if its length doesn't
        match coords length after alignment attempts.
    """
    coords = np.asarray(coords)
    activity = np.asarray(activity)

    if times is not None:
        times = np.asarray(times)
        try:
            activity = activity[times]
        except (IndexError, TypeError, ValueError) as exc:
            # Out-of-range, non-integer or wrongly-shaped indices all surface here.
            raise ValueError(
                f"Failed to index {label} with `times`. Ensure `times` indexes the original time axis."
            ) from exc

    n_coords = coords.shape[0]
    if activity.shape[0] == n_coords:
        return activity

    # Try to reproduce decode's zero-spike filtering if lengths mismatch.
    if auto_filter and times is None and activity.ndim == 2:
        active_rows = np.any(activity > 0, axis=1)
        if np.count_nonzero(active_rows) == n_coords:
            return activity[active_rows]

    raise ValueError(
        f"{label} length must match coords length. Got {activity.shape[0]} vs {coords.shape[0]}. "
        "If coords are computed on a subset of timepoints (e.g., decode['times']), pass "
        "`times=decoding['times']` or slice the activity accordingly."
    )
|
|
311
|
+
|
|
312
|
+
|
|
243
313
|
def align_coords_to_position_2d(
|
|
244
314
|
t_full: np.ndarray,
|
|
245
315
|
x_full: np.ndarray,
|
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import os
|
|
4
3
|
from typing import Any
|
|
5
4
|
|
|
6
5
|
import matplotlib.pyplot as plt
|
|
@@ -24,31 +23,7 @@ from ...visualization.core import (
|
|
|
24
23
|
from ...visualization.core.jupyter_utils import display_animation_in_jupyter, is_jupyter_environment
|
|
25
24
|
from .config import CANN2DPlotConfig, ProcessingError, SpikeEmbeddingConfig
|
|
26
25
|
from .embedding import embed_spike_trains
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
def _ensure_plot_config(
|
|
30
|
-
config: PlotConfig | None,
|
|
31
|
-
factory,
|
|
32
|
-
*,
|
|
33
|
-
kwargs: dict[str, Any] | None = None,
|
|
34
|
-
**defaults: Any,
|
|
35
|
-
) -> PlotConfig:
|
|
36
|
-
if config is None:
|
|
37
|
-
defaults.update({"kwargs": kwargs or {}})
|
|
38
|
-
return factory(**defaults)
|
|
39
|
-
|
|
40
|
-
if kwargs:
|
|
41
|
-
config_kwargs = config.kwargs or {}
|
|
42
|
-
config_kwargs.update(kwargs)
|
|
43
|
-
config.kwargs = config_kwargs
|
|
44
|
-
return config
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
def _ensure_parent_dir(save_path: str | None) -> None:
|
|
48
|
-
if save_path:
|
|
49
|
-
parent = os.path.dirname(save_path)
|
|
50
|
-
if parent:
|
|
51
|
-
os.makedirs(parent, exist_ok=True)
|
|
26
|
+
from .utils import _ensure_parent_dir, _ensure_plot_config
|
|
52
27
|
|
|
53
28
|
|
|
54
29
|
def _render_torus_frame(frame_index: int, frame_data: dict[str, Any]) -> np.ndarray:
|
|
@@ -403,7 +378,7 @@ def plot_path_compare_1d(
|
|
|
403
378
|
return fig, axes
|
|
404
379
|
|
|
405
380
|
|
|
406
|
-
def
|
|
381
|
+
def plot_cohomap_scatter(
|
|
407
382
|
decoding_result: dict[str, Any],
|
|
408
383
|
position_data: dict[str, Any],
|
|
409
384
|
config: PlotConfig | None = None,
|
|
@@ -453,7 +428,7 @@ def plot_cohomap(
|
|
|
453
428
|
>>> # Decode coordinates
|
|
454
429
|
>>> decoding = decode_circular_coordinates(persistence_result, spike_data)
|
|
455
430
|
>>> # Visualize with trajectory data
|
|
456
|
-
>>> fig =
|
|
431
|
+
>>> fig = plot_cohomap_scatter(
|
|
457
432
|
... decoding,
|
|
458
433
|
... position_data={'x': xx, 'y': yy},
|
|
459
434
|
... save_path='cohomap.png',
|
|
@@ -518,7 +493,7 @@ def plot_cohomap(
|
|
|
518
493
|
return fig
|
|
519
494
|
|
|
520
495
|
|
|
521
|
-
def
|
|
496
|
+
def plot_cohomap_scatter_multi(
|
|
522
497
|
decoding_result: dict,
|
|
523
498
|
position_data: dict,
|
|
524
499
|
config: PlotConfig | None = None,
|
|
@@ -560,7 +535,7 @@ def plot_cohomap_multi(
|
|
|
560
535
|
|
|
561
536
|
Examples
|
|
562
537
|
--------
|
|
563
|
-
>>> fig =
|
|
538
|
+
>>> fig = plot_cohomap_scatter_multi(decoding, {"x": xx, "y": yy}, show=False) # doctest: +SKIP
|
|
564
539
|
"""
|
|
565
540
|
config = _ensure_plot_config(
|
|
566
541
|
config,
|
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
"""Shared utility functions for ASA (Attractor State Analysis) modules."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from scipy.ndimage import gaussian_filter
|
|
10
|
+
|
|
11
|
+
from ...visualization.core import PlotConfig
|
|
12
|
+
from .path import find_coords_matrix, find_times_box
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _ensure_plot_config(
    config: PlotConfig | None,
    factory,
    *args,
    kwargs: dict | None = None,
    **defaults,
) -> PlotConfig:
    """Ensure a PlotConfig exists, creating one from factory if needed.

    Args:
        config: Optional existing PlotConfig. When given, it is returned (its
            ``kwargs`` attribute may be replaced by a merged dict).
        factory: Callable that builds a PlotConfig when ``config`` is None.
        *args: Positional arguments for factory.
        kwargs: Optional dict merged into ``config.kwargs`` (entries here win), or
            passed as ``kwargs=`` to the factory. Never mutated and never aliased.
        **defaults: Keyword arguments for factory.

    Returns:
        PlotConfig instance.
    """
    if config is None:
        if kwargs:
            # Copy so the new config does not share the caller's dict.
            defaults["kwargs"] = dict(kwargs)
        return factory(*args, **defaults)

    # If config exists and kwargs provided, merge them into a fresh dict rather
    # than updating config.kwargs in place: that dict may be owned by the caller
    # or shared with other configs.
    if kwargs:
        config.kwargs = {**(config.kwargs or {}), **kwargs}
    return config
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _ensure_parent_dir(save_path: str | None) -> None:
|
|
48
|
+
"""Create parent directory for save_path if it doesn't exist.
|
|
49
|
+
|
|
50
|
+
Args:
|
|
51
|
+
save_path: Optional file path. If provided, creates parent directory.
|
|
52
|
+
"""
|
|
53
|
+
if save_path:
|
|
54
|
+
parent = os.path.dirname(save_path)
|
|
55
|
+
if parent:
|
|
56
|
+
os.makedirs(parent, exist_ok=True)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _circmean(x: np.ndarray) -> float:
    """Return the circular mean of a set of angles.

    The angles are averaged as unit vectors; the mean vector's direction is
    returned.

    Args:
        x: Array of angles in radians.

    Returns:
        Circular mean in radians, in the range returned by ``np.arctan2``.
    """
    mean_sin = np.mean(np.sin(x))
    mean_cos = np.mean(np.cos(x))
    return float(np.arctan2(mean_sin, mean_cos))
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _smooth_circular_map(
|
|
72
|
+
mtot: np.ndarray,
|
|
73
|
+
smooth_sigma: float,
|
|
74
|
+
*,
|
|
75
|
+
fill_nan: bool = False,
|
|
76
|
+
fill_sigma: float | None = None,
|
|
77
|
+
fill_min_weight: float = 1e-3,
|
|
78
|
+
) -> np.ndarray:
|
|
79
|
+
"""Smooth a circular phase map using Gaussian filtering in sin/cos space.
|
|
80
|
+
|
|
81
|
+
Args:
|
|
82
|
+
mtot: Phase map array (angles in radians).
|
|
83
|
+
smooth_sigma: Gaussian smoothing sigma.
|
|
84
|
+
fill_nan: Whether to fill NaN values using weighted interpolation.
|
|
85
|
+
fill_sigma: Sigma for NaN filling (defaults to smooth_sigma).
|
|
86
|
+
fill_min_weight: Minimum weight threshold for valid interpolation.
|
|
87
|
+
|
|
88
|
+
Returns:
|
|
89
|
+
Smoothed phase map.
|
|
90
|
+
"""
|
|
91
|
+
mtot = np.asarray(mtot, dtype=float)
|
|
92
|
+
nans = np.isnan(mtot)
|
|
93
|
+
mask = (~nans).astype(float)
|
|
94
|
+
sintot = np.sin(mtot)
|
|
95
|
+
costot = np.cos(mtot)
|
|
96
|
+
sintot[nans] = 0.0
|
|
97
|
+
costot[nans] = 0.0
|
|
98
|
+
|
|
99
|
+
if fill_nan:
|
|
100
|
+
if fill_sigma is None:
|
|
101
|
+
fill_sigma = smooth_sigma if smooth_sigma and smooth_sigma > 0 else 1.0
|
|
102
|
+
weight = gaussian_filter(mask, fill_sigma)
|
|
103
|
+
sintot = gaussian_filter(sintot * mask, fill_sigma)
|
|
104
|
+
costot = gaussian_filter(costot * mask, fill_sigma)
|
|
105
|
+
min_weight = max(float(fill_min_weight), 0.0)
|
|
106
|
+
valid = weight > min_weight
|
|
107
|
+
sintot = np.divide(sintot, weight, out=np.zeros_like(sintot), where=valid)
|
|
108
|
+
costot = np.divide(costot, weight, out=np.zeros_like(costot), where=valid)
|
|
109
|
+
mtot = np.arctan2(sintot, costot)
|
|
110
|
+
if fill_min_weight > 0:
|
|
111
|
+
mtot[~valid] = np.nan
|
|
112
|
+
return mtot
|
|
113
|
+
|
|
114
|
+
if smooth_sigma and smooth_sigma > 0:
|
|
115
|
+
sintot = gaussian_filter(sintot, smooth_sigma)
|
|
116
|
+
costot = gaussian_filter(costot, smooth_sigma)
|
|
117
|
+
mtot = np.arctan2(sintot, costot)
|
|
118
|
+
mtot[nans] = np.nan
|
|
119
|
+
return mtot
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def _extract_coords_and_times(
    decoding_result: dict[str, Any],
    coords_key: str | None = None,
) -> tuple[np.ndarray, np.ndarray | None]:
    """Pull decoded coordinates and their time indices out of a decoding result.

    Lookup order for coordinates: explicit ``coords_key`` (must exist), then the
    ``"coordsbox"`` entry, then whatever ``find_coords_matrix`` locates. Time
    indices always come from ``find_times_box``.

    Args:
        decoding_result: Dictionary containing decoding results.
        coords_key: Optional key to read the coordinates from.

    Returns:
        Tuple ``(coords, times_box)``. ``coords`` is expected to be (T, 2) but this
        is not checked here; ``times_box`` may be None.

    Raises:
        KeyError: If ``coords_key`` is given but absent from ``decoding_result``.
    """
    if coords_key is None:
        if "coordsbox" in decoding_result:
            coords = np.asarray(decoding_result["coordsbox"])
        else:
            coords, _ = find_coords_matrix(decoding_result)
    elif coords_key in decoding_result:
        coords = np.asarray(decoding_result[coords_key])
    else:
        raise KeyError(f"coords_key '{coords_key}' not found in decoding_result.")

    times_box, _ = find_times_box(decoding_result)
    return coords, times_box
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def _phase_map_valid_fraction(phase_map: np.ndarray) -> float:
    """Return the share of finite entries in a phase map.

    NaN and +/-inf both count as invalid.

    Args:
        phase_map: Phase map array.

    Returns:
        Fraction in [0.0, 1.0]; 0.0 for an empty map.
    """
    finite = np.isfinite(phase_map)
    return float(np.mean(finite)) if finite.size else 0.0
|
|
@@ -11,6 +11,9 @@ Vollan, Gardner, Moser & Moser (Nature, 2025)
|
|
|
11
11
|
__version__ = "0.1.0"
|
|
12
12
|
|
|
13
13
|
from .core import ( # noqa: F401
|
|
14
|
+
BTNAnalyzer,
|
|
15
|
+
BTNConfig,
|
|
16
|
+
BTNResult,
|
|
14
17
|
GridnessAnalyzer,
|
|
15
18
|
GridnessResult,
|
|
16
19
|
HDCellResult,
|
|
@@ -46,6 +49,8 @@ from .utils import ( # noqa: F401
|
|
|
46
49
|
)
|
|
47
50
|
from .visualization import ( # noqa: F401
|
|
48
51
|
plot_autocorrelogram,
|
|
52
|
+
plot_btn_autocorr_summary,
|
|
53
|
+
plot_btn_distance_matrix,
|
|
49
54
|
plot_grid_score_histogram,
|
|
50
55
|
plot_gridness_analysis,
|
|
51
56
|
plot_hd_analysis,
|
|
@@ -57,6 +62,9 @@ from .visualization import ( # noqa: F401
|
|
|
57
62
|
__all__ = [
|
|
58
63
|
"GridnessAnalyzer",
|
|
59
64
|
"GridnessResult",
|
|
65
|
+
"BTNAnalyzer",
|
|
66
|
+
"BTNConfig",
|
|
67
|
+
"BTNResult",
|
|
60
68
|
"HeadDirectionAnalyzer",
|
|
61
69
|
"HDCellResult",
|
|
62
70
|
"compute_2d_autocorrelation",
|
|
@@ -94,4 +102,6 @@ __all__ = [
|
|
|
94
102
|
"plot_polar_tuning",
|
|
95
103
|
"plot_temporal_autocorr",
|
|
96
104
|
"plot_hd_analysis",
|
|
105
|
+
"plot_btn_autocorr_summary",
|
|
106
|
+
"plot_btn_distance_matrix",
|
|
97
107
|
]
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Core analysis modules."""
|
|
2
2
|
|
|
3
|
+
from .btn import BTNAnalyzer, BTNConfig, BTNResult
|
|
3
4
|
from .grid_cells import GridnessAnalyzer, GridnessResult, compute_2d_autocorrelation
|
|
4
5
|
from .grid_modules_leiden import identify_grid_modules_and_stats
|
|
5
6
|
from .head_direction import HDCellResult, HeadDirectionAnalyzer
|
|
@@ -15,6 +16,9 @@ __all__ = [
|
|
|
15
16
|
"GridnessAnalyzer",
|
|
16
17
|
"compute_2d_autocorrelation",
|
|
17
18
|
"GridnessResult",
|
|
19
|
+
"BTNAnalyzer",
|
|
20
|
+
"BTNConfig",
|
|
21
|
+
"BTNResult",
|
|
18
22
|
"HeadDirectionAnalyzer",
|
|
19
23
|
"HDCellResult",
|
|
20
24
|
"compute_rate_map",
|
|
@@ -0,0 +1,272 @@
|
|
|
1
|
+
"""
|
|
2
|
+
BTN (Bursty/Theta/Non-bursty) classification.
|
|
3
|
+
|
|
4
|
+
Workflow:
|
|
5
|
+
1) Compute ISI autocorrelograms from spike times.
|
|
6
|
+
2) Normalize and smooth each autocorr curve.
|
|
7
|
+
3) Compute cosine distance between curves.
|
|
8
|
+
4) Cluster with Tomato using a manual kNN graph and density weights.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from dataclasses import dataclass
|
|
14
|
+
from typing import Any
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
from scipy.ndimage import gaussian_filter1d
|
|
18
|
+
from scipy.spatial.distance import pdist, squareform
|
|
19
|
+
|
|
20
|
+
try:
|
|
21
|
+
from gudhi.clustering.tomato import Tomato
|
|
22
|
+
except Exception as exc: # pragma: no cover - optional dependency
|
|
23
|
+
Tomato = None
|
|
24
|
+
_TOMATO_IMPORT_ERROR = exc
|
|
25
|
+
else:
|
|
26
|
+
_TOMATO_IMPORT_ERROR = None
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class BTNConfig:
    """Configuration for BTN clustering (consumed by ``BTNAnalyzer.classify_btn``)."""

    # Maximum lag covered by the ISI autocorrelogram, in the same unit as the spike times.
    maxt: float = 0.2
    # Bin width of the autocorrelogram for the linear binning modes (unused when b_log).
    res: float = 1e-3
    # Gaussian sigma, in bins, used to smooth each normalized autocorrelogram.
    smooth_sigma: float = 4.0
    # Number of nearest neighbours in the kNN graph (clipped to [1, N - 1]).
    nbs: int = 80
    # Cluster count forced on Tomato; None leaves Tomato's own n_clusters_ untouched.
    n_clusters: int | None = 4
    # Distance metric passed to scipy.spatial.distance.pdist.
    metric: str = "cosine"
    # If True, only non-negative lags are counted (one-sided autocorrelogram).
    b_one: bool = True
    # If True, use 101 log-spaced edges from 0.005 to maxt, mirrored to negative lags.
    b_log: bool = False
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass
class BTNResult:
    """Result container for BTN clustering.

    Instances are built by ``BTNAnalyzer.classify_btn``.
    """

    # Cluster id per neuron, as an int array built from Tomato's ``labels_``.
    # Negative ids are excluded when cluster sizes are counted.
    labels: np.ndarray
    # The configured ``BTNConfig.n_clusters`` when set; otherwise the number of
    # distinct non-negative labels.
    n_clusters: int
    # ``np.bincount`` of the non-negative labels (empty array if there are none).
    cluster_sizes: np.ndarray
    # Per-neuron label strings looked up in ``mapping`` ("unknown" for ids not in
    # it); None when no mapping was given.
    btn_labels: np.ndarray | None
    # The cluster-id -> label mapping passed to classify_btn, if any.
    mapping: dict[int, str] | None
    # Intermediate arrays (autocorrelograms, distance matrix, kNN graph, density
    # weights, bin edges) when ``return_intermediates=True``; otherwise None.
    intermediates: dict[str, np.ndarray] | None
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class BTNAnalyzer:
    """Analyzer that clusters neurons into BTN groups using Tomato.

    Pipeline (see ``classify_btn``):
    1. per-neuron ISI autocorrelograms;
    2. normalisation by the first bin;
    3. Gaussian smoothing;
    4. pairwise distances;
    5. a kNN graph with ``exp(-distance)`` density weights;
    6. Tomato clustering.
    """

    def __init__(self, config: BTNConfig | None = None):
        # Falls back to default settings when no config is given.
        self.config = config or BTNConfig()

    def classify_btn(
        self,
        spike_data: dict[str, Any],
        *,
        mapping: dict[int, str] | None = None,
        return_intermediates: bool = False,
        plot_diagram: bool = False,
    ) -> BTNResult:
        """Cluster neurons into BTN classes using ISI autocorr + Tomato.

        Parameters
        ----------
        spike_data : dict
            ASA-style dict with keys ``spike`` and ``t`` (and optionally x/y).
        mapping : dict, optional
            Optional mapping from cluster id to BTN label string.
        return_intermediates : bool
            If True, include intermediate arrays in the result.
        plot_diagram : bool
            If True, call Tomato.plot_diagram() for visual inspection.

        Returns
        -------
        BTNResult
            Labels, cluster count/sizes, optional mapped labels and intermediates.

        Raises
        ------
        ImportError
            If gudhi (Tomato) is not installed.
        ValueError
            If ``spike_data`` is malformed (see ``_extract_spike_times``).
        """
        _require_tomato()
        cfg = self.config

        spikes = _extract_spike_times(spike_data)
        acorr, bin_times = _isi_autocorr(
            spikes,
            maxt=cfg.maxt,
            res=cfg.res,
            b_one=cfg.b_one,
            b_log=cfg.b_log,
        )
        acorr_norm = _normalize_autocorr(acorr)
        acorr_smooth = gaussian_filter1d(acorr_norm, sigma=cfg.smooth_sigma, axis=1)

        dist = squareform(pdist(acorr_smooth, metric=cfg.metric))
        knn_indices, knn_dists = self._knn_graph(dist, cfg.nbs)
        # Density estimate per neuron: sum of exp(-d) over its nearest neighbours.
        weights = np.sum(np.exp(-knn_dists), axis=1)

        tomato = Tomato(graph_type="manual", density_type="manual", metric="precomputed")
        tomato.fit(knn_indices, weights=weights)
        if plot_diagram:
            tomato.plot_diagram()
        if cfg.n_clusters is not None:
            # NOTE(review): per gudhi's Tomato docs, assigning n_clusters_ re-derives
            # labels_ for that many clusters — confirm for the installed version.
            tomato.n_clusters_ = int(cfg.n_clusters)
        labels = np.asarray(tomato.labels_, dtype=int)

        # Negative labels are left out of the size/count statistics.
        valid_labels = labels[labels >= 0]
        cluster_sizes = np.bincount(valid_labels) if valid_labels.size else np.array([])

        btn_labels = None
        if mapping is not None:
            btn_labels = np.array([mapping.get(int(c), "unknown") for c in labels], dtype=object)

        intermediates = None
        if return_intermediates:
            intermediates = {
                "acorr": acorr,
                "acorr_norm": acorr_norm,
                "acorr_smooth": acorr_smooth,
                "distance_matrix": dist,
                "knn_indices": knn_indices,
                "knn_dists": knn_dists,
                "density_weights": weights,
                "bin_times": bin_times,
            }

        if cfg.n_clusters is not None:
            n_clusters = int(cfg.n_clusters)
        else:
            n_clusters = int(np.unique(valid_labels).size)

        return BTNResult(
            labels=labels,
            n_clusters=n_clusters,
            cluster_sizes=cluster_sizes,
            btn_labels=btn_labels,
            mapping=mapping,
            intermediates=intermediates,
        )

    @staticmethod
    def _knn_graph(dist: np.ndarray, nbs: int) -> tuple[np.ndarray, np.ndarray]:
        """Return ``(knn_indices, knn_dists)`` for a square distance matrix.

        Each node's own index is excluded from its neighbour list, even when other
        nodes are at distance 0 from it (duplicate curves). Previously, slicing
        ``argsort(...)[:, 1:]`` assumed self always sorts first. ``nbs`` is clipped
        to ``[1, N - 1]``; a single node gets itself as its only neighbour.

        Non-finite distances (e.g. cosine distance involving an all-zero curve) are
        replaced by the largest finite distance so that they cannot poison the
        neighbour ordering or the density weights.
        """
        dist = np.asarray(dist, dtype=float)
        finite = np.isfinite(dist)
        if not finite.all():
            fill = dist[finite].max() if finite.any() else 0.0
            dist = np.where(finite, dist, fill)

        num_nodes = dist.shape[0]
        if num_nodes <= 1:
            knn_indices = np.argsort(dist, axis=1)[:, :1]
        else:
            k = max(1, min(int(nbs), num_nodes - 1))
            masked = dist.copy()
            np.fill_diagonal(masked, np.inf)
            knn_indices = np.argsort(masked, axis=1, kind="stable")[:, :k]
        knn_dists = dist[np.arange(num_nodes)[:, None], knn_indices]
        return knn_indices, knn_dists
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _require_tomato() -> None:
    """Fail fast with an actionable ImportError when gudhi's Tomato is unavailable.

    The exception captured at import time (if any) is chained as ``__cause__``.
    """
    if Tomato is not None:
        return
    raise ImportError(
        "Tomato clustering requires gudhi. Install with: pip install gudhi"
    ) from _TOMATO_IMPORT_ERROR
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _extract_spike_times(spike_data: dict[str, Any]) -> dict[int, np.ndarray]:
    """Return per-neuron spike times clipped to the recording window.

    Parameters
    ----------
    spike_data : dict
        Must contain ``"spike"`` (per-neuron spike times as a dict, a sequence,
        or a 0-d object array wrapping either) and ``"t"`` (time samples that
        define the window ``[min(t), max(t))``).

    Returns
    -------
    dict[int, ndarray]
        Neurons re-indexed ``0..N-1`` in iteration order. Dict inputs keep their
        key order. Empty neurons map to an empty float array.

    Raises
    ------
    ValueError
        If the keys are missing, ``spike`` looks like a binned (T x N) matrix, or
        ``t`` is empty.
    """
    if not isinstance(spike_data, dict) or "spike" not in spike_data or "t" not in spike_data:
        raise ValueError("spike_data must be a dict with keys 'spike' and 't'")

    spike_raw = spike_data["spike"]
    t = np.asarray(spike_data["t"])
    if (
        isinstance(spike_raw, np.ndarray)
        and spike_raw.ndim == 2
        and spike_raw.shape[0] == t.shape[0]
    ):
        raise ValueError(
            "BTN expects spike times per neuron, but got a binned spike matrix (T x N)."
        )
    if t.size == 0:
        # np.min/np.max below would otherwise fail with an opaque reduction error.
        raise ValueError("spike_data['t'] must be non-empty to define the recording window.")

    # A 0-d object array (e.g. a dict stored via np.save) must be unwrapped first.
    if (
        hasattr(spike_raw, "item")
        and callable(spike_raw.item)
        and np.asarray(spike_raw).shape == ()
    ):
        spikes_all = spike_raw[()]
    else:
        spikes_all = spike_raw

    t_min = np.min(t)
    t_max = np.max(t)

    def _in_window(s: np.ndarray) -> np.ndarray:
        # Half-open window [t_min, t_max).
        return s[(s >= t_min) & (s < t_max)]

    if isinstance(spikes_all, dict):
        return {i: _in_window(np.asarray(spikes_all[key])) for i, key in enumerate(spikes_all)}

    spikes: dict[int, np.ndarray] = {}
    for i in range(len(spikes_all)):
        entry = spikes_all[i]
        s = np.asarray(entry) if len(entry) > 0 else np.array([])
        spikes[i] = _in_window(s) if s.size > 0 else np.array([])
    return spikes
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def _num_steps(span: float, res: float) -> int:
    """Number of whole ``res`` steps in ``span``, tolerant of float round-off.

    ``int(0.3 / 0.1)`` is 2 because the quotient is 2.9999999999999996. Snapping
    quotients within a relative 1e-9 of an integer avoids silently losing a bin,
    while non-integral ratios are still truncated as before.
    """
    ratio = span / res
    nearest = round(ratio)
    if abs(ratio - nearest) <= 1e-9 * max(1.0, abs(ratio)):
        return int(nearest)
    return int(ratio)


def _autocorr_bin_edges(*, maxt: float, res: float, b_one: bool, b_log: bool) -> np.ndarray:
    """Return the histogram bin edges used by ``_isi_autocorr``."""
    if b_log:
        # 101 log-spaced edges from 0.005 to maxt, mirrored to negative lags.
        num_half = 100
        exponents = np.linspace(np.log10(0.005), np.log10(maxt), num_half + 1)
        half = np.power(np.ones(num_half + 1) * 10, exponents)
        return np.unique(np.concatenate((-half, half)))
    if b_one:
        return np.linspace(0, maxt, _num_steps(maxt, res) + 1)
    return np.linspace(-maxt, maxt, _num_steps(2 * maxt, res) + 1)


def _isi_autocorr(
    spk: dict[int, np.ndarray],
    *,
    maxt: float = 0.2,
    res: float = 1e-3,
    b_one: bool = True,
    b_log: bool = False,
) -> tuple[np.ndarray, np.ndarray]:
    """Compute ISI autocorrelogram for each neuron.

    Parameters
    ----------
    spk : dict
        Mapping neuron key -> spike times; dict iteration order sets row order.
    maxt : float
        Maximum lag, in the same unit as the spike times.
    res : float
        Bin width for the linear binning modes (ignored when ``b_log``).
    b_one : bool
        If True, only non-negative lags are counted.
    b_log : bool
        If True, use symmetric log-spaced edges. This takes precedence over
        ``b_one`` for the edge layout; ``b_one`` still drops negative lags.

    Returns
    -------
    acorr : ndarray
        Integer counts, shape (N, len(bin_times) - 1).
    bin_times : ndarray
        Bin edges used for histogramming.
    """
    bin_times = _autocorr_bin_edges(maxt=maxt, res=res, b_one=b_one, b_log=b_log)
    num_edges = len(bin_times)
    acorr = np.zeros((len(spk), num_edges - 1), dtype=int)

    # Open lag window (lo_lag, hi_lag), shrunk by 1e-5 so lags of exactly maxt are
    # excluded. The one-sided mode keeps a tiny negative margin so the zero-lag
    # (self) pair survives until the ``dd >= 0`` filter.
    hi_lag = maxt - 1e-5
    lo_lag = -1e-5 if b_one else -hi_lag

    for i, key in enumerate(spk):
        # The histogram does not depend on spike order, so sort once and find each
        # spike's lag window by binary search: O(n log n + pairs) instead of the
        # O(n^2) full boolean scan per spike.
        times = np.sort(np.asarray(spk[key]).ravel())
        starts = np.searchsorted(times, times + lo_lag, side="right")  # first > t + lo_lag
        stops = np.searchsorted(times, times + hi_lag, side="left")  # first >= t + hi_lag
        for ss, start, stop in zip(times, starts, stops):
            dd = times[start:stop] - ss
            if b_one:
                dd = dd[dd >= 0]
            bins = np.digitize(dd, bin_times) - 1
            bins = bins[(bins >= 0) & (bins < num_edges - 1)]
            if bins.size:
                acorr[i, :] += np.bincount(bins, minlength=num_edges)[:-1]

    return acorr, bin_times
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def _normalize_autocorr(acorr: np.ndarray) -> np.ndarray:
    """Scale each autocorrelogram row by its first bin, then zero that bin.

    Rows whose first bin is not positive are left unscaled (apart from the zeroed
    first bin). The input array is not modified; a new float array is returned.
    """
    out = acorr.astype(float, copy=True)
    first_bin = out[:, 0].copy()
    scalable = first_bin > 0
    out[scalable] /= first_bin[scalable, None]
    out[:, 0] = 0.0
    return out
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Visualization modules."""
|
|
2
2
|
|
|
3
|
+
from .btn_plots import plot_btn_autocorr_summary, plot_btn_distance_matrix
|
|
3
4
|
from .grid_plots import (
|
|
4
5
|
plot_autocorrelogram,
|
|
5
6
|
plot_grid_score_histogram,
|
|
@@ -16,4 +17,6 @@ __all__ = [
|
|
|
16
17
|
"plot_polar_tuning",
|
|
17
18
|
"plot_temporal_autocorr",
|
|
18
19
|
"plot_hd_analysis",
|
|
20
|
+
"plot_btn_distance_matrix",
|
|
21
|
+
"plot_btn_autocorr_summary",
|
|
19
22
|
]
|