canns 0.14.2__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. canns/analyzer/data/asa/__init__.py +77 -21
  2. canns/analyzer/data/asa/coho.py +97 -0
  3. canns/analyzer/data/asa/cohomap.py +408 -0
  4. canns/analyzer/data/asa/cohomap_scatter.py +10 -0
  5. canns/analyzer/data/asa/cohomap_vectors.py +311 -0
  6. canns/analyzer/data/asa/cohospace.py +173 -1153
  7. canns/analyzer/data/asa/cohospace_phase_centers.py +137 -0
  8. canns/analyzer/data/asa/cohospace_scatter.py +1220 -0
  9. canns/analyzer/data/asa/embedding.py +3 -4
  10. canns/analyzer/data/asa/plotting.py +4 -4
  11. canns/analyzer/data/cell_classification/__init__.py +10 -0
  12. canns/analyzer/data/cell_classification/core/__init__.py +4 -0
  13. canns/analyzer/data/cell_classification/core/btn.py +272 -0
  14. canns/analyzer/data/cell_classification/visualization/__init__.py +3 -0
  15. canns/analyzer/data/cell_classification/visualization/btn_plots.py +258 -0
  16. canns/analyzer/visualization/__init__.py +2 -0
  17. canns/analyzer/visualization/core/config.py +20 -0
  18. canns/analyzer/visualization/theta_sweep_plots.py +142 -0
  19. canns/pipeline/asa/runner.py +19 -19
  20. canns/pipeline/asa_gui/__init__.py +5 -3
  21. canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +32 -4
  22. canns/pipeline/asa_gui/core/runner.py +23 -23
  23. canns/pipeline/asa_gui/views/pages/preprocess_page.py +250 -8
  24. {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/METADATA +2 -1
  25. {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/RECORD +28 -20
  26. {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/WHEEL +0 -0
  27. {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/entry_points.txt +0 -0
  28. {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/licenses/LICENSE +0 -0
@@ -31,10 +31,9 @@ def embed_spike_trains(spike_trains, config: SpikeEmbeddingConfig | None = None,
     Returns
     -------
     tuple
-        ``(spikes_bin, xx, yy, tt)`` where:
-        - ``spikes_bin`` is a (T, N) binned spike matrix.
-        - ``xx``, ``yy``, ``tt`` are position/time arrays when ``speed_filter=True``,
-          otherwise ``None``.
+        ``(spikes_bin, xx, yy, tt)``. ``spikes_bin`` is a (T, N) binned spike matrix.
+        ``xx``, ``yy``, ``tt`` are position/time arrays when ``speed_filter=True``,
+        otherwise ``None``.

     Examples
     --------
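The reworded Returns block pins down the tuple layout; as a quick illustration of that contract, a doctest-style sketch in the spirit of the Examples section above (placeholder inputs assumed):

>>> spikes_bin, xx, yy, tt = embed_spike_trains(spike_trains, config)  # doctest: +SKIP
>>> spikes_bin.shape  # (T, N) binned spike matrix  # doctest: +SKIP
>>> xx  # position array when speed_filter=True, otherwise None  # doctest: +SKIP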
@@ -403,7 +403,7 @@ def plot_path_compare_1d(
     return fig, axes


-def plot_cohomap(
+def plot_cohomap_scatter(
     decoding_result: dict[str, Any],
     position_data: dict[str, Any],
     config: PlotConfig | None = None,
@@ -453,7 +453,7 @@ def plot_cohomap(
     >>> # Decode coordinates
     >>> decoding = decode_circular_coordinates(persistence_result, spike_data)
     >>> # Visualize with trajectory data
-    >>> fig = plot_cohomap(
+    >>> fig = plot_cohomap_scatter(
     ...     decoding,
     ...     position_data={'x': xx, 'y': yy},
     ...     save_path='cohomap.png',
@@ -518,7 +518,7 @@ def plot_cohomap(
     return fig


-def plot_cohomap_multi(
+def plot_cohomap_scatter_multi(
     decoding_result: dict,
     position_data: dict,
     config: PlotConfig | None = None,
@@ -560,7 +560,7 @@ def plot_cohomap_multi(

     Examples
     --------
-    >>> fig = plot_cohomap_multi(decoding, {"x": xx, "y": yy}, show=False)  # doctest: +SKIP
+    >>> fig = plot_cohomap_scatter_multi(decoding, {"x": xx, "y": yy}, show=False)  # doctest: +SKIP
     """
     config = _ensure_plot_config(
         config,
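For callers the rename is mechanical; a hedged before/after sketch of downstream code (variable names taken from the doctests above, everything else a placeholder):

# canns 0.14.2
fig = plot_cohomap(decoding, position_data={'x': xx, 'y': yy})
fig = plot_cohomap_multi(decoding, {'x': xx, 'y': yy}, show=False)

# canns 0.15.0
fig = plot_cohomap_scatter(decoding, position_data={'x': xx, 'y': yy})
fig = plot_cohomap_scatter_multi(decoding, {'x': xx, 'y': yy}, show=False)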
@@ -11,6 +11,9 @@ Vollan, Gardner, Moser & Moser (Nature, 2025)
 __version__ = "0.1.0"

 from .core import (  # noqa: F401
+    BTNAnalyzer,
+    BTNConfig,
+    BTNResult,
     GridnessAnalyzer,
     GridnessResult,
     HDCellResult,
@@ -46,6 +49,8 @@ from .utils import (  # noqa: F401
 )
 from .visualization import (  # noqa: F401
     plot_autocorrelogram,
+    plot_btn_autocorr_summary,
+    plot_btn_distance_matrix,
     plot_grid_score_histogram,
     plot_gridness_analysis,
     plot_hd_analysis,
@@ -57,6 +62,9 @@ from .visualization import (  # noqa: F401
 __all__ = [
     "GridnessAnalyzer",
     "GridnessResult",
+    "BTNAnalyzer",
+    "BTNConfig",
+    "BTNResult",
     "HeadDirectionAnalyzer",
     "HDCellResult",
     "compute_2d_autocorrelation",
@@ -94,4 +102,6 @@ __all__ = [
     "plot_polar_tuning",
     "plot_temporal_autocorr",
     "plot_hd_analysis",
+    "plot_btn_autocorr_summary",
+    "plot_btn_distance_matrix",
 ]
@@ -1,5 +1,6 @@
 """Core analysis modules."""

+from .btn import BTNAnalyzer, BTNConfig, BTNResult
 from .grid_cells import GridnessAnalyzer, GridnessResult, compute_2d_autocorrelation
 from .grid_modules_leiden import identify_grid_modules_and_stats
 from .head_direction import HDCellResult, HeadDirectionAnalyzer
@@ -15,6 +16,9 @@ __all__ = [
     "GridnessAnalyzer",
     "compute_2d_autocorrelation",
     "GridnessResult",
+    "BTNAnalyzer",
+    "BTNConfig",
+    "BTNResult",
     "HeadDirectionAnalyzer",
     "HDCellResult",
     "compute_rate_map",
@@ -0,0 +1,272 @@
+"""
+BTN (Bursty/Theta/Non-bursty) classification.
+
+Workflow:
+1) Compute ISI autocorrelograms from spike times.
+2) Normalize and smooth each autocorr curve.
+3) Compute cosine distance between curves.
+4) Cluster with Tomato using a manual kNN graph and density weights.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any
+
+import numpy as np
+from scipy.ndimage import gaussian_filter1d
+from scipy.spatial.distance import pdist, squareform
+
+try:
+    from gudhi.clustering.tomato import Tomato
+except Exception as exc:  # pragma: no cover - optional dependency
+    Tomato = None
+    _TOMATO_IMPORT_ERROR = exc
+else:
+    _TOMATO_IMPORT_ERROR = None
+
+
+@dataclass
+class BTNConfig:
+    """Configuration for BTN clustering."""
+
+    maxt: float = 0.2
+    res: float = 1e-3
+    smooth_sigma: float = 4.0
+    nbs: int = 80
+    n_clusters: int = 4
+    metric: str = "cosine"
+    b_one: bool = True
+    b_log: bool = False
+
+
+@dataclass
+class BTNResult:
+    """Result container for BTN clustering."""
+
+    labels: np.ndarray
+    n_clusters: int
+    cluster_sizes: np.ndarray
+    btn_labels: np.ndarray | None
+    mapping: dict[int, str] | None
+    intermediates: dict[str, np.ndarray] | None
+
+
+class BTNAnalyzer:
+    """Analyzer that clusters neurons into BTN groups using Tomato."""
+
+    def __init__(self, config: BTNConfig | None = None):
+        self.config = config or BTNConfig()
+
+    def classify_btn(
+        self,
+        spike_data: dict[str, Any],
+        *,
+        mapping: dict[int, str] | None = None,
+        return_intermediates: bool = False,
+        plot_diagram: bool = False,
+    ) -> BTNResult:
+        """Cluster neurons into BTN classes using ISI autocorr + Tomato.
+
+        Parameters
+        ----------
+        spike_data : dict
+            ASA-style dict with keys ``spike`` and ``t`` (and optionally x/y).
+        mapping : dict, optional
+            Optional mapping from cluster id to BTN label string.
+        return_intermediates : bool
+            If True, include intermediate arrays in the result.
+        plot_diagram : bool
+            If True, call Tomato.plot_diagram() for visual inspection.
+        """
+        _require_tomato()
+        spikes = _extract_spike_times(spike_data)
+        acorr, bin_times = _isi_autocorr(
+            spikes,
+            maxt=self.config.maxt,
+            res=self.config.res,
+            b_one=self.config.b_one,
+            b_log=self.config.b_log,
+        )
+        acorr_norm = _normalize_autocorr(acorr)
+        acorr_smooth = gaussian_filter1d(acorr_norm, sigma=self.config.smooth_sigma, axis=1)
+
+        dist = squareform(pdist(acorr_smooth, metric=self.config.metric))
+        num_nodes = dist.shape[0]
+        order = np.argsort(dist, axis=1)
+        if num_nodes > 1:
+            nbs_max = num_nodes - 1
+            nbs = max(1, min(int(self.config.nbs), nbs_max))
+            knn_indices = order[:, 1 : nbs + 1]
+        else:
+            nbs = 1
+            knn_indices = order[:, :1]
+        knn_dists = dist[np.arange(dist.shape[0])[:, None], knn_indices]
+        weights = np.sum(np.exp(-knn_dists), axis=1)
+
+        t = Tomato(graph_type="manual", density_type="manual", metric="precomputed")
+        t.fit(knn_indices, weights=weights)
+        if plot_diagram:
+            t.plot_diagram()
+        if self.config.n_clusters is not None:
+            t.n_clusters_ = int(self.config.n_clusters)
+        labels = np.asarray(t.labels_, dtype=int)
+        if labels.size:
+            if np.any(labels < 0):
+                valid = labels[labels >= 0]
+                cluster_sizes = np.bincount(valid) if valid.size else np.array([])
+            else:
+                cluster_sizes = np.bincount(labels)
+        else:
+            cluster_sizes = np.array([])
+
+        btn_labels = None
+        if mapping is not None:
+            btn_labels = np.array([mapping.get(int(c), "unknown") for c in labels], dtype=object)
+
+        intermediates = None
+        if return_intermediates:
+            intermediates = {
+                "acorr": acorr,
+                "acorr_norm": acorr_norm,
+                "acorr_smooth": acorr_smooth,
+                "distance_matrix": dist,
+                "knn_indices": knn_indices,
+                "knn_dists": knn_dists,
+                "density_weights": weights,
+                "bin_times": bin_times,
+            }
+
+        if self.config.n_clusters is not None:
+            n_clusters = int(self.config.n_clusters)
+        else:
+            if labels.size:
+                valid_labels = labels[labels >= 0]
+                n_clusters = int(np.unique(valid_labels).size)
+            else:
+                n_clusters = 0
+
+        return BTNResult(
+            labels=labels,
+            n_clusters=n_clusters,
+            cluster_sizes=cluster_sizes,
+            btn_labels=btn_labels,
+            mapping=mapping,
+            intermediates=intermediates,
+        )
+
+
+def _require_tomato() -> None:
+    if Tomato is None:
+        raise ImportError(
+            "Tomato clustering requires gudhi. Install with: pip install gudhi"
+        ) from _TOMATO_IMPORT_ERROR
+
+
+def _extract_spike_times(spike_data: dict[str, Any]) -> dict[int, np.ndarray]:
+    if not isinstance(spike_data, dict) or "spike" not in spike_data or "t" not in spike_data:
+        raise ValueError("spike_data must be a dict with keys 'spike' and 't'")
+
+    spike_raw = spike_data["spike"]
+    t = np.asarray(spike_data["t"])
+    if (
+        isinstance(spike_raw, np.ndarray)
+        and spike_raw.ndim == 2
+        and spike_raw.shape[0] == t.shape[0]
+    ):
+        raise ValueError(
+            "BTN expects spike times per neuron, but got a binned spike matrix (T x N)."
+        )
+    if (
+        hasattr(spike_raw, "item")
+        and callable(spike_raw.item)
+        and np.asarray(spike_raw).shape == ()
+    ):
+        spikes_all = spike_raw[()]
+    elif isinstance(spike_raw, dict):
+        spikes_all = spike_raw
+    elif isinstance(spike_raw, (list, np.ndarray)):
+        spikes_all = spike_raw
+    else:
+        spikes_all = spike_raw
+
+    min_time0 = np.min(t)
+    max_time0 = np.max(t)
+
+    spikes: dict[int, np.ndarray] = {}
+    if isinstance(spikes_all, dict):
+        for i, key in enumerate(spikes_all.keys()):
+            s = np.asarray(spikes_all[key])
+            spikes[i] = s[(s >= min_time0) & (s < max_time0)]
+    else:
+        cell_inds = np.arange(len(spikes_all))
+        for i, m in enumerate(cell_inds):
+            s = np.asarray(spikes_all[m]) if len(spikes_all[m]) > 0 else np.array([])
+            if s.size > 0:
+                spikes[i] = s[(s >= min_time0) & (s < max_time0)]
+            else:
+                spikes[i] = np.array([])
+
+    return spikes
+
+
+def _isi_autocorr(
+    spk: dict[int, np.ndarray],
+    *,
+    maxt: float = 0.2,
+    res: float = 1e-3,
+    b_one: bool = True,
+    b_log: bool = False,
+) -> tuple[np.ndarray, np.ndarray]:
+    """Compute ISI autocorrelogram for each neuron.
+
+    Returns
+    -------
+    acorr : ndarray
+        Shape (N, num_bins).
+    bin_times : ndarray
+        Bin edges used for histogramming.
+    """
+    if b_log:
+        num_bins = 100
+        bin_times = np.ones(num_bins + 1) * 10
+        bin_times = np.power(bin_times, np.linspace(np.log10(0.005), np.log10(maxt), num_bins + 1))
+        bin_times = np.unique(np.concatenate((-bin_times, bin_times)))
+        num_bins = len(bin_times)
+    elif b_one:
+        num_bins = int(maxt / res) + 1
+        bin_times = np.linspace(0, maxt, num_bins)
+    else:
+        num_bins = int(2 * maxt / res) + 1
+        bin_times = np.linspace(-maxt, maxt, num_bins)
+
+    num_neurons = len(spk)
+    acorr = np.zeros((num_neurons, len(bin_times) - 1), dtype=int)
+
+    maxt = maxt - 1e-5
+    mint = -maxt
+    if b_one:
+        mint = -1e-5
+
+    for i, n in enumerate(spk):
+        spike_times = np.asarray(spk[n])
+        for ss in spike_times:
+            stemp = spike_times[(spike_times < ss + maxt) & (spike_times > ss + mint)]
+            dd = stemp - ss
+            if b_one:
+                dd = dd[dd >= 0]
+            bins = np.digitize(dd, bin_times) - 1
+            bins = bins[(bins >= 0) & (bins < num_bins - 1)]
+            if bins.size:
+                acorr[i, :] += np.bincount(bins, minlength=num_bins)[:-1]
+
+    return acorr, bin_times
+
+
+def _normalize_autocorr(acorr: np.ndarray) -> np.ndarray:
+    acorr = acorr.astype(float, copy=True)
+    denom = acorr[:, 0].copy()
+    valid = denom > 0
+    acorr[valid, :] = acorr[valid, :] / denom[valid, None]
+    acorr[:, 0] = 0.0
+    return acorr
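Taken together, btn.py implements the four workflow steps listed in its module docstring. A minimal usage sketch, assuming the import path implied by the file list (canns.analyzer.data.cell_classification), that gudhi is installed, and purely synthetic spike times; the cluster-id-to-label mapping below is hypothetical:

import numpy as np
from canns.analyzer.data.cell_classification import BTNAnalyzer, BTNConfig

rng = np.random.default_rng(0)
# ASA-style input: per-neuron spike times under 'spike', session time base under 't'
spike_data = {
    "spike": {n: np.sort(rng.uniform(0.0, 120.0, size=400)) for n in range(30)},
    "t": np.arange(0.0, 120.0, 0.01),
}

analyzer = BTNAnalyzer(BTNConfig(n_clusters=3, nbs=20))
result = analyzer.classify_btn(
    spike_data,
    mapping={0: "B", 1: "T", 2: "N"},  # hypothetical cluster-id -> class mapping
    return_intermediates=True,
)
print(result.n_clusters, result.cluster_sizes)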
@@ -1,5 +1,6 @@
 """Visualization modules."""

+from .btn_plots import plot_btn_autocorr_summary, plot_btn_distance_matrix
 from .grid_plots import (
     plot_autocorrelogram,
     plot_grid_score_histogram,
@@ -16,4 +17,6 @@ __all__ = [
     "plot_polar_tuning",
     "plot_temporal_autocorr",
     "plot_hd_analysis",
+    "plot_btn_distance_matrix",
+    "plot_btn_autocorr_summary",
 ]
@@ -0,0 +1,258 @@
+"""BTN visualization utilities."""
+
+from __future__ import annotations
+
+from typing import Any
+
+import numpy as np
+from matplotlib import pyplot as plt
+from scipy.ndimage import gaussian_filter1d
+
+from canns.analyzer.visualization.core.config import PlotConfig, finalize_figure
+
+_DEFAULT_BTN_COLORS = {
+    "B": "#1f77b4",
+    "T": "#000000",
+    "N": "#2ca02c",
+}
+
+
+def _ensure_plot_config(
+    config: PlotConfig | None,
+    factory,
+    *,
+    kwargs: dict[str, Any] | None = None,
+    **defaults: Any,
+) -> PlotConfig:
+    if config is None:
+        defaults.update({"kwargs": kwargs or {}})
+        return factory(**defaults)
+
+    if kwargs:
+        config_kwargs = config.kwargs or {}
+        config_kwargs.update(kwargs)
+        config.kwargs = config_kwargs
+    return config
+
+
+def _canonical_label(label: str) -> str:
+    lab = label.strip().lower()
+    if lab in ("b", "bursty"):
+        return "B"
+    if lab in ("t", "theta", "theta-modulated", "theta_modulated", "theta modulated"):
+        return "T"
+    if lab in ("n", "nonbursty", "non-bursty", "non_bursty"):
+        return "N"
+    return label
+
+
+def _cluster_order(labels: np.ndarray, mapping: dict[int, str] | None) -> list[int]:
+    cids = [int(c) for c in np.unique(labels)]
+    if mapping is None:
+        return sorted(cids)
+
+    def _key(cid: int) -> tuple[int, str]:
+        lab = _canonical_label(mapping.get(int(cid), str(cid)))
+        order = {"B": 0, "T": 1, "N": 2}.get(lab, 999)
+        return (order, str(lab))
+
+    return sorted(cids, key=_key)
+
+
+def _label_color(
+    label: str,
+    colors: dict[str, str] | None,
+    fallback_idx: int,
+) -> str:
+    if colors and label in colors:
+        return colors[label]
+    if label in _DEFAULT_BTN_COLORS:
+        return _DEFAULT_BTN_COLORS[label]
+    cmap = plt.get_cmap("tab10")
+    return cmap(fallback_idx % 10)
+
+
+def _normalize_rows(acorr: np.ndarray, mode: str | None) -> np.ndarray:
+    if mode is None or mode == "none":
+        return acorr
+    if mode == "probability":
+        denom = acorr.sum(axis=1, keepdims=True)
+    elif mode == "peak":
+        denom = acorr.max(axis=1, keepdims=True)
+    elif mode == "first":
+        denom = acorr[:, :1]
+    else:
+        raise ValueError(f"Unknown normalize mode: {mode!r}")
+    denom = np.where(denom == 0, 1.0, denom)
+    return acorr / denom
+
+
+def plot_btn_distance_matrix(
+    *,
+    dist: np.ndarray | None = None,
+    labels: np.ndarray | None = None,
+    mapping: dict[int, str] | None = None,
+    sort_by_label: bool = True,
+    title: str = "BTN distance matrix",
+    cmap: str = "afmhot",
+    figsize: tuple[int, int] = (5, 5),
+    save_path: str | None = None,
+    show: bool = True,
+    ax: plt.Axes | None = None,
+    config: PlotConfig | None = None,
+) -> tuple[plt.Figure, plt.Axes, np.ndarray]:
+    """Plot a distance matrix heatmap sorted by BTN cluster labels."""
+    if dist is None or labels is None:
+        raise ValueError("dist and labels are required.")
+
+    labels = np.asarray(labels).astype(int)
+
+    if sort_by_label:
+        cids = _cluster_order(labels, mapping)
+        order = np.concatenate([np.where(labels == c)[0] for c in cids])
+    else:
+        order = np.arange(len(labels))
+
+    dist_sorted = dist[np.ix_(order, order)]
+
+    config = _ensure_plot_config(
+        config,
+        PlotConfig.for_static_plot,
+        title=title,
+        figsize=figsize,
+        save_path=save_path,
+        show=show,
+        kwargs={},
+    )
+
+    created_fig = False
+    if ax is None:
+        fig, ax = plt.subplots(1, 1, figsize=config.figsize)
+        created_fig = True
+    else:
+        fig = ax.figure
+
+    im = ax.imshow(dist_sorted, cmap=cmap, origin="lower", interpolation="nearest")
+    fig.colorbar(im, ax=ax, label="Cosine distance")
+    ax.set_title(config.title)
+    ax.set_xlabel("Neuron")
+    ax.set_ylabel("Neuron")
+
+    if sort_by_label:
+        sizes = [np.sum(labels == c) for c in cids]
+        boundaries = np.cumsum(sizes)[:-1]
+        for b in boundaries:
+            ax.axhline(b - 0.5, color="w", linewidth=0.6, alpha=0.7)
+            ax.axvline(b - 0.5, color="w", linewidth=0.6, alpha=0.7)
+
+    if created_fig:
+        fig.tight_layout()
+    finalize_figure(fig, config, rasterize_artists=[im] if config.rasterized else None)

+    return fig, ax, order
+
+
+def plot_btn_autocorr_summary(
+    *,
+    acorr: np.ndarray | None = None,
+    labels: np.ndarray | None = None,
+    bin_times: np.ndarray | None = None,
+    res: float | None = None,
+    mapping: dict[int, str] | None = None,
+    colors: dict[str, str] | None = None,
+    normalize: str | None = "probability",
+    smooth_sigma: float | None = None,
+    long_max_ms: float | None = 200.0,
+    short_max_ms: float | None = None,
+    title: str = "BTN temporal autocorr",
+    figsize: tuple[int, int] = (8, 3),
+    save_path: str | None = None,
+    show: bool = True,
+    config: PlotConfig | None = None,
+) -> plt.Figure:
+    """Plot class-averaged ISI autocorr curves (mean +/- SEM)."""
+    if acorr is None or labels is None:
+        raise ValueError("acorr and labels are required.")
+
+    labels = np.asarray(labels).astype(int)
+    acorr = np.asarray(acorr)
+    acorr_plot = _normalize_rows(acorr.astype(float, copy=False), normalize)
+    if smooth_sigma is not None:
+        acorr_plot = gaussian_filter1d(acorr_plot, sigma=float(smooth_sigma), axis=1)
+
+    if bin_times is not None:
+        bin_times = np.asarray(bin_times)
+        x = 0.5 * (bin_times[:-1] + bin_times[1:])
+    elif res is not None:
+        x = np.arange(acorr.shape[1]) * float(res)
+    else:
+        raise ValueError("Provide bin_times or res to define lag axis.")
+
+    x_ms = x * 1000.0
+
+    cids = _cluster_order(labels, mapping)
+    label_strings = []
+    for c in cids:
+        if mapping is not None:
+            label_strings.append(_canonical_label(mapping.get(int(c), str(c))))
+        else:
+            label_strings.append(str(c))
+
+    show_short = short_max_ms is not None
+    ncols = 2 if show_short else 1
+
+    config = _ensure_plot_config(
+        config,
+        PlotConfig.for_static_plot,
+        title=title,
+        figsize=figsize if ncols == 1 else (figsize[0] * 1.6, figsize[1]),
+        save_path=save_path,
+        show=show,
+        kwargs={},
+    )
+
+    fig, axes = plt.subplots(1, ncols, figsize=config.figsize)
+    if ncols == 1:
+        axes = [axes]
+
+    def _plot_panel(ax: plt.Axes, max_ms: float | None, panel_title: str):
+        if max_ms is None:
+            mask = np.ones_like(x_ms, dtype=bool)
+        else:
+            mask = x_ms <= max_ms
+
+        for idx, (cid, label_str) in enumerate(zip(cids, label_strings, strict=False)):
+            rows = acorr_plot[labels == cid]
+            if rows.size == 0:
+                continue
+            mean = rows.mean(axis=0)
+            sem = rows.std(axis=0) / np.sqrt(rows.shape[0])
+            color = _label_color(label_str, colors, idx)
+            ax.plot(x_ms[mask], mean[mask], color=color, lw=2, label=label_str)
+            ax.fill_between(
+                x_ms[mask],
+                (mean - sem)[mask],
+                (mean + sem)[mask],
+                color=color,
+                alpha=0.25,
+                linewidth=0,
+            )
+
+        ax.set_xlabel("Lag (ms)")
+        ax.set_title(panel_title)
+        ax.grid(False)
+
+    _plot_panel(axes[0], long_max_ms, "Long lag")
+    ylabel = "Probability" if normalize == "probability" else "Autocorr (norm)"
+    axes[0].set_ylabel(ylabel)
+
+    if show_short:
+        _plot_panel(axes[1], float(short_max_ms), "Short lag")
+
+    for ax in axes:
+        ax.legend(frameon=False)
+
+    fig.suptitle(config.title)
+    fig.tight_layout()
+    finalize_figure(fig, config)
+    return fig
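Both plotting entry points take the arrays that classify_btn exposes via return_intermediates=True; a hedged sketch wiring them together, reusing the BTNResult `result` from the sketch after btn.py above:

from canns.analyzer.data.cell_classification import (
    plot_btn_autocorr_summary,
    plot_btn_distance_matrix,
)

inter = result.intermediates  # populated because return_intermediates=True
fig, ax, order = plot_btn_distance_matrix(
    dist=inter["distance_matrix"],
    labels=result.labels,
    mapping=result.mapping,
    show=False,
)
fig2 = plot_btn_autocorr_summary(
    acorr=inter["acorr"],
    labels=result.labels,
    bin_times=inter["bin_times"],
    mapping=result.mapping,
    normalize="probability",
    short_max_ms=50.0,  # adds the short-lag panel
    show=False,
)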
@@ -34,6 +34,7 @@ from .theta_sweep_plots import (
     create_theta_sweep_grid_cell_animation,
     create_theta_sweep_place_cell_animation,
     plot_grid_cell_manifold,
+    plot_internal_position_trajectory,
     plot_population_activity_with_theta,
 )
 from .tuning_plots import tuning_curve
@@ -70,5 +71,6 @@ __all__ = [
     "create_theta_sweep_grid_cell_animation",
     "create_theta_sweep_place_cell_animation",
     "plot_grid_cell_manifold",
+    "plot_internal_position_trajectory",
     "plot_population_activity_with_theta",
 ]
@@ -469,6 +469,26 @@ class PlotConfigs:
         defaults.update(kwargs)
         return PlotConfig.for_static_plot(**defaults)

+    @staticmethod
+    def internal_position_trajectory_static(**kwargs: Any) -> PlotConfig:
+        defaults: dict[str, Any] = {
+            "title": "Internal Position vs. Real Trajectory",
+            "figsize": (6, 4),
+        }
+        plot_kwargs: dict[str, Any] = {
+            "cmap": "cool",
+            "add_colorbar": True,
+            "colorbar": {"label": "Max GC activity"},
+            "trajectory_color": "black",
+            "trajectory_linewidth": 1.0,
+            "scatter_size": 4,
+            "scatter_alpha": 0.9,
+        }
+        plot_kwargs.update(kwargs.pop("kwargs", {}))
+        defaults["kwargs"] = plot_kwargs
+        defaults.update(kwargs)
+        return PlotConfig.for_static_plot(**defaults)
+
     @staticmethod
     def direction_cell_polar(**kwargs: Any) -> PlotConfig:
         """Configuration for direction cell polar plot visualization.