canns 0.14.3__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/analyzer/data/asa/__init__.py +56 -21
- canns/analyzer/data/asa/coho.py +21 -0
- canns/analyzer/data/asa/cohomap.py +453 -0
- canns/analyzer/data/asa/cohomap_vectors.py +365 -0
- canns/analyzer/data/asa/cohospace.py +155 -1165
- canns/analyzer/data/asa/cohospace_phase_centers.py +119 -0
- canns/analyzer/data/asa/cohospace_scatter.py +1115 -0
- canns/analyzer/data/asa/embedding.py +5 -7
- canns/analyzer/data/asa/fr.py +1 -8
- canns/analyzer/data/asa/path.py +70 -0
- canns/analyzer/data/asa/plotting.py +5 -30
- canns/analyzer/data/asa/utils.py +160 -0
- canns/analyzer/data/cell_classification/__init__.py +10 -0
- canns/analyzer/data/cell_classification/core/__init__.py +4 -0
- canns/analyzer/data/cell_classification/core/btn.py +272 -0
- canns/analyzer/data/cell_classification/visualization/__init__.py +3 -0
- canns/analyzer/data/cell_classification/visualization/btn_plots.py +241 -0
- canns/analyzer/visualization/__init__.py +2 -0
- canns/analyzer/visualization/core/config.py +20 -0
- canns/analyzer/visualization/theta_sweep_plots.py +142 -0
- canns/pipeline/asa/runner.py +19 -19
- canns/pipeline/asa_gui/__init__.py +5 -3
- canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +3 -1
- canns/pipeline/asa_gui/core/runner.py +23 -23
- canns/pipeline/asa_gui/views/pages/preprocess_page.py +7 -12
- {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/METADATA +1 -1
- {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/RECORD +30 -23
- canns/analyzer/data/asa/filters.py +0 -208
- {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/WHEEL +0 -0
- {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/entry_points.txt +0 -0
- {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,241 @@
|
|
|
1
|
+
"""BTN visualization utilities."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from matplotlib import pyplot as plt
|
|
9
|
+
from scipy.ndimage import gaussian_filter1d
|
|
10
|
+
|
|
11
|
+
from canns.analyzer.data.asa.utils import _ensure_plot_config
|
|
12
|
+
from canns.analyzer.visualization.core.config import PlotConfig, finalize_figure
|
|
13
|
+
|
|
14
|
+
# Default plot colors keyed by canonical BTN label ("B", "T", "N").
# _label_color falls back to this table when the caller's ``colors``
# mapping has no entry for a label.
_DEFAULT_BTN_COLORS = {
    "B": "#1f77b4",
    "T": "#000000",
    "N": "#2ca02c",
}
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _canonical_label(label: str) -> str:
    """Map a free-form BTN class name to its one-letter code.

    Matching is case-insensitive and ignores surrounding whitespace.

    Args:
        label: Class name such as ``"bursty"``, ``"Theta-Modulated"`` or ``"N"``.

    Returns:
        ``"B"``, ``"T"`` or ``"N"`` for a recognized alias; otherwise the
        original ``label`` exactly as passed in (not stripped or lowercased).
    """
    lab = label.strip().lower()
    if lab in ("b", "bursty"):
        return "B"
    if lab in ("t", "theta", "theta-modulated", "theta_modulated", "theta modulated"):
        return "T"
    # "non bursty" (space-separated) is accepted for parity with
    # "theta modulated"; previously it fell through as an unknown label.
    if lab in ("n", "nonbursty", "non-bursty", "non_bursty", "non bursty"):
        return "N"
    return label
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _cluster_order(labels: np.ndarray, mapping: dict[int, str] | None) -> list[int]:
|
|
33
|
+
cids = [int(c) for c in np.unique(labels)]
|
|
34
|
+
if mapping is None:
|
|
35
|
+
return sorted(cids)
|
|
36
|
+
|
|
37
|
+
def _key(cid: int) -> tuple[int, str]:
|
|
38
|
+
lab = _canonical_label(mapping.get(int(cid), str(cid)))
|
|
39
|
+
order = {"B": 0, "T": 1, "N": 2}.get(lab, 999)
|
|
40
|
+
return (order, str(lab))
|
|
41
|
+
|
|
42
|
+
return sorted(cids, key=_key)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _label_color(
|
|
46
|
+
label: str,
|
|
47
|
+
colors: dict[str, str] | None,
|
|
48
|
+
fallback_idx: int,
|
|
49
|
+
) -> str:
|
|
50
|
+
if colors and label in colors:
|
|
51
|
+
return colors[label]
|
|
52
|
+
if label in _DEFAULT_BTN_COLORS:
|
|
53
|
+
return _DEFAULT_BTN_COLORS[label]
|
|
54
|
+
cmap = plt.get_cmap("tab10")
|
|
55
|
+
return cmap(fallback_idx % 10)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _normalize_rows(acorr: np.ndarray, mode: str | None) -> np.ndarray:
|
|
59
|
+
if mode is None or mode == "none":
|
|
60
|
+
return acorr
|
|
61
|
+
if mode == "probability":
|
|
62
|
+
denom = acorr.sum(axis=1, keepdims=True)
|
|
63
|
+
elif mode == "peak":
|
|
64
|
+
denom = acorr.max(axis=1, keepdims=True)
|
|
65
|
+
elif mode == "first":
|
|
66
|
+
denom = acorr[:, :1]
|
|
67
|
+
else:
|
|
68
|
+
raise ValueError(f"Unknown normalize mode: {mode!r}")
|
|
69
|
+
denom = np.where(denom == 0, 1.0, denom)
|
|
70
|
+
return acorr / denom
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def plot_btn_distance_matrix(
    *,
    dist: np.ndarray | None = None,
    labels: np.ndarray | None = None,
    mapping: dict[int, str] | None = None,
    sort_by_label: bool = True,
    title: str = "BTN distance matrix",
    cmap: str = "afmhot",
    figsize: tuple[int, int] = (5, 5),
    save_path: str | None = None,
    show: bool = True,
    ax: plt.Axes | None = None,
    config: PlotConfig | None = None,
) -> tuple[plt.Figure, plt.Axes, np.ndarray]:
    """Plot a pairwise distance matrix as a heatmap, optionally grouped by cluster.

    Args:
        dist: Square ``(n, n)`` distance matrix (any array-like).
        labels: Length-``n`` integer cluster id per row/column of ``dist``.
        mapping: Optional cluster id -> class name, used only to order the
            clusters (see ``_cluster_order``).
        sort_by_label: If true, rows/columns are permuted so members of the
            same cluster are contiguous and thin white lines are drawn at
            the cluster boundaries. If false, the original order is kept.
        title: Axes title, used when ``config`` is not supplied.
        cmap: Colormap passed to ``imshow``.
        figsize: Figure size, used when ``config`` is not supplied.
        save_path: Output path, used when ``config`` is not supplied.
        show: Show flag, used when ``config`` is not supplied.
        ax: Existing axes to draw on; a new figure is created when omitted.
        config: Optional ``PlotConfig``. NOTE(review): presumably takes
            precedence over ``title``/``figsize``/``save_path``/``show`` via
            ``_ensure_plot_config`` - confirm against that helper.

    Returns:
        ``(fig, ax, order)`` where ``order`` is the index permutation applied
        to the rows and columns of ``dist``.

    Raises:
        ValueError: If ``dist`` or ``labels`` is missing, or ``dist`` is not
            an ``(n, n)`` matrix with ``n == len(labels)``.
    """
    if dist is None or labels is None:
        raise ValueError("dist and labels are required.")

    # Coerce so list inputs work with the fancy indexing below.
    dist = np.asarray(dist)
    labels = np.asarray(labels).astype(int)

    n_items = len(labels)
    if dist.shape != (n_items, n_items):
        # Without this check an oversized matrix would be silently cropped
        # and an undersized one would fail with an opaque IndexError.
        raise ValueError(
            f"dist must have shape ({n_items}, {n_items}) to match labels, got {dist.shape}"
        )

    if sort_by_label:
        cids = _cluster_order(labels, mapping)
        groups = [np.where(labels == c)[0] for c in cids]
        order = np.concatenate(groups)
    else:
        groups = []
        order = np.arange(n_items)

    dist_sorted = dist[np.ix_(order, order)]

    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title=title,
        figsize=figsize,
        save_path=save_path,
        show=show,
        kwargs={},
    )

    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(1, 1, figsize=config.figsize)
    else:
        fig = ax.figure

    im = ax.imshow(dist_sorted, cmap=cmap, origin="lower", interpolation="nearest")
    # The colorbar label is fixed; callers passing a non-cosine metric will
    # still see "Cosine distance".
    fig.colorbar(im, ax=ax, label="Cosine distance")
    ax.set_title(config.title)
    ax.set_xlabel("Neuron")
    ax.set_ylabel("Neuron")

    if sort_by_label:
        # Separator after every cluster except the last; -0.5 puts the line
        # on the edge between two image cells rather than through a cell.
        boundaries = np.cumsum([len(g) for g in groups])[:-1]
        for b in boundaries:
            ax.axhline(b - 0.5, color="w", linewidth=0.6, alpha=0.7)
            ax.axvline(b - 0.5, color="w", linewidth=0.6, alpha=0.7)

    if created_fig:
        # Leave layout alone on caller-owned figures.
        fig.tight_layout()
    finalize_figure(fig, config, rasterize_artists=[im] if config.rasterized else None)

    return fig, ax, order
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def plot_btn_autocorr_summary(
    *,
    acorr: np.ndarray | None = None,
    labels: np.ndarray | None = None,
    bin_times: np.ndarray | None = None,
    res: float | None = None,
    mapping: dict[int, str] | None = None,
    colors: dict[str, str] | None = None,
    normalize: str | None = "probability",
    smooth_sigma: float | None = None,
    long_max_ms: float | None = 200.0,
    short_max_ms: float | None = None,
    title: str = "BTN temporal autocorr",
    figsize: tuple[int, int] = (8, 3),
    save_path: str | None = None,
    show: bool = True,
    config: PlotConfig | None = None,
) -> plt.Figure:
    """Plot class-averaged ISI autocorr curves (mean +/- SEM).

    One curve per cluster is drawn in a "Long lag" panel; when
    ``short_max_ms`` is given a second "Short lag" panel is added.

    Args:
        acorr: ``(n_neurons, n_lags)`` autocorrelogram matrix.
        labels: Length-``n_neurons`` integer cluster id per row of ``acorr``.
        bin_times: Bin edges; lag values are the midpoints of consecutive
            edges, so ``len(bin_times)`` must be ``n_lags + 1``.
        res: Bin width; used when ``bin_times`` is omitted, giving lags
            ``0, res, 2*res, ...``.
        mapping: Optional cluster id -> class name for ordering and legend.
        colors: Optional label -> color overrides.
        normalize: Row normalization mode passed to ``_normalize_rows``.
        smooth_sigma: If given, Gaussian smoothing sigma (in bins) applied
            along the lag axis after normalization.
        long_max_ms: Upper lag limit of the first panel; ``None`` shows all.
        short_max_ms: Upper lag limit of the optional second panel.
        title: Figure suptitle, used when ``config`` is not supplied.
        figsize: Figure size for one panel; widened by 1.6x for two panels.
        save_path: Output path, used when ``config`` is not supplied.
        show: Show flag, used when ``config`` is not supplied.
        config: Optional ``PlotConfig``. NOTE(review): presumably overrides
            the individual title/figsize/save_path/show values - confirm
            against ``_ensure_plot_config``.

    Returns:
        The created figure.

    Raises:
        ValueError: If required inputs are missing, shapes are inconsistent,
            neither ``bin_times`` nor ``res`` is given, or ``normalize`` is
            not a known mode.
    """
    if acorr is None or labels is None:
        raise ValueError("acorr and labels are required.")

    labels = np.asarray(labels).astype(int)
    acorr = np.asarray(acorr)
    if acorr.ndim != 2:
        raise ValueError(f"acorr must be 2D (n_neurons, n_lags), got shape {acorr.shape}")
    if len(labels) != acorr.shape[0]:
        raise ValueError(
            f"labels length ({len(labels)}) must match acorr rows ({acorr.shape[0]})"
        )

    if bin_times is not None:
        bin_times = np.asarray(bin_times)
        x = 0.5 * (bin_times[:-1] + bin_times[1:])
    elif res is not None:
        x = np.arange(acorr.shape[1]) * float(res)
    else:
        raise ValueError("Provide bin_times or res to define lag axis.")

    if x.shape[0] != acorr.shape[1]:
        # Catch this up front; otherwise it surfaces later as a confusing
        # boolean-index mismatch while masking the curves.
        raise ValueError(
            f"lag axis has {x.shape[0]} points but acorr has {acorr.shape[1]} columns; "
            "bin_times must contain n_lags + 1 edges"
        )

    acorr_plot = _normalize_rows(acorr.astype(float, copy=False), normalize)
    if smooth_sigma is not None:
        acorr_plot = gaussian_filter1d(acorr_plot, sigma=float(smooth_sigma), axis=1)

    # Axis is labelled in ms; assumes bin_times/res are in seconds - TODO confirm.
    x_ms = x * 1000.0

    cids = _cluster_order(labels, mapping)
    if mapping is not None:
        label_strings = [_canonical_label(mapping.get(int(c), str(c))) for c in cids]
    else:
        label_strings = [str(c) for c in cids]

    show_short = short_max_ms is not None
    ncols = 2 if show_short else 1

    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title=title,
        figsize=figsize if ncols == 1 else (figsize[0] * 1.6, figsize[1]),
        save_path=save_path,
        show=show,
        kwargs={},
    )

    fig, axes = plt.subplots(1, ncols, figsize=config.figsize)
    if ncols == 1:
        # plt.subplots returns a bare Axes for a single panel.
        axes = [axes]

    def _plot_panel(ax: plt.Axes, max_ms: float | None, panel_title: str):
        """Draw every cluster's mean curve and SEM band up to ``max_ms``."""
        if max_ms is None:
            mask = np.ones_like(x_ms, dtype=bool)
        else:
            mask = x_ms <= max_ms

        for idx, (cid, label_str) in enumerate(zip(cids, label_strings, strict=False)):
            rows = acorr_plot[labels == cid]
            if rows.size == 0:
                continue
            mean = rows.mean(axis=0)
            # SEM uses the population std (ddof=0), as in the original code.
            sem = rows.std(axis=0) / np.sqrt(rows.shape[0])
            color = _label_color(label_str, colors, idx)
            ax.plot(x_ms[mask], mean[mask], color=color, lw=2, label=label_str)
            ax.fill_between(
                x_ms[mask],
                (mean - sem)[mask],
                (mean + sem)[mask],
                color=color,
                alpha=0.25,
                linewidth=0,
            )

        ax.set_xlabel("Lag (ms)")
        ax.set_title(panel_title)
        ax.grid(False)

    _plot_panel(axes[0], long_max_ms, "Long lag")
    if normalize == "probability":
        ylabel = "Probability"
    elif normalize is None or normalize == "none":
        # Nothing was normalized, so do not claim "(norm)".
        ylabel = "Autocorr"
    else:
        ylabel = "Autocorr (norm)"
    axes[0].set_ylabel(ylabel)

    if show_short:
        _plot_panel(axes[1], float(short_max_ms), "Short lag")

    for ax in axes:
        ax.legend(frameon=False)

    fig.suptitle(config.title)
    fig.tight_layout()
    finalize_figure(fig, config)
    return fig
|
|
@@ -34,6 +34,7 @@ from .theta_sweep_plots import (
|
|
|
34
34
|
create_theta_sweep_grid_cell_animation,
|
|
35
35
|
create_theta_sweep_place_cell_animation,
|
|
36
36
|
plot_grid_cell_manifold,
|
|
37
|
+
plot_internal_position_trajectory,
|
|
37
38
|
plot_population_activity_with_theta,
|
|
38
39
|
)
|
|
39
40
|
from .tuning_plots import tuning_curve
|
|
@@ -70,5 +71,6 @@ __all__ = [
|
|
|
70
71
|
"create_theta_sweep_grid_cell_animation",
|
|
71
72
|
"create_theta_sweep_place_cell_animation",
|
|
72
73
|
"plot_grid_cell_manifold",
|
|
74
|
+
"plot_internal_position_trajectory",
|
|
73
75
|
"plot_population_activity_with_theta",
|
|
74
76
|
]
|
|
@@ -469,6 +469,26 @@ class PlotConfigs:
|
|
|
469
469
|
defaults.update(kwargs)
|
|
470
470
|
return PlotConfig.for_static_plot(**defaults)
|
|
471
471
|
|
|
472
|
+
@staticmethod
def internal_position_trajectory_static(**kwargs: Any) -> PlotConfig:
    """Build the static-plot config for the internal-position-vs-trajectory figure.

    Top-level defaults are a title and a ``(6, 4)`` figure size. The nested
    ``kwargs`` entry carries the plotting options (colormap, colorbar
    settings, trajectory line style, scatter size/alpha).

    Args:
        **kwargs: Overrides. A ``"kwargs"`` entry is merged into the nested
            plotting options; every other entry replaces or extends the
            top-level values passed to ``PlotConfig.for_static_plot``.

    Returns:
        The result of ``PlotConfig.for_static_plot`` for the merged settings.
    """
    # Pull the nested overrides out first so the remaining entries are
    # purely top-level settings.
    nested_overrides = kwargs.pop("kwargs", {})

    nested_options: dict[str, Any] = dict(
        cmap="cool",
        add_colorbar=True,
        colorbar={"label": "Max GC activity"},
        trajectory_color="black",
        trajectory_linewidth=1.0,
        scatter_size=4,
        scatter_alpha=0.9,
    )
    nested_options.update(nested_overrides)

    settings: dict[str, Any] = dict(
        title="Internal Position vs. Real Trajectory",
        figsize=(6, 4),
        kwargs=nested_options,
    )
    settings.update(kwargs)
    return PlotConfig.for_static_plot(**settings)
|
|
491
|
+
|
|
472
492
|
@staticmethod
|
|
473
493
|
def direction_cell_polar(**kwargs: Any) -> PlotConfig:
|
|
474
494
|
"""Configuration for direction cell polar plot visualization.
|
|
@@ -760,6 +760,148 @@ def plot_grid_cell_manifold(
|
|
|
760
760
|
raise e
|
|
761
761
|
|
|
762
762
|
|
|
763
|
+
def plot_internal_position_trajectory(
    internal_position: np.ndarray,
    position: np.ndarray,
    max_activity: np.ndarray | None = None,
    env_size: float | tuple[float, float] | tuple[float, float, float, float] | None = None,
    config: PlotConfig | None = None,
    ax: plt.Axes | None = None,
    # Backward compatibility parameters
    title: str = "Internal Position (GC bump) vs. Real Trajectory",
    figsize: tuple[int, int] = (6, 4),
    cmap: str = "cool",
    show: bool = True,
    save_path: str | None = None,
    **kwargs,
) -> tuple[plt.Figure, plt.Axes]:
    """Plot internal position (GC bump) against the real trajectory.

    The internal positions are drawn as a scatter (optionally colored by
    ``max_activity``) and the real trajectory as a line on the same axes.

    Args:
        internal_position: Internal decoded positions ``(T, 2)``.
        position: Real positions ``(T, 2)``.
        max_activity: Optional per-time max activity to color the internal position.
            When omitted, no colorbar is drawn.
        env_size: Environment size. If float, uses ``[0, env_size]`` for both axes.
            If a tuple of 2, treats as ``(width, height)``. If a tuple of 4, treats as
            ``(xmin, xmax, ymin, ymax)``.
        config: PlotConfig object for unified configuration. When given, the
            ``title``/``figsize``/``cmap``/``show``/``save_path``/``**kwargs``
            arguments are not used.
        ax: Optional axis to draw on instead of creating a new figure. When
            given, the figure is never shown by this function.
        title: Axes title (only when ``config`` is omitted).
        figsize: Figure size (only when ``config`` is omitted).
        cmap: Scatter colormap (only when ``config`` is omitted).
        show: Whether to show the figure (only when ``config`` is omitted).
        save_path: Output path (only when ``config`` is omitted).
        **kwargs: Additional parameters for backward compatibility. Recognized
            keys are ``trajectory_color``, ``trajectory_linewidth``,
            ``scatter_size``, ``scatter_alpha``, ``add_colorbar`` and
            ``colorbar``; anything else is forwarded to ``ax.scatter``.

    Returns:
        tuple: ``(figure, axis)`` objects.

    Raises:
        ValueError: If the input shapes are inconsistent or ``env_size`` is a
            sequence whose length is not 2 or 4.
    """
    # Coerce so plain lists work with the shape checks and slicing below.
    internal_position = np.asarray(internal_position)
    position = np.asarray(position)
    if max_activity is not None:
        max_activity = np.asarray(max_activity)

    if internal_position.ndim != 2 or internal_position.shape[1] != 2:
        raise ValueError(
            f"internal_position must be (T, 2), got shape {internal_position.shape}"
        )
    if position.ndim != 2 or position.shape[1] != 2:
        raise ValueError(f"position must be (T, 2), got shape {position.shape}")
    if internal_position.shape[0] != position.shape[0]:
        raise ValueError(
            "internal_position and position must have the same length: "
            f"{internal_position.shape[0]} != {position.shape[0]}"
        )
    if max_activity is not None and max_activity.shape[0] != internal_position.shape[0]:
        raise ValueError(
            "max_activity must have same length as internal_position: "
            f"{max_activity.shape[0]} != {internal_position.shape[0]}"
        )

    if config is None:
        config = PlotConfig(
            title=title,
            figsize=figsize,
            show=show,
            save_path=save_path,
            kwargs={"cmap": cmap, **kwargs},
        )

    # Work on a copy: the pops below must not alter the caller's config if
    # to_matplotlib_kwargs() hands back its own dict (NOTE(review): confirm).
    plot_kwargs = dict(config.to_matplotlib_kwargs())
    trajectory_color = plot_kwargs.pop("trajectory_color", "black")
    trajectory_linewidth = plot_kwargs.pop("trajectory_linewidth", 1.0)
    scatter_size = plot_kwargs.pop("scatter_size", 4)
    scatter_alpha = plot_kwargs.pop("scatter_alpha", 0.9)
    add_colorbar = bool(plot_kwargs.pop("add_colorbar", True))
    # Always remove "colorbar": leaving it behind when add_colorbar is False
    # would forward it to ax.scatter, which rejects unknown keywords.
    colorbar_options = plot_kwargs.pop("colorbar", {})
    if not add_colorbar:
        colorbar_options = {}
    plot_kwargs.setdefault("cmap", "cool")

    if max_activity is None:
        add_colorbar = False
        # With no color data the colormap is ignored by scatter and only
        # triggers a warning, so do not pass it.
        plot_kwargs.pop("cmap", None)

    axis_provided = ax is not None
    if not axis_provided:
        fig, ax = plt.subplots(figsize=config.figsize)
    else:
        fig = ax.figure

    try:
        scatter = ax.scatter(
            internal_position[:, 0],
            internal_position[:, 1],
            c=max_activity,
            s=scatter_size,
            alpha=scatter_alpha,
            **plot_kwargs,
        )
        ax.plot(position[:, 0], position[:, 1], color=trajectory_color, lw=trajectory_linewidth)

        ax.set_aspect("equal", adjustable="box")
        if config.title:
            ax.set_title(config.title, fontsize=14, fontweight="bold")

        if env_size is not None:
            if isinstance(env_size, (tuple, list, np.ndarray)):
                if len(env_size) == 2:
                    ax.set_xlim(0, env_size[0])
                    ax.set_ylim(0, env_size[1])
                elif len(env_size) == 4:
                    ax.set_xlim(env_size[0], env_size[1])
                    ax.set_ylim(env_size[2], env_size[3])
                else:
                    raise ValueError("env_size tuple must be length 2 or 4.")
            else:
                ax.set_xlim(0, env_size)
                ax.set_ylim(0, env_size)

        sns.despine(ax=ax)

        if add_colorbar:
            default_cbar_opts = {"pad": 0.15, "size": "5%", "label": "Max GC activity"}
            if isinstance(colorbar_options, dict):
                extra_cbar_kwargs = colorbar_options.get("kwargs", {})
                for key in ("pad", "size", "label"):
                    if key in colorbar_options:
                        default_cbar_opts[key] = colorbar_options[key]
            else:
                extra_cbar_kwargs = {}

            # Dedicated axes to the right of the plot for the colorbar.
            divider = make_axes_locatable(ax)
            cax = divider.append_axes(
                "right", size=default_cbar_opts["size"], pad=default_cbar_opts["pad"]
            )
            cbar = fig.colorbar(scatter, cax=cax, **extra_cbar_kwargs)
            if default_cbar_opts["label"]:
                cbar.set_label(default_cbar_opts["label"], fontsize=12)

        if not axis_provided:
            fig.tight_layout()

        # Never show a figure the caller owns; they decide when to display it.
        show_flag = config.show and not axis_provided
        finalize_figure(
            fig,
            replace(config, show=show_flag),
            rasterize_artists=[scatter] if config.rasterized else None,
            always_close=not show_flag,
        )
        return fig, ax

    except Exception:
        # Only clean up figures created here; closing a caller-supplied
        # figure on error would destroy their other axes.
        if not axis_provided:
            plt.close(fig)
        raise
|
|
903
|
+
|
|
904
|
+
|
|
763
905
|
@dataclass(slots=True)
|
|
764
906
|
class _PlaceCellAnimationData:
|
|
765
907
|
"""Immutable container for place cell animation arrays."""
|
canns/pipeline/asa/runner.py
CHANGED
|
@@ -628,7 +628,7 @@ class PipelineRunner:
|
|
|
628
628
|
self, asa_data: dict[str, Any], state: WorkflowState, log_callback
|
|
629
629
|
) -> dict[str, Path]:
|
|
630
630
|
"""Run CohoMap analysis (TDA + decode + plotting)."""
|
|
631
|
-
from canns.analyzer.data.asa import
|
|
631
|
+
from canns.analyzer.data.asa import plot_cohomap_scatter_multi
|
|
632
632
|
from canns.analyzer.visualization import PlotConfigs
|
|
633
633
|
|
|
634
634
|
tda_dir = self._results_dir(state) / "TDA"
|
|
@@ -661,7 +661,7 @@ class PipelineRunner:
|
|
|
661
661
|
log_callback("Generating cohomology map...")
|
|
662
662
|
pos = self._aligned_pos if self._aligned_pos is not None else asa_data
|
|
663
663
|
config = PlotConfigs.cohomap(show=False, save_path=str(cohomap_path))
|
|
664
|
-
|
|
664
|
+
plot_cohomap_scatter_multi(
|
|
665
665
|
decoding_result=decode_result,
|
|
666
666
|
position_data={"x": pos["x"], "y": pos["y"]},
|
|
667
667
|
config=config,
|
|
@@ -838,16 +838,16 @@ class PipelineRunner:
|
|
|
838
838
|
) -> dict[str, Path]:
|
|
839
839
|
"""Run cohomology space visualization."""
|
|
840
840
|
from canns.analyzer.data.asa import (
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
841
|
+
plot_cohospace_scatter_neuron_1d,
|
|
842
|
+
plot_cohospace_scatter_neuron_2d,
|
|
843
|
+
plot_cohospace_scatter_population_1d,
|
|
844
|
+
plot_cohospace_scatter_population_2d,
|
|
845
|
+
plot_cohospace_scatter_trajectory_1d,
|
|
846
|
+
plot_cohospace_scatter_trajectory_2d,
|
|
847
847
|
)
|
|
848
|
-
from canns.analyzer.data.asa.
|
|
849
|
-
|
|
850
|
-
|
|
848
|
+
from canns.analyzer.data.asa.cohospace_scatter import (
|
|
849
|
+
plot_cohospace_scatter_neuron_skewed,
|
|
850
|
+
plot_cohospace_scatter_population_skewed,
|
|
851
851
|
)
|
|
852
852
|
from canns.analyzer.visualization import PlotConfigs
|
|
853
853
|
|
|
@@ -937,7 +937,7 @@ class PipelineRunner:
|
|
|
937
937
|
traj_path = out_dir / "cohospace_trajectory.png"
|
|
938
938
|
if dim_mode == "1d":
|
|
939
939
|
traj_cfg = PlotConfigs.cohospace_trajectory_1d(show=False, save_path=str(traj_path))
|
|
940
|
-
|
|
940
|
+
plot_cohospace_scatter_trajectory_1d(
|
|
941
941
|
coords=coords2,
|
|
942
942
|
times=None,
|
|
943
943
|
subsample=subsample,
|
|
@@ -945,7 +945,7 @@ class PipelineRunner:
|
|
|
945
945
|
)
|
|
946
946
|
else:
|
|
947
947
|
traj_cfg = PlotConfigs.cohospace_trajectory_2d(show=False, save_path=str(traj_path))
|
|
948
|
-
|
|
948
|
+
plot_cohospace_scatter_trajectory_2d(
|
|
949
949
|
coords=coords2,
|
|
950
950
|
times=None,
|
|
951
951
|
subsample=subsample,
|
|
@@ -958,7 +958,7 @@ class PipelineRunner:
|
|
|
958
958
|
log_callback(f"Plotting neuron {neuron_id}...")
|
|
959
959
|
neuron_path = out_dir / f"cohospace_neuron_{neuron_id}.png"
|
|
960
960
|
if unfold == "skew" and dim_mode != "1d":
|
|
961
|
-
|
|
961
|
+
plot_cohospace_scatter_neuron_skewed(
|
|
962
962
|
coords=coordsbox2,
|
|
963
963
|
activity=activity,
|
|
964
964
|
neuron_id=int(neuron_id),
|
|
@@ -974,7 +974,7 @@ class PipelineRunner:
|
|
|
974
974
|
neuron_cfg = PlotConfigs.cohospace_neuron_1d(
|
|
975
975
|
show=False, save_path=str(neuron_path)
|
|
976
976
|
)
|
|
977
|
-
|
|
977
|
+
plot_cohospace_scatter_neuron_1d(
|
|
978
978
|
coords=coordsbox2,
|
|
979
979
|
activity=activity,
|
|
980
980
|
neuron_id=int(neuron_id),
|
|
@@ -986,7 +986,7 @@ class PipelineRunner:
|
|
|
986
986
|
neuron_cfg = PlotConfigs.cohospace_neuron_2d(
|
|
987
987
|
show=False, save_path=str(neuron_path)
|
|
988
988
|
)
|
|
989
|
-
|
|
989
|
+
plot_cohospace_scatter_neuron_2d(
|
|
990
990
|
coords=coordsbox2,
|
|
991
991
|
activity=activity,
|
|
992
992
|
neuron_id=int(neuron_id),
|
|
@@ -1001,7 +1001,7 @@ class PipelineRunner:
|
|
|
1001
1001
|
pop_path = out_dir / "cohospace_population.png"
|
|
1002
1002
|
neuron_ids = list(range(activity.shape[1]))
|
|
1003
1003
|
if unfold == "skew" and dim_mode != "1d":
|
|
1004
|
-
|
|
1004
|
+
plot_cohospace_scatter_population_skewed(
|
|
1005
1005
|
coords=coords2,
|
|
1006
1006
|
activity=activity,
|
|
1007
1007
|
neuron_ids=neuron_ids,
|
|
@@ -1017,7 +1017,7 @@ class PipelineRunner:
|
|
|
1017
1017
|
pop_cfg = PlotConfigs.cohospace_population_1d(
|
|
1018
1018
|
show=False, save_path=str(pop_path)
|
|
1019
1019
|
)
|
|
1020
|
-
|
|
1020
|
+
plot_cohospace_scatter_population_1d(
|
|
1021
1021
|
coords=coords2,
|
|
1022
1022
|
activity=activity,
|
|
1023
1023
|
neuron_ids=neuron_ids,
|
|
@@ -1029,7 +1029,7 @@ class PipelineRunner:
|
|
|
1029
1029
|
pop_cfg = PlotConfigs.cohospace_population_2d(
|
|
1030
1030
|
show=False, save_path=str(pop_path)
|
|
1031
1031
|
)
|
|
1032
|
-
|
|
1032
|
+
plot_cohospace_scatter_population_2d(
|
|
1033
1033
|
coords=coords2,
|
|
1034
1034
|
activity=activity,
|
|
1035
1035
|
neuron_ids=neuron_ids,
|
|
@@ -2,9 +2,9 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
import sys
|
|
6
|
-
import os
|
|
7
5
|
import importlib.util
|
|
6
|
+
import os
|
|
7
|
+
import sys
|
|
8
8
|
|
|
9
9
|
__all__ = ["main", "ASAGuiApp"]
|
|
10
10
|
|
|
@@ -14,7 +14,9 @@ if _pyside6_missing:
|
|
|
14
14
|
try: # pragma: no cover - only used in CI/test runs
|
|
15
15
|
import pytest
|
|
16
16
|
|
|
17
|
-
pytest.skip(
|
|
17
|
+
pytest.skip(
|
|
18
|
+
"PySide6 is not installed; skipping asa_gui module.", allow_module_level=True
|
|
19
|
+
)
|
|
18
20
|
except Exception:
|
|
19
21
|
pass
|
|
20
22
|
|
|
@@ -250,7 +250,9 @@ class PathCompareMode(AbstractAnalysisMode):
|
|
|
250
250
|
self.dim.setToolTip("1D decoded dimension index.")
|
|
251
251
|
self.dim1.setToolTip("2D decoded dimension 1.")
|
|
252
252
|
self.dim2.setToolTip("2D decoded dimension 2.")
|
|
253
|
-
self.use_box.setToolTip(
|
|
253
|
+
self.use_box.setToolTip(
|
|
254
|
+
"Use coordsbox/times_box alignment (recommended with speed_filter)."
|
|
255
|
+
)
|
|
254
256
|
self.interp_full.setToolTip("Interpolate back to full trajectory.")
|
|
255
257
|
self.coords_key.setToolTip("Optional decode coords key (default coords/coordsbox).")
|
|
256
258
|
self.times_key.setToolTip("Optional times_box key.")
|