canns 0.14.3__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. canns/analyzer/data/asa/__init__.py +56 -21
  2. canns/analyzer/data/asa/coho.py +21 -0
  3. canns/analyzer/data/asa/cohomap.py +453 -0
  4. canns/analyzer/data/asa/cohomap_vectors.py +365 -0
  5. canns/analyzer/data/asa/cohospace.py +155 -1165
  6. canns/analyzer/data/asa/cohospace_phase_centers.py +119 -0
  7. canns/analyzer/data/asa/cohospace_scatter.py +1115 -0
  8. canns/analyzer/data/asa/embedding.py +5 -7
  9. canns/analyzer/data/asa/fr.py +1 -8
  10. canns/analyzer/data/asa/path.py +70 -0
  11. canns/analyzer/data/asa/plotting.py +5 -30
  12. canns/analyzer/data/asa/utils.py +160 -0
  13. canns/analyzer/data/cell_classification/__init__.py +10 -0
  14. canns/analyzer/data/cell_classification/core/__init__.py +4 -0
  15. canns/analyzer/data/cell_classification/core/btn.py +272 -0
  16. canns/analyzer/data/cell_classification/visualization/__init__.py +3 -0
  17. canns/analyzer/data/cell_classification/visualization/btn_plots.py +241 -0
  18. canns/analyzer/visualization/__init__.py +2 -0
  19. canns/analyzer/visualization/core/config.py +20 -0
  20. canns/analyzer/visualization/theta_sweep_plots.py +142 -0
  21. canns/pipeline/asa/runner.py +19 -19
  22. canns/pipeline/asa_gui/__init__.py +5 -3
  23. canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +3 -1
  24. canns/pipeline/asa_gui/core/runner.py +23 -23
  25. canns/pipeline/asa_gui/views/pages/preprocess_page.py +7 -12
  26. {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/METADATA +1 -1
  27. {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/RECORD +30 -23
  28. canns/analyzer/data/asa/filters.py +0 -208
  29. {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/WHEEL +0 -0
  30. {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/entry_points.txt +0 -0
  31. {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/licenses/LICENSE +0 -0
@@ -6,7 +6,6 @@ import numpy as np
 from scipy.ndimage import gaussian_filter1d
 
 from .config import DataLoadError, ProcessingError, SpikeEmbeddingConfig
-from .filters import _gaussian_filter1d
 
 
 def embed_spike_trains(spike_trains, config: SpikeEmbeddingConfig | None = None, **kwargs):
@@ -31,10 +30,9 @@ def embed_spike_trains(spike_trains, config: SpikeEmbeddingConfig | None = None,
     Returns
     -------
     tuple
-        ``(spikes_bin, xx, yy, tt)`` where:
-        - ``spikes_bin`` is a (T, N) binned spike matrix.
-        - ``xx``, ``yy``, ``tt`` are position/time arrays when ``speed_filter=True``,
-          otherwise ``None``.
+        ``(spikes_bin, xx, yy, tt)``. ``spikes_bin`` is a (T, N) binned spike matrix.
+        ``xx``, ``yy``, ``tt`` are position/time arrays when ``speed_filter=True``,
+        otherwise ``None``.
 
     Examples
     --------
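For readers updating call sites against the revised docstring, a minimal sketch of the documented return contract; the `SpikeEmbeddingConfig(speed_filter=...)` construction and the toy input format are assumptions inferred from this docstring, not APIs shown in the diff:

```python
# Hypothetical sketch of the (spikes_bin, xx, yy, tt) contract described above.
from canns.analyzer.data.asa import embed_spike_trains
from canns.analyzer.data.asa.config import SpikeEmbeddingConfig

spike_trains = ...  # per-neuron spike times, in whatever format the loader produced
config = SpikeEmbeddingConfig(speed_filter=True)  # field name assumed from the docstring
spikes_bin, xx, yy, tt = embed_spike_trains(spike_trains, config=config)
# spikes_bin: (T, N) binned spike matrix; xx/yy/tt are None when speed_filter=False.
```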
@@ -263,8 +261,8 @@ def _load_pos(t, x, y, res=100000, dt=1000):
     xx = np.interp(tt, t, x)
     yy = np.interp(tt, t, y)
 
-    xxs = _gaussian_filter1d(xx - np.min(xx), sigma=100)
-    yys = _gaussian_filter1d(yy - np.min(yy), sigma=100)
+    xxs = gaussian_filter1d(xx - np.min(xx), sigma=100)
+    yys = gaussian_filter1d(yy - np.min(yy), sigma=100)
     dx = (xxs[1:] - xxs[:-1]) * 100
     dy = (yys[1:] - yys[:-1]) * 100
     speed = np.sqrt(dx**2 + dy**2) / 0.01
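The private `_gaussian_filter1d` removed here (filters.py is deleted wholesale elsewhere in this diff) is replaced by SciPy's implementation. A standalone sketch of the same smooth-then-differentiate speed computation, assuming the 10 ms sample step implied by the `/ 0.01`:

```python
import numpy as np
from scipy.ndimage import gaussian_filter1d

rng = np.random.default_rng(0)
xx = np.cumsum(rng.normal(size=1000))                # toy 1-D position trace
xxs = gaussian_filter1d(xx - np.min(xx), sigma=100)  # same call as in _load_pos
dx = (xxs[1:] - xxs[:-1]) * 100                      # scale factor as in the hunk
speed = np.abs(dx) / 0.01                            # 1-D analogue; assumes 10 ms bins
```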
@@ -1,18 +1,11 @@
 from __future__ import annotations
 
-import os
 from dataclasses import dataclass
 
 import numpy as np
 
 from ...visualization.core import PlotConfig, finalize_figure
-
-
-def _ensure_parent_dir(save_path: str | None) -> None:
-    if save_path:
-        parent = os.path.dirname(save_path)
-        if parent:
-            os.makedirs(parent, exist_ok=True)
+from .utils import _ensure_parent_dir
 
 
 def _slice_range(r: tuple[int, int] | None, length: int) -> slice:
@@ -240,6 +240,76 @@ def interp_coords_to_full_1d(idx_map: np.ndarray, coords1: np.ndarray, T_full: i
     return np.mod(out, 2 * np.pi)[:, None]
 
 
+def _align_activity_to_coords(
+    coords: np.ndarray,
+    activity: np.ndarray,
+    times: np.ndarray | None = None,
+    *,
+    label: str = "activity",
+    auto_filter: bool = True,
+) -> np.ndarray:
+    """
+    Align activity to coords by optional time indices and validate lengths.
+
+    Parameters
+    ----------
+    coords : ndarray
+        Decoded coordinates array.
+    activity : ndarray
+        Activity matrix (firing rate or spikes).
+    times : ndarray, optional
+        Optional time indices to align activity to coords when coords are computed
+        on a subset of timepoints.
+    label : str
+        Label for error messages (default: "activity").
+    auto_filter : bool
+        If True and lengths mismatch, auto-filter activity with activity>0 to mimic
+        decode filtering.
+
+    Returns
+    -------
+    ndarray
+        Aligned activity array.
+
+    Raises
+    ------
+    ValueError
+        If activity length doesn't match coords length after alignment attempts.
+    """
+    coords = np.asarray(coords)
+    activity = np.asarray(activity)
+
+    if times is not None:
+        times = np.asarray(times)
+        try:
+            activity = activity[times]
+        except Exception as exc:
+            raise ValueError(
+                f"Failed to index {label} with `times`. Ensure `times` indexes the original time axis."
+            ) from exc
+
+    if activity.shape[0] != coords.shape[0]:
+        # Try to reproduce decode's zero-spike filtering if lengths mismatch.
+        if auto_filter and times is None and activity.ndim == 2:
+            mask = np.sum(activity > 0, axis=1) >= 1
+            if mask.sum() == coords.shape[0]:
+                activity = activity[mask]
+            else:
+                raise ValueError(
+                    f"{label} length must match coords length. Got {activity.shape[0]} vs {coords.shape[0]}. "
+                    "If coords are computed on a subset of timepoints (e.g., decode['times']), pass "
+                    "`times=decoding['times']` or slice the activity accordingly."
+                )
+        else:
+            raise ValueError(
+                f"{label} length must match coords length. Got {activity.shape[0]} vs {coords.shape[0]}. "
+                "If coords are computed on a subset of timepoints (e.g., decode['times']), pass "
+                "`times=decoding['times']` or slice the activity accordingly."
+            )
+
+    return activity
+
+
 def align_coords_to_position_2d(
     t_full: np.ndarray,
     x_full: np.ndarray,
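A toy sketch of the alignment contract the new helper enforces (calling the private function directly from this module; shapes are synthetic):

```python
import numpy as np

rng = np.random.default_rng(0)
activity = rng.poisson(1.0, size=(1000, 50))   # full-length (T, N) activity
times = np.arange(0, 1000, 2)                  # coords decoded on a timepoint subset
coords = rng.uniform(0, 2 * np.pi, size=(times.size, 2))

aligned = _align_activity_to_coords(coords, activity, times=times)
assert aligned.shape[0] == coords.shape[0]     # activity was subset via `times`
# Without `times`, a length mismatch either reproduces the decoder's zero-spike
# filtering (when the surviving row count matches coords) or raises ValueError.
```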
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import os
 from typing import Any
 
 import matplotlib.pyplot as plt
@@ -24,31 +23,7 @@ from ...visualization.core import (
 from ...visualization.core.jupyter_utils import display_animation_in_jupyter, is_jupyter_environment
 from .config import CANN2DPlotConfig, ProcessingError, SpikeEmbeddingConfig
 from .embedding import embed_spike_trains
-
-
-def _ensure_plot_config(
-    config: PlotConfig | None,
-    factory,
-    *,
-    kwargs: dict[str, Any] | None = None,
-    **defaults: Any,
-) -> PlotConfig:
-    if config is None:
-        defaults.update({"kwargs": kwargs or {}})
-        return factory(**defaults)
-
-    if kwargs:
-        config_kwargs = config.kwargs or {}
-        config_kwargs.update(kwargs)
-        config.kwargs = config_kwargs
-    return config
-
-
-def _ensure_parent_dir(save_path: str | None) -> None:
-    if save_path:
-        parent = os.path.dirname(save_path)
-        if parent:
-            os.makedirs(parent, exist_ok=True)
+from .utils import _ensure_parent_dir, _ensure_plot_config
 
 
 def _render_torus_frame(frame_index: int, frame_data: dict[str, Any]) -> np.ndarray:
@@ -403,7 +378,7 @@ def plot_path_compare_1d(
     return fig, axes
 
 
-def plot_cohomap(
+def plot_cohomap_scatter(
     decoding_result: dict[str, Any],
     position_data: dict[str, Any],
     config: PlotConfig | None = None,
@@ -453,7 +428,7 @@
     >>> # Decode coordinates
     >>> decoding = decode_circular_coordinates(persistence_result, spike_data)
     >>> # Visualize with trajectory data
-    >>> fig = plot_cohomap(
+    >>> fig = plot_cohomap_scatter(
     ...     decoding,
     ...     position_data={'x': xx, 'y': yy},
     ...     save_path='cohomap.png',
@@ -518,7 +493,7 @@
     return fig
 
 
-def plot_cohomap_multi(
+def plot_cohomap_scatter_multi(
    decoding_result: dict,
    position_data: dict,
    config: PlotConfig | None = None,
@@ -560,7 +535,7 @@ def plot_cohomap_multi(
 
     Examples
     --------
-    >>> fig = plot_cohomap_multi(decoding, {"x": xx, "y": yy}, show=False) # doctest: +SKIP
+    >>> fig = plot_cohomap_scatter_multi(decoding, {"x": xx, "y": yy}, show=False) # doctest: +SKIP
     """
     config = _ensure_plot_config(
         config,
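Taken together, the four hunks above rename the scatter-based cohomap plotters. Call sites written against 0.14.3 would migrate as in this sketch (arguments follow the doctests above; whether the old names survive as aliases is not visible in this diff):

```python
# 0.14.3
fig = plot_cohomap(decoding, position_data={"x": xx, "y": yy}, show=False)
fig = plot_cohomap_multi(decoding, {"x": xx, "y": yy}, show=False)

# 0.15.1: same arguments, new names
fig = plot_cohomap_scatter(decoding, position_data={"x": xx, "y": yy}, show=False)
fig = plot_cohomap_scatter_multi(decoding, {"x": xx, "y": yy}, show=False)
```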
@@ -0,0 +1,160 @@
+"""Shared utility functions for ASA (Attractor State Analysis) modules."""
+
+from __future__ import annotations
+
+import os
+from typing import Any
+
+import numpy as np
+from scipy.ndimage import gaussian_filter
+
+from ...visualization.core import PlotConfig
+from .path import find_coords_matrix, find_times_box
+
+
+def _ensure_plot_config(
+    config: PlotConfig | None,
+    factory,
+    *args,
+    kwargs: dict | None = None,
+    **defaults,
+) -> PlotConfig:
+    """Ensure a PlotConfig exists, creating one from factory if needed.
+
+    Args:
+        config: Optional existing PlotConfig.
+        factory: Factory function to create PlotConfig if config is None.
+        *args: Positional arguments for factory.
+        kwargs: Optional dict to merge into config.kwargs.
+        **defaults: Keyword arguments for factory.
+
+    Returns:
+        PlotConfig instance.
+    """
+    if config is None:
+        if kwargs:
+            defaults.update({"kwargs": kwargs})
+        return factory(*args, **defaults)
+
+    # If config exists and kwargs provided, merge them
+    if kwargs:
+        config_kwargs = config.kwargs or {}
+        config_kwargs.update(kwargs)
+        config.kwargs = config_kwargs
+    return config
+
+
+def _ensure_parent_dir(save_path: str | None) -> None:
+    """Create parent directory for save_path if it doesn't exist.
+
+    Args:
+        save_path: Optional file path. If provided, creates parent directory.
+    """
+    if save_path:
+        parent = os.path.dirname(save_path)
+        if parent:
+            os.makedirs(parent, exist_ok=True)
+
+
+def _circmean(x: np.ndarray) -> float:
+    """Compute circular mean of angles.
+
+    Args:
+        x: Array of angles in radians.
+
+    Returns:
+        Circular mean in radians.
+    """
+    return float(np.arctan2(np.mean(np.sin(x)), np.mean(np.cos(x))))
+
+
+def _smooth_circular_map(
+    mtot: np.ndarray,
+    smooth_sigma: float,
+    *,
+    fill_nan: bool = False,
+    fill_sigma: float | None = None,
+    fill_min_weight: float = 1e-3,
+) -> np.ndarray:
+    """Smooth a circular phase map using Gaussian filtering in sin/cos space.
+
+    Args:
+        mtot: Phase map array (angles in radians).
+        smooth_sigma: Gaussian smoothing sigma.
+        fill_nan: Whether to fill NaN values using weighted interpolation.
+        fill_sigma: Sigma for NaN filling (defaults to smooth_sigma).
+        fill_min_weight: Minimum weight threshold for valid interpolation.
+
+    Returns:
+        Smoothed phase map.
+    """
+    mtot = np.asarray(mtot, dtype=float)
+    nans = np.isnan(mtot)
+    mask = (~nans).astype(float)
+    sintot = np.sin(mtot)
+    costot = np.cos(mtot)
+    sintot[nans] = 0.0
+    costot[nans] = 0.0
+
+    if fill_nan:
+        if fill_sigma is None:
+            fill_sigma = smooth_sigma if smooth_sigma and smooth_sigma > 0 else 1.0
+        weight = gaussian_filter(mask, fill_sigma)
+        sintot = gaussian_filter(sintot * mask, fill_sigma)
+        costot = gaussian_filter(costot * mask, fill_sigma)
+        min_weight = max(float(fill_min_weight), 0.0)
+        valid = weight > min_weight
+        sintot = np.divide(sintot, weight, out=np.zeros_like(sintot), where=valid)
+        costot = np.divide(costot, weight, out=np.zeros_like(costot), where=valid)
+        mtot = np.arctan2(sintot, costot)
+        if fill_min_weight > 0:
+            mtot[~valid] = np.nan
+        return mtot
+
+    if smooth_sigma and smooth_sigma > 0:
+        sintot = gaussian_filter(sintot, smooth_sigma)
+        costot = gaussian_filter(costot, smooth_sigma)
+        mtot = np.arctan2(sintot, costot)
+    mtot[nans] = np.nan
+    return mtot
+
+
+def _extract_coords_and_times(
+    decoding_result: dict[str, Any],
+    coords_key: str | None = None,
+) -> tuple[np.ndarray, np.ndarray | None]:
+    """Extract coordinates and time indices from decoding result.
+
+    Args:
+        decoding_result: Dictionary containing decoding results.
+        coords_key: Optional key for coordinates (defaults to 'coordsbox' or auto-detect).
+
+    Returns:
+        Tuple of (coords, times_box) where coords is (T, 2) and times_box is optional.
+    """
+    if coords_key is not None:
+        if coords_key not in decoding_result:
+            raise KeyError(f"coords_key '{coords_key}' not found in decoding_result.")
+        coords = np.asarray(decoding_result[coords_key])
+    elif "coordsbox" in decoding_result:
+        coords = np.asarray(decoding_result["coordsbox"])
+    else:
+        coords, _ = find_coords_matrix(decoding_result)
+
+    times_box, _ = find_times_box(decoding_result)
+    return coords, times_box
+
+
+def _phase_map_valid_fraction(phase_map: np.ndarray) -> float:
+    """Calculate fraction of valid (non-NaN) values in phase map.
+
+    Args:
+        phase_map: Phase map array.
+
+    Returns:
+        Fraction of valid values (0.0 to 1.0).
+    """
+    valid = np.isfinite(phase_map)
+    if valid.size == 0:
+        return 0.0
+    return float(np.mean(valid))
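The sin/cos trick in `_smooth_circular_map` avoids averaging raw angles across the 0/2π wrap. A self-contained sketch of the same idea, without the private helpers:

```python
import numpy as np
from scipy.ndimage import gaussian_filter

# Toy phase map straddling the wrap: the left half sits at 0.1 rad, the right
# half at 2*pi - 0.1, i.e. nearly the same direction on the circle.
phases = np.full((8, 8), 0.1)
phases[:, 4:] = 2 * np.pi - 0.1

# Naive smoothing of raw angles would blur the seam toward pi; filtering the
# sin/cos components and recombining with arctan2 keeps values near 0 (mod 2*pi).
smoothed = np.arctan2(
    gaussian_filter(np.sin(phases), sigma=1.0),
    gaussian_filter(np.cos(phases), sigma=1.0),
)
print(np.round(np.mod(smoothed, 2 * np.pi), 2))
```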
@@ -11,6 +11,9 @@ Vollan, Gardner, Moser & Moser (Nature, 2025)
 __version__ = "0.1.0"
 
 from .core import ( # noqa: F401
+    BTNAnalyzer,
+    BTNConfig,
+    BTNResult,
     GridnessAnalyzer,
     GridnessResult,
     HDCellResult,
@@ -46,6 +49,8 @@ from .utils import ( # noqa: F401
 )
 from .visualization import ( # noqa: F401
     plot_autocorrelogram,
+    plot_btn_autocorr_summary,
+    plot_btn_distance_matrix,
     plot_grid_score_histogram,
     plot_gridness_analysis,
     plot_hd_analysis,
@@ -57,6 +62,9 @@ from .visualization import ( # noqa: F401
 __all__ = [
     "GridnessAnalyzer",
     "GridnessResult",
+    "BTNAnalyzer",
+    "BTNConfig",
+    "BTNResult",
     "HeadDirectionAnalyzer",
     "HDCellResult",
     "compute_2d_autocorrelation",
@@ -94,4 +102,6 @@
     "plot_polar_tuning",
     "plot_temporal_autocorr",
     "plot_hd_analysis",
+    "plot_btn_autocorr_summary",
+    "plot_btn_distance_matrix",
 ]
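With these re-exports, the new BTN API becomes importable from the subpackage root (paths per the files-changed list above):

```python
from canns.analyzer.data.cell_classification import (
    BTNAnalyzer,
    BTNConfig,
    BTNResult,
    plot_btn_autocorr_summary,
    plot_btn_distance_matrix,
)
```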
@@ -1,5 +1,6 @@
 """Core analysis modules."""
 
+from .btn import BTNAnalyzer, BTNConfig, BTNResult
 from .grid_cells import GridnessAnalyzer, GridnessResult, compute_2d_autocorrelation
 from .grid_modules_leiden import identify_grid_modules_and_stats
 from .head_direction import HDCellResult, HeadDirectionAnalyzer
@@ -15,6 +16,9 @@ __all__ = [
     "GridnessAnalyzer",
     "compute_2d_autocorrelation",
     "GridnessResult",
+    "BTNAnalyzer",
+    "BTNConfig",
+    "BTNResult",
     "HeadDirectionAnalyzer",
     "HDCellResult",
     "compute_rate_map",
@@ -0,0 +1,272 @@
+"""
+BTN (Bursty/Theta/Non-bursty) classification.
+
+Workflow:
+1) Compute ISI autocorrelograms from spike times.
+2) Normalize and smooth each autocorr curve.
+3) Compute cosine distance between curves.
+4) Cluster with Tomato using a manual kNN graph and density weights.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any
+
+import numpy as np
+from scipy.ndimage import gaussian_filter1d
+from scipy.spatial.distance import pdist, squareform
+
+try:
+    from gudhi.clustering.tomato import Tomato
+except Exception as exc: # pragma: no cover - optional dependency
+    Tomato = None
+    _TOMATO_IMPORT_ERROR = exc
+else:
+    _TOMATO_IMPORT_ERROR = None
+
+
+@dataclass
+class BTNConfig:
+    """Configuration for BTN clustering."""
+
+    maxt: float = 0.2
+    res: float = 1e-3
+    smooth_sigma: float = 4.0
+    nbs: int = 80
+    n_clusters: int = 4
+    metric: str = "cosine"
+    b_one: bool = True
+    b_log: bool = False
+
+
+@dataclass
+class BTNResult:
+    """Result container for BTN clustering."""
+
+    labels: np.ndarray
+    n_clusters: int
+    cluster_sizes: np.ndarray
+    btn_labels: np.ndarray | None
+    mapping: dict[int, str] | None
+    intermediates: dict[str, np.ndarray] | None
+
+
+class BTNAnalyzer:
+    """Analyzer that clusters neurons into BTN groups using Tomato."""
+
+    def __init__(self, config: BTNConfig | None = None):
+        self.config = config or BTNConfig()
+
+    def classify_btn(
+        self,
+        spike_data: dict[str, Any],
+        *,
+        mapping: dict[int, str] | None = None,
+        return_intermediates: bool = False,
+        plot_diagram: bool = False,
+    ) -> BTNResult:
+        """Cluster neurons into BTN classes using ISI autocorr + Tomato.
+
+        Parameters
+        ----------
+        spike_data : dict
+            ASA-style dict with keys ``spike`` and ``t`` (and optionally x/y).
+        mapping : dict, optional
+            Optional mapping from cluster id to BTN label string.
+        return_intermediates : bool
+            If True, include intermediate arrays in the result.
+        plot_diagram : bool
+            If True, call Tomato.plot_diagram() for visual inspection.
+        """
+        _require_tomato()
+        spikes = _extract_spike_times(spike_data)
+        acorr, bin_times = _isi_autocorr(
+            spikes,
+            maxt=self.config.maxt,
+            res=self.config.res,
+            b_one=self.config.b_one,
+            b_log=self.config.b_log,
+        )
+        acorr_norm = _normalize_autocorr(acorr)
+        acorr_smooth = gaussian_filter1d(acorr_norm, sigma=self.config.smooth_sigma, axis=1)
+
+        dist = squareform(pdist(acorr_smooth, metric=self.config.metric))
+        num_nodes = dist.shape[0]
+        order = np.argsort(dist, axis=1)
+        if num_nodes > 1:
+            nbs_max = num_nodes - 1
+            nbs = max(1, min(int(self.config.nbs), nbs_max))
+            knn_indices = order[:, 1 : nbs + 1]
+        else:
+            nbs = 1
+            knn_indices = order[:, :1]
+        knn_dists = dist[np.arange(dist.shape[0])[:, None], knn_indices]
+        weights = np.sum(np.exp(-knn_dists), axis=1)
+
+        t = Tomato(graph_type="manual", density_type="manual", metric="precomputed")
+        t.fit(knn_indices, weights=weights)
+        if plot_diagram:
+            t.plot_diagram()
+        if self.config.n_clusters is not None:
+            t.n_clusters_ = int(self.config.n_clusters)
+        labels = np.asarray(t.labels_, dtype=int)
+        if labels.size:
+            if np.any(labels < 0):
+                valid = labels[labels >= 0]
+                cluster_sizes = np.bincount(valid) if valid.size else np.array([])
+            else:
+                cluster_sizes = np.bincount(labels)
+        else:
+            cluster_sizes = np.array([])
+
+        btn_labels = None
+        if mapping is not None:
+            btn_labels = np.array([mapping.get(int(c), "unknown") for c in labels], dtype=object)
+
+        intermediates = None
+        if return_intermediates:
+            intermediates = {
+                "acorr": acorr,
+                "acorr_norm": acorr_norm,
+                "acorr_smooth": acorr_smooth,
+                "distance_matrix": dist,
+                "knn_indices": knn_indices,
+                "knn_dists": knn_dists,
+                "density_weights": weights,
+                "bin_times": bin_times,
+            }
+
+        if self.config.n_clusters is not None:
+            n_clusters = int(self.config.n_clusters)
+        else:
+            if labels.size:
+                valid_labels = labels[labels >= 0]
+                n_clusters = int(np.unique(valid_labels).size)
+            else:
+                n_clusters = 0
+
+        return BTNResult(
+            labels=labels,
+            n_clusters=n_clusters,
+            cluster_sizes=cluster_sizes,
+            btn_labels=btn_labels,
+            mapping=mapping,
+            intermediates=intermediates,
+        )
+
+
+def _require_tomato() -> None:
+    if Tomato is None:
+        raise ImportError(
+            "Tomato clustering requires gudhi. Install with: pip install gudhi"
+        ) from _TOMATO_IMPORT_ERROR
+
+
+def _extract_spike_times(spike_data: dict[str, Any]) -> dict[int, np.ndarray]:
+    if not isinstance(spike_data, dict) or "spike" not in spike_data or "t" not in spike_data:
+        raise ValueError("spike_data must be a dict with keys 'spike' and 't'")
+
+    spike_raw = spike_data["spike"]
+    t = np.asarray(spike_data["t"])
+    if (
+        isinstance(spike_raw, np.ndarray)
+        and spike_raw.ndim == 2
+        and spike_raw.shape[0] == t.shape[0]
+    ):
+        raise ValueError(
+            "BTN expects spike times per neuron, but got a binned spike matrix (T x N)."
+        )
+    if (
+        hasattr(spike_raw, "item")
+        and callable(spike_raw.item)
+        and np.asarray(spike_raw).shape == ()
+    ):
+        spikes_all = spike_raw[()]
+    elif isinstance(spike_raw, dict):
+        spikes_all = spike_raw
+    elif isinstance(spike_raw, (list, np.ndarray)):
+        spikes_all = spike_raw
+    else:
+        spikes_all = spike_raw
+
+    min_time0 = np.min(t)
+    max_time0 = np.max(t)
+
+    spikes: dict[int, np.ndarray] = {}
+    if isinstance(spikes_all, dict):
+        for i, key in enumerate(spikes_all.keys()):
+            s = np.asarray(spikes_all[key])
+            spikes[i] = s[(s >= min_time0) & (s < max_time0)]
+    else:
+        cell_inds = np.arange(len(spikes_all))
+        for i, m in enumerate(cell_inds):
+            s = np.asarray(spikes_all[m]) if len(spikes_all[m]) > 0 else np.array([])
+            if s.size > 0:
+                spikes[i] = s[(s >= min_time0) & (s < max_time0)]
+            else:
+                spikes[i] = np.array([])
+
+    return spikes
+
+
+def _isi_autocorr(
+    spk: dict[int, np.ndarray],
+    *,
+    maxt: float = 0.2,
+    res: float = 1e-3,
+    b_one: bool = True,
+    b_log: bool = False,
+) -> tuple[np.ndarray, np.ndarray]:
+    """Compute ISI autocorrelogram for each neuron.
+
+    Returns
+    -------
+    acorr : ndarray
+        Shape (N, num_bins).
+    bin_times : ndarray
+        Bin edges used for histogramming.
+    """
+    if b_log:
+        num_bins = 100
+        bin_times = np.ones(num_bins + 1) * 10
+        bin_times = np.power(bin_times, np.linspace(np.log10(0.005), np.log10(maxt), num_bins + 1))
+        bin_times = np.unique(np.concatenate((-bin_times, bin_times)))
+        num_bins = len(bin_times)
+    elif b_one:
+        num_bins = int(maxt / res) + 1
+        bin_times = np.linspace(0, maxt, num_bins)
+    else:
+        num_bins = int(2 * maxt / res) + 1
+        bin_times = np.linspace(-maxt, maxt, num_bins)
+
+    num_neurons = len(spk)
+    acorr = np.zeros((num_neurons, len(bin_times) - 1), dtype=int)
+
+    maxt = maxt - 1e-5
+    mint = -maxt
+    if b_one:
+        mint = -1e-5
+
+    for i, n in enumerate(spk):
+        spike_times = np.asarray(spk[n])
+        for ss in spike_times:
+            stemp = spike_times[(spike_times < ss + maxt) & (spike_times > ss + mint)]
+            dd = stemp - ss
+            if b_one:
+                dd = dd[dd >= 0]
+            bins = np.digitize(dd, bin_times) - 1
+            bins = bins[(bins >= 0) & (bins < num_bins - 1)]
+            if bins.size:
+                acorr[i, :] += np.bincount(bins, minlength=num_bins)[:-1]
+
+    return acorr, bin_times
+
+
+def _normalize_autocorr(acorr: np.ndarray) -> np.ndarray:
+    acorr = acorr.astype(float, copy=True)
+    denom = acorr[:, 0].copy()
+    valid = denom > 0
+    acorr[valid, :] = acorr[valid, :] / denom[valid, None]
+    acorr[:, 0] = 0.0
+    return acorr
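An end-to-end sketch of the new analyzer on synthetic data; it requires the optional gudhi dependency, and the cluster-id-to-label mapping is illustrative (nothing in this module fixes which cluster is "bursty"):

```python
import numpy as np

from canns.analyzer.data.cell_classification import BTNAnalyzer, BTNConfig

rng = np.random.default_rng(0)
# ASA-style dict: per-neuron spike times under "spike", session clock under "t".
spike_data = {
    "spike": {i: np.sort(rng.uniform(0, 600, size=int(rng.integers(500, 2000))))
              for i in range(120)},
    "t": np.linspace(0, 600, 60_000),
}

analyzer = BTNAnalyzer(BTNConfig(n_clusters=4, nbs=80))
result = analyzer.classify_btn(
    spike_data,
    mapping={0: "bursty", 1: "theta", 2: "non-bursty", 3: "other"},  # illustrative
    return_intermediates=True,
)
print(result.n_clusters, result.cluster_sizes)
print(result.btn_labels[:10])
```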
@@ -1,5 +1,6 @@
 """Visualization modules."""
 
+from .btn_plots import plot_btn_autocorr_summary, plot_btn_distance_matrix
 from .grid_plots import (
     plot_autocorrelogram,
     plot_grid_score_histogram,
@@ -16,4 +17,6 @@ __all__ = [
     "plot_polar_tuning",
     "plot_temporal_autocorr",
     "plot_hd_analysis",
+    "plot_btn_distance_matrix",
+    "plot_btn_autocorr_summary",
 ]