canns 0.14.2__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/analyzer/data/asa/__init__.py +77 -21
- canns/analyzer/data/asa/coho.py +97 -0
- canns/analyzer/data/asa/cohomap.py +408 -0
- canns/analyzer/data/asa/cohomap_scatter.py +10 -0
- canns/analyzer/data/asa/cohomap_vectors.py +311 -0
- canns/analyzer/data/asa/cohospace.py +173 -1153
- canns/analyzer/data/asa/cohospace_phase_centers.py +137 -0
- canns/analyzer/data/asa/cohospace_scatter.py +1220 -0
- canns/analyzer/data/asa/embedding.py +3 -4
- canns/analyzer/data/asa/plotting.py +4 -4
- canns/analyzer/data/cell_classification/__init__.py +10 -0
- canns/analyzer/data/cell_classification/core/__init__.py +4 -0
- canns/analyzer/data/cell_classification/core/btn.py +272 -0
- canns/analyzer/data/cell_classification/visualization/__init__.py +3 -0
- canns/analyzer/data/cell_classification/visualization/btn_plots.py +258 -0
- canns/analyzer/visualization/__init__.py +2 -0
- canns/analyzer/visualization/core/config.py +20 -0
- canns/analyzer/visualization/theta_sweep_plots.py +142 -0
- canns/pipeline/asa/runner.py +19 -19
- canns/pipeline/asa_gui/__init__.py +5 -3
- canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +32 -4
- canns/pipeline/asa_gui/core/runner.py +23 -23
- canns/pipeline/asa_gui/views/pages/preprocess_page.py +250 -8
- {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/METADATA +2 -1
- {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/RECORD +28 -20
- {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/WHEEL +0 -0
- {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/entry_points.txt +0 -0
- {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -31,10 +31,9 @@ def embed_spike_trains(spike_trains, config: SpikeEmbeddingConfig | None = None,
|
|
|
31
31
|
Returns
|
|
32
32
|
-------
|
|
33
33
|
tuple
|
|
34
|
-
``(spikes_bin, xx, yy, tt)``
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
otherwise ``None``.
|
|
34
|
+
``(spikes_bin, xx, yy, tt)``. ``spikes_bin`` is a (T, N) binned spike matrix.
|
|
35
|
+
``xx``, ``yy``, ``tt`` are position/time arrays when ``speed_filter=True``,
|
|
36
|
+
otherwise ``None``.
|
|
38
37
|
|
|
39
38
|
Examples
|
|
40
39
|
--------
|
|
@@ -403,7 +403,7 @@ def plot_path_compare_1d(
|
|
|
403
403
|
return fig, axes
|
|
404
404
|
|
|
405
405
|
|
|
406
|
-
def
|
|
406
|
+
def plot_cohomap_scatter(
|
|
407
407
|
decoding_result: dict[str, Any],
|
|
408
408
|
position_data: dict[str, Any],
|
|
409
409
|
config: PlotConfig | None = None,
|
|
@@ -453,7 +453,7 @@ def plot_cohomap(
|
|
|
453
453
|
>>> # Decode coordinates
|
|
454
454
|
>>> decoding = decode_circular_coordinates(persistence_result, spike_data)
|
|
455
455
|
>>> # Visualize with trajectory data
|
|
456
|
-
>>> fig =
|
|
456
|
+
>>> fig = plot_cohomap_scatter(
|
|
457
457
|
... decoding,
|
|
458
458
|
... position_data={'x': xx, 'y': yy},
|
|
459
459
|
... save_path='cohomap.png',
|
|
@@ -518,7 +518,7 @@ def plot_cohomap(
|
|
|
518
518
|
return fig
|
|
519
519
|
|
|
520
520
|
|
|
521
|
-
def
|
|
521
|
+
def plot_cohomap_scatter_multi(
|
|
522
522
|
decoding_result: dict,
|
|
523
523
|
position_data: dict,
|
|
524
524
|
config: PlotConfig | None = None,
|
|
@@ -560,7 +560,7 @@ def plot_cohomap_multi(
|
|
|
560
560
|
|
|
561
561
|
Examples
|
|
562
562
|
--------
|
|
563
|
-
>>> fig =
|
|
563
|
+
>>> fig = plot_cohomap_scatter_multi(decoding, {"x": xx, "y": yy}, show=False) # doctest: +SKIP
|
|
564
564
|
"""
|
|
565
565
|
config = _ensure_plot_config(
|
|
566
566
|
config,
|
|
@@ -11,6 +11,9 @@ Vollan, Gardner, Moser & Moser (Nature, 2025)
|
|
|
11
11
|
__version__ = "0.1.0"
|
|
12
12
|
|
|
13
13
|
from .core import ( # noqa: F401
|
|
14
|
+
BTNAnalyzer,
|
|
15
|
+
BTNConfig,
|
|
16
|
+
BTNResult,
|
|
14
17
|
GridnessAnalyzer,
|
|
15
18
|
GridnessResult,
|
|
16
19
|
HDCellResult,
|
|
@@ -46,6 +49,8 @@ from .utils import ( # noqa: F401
|
|
|
46
49
|
)
|
|
47
50
|
from .visualization import ( # noqa: F401
|
|
48
51
|
plot_autocorrelogram,
|
|
52
|
+
plot_btn_autocorr_summary,
|
|
53
|
+
plot_btn_distance_matrix,
|
|
49
54
|
plot_grid_score_histogram,
|
|
50
55
|
plot_gridness_analysis,
|
|
51
56
|
plot_hd_analysis,
|
|
@@ -57,6 +62,9 @@ from .visualization import ( # noqa: F401
|
|
|
57
62
|
__all__ = [
|
|
58
63
|
"GridnessAnalyzer",
|
|
59
64
|
"GridnessResult",
|
|
65
|
+
"BTNAnalyzer",
|
|
66
|
+
"BTNConfig",
|
|
67
|
+
"BTNResult",
|
|
60
68
|
"HeadDirectionAnalyzer",
|
|
61
69
|
"HDCellResult",
|
|
62
70
|
"compute_2d_autocorrelation",
|
|
@@ -94,4 +102,6 @@ __all__ = [
|
|
|
94
102
|
"plot_polar_tuning",
|
|
95
103
|
"plot_temporal_autocorr",
|
|
96
104
|
"plot_hd_analysis",
|
|
105
|
+
"plot_btn_autocorr_summary",
|
|
106
|
+
"plot_btn_distance_matrix",
|
|
97
107
|
]
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Core analysis modules."""
|
|
2
2
|
|
|
3
|
+
from .btn import BTNAnalyzer, BTNConfig, BTNResult
|
|
3
4
|
from .grid_cells import GridnessAnalyzer, GridnessResult, compute_2d_autocorrelation
|
|
4
5
|
from .grid_modules_leiden import identify_grid_modules_and_stats
|
|
5
6
|
from .head_direction import HDCellResult, HeadDirectionAnalyzer
|
|
@@ -15,6 +16,9 @@ __all__ = [
|
|
|
15
16
|
"GridnessAnalyzer",
|
|
16
17
|
"compute_2d_autocorrelation",
|
|
17
18
|
"GridnessResult",
|
|
19
|
+
"BTNAnalyzer",
|
|
20
|
+
"BTNConfig",
|
|
21
|
+
"BTNResult",
|
|
18
22
|
"HeadDirectionAnalyzer",
|
|
19
23
|
"HDCellResult",
|
|
20
24
|
"compute_rate_map",
|
|
@@ -0,0 +1,272 @@
|
|
|
1
|
+
"""
|
|
2
|
+
BTN (Bursty/Theta/Non-bursty) classification.
|
|
3
|
+
|
|
4
|
+
Workflow:
|
|
5
|
+
1) Compute ISI autocorrelograms from spike times.
|
|
6
|
+
2) Normalize and smooth each autocorr curve.
|
|
7
|
+
3) Compute cosine distance between curves.
|
|
8
|
+
4) Cluster with Tomato using a manual kNN graph and density weights.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from dataclasses import dataclass
|
|
14
|
+
from typing import Any
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
from scipy.ndimage import gaussian_filter1d
|
|
18
|
+
from scipy.spatial.distance import pdist, squareform
|
|
19
|
+
|
|
20
|
+
try:
|
|
21
|
+
from gudhi.clustering.tomato import Tomato
|
|
22
|
+
except Exception as exc: # pragma: no cover - optional dependency
|
|
23
|
+
Tomato = None
|
|
24
|
+
_TOMATO_IMPORT_ERROR = exc
|
|
25
|
+
else:
|
|
26
|
+
_TOMATO_IMPORT_ERROR = None
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class BTNConfig:
    """Configuration for BTN clustering.

    Attributes
    ----------
    maxt : float
        Maximum ISI lag in seconds (the plots convert lags to ms by * 1000).
    res : float
        Bin width in seconds for linear binning; ignored when ``b_log`` is True.
    smooth_sigma : float
        Gaussian smoothing sigma, in bins, applied along the lag axis.
    nbs : int
        Number of nearest neighbours in the kNN graph (clamped to N - 1).
    n_clusters : int or None
        Number of clusters Tomato is asked to merge to. ``None`` leaves
        Tomato's own cluster count untouched.
    metric : str
        ``scipy.spatial.distance.pdist`` metric used between smoothed curves.
    b_one : bool
        One-sided autocorrelogram (non-negative lags only).
    b_log : bool
        Log-spaced lag bins instead of linear bins of width ``res``.

    Raises
    ------
    ValueError
        If ``maxt`` or ``res`` is not positive, or ``n_clusters`` < 1.
    """

    maxt: float = 0.2
    res: float = 1e-3
    smooth_sigma: float = 4.0
    nbs: int = 80
    # Optional: classify_btn checks ``is not None`` before applying it.
    n_clusters: int | None = 4
    metric: str = "cosine"
    b_one: bool = True
    b_log: bool = False

    def __post_init__(self) -> None:
        # ``not x > 0`` also rejects NaN.
        if not self.maxt > 0:
            raise ValueError(f"maxt must be positive, got {self.maxt!r}")
        if not self.res > 0:
            raise ValueError(f"res must be positive, got {self.res!r}")
        if self.n_clusters is not None and int(self.n_clusters) < 1:
            raise ValueError(f"n_clusters must be >= 1 or None, got {self.n_clusters!r}")
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass
class BTNResult:
    """Result container for BTN clustering.

    Attributes
    ----------
    labels : ndarray of int
        Tomato cluster id per neuron.
    n_clusters : int
        Number of clusters.
    cluster_sizes : ndarray
        ``np.bincount`` of the non-negative labels (empty if there are none).
    btn_labels : ndarray of object or None
        Per-neuron BTN label strings, present only when a mapping was given.
    mapping : dict or None
        The cluster-id -> label mapping that was used, if any.
    intermediates : dict or None
        Intermediate arrays (autocorrs, distances, kNN graph), if requested.
    """

    labels: np.ndarray
    n_clusters: int
    cluster_sizes: np.ndarray
    # Optional outputs default to None so callers can construct minimal results.
    btn_labels: np.ndarray | None = None
    mapping: dict[int, str] | None = None
    intermediates: dict[str, np.ndarray] | None = None
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class BTNAnalyzer:
    """Analyzer that clusters neurons into BTN groups using Tomato.

    BTN = Bursty / Theta-modulated / Non-bursty. Each neuron's ISI
    autocorrelogram is normalized and smoothed. Curves are compared with a
    pairwise distance, and the resulting kNN graph is clustered with gudhi's
    Tomato.

    Parameters
    ----------
    config : BTNConfig, optional
        Clustering parameters; ``BTNConfig()`` is used when omitted.
    """

    def __init__(self, config: BTNConfig | None = None):
        self.config = config or BTNConfig()

    def classify_btn(
        self,
        spike_data: dict[str, Any],
        *,
        mapping: dict[int, str] | None = None,
        return_intermediates: bool = False,
        plot_diagram: bool = False,
    ) -> BTNResult:
        """Cluster neurons into BTN classes using ISI autocorr + Tomato.

        Parameters
        ----------
        spike_data : dict
            ASA-style dict with keys ``spike`` and ``t`` (and optionally x/y).
        mapping : dict, optional
            Optional mapping from cluster id to BTN label string.
        return_intermediates : bool
            If True, include intermediate arrays in the result.
        plot_diagram : bool
            If True, call Tomato.plot_diagram() for visual inspection.

        Returns
        -------
        BTNResult
            ``labels`` has one Tomato cluster id per neuron. ``n_clusters`` is
            the number of distinct non-negative labels actually produced, and
            ``cluster_sizes`` counts the non-negative labels.

        Raises
        ------
        ImportError
            If gudhi is not installed.
        ValueError
            If ``spike_data`` is malformed or contains no neurons.
        """
        _require_tomato()
        cfg = self.config
        spikes = _extract_spike_times(spike_data)
        if not spikes:
            # squareform() of an empty condensed vector returns a 1x1 matrix,
            # so without this guard Tomato would label a phantom neuron.
            raise ValueError("spike_data contains no neurons; nothing to classify.")

        acorr, bin_times = _isi_autocorr(
            spikes,
            maxt=cfg.maxt,
            res=cfg.res,
            b_one=cfg.b_one,
            b_log=cfg.b_log,
        )
        acorr_norm = _normalize_autocorr(acorr)
        acorr_smooth = gaussian_filter1d(acorr_norm, sigma=cfg.smooth_sigma, axis=1)

        # NOTE(review): with metric="cosine", an all-zero curve (e.g. a neuron
        # whose only in-window lag is its own zero lag) likely yields NaN
        # distances that propagate into the density weights; consider
        # filtering such neurons upstream.
        dist = squareform(pdist(acorr_smooth, metric=cfg.metric))
        num_nodes = dist.shape[0]
        if num_nodes > 1:
            nbs = max(1, min(int(cfg.nbs), num_nodes - 1))
            order = np.argsort(dist, axis=1)
            # Remove each row's own index explicitly instead of assuming it
            # sorts first. Zero-distance ties or NaNs can move it.
            not_self = order != np.arange(num_nodes)[:, None]
            others = order[not_self].reshape(num_nodes, num_nodes - 1)
            knn_indices = others[:, :nbs]
        else:
            # A single neuron can only point at itself.
            knn_indices = np.zeros((num_nodes, 1), dtype=int)
        knn_dists = dist[np.arange(num_nodes)[:, None], knn_indices]
        # Density estimate: closer neighbours contribute more weight.
        weights = np.sum(np.exp(-knn_dists), axis=1)

        tomato = Tomato(graph_type="manual", density_type="manual", metric="precomputed")
        tomato.fit(knn_indices, weights=weights)
        if plot_diagram:
            tomato.plot_diagram()
        if cfg.n_clusters is not None:
            # Assigning n_clusters_ after fit re-merges the fitted hierarchy.
            tomato.n_clusters_ = int(cfg.n_clusters)
        labels = np.asarray(tomato.labels_, dtype=int)

        # Negative labels are excluded from the size and count statistics.
        valid = labels[labels >= 0]
        cluster_sizes = np.bincount(valid) if valid.size else np.array([])
        # Report the clusters actually present; this may be fewer than requested.
        n_clusters = int(np.unique(valid).size)

        btn_labels = None
        if mapping is not None:
            btn_labels = np.array([mapping.get(int(c), "unknown") for c in labels], dtype=object)

        intermediates = None
        if return_intermediates:
            intermediates = {
                "acorr": acorr,
                "acorr_norm": acorr_norm,
                "acorr_smooth": acorr_smooth,
                "distance_matrix": dist,
                "knn_indices": knn_indices,
                "knn_dists": knn_dists,
                "density_weights": weights,
                "bin_times": bin_times,
            }

        return BTNResult(
            labels=labels,
            n_clusters=n_clusters,
            cluster_sizes=cluster_sizes,
            btn_labels=btn_labels,
            mapping=mapping,
            intermediates=intermediates,
        )
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _require_tomato() -> None:
    """Fail fast with an actionable message when gudhi's Tomato is unavailable.

    Raises
    ------
    ImportError
        If the module-level ``from gudhi.clustering.tomato import Tomato``
        failed. The original import error is chained as the cause.
    """
    if Tomato is not None:
        return
    message = "Tomato clustering requires gudhi. Install with: pip install gudhi"
    raise ImportError(message) from _TOMATO_IMPORT_ERROR
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _extract_spike_times(spike_data: dict[str, Any]) -> dict[int, np.ndarray]:
|
|
167
|
+
if not isinstance(spike_data, dict) or "spike" not in spike_data or "t" not in spike_data:
|
|
168
|
+
raise ValueError("spike_data must be a dict with keys 'spike' and 't'")
|
|
169
|
+
|
|
170
|
+
spike_raw = spike_data["spike"]
|
|
171
|
+
t = np.asarray(spike_data["t"])
|
|
172
|
+
if (
|
|
173
|
+
isinstance(spike_raw, np.ndarray)
|
|
174
|
+
and spike_raw.ndim == 2
|
|
175
|
+
and spike_raw.shape[0] == t.shape[0]
|
|
176
|
+
):
|
|
177
|
+
raise ValueError(
|
|
178
|
+
"BTN expects spike times per neuron, but got a binned spike matrix (T x N)."
|
|
179
|
+
)
|
|
180
|
+
if (
|
|
181
|
+
hasattr(spike_raw, "item")
|
|
182
|
+
and callable(spike_raw.item)
|
|
183
|
+
and np.asarray(spike_raw).shape == ()
|
|
184
|
+
):
|
|
185
|
+
spikes_all = spike_raw[()]
|
|
186
|
+
elif isinstance(spike_raw, dict):
|
|
187
|
+
spikes_all = spike_raw
|
|
188
|
+
elif isinstance(spike_raw, (list, np.ndarray)):
|
|
189
|
+
spikes_all = spike_raw
|
|
190
|
+
else:
|
|
191
|
+
spikes_all = spike_raw
|
|
192
|
+
|
|
193
|
+
min_time0 = np.min(t)
|
|
194
|
+
max_time0 = np.max(t)
|
|
195
|
+
|
|
196
|
+
spikes: dict[int, np.ndarray] = {}
|
|
197
|
+
if isinstance(spikes_all, dict):
|
|
198
|
+
for i, key in enumerate(spikes_all.keys()):
|
|
199
|
+
s = np.asarray(spikes_all[key])
|
|
200
|
+
spikes[i] = s[(s >= min_time0) & (s < max_time0)]
|
|
201
|
+
else:
|
|
202
|
+
cell_inds = np.arange(len(spikes_all))
|
|
203
|
+
for i, m in enumerate(cell_inds):
|
|
204
|
+
s = np.asarray(spikes_all[m]) if len(spikes_all[m]) > 0 else np.array([])
|
|
205
|
+
if s.size > 0:
|
|
206
|
+
spikes[i] = s[(s >= min_time0) & (s < max_time0)]
|
|
207
|
+
else:
|
|
208
|
+
spikes[i] = np.array([])
|
|
209
|
+
|
|
210
|
+
return spikes
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def _isi_autocorr(
|
|
214
|
+
spk: dict[int, np.ndarray],
|
|
215
|
+
*,
|
|
216
|
+
maxt: float = 0.2,
|
|
217
|
+
res: float = 1e-3,
|
|
218
|
+
b_one: bool = True,
|
|
219
|
+
b_log: bool = False,
|
|
220
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
|
221
|
+
"""Compute ISI autocorrelogram for each neuron.
|
|
222
|
+
|
|
223
|
+
Returns
|
|
224
|
+
-------
|
|
225
|
+
acorr : ndarray
|
|
226
|
+
Shape (N, num_bins).
|
|
227
|
+
bin_times : ndarray
|
|
228
|
+
Bin edges used for histogramming.
|
|
229
|
+
"""
|
|
230
|
+
if b_log:
|
|
231
|
+
num_bins = 100
|
|
232
|
+
bin_times = np.ones(num_bins + 1) * 10
|
|
233
|
+
bin_times = np.power(bin_times, np.linspace(np.log10(0.005), np.log10(maxt), num_bins + 1))
|
|
234
|
+
bin_times = np.unique(np.concatenate((-bin_times, bin_times)))
|
|
235
|
+
num_bins = len(bin_times)
|
|
236
|
+
elif b_one:
|
|
237
|
+
num_bins = int(maxt / res) + 1
|
|
238
|
+
bin_times = np.linspace(0, maxt, num_bins)
|
|
239
|
+
else:
|
|
240
|
+
num_bins = int(2 * maxt / res) + 1
|
|
241
|
+
bin_times = np.linspace(-maxt, maxt, num_bins)
|
|
242
|
+
|
|
243
|
+
num_neurons = len(spk)
|
|
244
|
+
acorr = np.zeros((num_neurons, len(bin_times) - 1), dtype=int)
|
|
245
|
+
|
|
246
|
+
maxt = maxt - 1e-5
|
|
247
|
+
mint = -maxt
|
|
248
|
+
if b_one:
|
|
249
|
+
mint = -1e-5
|
|
250
|
+
|
|
251
|
+
for i, n in enumerate(spk):
|
|
252
|
+
spike_times = np.asarray(spk[n])
|
|
253
|
+
for ss in spike_times:
|
|
254
|
+
stemp = spike_times[(spike_times < ss + maxt) & (spike_times > ss + mint)]
|
|
255
|
+
dd = stemp - ss
|
|
256
|
+
if b_one:
|
|
257
|
+
dd = dd[dd >= 0]
|
|
258
|
+
bins = np.digitize(dd, bin_times) - 1
|
|
259
|
+
bins = bins[(bins >= 0) & (bins < num_bins - 1)]
|
|
260
|
+
if bins.size:
|
|
261
|
+
acorr[i, :] += np.bincount(bins, minlength=num_bins)[:-1]
|
|
262
|
+
|
|
263
|
+
return acorr, bin_times
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def _normalize_autocorr(acorr: np.ndarray) -> np.ndarray:
|
|
267
|
+
acorr = acorr.astype(float, copy=True)
|
|
268
|
+
denom = acorr[:, 0].copy()
|
|
269
|
+
valid = denom > 0
|
|
270
|
+
acorr[valid, :] = acorr[valid, :] / denom[valid, None]
|
|
271
|
+
acorr[:, 0] = 0.0
|
|
272
|
+
return acorr
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Visualization modules."""
|
|
2
2
|
|
|
3
|
+
from .btn_plots import plot_btn_autocorr_summary, plot_btn_distance_matrix
|
|
3
4
|
from .grid_plots import (
|
|
4
5
|
plot_autocorrelogram,
|
|
5
6
|
plot_grid_score_histogram,
|
|
@@ -16,4 +17,6 @@ __all__ = [
|
|
|
16
17
|
"plot_polar_tuning",
|
|
17
18
|
"plot_temporal_autocorr",
|
|
18
19
|
"plot_hd_analysis",
|
|
20
|
+
"plot_btn_distance_matrix",
|
|
21
|
+
"plot_btn_autocorr_summary",
|
|
19
22
|
]
|
|
@@ -0,0 +1,258 @@
|
|
|
1
|
+
"""BTN visualization utilities."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from matplotlib import pyplot as plt
|
|
9
|
+
from scipy.ndimage import gaussian_filter1d
|
|
10
|
+
|
|
11
|
+
from canns.analyzer.visualization.core.config import PlotConfig, finalize_figure
|
|
12
|
+
|
|
13
|
+
_DEFAULT_BTN_COLORS = {
|
|
14
|
+
"B": "#1f77b4",
|
|
15
|
+
"T": "#000000",
|
|
16
|
+
"N": "#2ca02c",
|
|
17
|
+
}
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _ensure_plot_config(
|
|
21
|
+
config: PlotConfig | None,
|
|
22
|
+
factory,
|
|
23
|
+
*,
|
|
24
|
+
kwargs: dict[str, Any] | None = None,
|
|
25
|
+
**defaults: Any,
|
|
26
|
+
) -> PlotConfig:
|
|
27
|
+
if config is None:
|
|
28
|
+
defaults.update({"kwargs": kwargs or {}})
|
|
29
|
+
return factory(**defaults)
|
|
30
|
+
|
|
31
|
+
if kwargs:
|
|
32
|
+
config_kwargs = config.kwargs or {}
|
|
33
|
+
config_kwargs.update(kwargs)
|
|
34
|
+
config.kwargs = config_kwargs
|
|
35
|
+
return config
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _canonical_label(label: str) -> str:
|
|
39
|
+
lab = label.strip().lower()
|
|
40
|
+
if lab in ("b", "bursty"):
|
|
41
|
+
return "B"
|
|
42
|
+
if lab in ("t", "theta", "theta-modulated", "theta_modulated", "theta modulated"):
|
|
43
|
+
return "T"
|
|
44
|
+
if lab in ("n", "nonbursty", "non-bursty", "non_bursty"):
|
|
45
|
+
return "N"
|
|
46
|
+
return label
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _cluster_order(labels: np.ndarray, mapping: dict[int, str] | None) -> list[int]:
|
|
50
|
+
cids = [int(c) for c in np.unique(labels)]
|
|
51
|
+
if mapping is None:
|
|
52
|
+
return sorted(cids)
|
|
53
|
+
|
|
54
|
+
def _key(cid: int) -> tuple[int, str]:
|
|
55
|
+
lab = _canonical_label(mapping.get(int(cid), str(cid)))
|
|
56
|
+
order = {"B": 0, "T": 1, "N": 2}.get(lab, 999)
|
|
57
|
+
return (order, str(lab))
|
|
58
|
+
|
|
59
|
+
return sorted(cids, key=_key)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _label_color(
|
|
63
|
+
label: str,
|
|
64
|
+
colors: dict[str, str] | None,
|
|
65
|
+
fallback_idx: int,
|
|
66
|
+
) -> str:
|
|
67
|
+
if colors and label in colors:
|
|
68
|
+
return colors[label]
|
|
69
|
+
if label in _DEFAULT_BTN_COLORS:
|
|
70
|
+
return _DEFAULT_BTN_COLORS[label]
|
|
71
|
+
cmap = plt.get_cmap("tab10")
|
|
72
|
+
return cmap(fallback_idx % 10)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _normalize_rows(acorr: np.ndarray, mode: str | None) -> np.ndarray:
|
|
76
|
+
if mode is None or mode == "none":
|
|
77
|
+
return acorr
|
|
78
|
+
if mode == "probability":
|
|
79
|
+
denom = acorr.sum(axis=1, keepdims=True)
|
|
80
|
+
elif mode == "peak":
|
|
81
|
+
denom = acorr.max(axis=1, keepdims=True)
|
|
82
|
+
elif mode == "first":
|
|
83
|
+
denom = acorr[:, :1]
|
|
84
|
+
else:
|
|
85
|
+
raise ValueError(f"Unknown normalize mode: {mode!r}")
|
|
86
|
+
denom = np.where(denom == 0, 1.0, denom)
|
|
87
|
+
return acorr / denom
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def plot_btn_distance_matrix(
    *,
    dist: np.ndarray | None = None,
    labels: np.ndarray | None = None,
    mapping: dict[int, str] | None = None,
    sort_by_label: bool = True,
    title: str = "BTN distance matrix",
    cmap: str = "afmhot",
    figsize: tuple[int, int] = (5, 5),
    save_path: str | None = None,
    show: bool = True,
    ax: plt.Axes | None = None,
    config: PlotConfig | None = None,
    colorbar_label: str = "Cosine distance",
) -> tuple[plt.Figure, plt.Axes, np.ndarray]:
    """Plot a distance matrix heatmap sorted by BTN cluster labels.

    Parameters
    ----------
    dist : array-like
        Square (N, N) pairwise distance matrix.
    labels : array-like of int
        Cluster id per neuron, length N.
    mapping : dict, optional
        Cluster id -> BTN label, used to order clusters as B, T, N.
    sort_by_label : bool
        Reorder rows and columns so each cluster is contiguous, and draw white
        separators between clusters.
    title, figsize, save_path, show : optional
        Figure options. They are ignored when ``config`` is given, and the
        config's own values are used instead.
    cmap : str
        Colormap for the heatmap.
    ax : Axes, optional
        Draw into this axes instead of creating a new figure.
    config : PlotConfig, optional
        Explicit plot configuration.
    colorbar_label : str
        Colorbar label; change it when ``dist`` used a non-cosine metric.

    Returns
    -------
    fig, ax, order
        ``order`` is the neuron permutation applied to ``dist``.

    Raises
    ------
    ValueError
        If inputs are missing, ``dist`` is not square, its size does not
        match ``labels``, or both are empty.
    """
    if dist is None or labels is None:
        raise ValueError("dist and labels are required.")

    dist = np.asarray(dist)
    labels = np.asarray(labels).astype(int)
    if dist.ndim != 2 or dist.shape[0] != dist.shape[1]:
        raise ValueError(f"dist must be a square 2-D matrix, got shape {dist.shape}.")
    if labels.shape != (dist.shape[0],):
        raise ValueError(
            f"labels must have one entry per row of dist ({dist.shape[0]}), "
            f"got shape {labels.shape}."
        )
    if labels.size == 0:
        raise ValueError("dist and labels are empty; nothing to plot.")

    if sort_by_label:
        cids = _cluster_order(labels, mapping)
        order = np.concatenate([np.flatnonzero(labels == c) for c in cids])
    else:
        cids = []
        order = np.arange(labels.size)

    dist_sorted = dist[np.ix_(order, order)]

    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title=title,
        figsize=figsize,
        save_path=save_path,
        show=show,
        kwargs={},
    )

    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(1, 1, figsize=config.figsize)
    else:
        fig = ax.figure

    im = ax.imshow(dist_sorted, cmap=cmap, origin="lower", interpolation="nearest")
    fig.colorbar(im, ax=ax, label=colorbar_label)
    ax.set_title(config.title)
    ax.set_xlabel("Neuron")
    ax.set_ylabel("Neuron")

    if sort_by_label:
        # Separator lines sit between consecutive cluster blocks.
        sizes = [int(np.sum(labels == c)) for c in cids]
        for b in np.cumsum(sizes)[:-1]:
            ax.axhline(b - 0.5, color="w", linewidth=0.6, alpha=0.7)
            ax.axvline(b - 0.5, color="w", linewidth=0.6, alpha=0.7)

    if created_fig:
        fig.tight_layout()
    finalize_figure(fig, config, rasterize_artists=[im] if config.rasterized else None)

    return fig, ax, order
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def plot_btn_autocorr_summary(
    *,
    acorr: np.ndarray | None = None,
    labels: np.ndarray | None = None,
    bin_times: np.ndarray | None = None,
    res: float | None = None,
    mapping: dict[int, str] | None = None,
    colors: dict[str, str] | None = None,
    normalize: str | None = "probability",
    smooth_sigma: float | None = None,
    long_max_ms: float | None = 200.0,
    short_max_ms: float | None = None,
    title: str = "BTN temporal autocorr",
    figsize: tuple[int, int] = (8, 3),
    save_path: str | None = None,
    show: bool = True,
    config: PlotConfig | None = None,
) -> plt.Figure:
    """Plot class-averaged ISI autocorr curves (mean +/- SEM).

    Parameters
    ----------
    acorr : array-like
        (N, B) autocorrelogram counts, one row per neuron.
    labels : array-like of int
        Cluster id per neuron, length N.
    bin_times : array-like, optional
        B + 1 bin edges in seconds; the lag axis uses their centres.
    res : float, optional
        Bin width in seconds, used when ``bin_times`` is absent. It assumes
        uniform bins starting at zero lag, and the axis uses bin centres.
    mapping, colors : dict, optional
        Cluster id -> BTN label, and label -> colour overrides.
    normalize : str or None
        Row normalization: "probability", "peak", "first" or None/"none".
    smooth_sigma : float, optional
        Gaussian smoothing (in bins) applied after normalization.
    long_max_ms, short_max_ms : float, optional
        Upper lag limits (ms) of the long panel and the optional short panel.
    title, figsize, save_path, show, config
        Figure options. When ``config`` is given, its values are used instead.

    Notes
    -----
    The SEM is ``std(ddof=0) / sqrt(n)`` over the neurons in each cluster.

    Raises
    ------
    ValueError
        On missing inputs, mismatched shapes, or when neither ``bin_times``
        nor ``res`` is given.
    """
    if acorr is None or labels is None:
        raise ValueError("acorr and labels are required.")

    labels = np.asarray(labels).astype(int)
    acorr = np.asarray(acorr)
    if acorr.ndim != 2:
        raise ValueError(f"acorr must be 2-D (neurons x bins), got shape {acorr.shape}.")
    if labels.shape != (acorr.shape[0],):
        raise ValueError(
            f"labels must have one entry per row of acorr ({acorr.shape[0]}), "
            f"got shape {labels.shape}."
        )

    # Lag axis in bin centres for both input forms, so they agree.
    if bin_times is not None:
        bin_times = np.asarray(bin_times)
        if bin_times.shape != (acorr.shape[1] + 1,):
            raise ValueError(
                f"bin_times must hold {acorr.shape[1] + 1} edges, got shape {bin_times.shape}."
            )
        x = 0.5 * (bin_times[:-1] + bin_times[1:])
    elif res is not None:
        x = (np.arange(acorr.shape[1]) + 0.5) * float(res)
    else:
        raise ValueError("Provide bin_times or res to define lag axis.")
    x_ms = x * 1000.0

    acorr_plot = _normalize_rows(acorr.astype(float, copy=False), normalize)
    if smooth_sigma is not None:
        acorr_plot = gaussian_filter1d(acorr_plot, sigma=float(smooth_sigma), axis=1)

    cids = _cluster_order(labels, mapping)
    if mapping is not None:
        label_strings = [_canonical_label(mapping.get(int(c), str(c))) for c in cids]
    else:
        label_strings = [str(c) for c in cids]

    show_short = short_max_ms is not None
    ncols = 2 if show_short else 1

    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title=title,
        # Widen the two-panel layout.
        figsize=figsize if ncols == 1 else (figsize[0] * 1.6, figsize[1]),
        save_path=save_path,
        show=show,
        kwargs={},
    )

    fig, axes = plt.subplots(1, ncols, figsize=config.figsize)
    if ncols == 1:
        axes = [axes]

    def _plot_panel(ax: plt.Axes, max_ms: float | None, panel_title: str):
        # Draw mean +/- SEM per cluster, limited to lags <= max_ms.
        mask = np.ones_like(x_ms, dtype=bool) if max_ms is None else x_ms <= max_ms
        for idx, (cid, label_str) in enumerate(zip(cids, label_strings, strict=False)):
            rows = acorr_plot[labels == cid]
            if rows.size == 0:
                continue
            mean = rows.mean(axis=0)
            sem = rows.std(axis=0) / np.sqrt(rows.shape[0])
            color = _label_color(label_str, colors, idx)
            ax.plot(x_ms[mask], mean[mask], color=color, lw=2, label=label_str)
            ax.fill_between(
                x_ms[mask],
                (mean - sem)[mask],
                (mean + sem)[mask],
                color=color,
                alpha=0.25,
                linewidth=0,
            )
        ax.set_xlabel("Lag (ms)")
        ax.set_title(panel_title)
        ax.grid(False)

    _plot_panel(axes[0], long_max_ms, "Long lag")
    axes[0].set_ylabel("Probability" if normalize == "probability" else "Autocorr (norm)")

    if show_short:
        _plot_panel(axes[1], float(short_max_ms), "Short lag")

    for ax in axes:
        # Skip legends with no labelled artists (e.g. empty label arrays).
        if ax.get_legend_handles_labels()[0]:
            ax.legend(frameon=False)

    fig.suptitle(config.title)
    fig.tight_layout()
    finalize_figure(fig, config)
    return fig
|
|
@@ -34,6 +34,7 @@ from .theta_sweep_plots import (
|
|
|
34
34
|
create_theta_sweep_grid_cell_animation,
|
|
35
35
|
create_theta_sweep_place_cell_animation,
|
|
36
36
|
plot_grid_cell_manifold,
|
|
37
|
+
plot_internal_position_trajectory,
|
|
37
38
|
plot_population_activity_with_theta,
|
|
38
39
|
)
|
|
39
40
|
from .tuning_plots import tuning_curve
|
|
@@ -70,5 +71,6 @@ __all__ = [
|
|
|
70
71
|
"create_theta_sweep_grid_cell_animation",
|
|
71
72
|
"create_theta_sweep_place_cell_animation",
|
|
72
73
|
"plot_grid_cell_manifold",
|
|
74
|
+
"plot_internal_position_trajectory",
|
|
73
75
|
"plot_population_activity_with_theta",
|
|
74
76
|
]
|
|
@@ -469,6 +469,26 @@ class PlotConfigs:
|
|
|
469
469
|
defaults.update(kwargs)
|
|
470
470
|
return PlotConfig.for_static_plot(**defaults)
|
|
471
471
|
|
|
472
|
+
@staticmethod
def internal_position_trajectory_static(**kwargs: Any) -> PlotConfig:
    """Static-plot config for an internal-position vs. real-trajectory plot.

    Defaults:

    - title "Internal Position vs. Real Trajectory" and figsize (6, 4);
    - plot kwargs for a ``cool`` colormap with a "Max GC activity" colorbar;
    - a black trajectory line of width 1.0;
    - scatter points of size 4 at alpha 0.9.

    Parameters
    ----------
    **kwargs
        Overrides forwarded to ``PlotConfig.for_static_plot``. A nested
        ``kwargs`` dict is merged key-by-key over the default plot kwargs,
        and ``kwargs=None`` is treated as no overrides.
    """
    defaults: dict[str, Any] = {
        "title": "Internal Position vs. Real Trajectory",
        "figsize": (6, 4),
    }
    plot_kwargs: dict[str, Any] = {
        "cmap": "cool",
        "add_colorbar": True,
        "colorbar": {"label": "Max GC activity"},
        "trajectory_color": "black",
        "trajectory_linewidth": 1.0,
        "scatter_size": 4,
        "scatter_alpha": 0.9,
    }
    # "or {}" so an explicit kwargs=None does not crash dict.update.
    plot_kwargs.update(kwargs.pop("kwargs", None) or {})
    defaults["kwargs"] = plot_kwargs
    defaults.update(kwargs)
    return PlotConfig.for_static_plot(**defaults)
|
|
491
|
+
|
|
472
492
|
@staticmethod
|
|
473
493
|
def direction_cell_polar(**kwargs: Any) -> PlotConfig:
|
|
474
494
|
"""Configuration for direction cell polar plot visualization.
|