canns 0.13.1__py3-none-any.whl → 0.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/analyzer/data/__init__.py +5 -1
- canns/analyzer/data/asa/__init__.py +27 -12
- canns/analyzer/data/asa/cohospace.py +336 -10
- canns/analyzer/data/asa/config.py +3 -0
- canns/analyzer/data/asa/embedding.py +48 -45
- canns/analyzer/data/asa/path.py +104 -2
- canns/analyzer/data/asa/plotting.py +88 -19
- canns/analyzer/data/asa/tda.py +11 -4
- canns/analyzer/data/cell_classification/__init__.py +97 -0
- canns/analyzer/data/cell_classification/core/__init__.py +26 -0
- canns/analyzer/data/cell_classification/core/grid_cells.py +633 -0
- canns/analyzer/data/cell_classification/core/grid_modules_leiden.py +288 -0
- canns/analyzer/data/cell_classification/core/head_direction.py +347 -0
- canns/analyzer/data/cell_classification/core/spatial_analysis.py +431 -0
- canns/analyzer/data/cell_classification/io/__init__.py +5 -0
- canns/analyzer/data/cell_classification/io/matlab_loader.py +417 -0
- canns/analyzer/data/cell_classification/utils/__init__.py +39 -0
- canns/analyzer/data/cell_classification/utils/circular_stats.py +383 -0
- canns/analyzer/data/cell_classification/utils/correlation.py +318 -0
- canns/analyzer/data/cell_classification/utils/geometry.py +442 -0
- canns/analyzer/data/cell_classification/utils/image_processing.py +416 -0
- canns/analyzer/data/cell_classification/visualization/__init__.py +19 -0
- canns/analyzer/data/cell_classification/visualization/grid_plots.py +292 -0
- canns/analyzer/data/cell_classification/visualization/hd_plots.py +200 -0
- canns/analyzer/metrics/__init__.py +2 -1
- canns/analyzer/visualization/core/config.py +46 -4
- canns/data/__init__.py +6 -1
- canns/data/datasets.py +154 -1
- canns/data/loaders.py +37 -0
- canns/pipeline/__init__.py +13 -9
- canns/pipeline/__main__.py +6 -0
- canns/pipeline/asa/runner.py +105 -41
- canns/pipeline/asa_gui/__init__.py +68 -0
- canns/pipeline/asa_gui/__main__.py +6 -0
- canns/pipeline/asa_gui/analysis_modes/__init__.py +42 -0
- canns/pipeline/asa_gui/analysis_modes/base.py +39 -0
- canns/pipeline/asa_gui/analysis_modes/batch_mode.py +21 -0
- canns/pipeline/asa_gui/analysis_modes/cohomap_mode.py +56 -0
- canns/pipeline/asa_gui/analysis_modes/cohospace_mode.py +194 -0
- canns/pipeline/asa_gui/analysis_modes/decode_mode.py +52 -0
- canns/pipeline/asa_gui/analysis_modes/fr_mode.py +81 -0
- canns/pipeline/asa_gui/analysis_modes/frm_mode.py +92 -0
- canns/pipeline/asa_gui/analysis_modes/gridscore_mode.py +123 -0
- canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +199 -0
- canns/pipeline/asa_gui/analysis_modes/tda_mode.py +112 -0
- canns/pipeline/asa_gui/app.py +29 -0
- canns/pipeline/asa_gui/controllers/__init__.py +6 -0
- canns/pipeline/asa_gui/controllers/analysis_controller.py +59 -0
- canns/pipeline/asa_gui/controllers/preprocess_controller.py +89 -0
- canns/pipeline/asa_gui/core/__init__.py +15 -0
- canns/pipeline/asa_gui/core/cache.py +14 -0
- canns/pipeline/asa_gui/core/runner.py +1936 -0
- canns/pipeline/asa_gui/core/state.py +324 -0
- canns/pipeline/asa_gui/core/worker.py +260 -0
- canns/pipeline/asa_gui/main_window.py +184 -0
- canns/pipeline/asa_gui/models/__init__.py +7 -0
- canns/pipeline/asa_gui/models/config.py +14 -0
- canns/pipeline/asa_gui/models/job.py +31 -0
- canns/pipeline/asa_gui/models/presets.py +21 -0
- canns/pipeline/asa_gui/resources/__init__.py +16 -0
- canns/pipeline/asa_gui/resources/dark.qss +167 -0
- canns/pipeline/asa_gui/resources/light.qss +163 -0
- canns/pipeline/asa_gui/resources/styles.qss +130 -0
- canns/pipeline/asa_gui/utils/__init__.py +1 -0
- canns/pipeline/asa_gui/utils/formatters.py +15 -0
- canns/pipeline/asa_gui/utils/io_adapters.py +40 -0
- canns/pipeline/asa_gui/utils/validators.py +41 -0
- canns/pipeline/asa_gui/views/__init__.py +1 -0
- canns/pipeline/asa_gui/views/help_content.py +171 -0
- canns/pipeline/asa_gui/views/pages/__init__.py +6 -0
- canns/pipeline/asa_gui/views/pages/analysis_page.py +565 -0
- canns/pipeline/asa_gui/views/pages/preprocess_page.py +492 -0
- canns/pipeline/asa_gui/views/panels/__init__.py +1 -0
- canns/pipeline/asa_gui/views/widgets/__init__.py +21 -0
- canns/pipeline/asa_gui/views/widgets/artifacts_tab.py +44 -0
- canns/pipeline/asa_gui/views/widgets/drop_zone.py +80 -0
- canns/pipeline/asa_gui/views/widgets/file_list.py +27 -0
- canns/pipeline/asa_gui/views/widgets/gridscore_tab.py +308 -0
- canns/pipeline/asa_gui/views/widgets/help_dialog.py +27 -0
- canns/pipeline/asa_gui/views/widgets/image_tab.py +50 -0
- canns/pipeline/asa_gui/views/widgets/image_viewer.py +97 -0
- canns/pipeline/asa_gui/views/widgets/log_box.py +16 -0
- canns/pipeline/asa_gui/views/widgets/pathcompare_tab.py +200 -0
- canns/pipeline/asa_gui/views/widgets/popup_combo.py +25 -0
- canns/pipeline/gallery/__init__.py +15 -5
- canns/pipeline/gallery/__main__.py +11 -0
- canns/pipeline/gallery/app.py +705 -0
- canns/pipeline/gallery/runner.py +790 -0
- canns/pipeline/gallery/state.py +51 -0
- canns/pipeline/gallery/styles.tcss +123 -0
- canns/pipeline/launcher.py +81 -0
- {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/METADATA +11 -1
- canns-0.14.0.dist-info/RECORD +163 -0
- canns-0.14.0.dist-info/entry_points.txt +5 -0
- canns/pipeline/_base.py +0 -50
- canns-0.13.1.dist-info/RECORD +0 -89
- canns-0.13.1.dist-info/entry_points.txt +0 -3
- {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/WHEEL +0 -0
- {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,292 @@
|
|
|
1
|
+
"""Grid cell visualization utilities."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from matplotlib import patches
|
|
9
|
+
from matplotlib import pyplot as plt
|
|
10
|
+
|
|
11
|
+
from canns.analyzer.visualization.core.config import PlotConfig, PlotConfigs, finalize_figure
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _ensure_plot_config(
|
|
15
|
+
config: PlotConfig | None,
|
|
16
|
+
factory,
|
|
17
|
+
*,
|
|
18
|
+
kwargs: dict[str, Any] | None = None,
|
|
19
|
+
**defaults: Any,
|
|
20
|
+
) -> PlotConfig:
|
|
21
|
+
if config is None:
|
|
22
|
+
defaults.update({"kwargs": kwargs or {}})
|
|
23
|
+
return factory(**defaults)
|
|
24
|
+
|
|
25
|
+
if kwargs:
|
|
26
|
+
config_kwargs = config.kwargs or {}
|
|
27
|
+
config_kwargs.update(kwargs)
|
|
28
|
+
config.kwargs = config_kwargs
|
|
29
|
+
return config
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def plot_autocorrelogram(
    autocorr: np.ndarray,
    config: PlotConfig | None = None,
    *,
    gridness_score: float | None = None,
    center_radius: float | None = None,
    peak_locations: np.ndarray | None = None,
    title: str = "Spatial Autocorrelation",
    xlabel: str = "X Lag (bins)",
    ylabel: str = "Y Lag (bins)",
    figsize: tuple[int, int] = (6, 6),
    save_path: str | None = None,
    show: bool = True,
    ax: plt.Axes | None = None,
    **kwargs: Any,
) -> tuple[plt.Figure, plt.Axes]:
    """Plot a 2D autocorrelogram with optional grid annotations.

    Parameters
    ----------
    autocorr : np.ndarray
        2D autocorrelation array, drawn with ``imshow`` (origin "lower" and
        bilinear interpolation unless the config overrides them).
    config : PlotConfig, optional
        When given, ``title``/``xlabel``/``ylabel``/``figsize``/``save_path``/
        ``show`` are taken from it; ``**kwargs`` are still merged in.
    gridness_score : float, optional
        Appended to the title as "(Gridness: x.xxx)".
    center_radius : float, optional
        Radius in bins of a dashed red circle drawn at the image center.
    peak_locations : np.ndarray, optional
        Peak coordinates; column 0 is plotted on x and column 1 on y.
        NOTE(review): confirm producers store peaks as (x, y), not (row, col).
    title, xlabel, ylabel, figsize, save_path, show
        Used to build the default config when ``config`` is None.
    ax : plt.Axes, optional
        Axes to draw into. When given, the figure is not tight-laid-out or
        passed to ``finalize_figure`` here; the caller owns it.
    **kwargs
        Extra keyword arguments merged into the config's kwargs.

    Returns
    -------
    (fig, ax) : tuple[plt.Figure, plt.Axes]
    """
    config = _ensure_plot_config(
        config,
        PlotConfigs.grid_autocorrelation,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
        figsize=figsize,
        save_path=save_path,
        show=show,
        kwargs=kwargs,
    )

    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(1, 1, figsize=config.figsize)
    else:
        fig = ax.figure

    plot_kwargs = config.to_matplotlib_kwargs()
    plot_kwargs.setdefault("origin", "lower")
    plot_kwargs.setdefault("interpolation", "bilinear")
    im = ax.imshow(autocorr, **plot_kwargs)
    fig.colorbar(im, ax=ax, label="Correlation")

    if center_radius is not None:
        # imshow places pixel i at data coordinate i, so the geometric center
        # of an (n_rows, n_cols) image is ((n_cols - 1) / 2, (n_rows - 1) / 2).
        # Using shape / 2 would shift the circle half a bin along each axis.
        # (Assumes no custom ``extent`` was supplied via the config.)
        n_rows, n_cols = np.shape(autocorr)[:2]
        circle = patches.Circle(
            ((n_cols - 1) / 2, (n_rows - 1) / 2),
            center_radius,
            fill=False,
            edgecolor="red",
            linewidth=2,
            linestyle="--",
            label="Center field",
        )
        ax.add_patch(circle)

    if peak_locations is not None:
        # atleast_2d lets a single peak be passed as a 1D (x, y) pair.
        peaks = np.atleast_2d(peak_locations)
        ax.plot(
            peaks[:, 0],
            peaks[:, 1],
            "r+",
            markersize=10,
            markeredgewidth=2,
            label="Grid peaks",
        )

    title_text = config.title
    if gridness_score is not None:
        title_text = f"{config.title} (Gridness: {gridness_score:.3f})"

    ax.set_title(title_text)
    ax.set_xlabel(config.xlabel)
    ax.set_ylabel(config.ylabel)

    if center_radius is not None or peak_locations is not None:
        ax.legend(loc="upper right")

    if created_fig:
        fig.tight_layout()
        finalize_figure(fig, config, rasterize_artists=[im] if config.rasterized else None)

    return fig, ax
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def plot_gridness_analysis(
    rate_map: np.ndarray,
    autocorr: np.ndarray,
    result,
    config: PlotConfig | None = None,
    *,
    title: str = "Grid Cell Analysis",
    figsize: tuple[int, int] = (15, 5),
    save_path: str | None = None,
    show: bool = True,
) -> plt.Figure:
    """Draw a three-panel grid summary: rate map, autocorrelogram, statistics.

    Parameters
    ----------
    rate_map : np.ndarray
        2D firing-rate map for the left panel.
    autocorr : np.ndarray
        2D spatial autocorrelogram for the middle panel.
    result
        Gridness result object. Attributes read here: ``score``,
        ``center_radius``, ``optimal_radius``, ``peak_locations``,
        ``spacing`` and ``orientation`` (first three entries listed),
        ``ellipse`` (indices 0-3) and ``ellipse_theta_deg``.
    config : PlotConfig, optional
        When given, ``title``/``figsize``/``save_path``/``show`` come from it.
    title : str
        Figure suptitle.
    figsize : tuple[int, int]
        Figure size in inches.
    save_path, show
        Stored on the config that is passed to ``finalize_figure``.

    Returns
    -------
    plt.Figure
        The created figure.
    """
    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title=title,
        figsize=figsize,
        save_path=save_path,
        show=show,
        kwargs={},
    )

    fig, axes = plt.subplots(1, 3, figsize=config.figsize)
    rate_ax, autocorr_ax, stats_ax = axes

    # Panel 1: rate map. fig.colorbar (not pyplot's global-state
    # plt.colorbar) guarantees the colorbar is attached to this figure.
    rate_im = rate_ax.imshow(rate_map, cmap="hot", origin="lower")
    rate_ax.set_title("Firing Rate Map")
    rate_ax.set_xlabel("X (bins)")
    rate_ax.set_ylabel("Y (bins)")
    fig.colorbar(rate_im, ax=rate_ax, label="Rate (Hz)")

    # Panel 2: autocorrelogram. Passing ax= keeps plot_autocorrelogram from
    # finalizing the figure itself; finalization happens once below.
    plot_autocorrelogram(
        autocorr,
        gridness_score=result.score,
        center_radius=result.center_radius,
        peak_locations=result.peak_locations,
        title="Autocorrelogram",
        ax=autocorr_ax,
    )

    # Panel 3: text summary of the grid statistics.
    stats_lines = [
        "Grid Cell Analysis",
        "",
        f"Gridness Score: {result.score:.3f}",
        f"Center Radius: {result.center_radius:.1f} bins",
        f"Optimal Radius: {result.optimal_radius:.1f} bins",
        "",
        "Grid Spacing (bins):",
        *(f"  {value:.2f}" for value in result.spacing[:3]),
        "",
        "Grid Orientation (°):",
        *(f"  {value:.1f}" for value in result.orientation[:3]),
        "",
        "Ellipse Parameters:",
        f"  Center: ({result.ellipse[0]:.1f}, {result.ellipse[1]:.1f})",
        f"  Radii: ({result.ellipse[2]:.1f}, {result.ellipse[3]:.1f})",
        f"  Angle: {result.ellipse_theta_deg:.1f}°",
    ]
    stats_ax.axis("off")
    stats_ax.text(
        0.1,
        0.5,
        "\n".join(stats_lines),
        fontsize=10,
        verticalalignment="center",
        family="monospace",
    )
    stats_ax.set_title("Grid Statistics")

    fig.suptitle(config.title, fontsize=14, fontweight="bold")
    fig.tight_layout()
    finalize_figure(fig, config)
    return fig
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def plot_rate_map(
    rate_map: np.ndarray,
    config: PlotConfig | None = None,
    *,
    title: str = "Firing Field (Rate Map)",
    xlabel: str = "X Position (bins)",
    ylabel: str = "Y Position (bins)",
    figsize: tuple[int, int] = (6, 6),
    colorbar: bool = True,
    save_path: str | None = None,
    show: bool = True,
    ax: plt.Axes | None = None,
    **kwargs: Any,
) -> tuple[plt.Figure, plt.Axes]:
    """Render a 2D spatial firing-rate map as an ``imshow`` heatmap.

    Parameters
    ----------
    rate_map : np.ndarray
        2D rate map to display.
    config : PlotConfig, optional
        Overrides ``title``/``xlabel``/``ylabel``/``figsize``/``save_path``/
        ``show``; ``**kwargs`` are still merged into it.
    colorbar : bool
        Whether to attach a "Firing Rate (Hz)" colorbar.
    ax : plt.Axes, optional
        Axes to draw into. Without it a new figure is created, tight-laid-out
        and handed to ``finalize_figure``.
    **kwargs
        Extra keyword arguments merged into the config's kwargs.

    Returns
    -------
    (fig, ax) : tuple[plt.Figure, plt.Axes]
    """
    cfg = _ensure_plot_config(
        config,
        PlotConfigs.firing_field_heatmap,
        kwargs=kwargs,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
        figsize=figsize,
        save_path=save_path,
        show=show,
    )

    owns_figure = ax is None
    if owns_figure:
        fig, ax = plt.subplots(1, 1, figsize=cfg.figsize)
    else:
        fig = ax.figure

    # Fill in imshow defaults only where the config did not choose them.
    imshow_kwargs = cfg.to_matplotlib_kwargs()
    for key, fallback in (("origin", "lower"), ("interpolation", "bilinear")):
        imshow_kwargs.setdefault(key, fallback)
    image = ax.imshow(rate_map, **imshow_kwargs)

    if colorbar:
        fig.colorbar(image, ax=ax, label="Firing Rate (Hz)")

    ax.set_title(cfg.title)
    ax.set_xlabel(cfg.xlabel)
    ax.set_ylabel(cfg.ylabel)

    if owns_figure:
        fig.tight_layout()
        to_rasterize = [image] if cfg.rasterized else None
        finalize_figure(fig, cfg, rasterize_artists=to_rasterize)

    return fig, ax
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def plot_grid_score_histogram(
    scores: np.ndarray,
    config: PlotConfig | None = None,
    *,
    bins: int = 30,
    title: str = "Grid Score Distribution",
    xlabel: str = "Grid Score",
    ylabel: str = "Count",
    figsize: tuple[int, int] = (6, 4),
    save_path: str | None = None,
    show: bool = True,
    ax: plt.Axes | None = None,
    **kwargs: Any,
) -> tuple[plt.Figure, plt.Axes]:
    """Draw a histogram of gridness scores, ignoring NaN and infinite values.

    Parameters
    ----------
    scores : np.ndarray
        Gridness scores; converted to a float array, and non-finite entries
        are dropped before binning.
    config : PlotConfig, optional
        Overrides the title/label/figsize/save_path/show arguments.
    bins : int
        Number of histogram bins passed to ``ax.hist``.
    ax : plt.Axes, optional
        Axes to draw into; if omitted a new figure is created and finalized.
    **kwargs
        Extra keyword arguments merged into the config's kwargs.

    Returns
    -------
    (fig, ax) : tuple[plt.Figure, plt.Axes]
    """
    cfg = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        kwargs=kwargs,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
        figsize=figsize,
        save_path=save_path,
        show=show,
    )

    as_float = np.asarray(scores, dtype=float)
    finite_scores = as_float[np.isfinite(as_float)]

    owns_figure = ax is None
    if owns_figure:
        fig, ax = plt.subplots(1, 1, figsize=cfg.figsize)
    else:
        fig = ax.figure

    ax.hist(finite_scores, bins=bins, **cfg.to_matplotlib_kwargs())
    ax.set_title(cfg.title)
    ax.set_xlabel(cfg.xlabel)
    ax.set_ylabel(cfg.ylabel)

    if owns_figure:
        fig.tight_layout()
        finalize_figure(fig, cfg)

    return fig, ax
|
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Head Direction Cell Visualization
|
|
3
|
+
|
|
4
|
+
Plotting functions for head direction cell analysis.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import matplotlib.pyplot as plt
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def plot_polar_tuning(
    angles: np.ndarray,
    rates: np.ndarray,
    preferred_direction: float | None = None,
    mvl: float | None = None,
    title: str = "Directional Tuning",
    ax: plt.Axes | None = None,
    close_curve: bool = True,
) -> plt.Axes:
    """
    Plot directional tuning curve in polar coordinates.

    Parameters
    ----------
    angles : np.ndarray
        Angular bin positions in radians (1D).
    rates : np.ndarray
        Firing rates for each bin, same length as ``angles``.
    preferred_direction : float, optional
        Preferred direction (radians), drawn as a dashed red radial line from
        0 to the largest non-NaN rate.
    mvl : float, optional
        Mean Vector Length, appended to the title as "(MVL: x.xxx)".
    title : str, optional
        Plot title.
    ax : plt.Axes, optional
        Polar axes to plot on. If None, creates a new 6x6 figure with a
        polar subplot.
    close_curve : bool, optional
        If True (default), repeat the first bin at ``angles[0] + 2*pi`` so
        the line and fill wrap around with no gap between the last and first
        bin. Set to False when ``angles`` covers only part of the circle.

    Returns
    -------
    ax : plt.Axes
        The axes object
    """
    if ax is None:
        fig = plt.figure(figsize=(6, 6))
        ax = fig.add_subplot(111, projection="polar")

    angles = np.asarray(angles, dtype=float)
    rates = np.asarray(rates, dtype=float)

    plot_angles, plot_rates = angles, rates
    if close_curve and angles.size > 0:
        # Add one full turn (not angles[0] itself) so the closing segment
        # follows the short arc instead of sweeping back around the circle.
        plot_angles = np.append(angles, angles[0] + 2 * np.pi)
        plot_rates = np.append(rates, rates[0])

    ax.plot(plot_angles, plot_rates, "b-", linewidth=2, label="Tuning curve")
    ax.fill_between(plot_angles, 0, plot_rates, alpha=0.3)

    if preferred_direction is not None:
        # nanmax so a single empty (NaN) bin does not turn the marker length
        # into NaN and hide it.
        max_rate = np.nanmax(rates)
        ax.plot(
            [preferred_direction, preferred_direction],
            [0, max_rate],
            "r--",
            linewidth=2,
            label=f"Preferred: {np.rad2deg(preferred_direction):.1f}°",
        )

    if mvl is not None:
        title += f" (MVL: {mvl:.3f})"

    ax.set_title(title, pad=20)
    ax.set_theta_zero_location("E")  # 0° to the right
    ax.set_theta_direction(1)  # Counterclockwise
    ax.legend(loc="upper right", bbox_to_anchor=(1.3, 1.1))

    return ax
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def plot_temporal_autocorr(
    lags: np.ndarray,
    acorr: np.ndarray,
    title: str = "Temporal Autocorrelation",
    ax: plt.Axes | None = None,
    xlabel: str = "Time Lag (ms)",
) -> plt.Axes:
    """
    Plot temporal autocorrelation as bar plot.

    Parameters
    ----------
    lags : np.ndarray
        Time lags (ms or bins). Bar width is 80% of the spacing between the
        first two lags, or 0.8 when fewer than two lags are given.
    acorr : np.ndarray
        Autocorrelation values, one per lag.
    title : str, optional
        Plot title
    ax : plt.Axes, optional
        Axes to plot on. If None, a new 8x4 figure is created.
    xlabel : str, optional
        X-axis label; change it when ``lags`` are not in milliseconds.

    Returns
    -------
    ax : plt.Axes
        The axes object
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(8, 4))

    lags = np.asarray(lags, dtype=float)
    # With fewer than two lags there is no spacing to measure (np.diff would
    # be empty), so fall back to a fixed width instead of raising IndexError.
    width = abs(float(lags[1] - lags[0])) * 0.8 if lags.size > 1 else 0.8
    ax.bar(lags, acorr, width=width, color="blue", alpha=0.7)
    ax.axhline(y=0, color="k", linestyle="--", linewidth=0.5)

    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Autocorrelation")
    ax.grid(True, alpha=0.3)

    return ax
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _interp_circular(x: np.ndarray, xp: np.ndarray, angles: np.ndarray) -> np.ndarray:
    """
    Interpolate an angle time series (radians) at the sample points ``x``.

    Plain ``np.interp`` on raw angles is wrong across the wrap-around seam:
    between 179° and -179° it yields about 0°. Interpolating the cosine and
    sine components separately and recombining them with ``arctan2`` follows
    the short arc instead, and NaNs only affect neighbouring samples.

    Parameters
    ----------
    x : np.ndarray
        Points at which to evaluate the angle.
    xp : np.ndarray
        Sample points of ``angles`` (``np.interp`` expects them increasing).
    angles : np.ndarray
        Angle samples in radians.

    Returns
    -------
    np.ndarray
        Interpolated angles in (-pi, pi], or in [0, 2*pi) when every finite
        input angle is non-negative, so the output keeps the input's range
        convention.
    """
    angles = np.asarray(angles, dtype=float)
    cos_interp = np.interp(x, xp, np.cos(angles))
    sin_interp = np.interp(x, xp, np.sin(angles))
    interpolated = np.arctan2(sin_interp, cos_interp)
    if np.all(angles[np.isfinite(angles)] >= 0):
        interpolated = np.mod(interpolated, 2 * np.pi)
    return interpolated


def plot_hd_analysis(
    result,
    time_stamps: np.ndarray | None = None,
    head_directions: np.ndarray | None = None,
    spike_times: np.ndarray | None = None,
    figsize: tuple = (15, 5),
) -> plt.Figure:
    """
    Comprehensive head direction analysis plot.

    Panels: polar tuning curve, linear tuning curve, and either the HD
    trajectory with spikes (when ``time_stamps`` and ``head_directions`` are
    given) or a text summary of the statistics.

    Parameters
    ----------
    result : HDCellResult
        Results from HeadDirectionAnalyzer. Attributes read here:
        ``tuning_curve`` (unpacked as ``(bin_centers, firing_rates)``, bin
        centers in radians), ``preferred_direction`` (may be None),
        ``mvl_hd``, ``mvl_theta`` (may be None), ``is_hd``, ``rayleigh_p``.
    time_stamps : np.ndarray, optional
        Time stamps of the head-direction samples.
    head_directions : np.ndarray, optional
        Head direction time series in radians.
    spike_times : np.ndarray, optional
        Spike times; their heading is interpolated circularly from
        ``head_directions``.
    figsize : tuple, optional
        Figure size

    Returns
    -------
    fig : plt.Figure
        The figure object
    """
    fig = plt.figure(figsize=figsize)
    has_preferred = result.preferred_direction is not None

    # Plot 1: Polar tuning curve
    ax1 = fig.add_subplot(131, projection="polar")
    bin_centers, firing_rates = result.tuning_curve
    plot_polar_tuning(
        bin_centers,
        firing_rates,
        preferred_direction=result.preferred_direction,
        mvl=result.mvl_hd,
        ax=ax1,
    )

    # Plot 2: Linear tuning curve
    ax2 = fig.add_subplot(132)
    bin_centers_deg = np.rad2deg(bin_centers)
    ax2.plot(bin_centers_deg, firing_rates, "b-", linewidth=2)
    ax2.fill_between(bin_centers_deg, 0, firing_rates, alpha=0.3)
    if has_preferred:
        ax2.axvline(
            np.rad2deg(result.preferred_direction),
            color="r",
            linestyle="--",
            linewidth=2,
            label="Preferred direction",
        )
    ax2.set_xlabel("Head Direction (°)")
    ax2.set_ylabel("Firing Rate (Hz)")
    ax2.set_title("Linear Tuning Curve")
    ax2.grid(True, alpha=0.3)
    if has_preferred:
        # The preferred-direction line is the only labelled artist; without
        # it legend() would just warn about missing labels.
        ax2.legend()

    # Plot 3: Statistics and trajectory (if provided)
    ax3 = fig.add_subplot(133)
    if time_stamps is not None and head_directions is not None:
        ax3.plot(time_stamps, np.rad2deg(head_directions), "k-", linewidth=0.5, alpha=0.5)
        if spike_times is not None:
            spike_hd = _interp_circular(spike_times, time_stamps, head_directions)
            ax3.plot(spike_times, np.rad2deg(spike_hd), "r.", markersize=3)
        ax3.set_xlabel("Time (s)")
        ax3.set_ylabel("Head Direction (°)")
        ax3.set_title("HD Trajectory with Spikes")
        # NOTE(review): fixed limits assume head directions in [-pi, pi];
        # data in [0, 2*pi) would be partly clipped — confirm convention.
        ax3.set_ylim([-180, 180])
    else:
        ax3.axis("off")
        if has_preferred:
            preferred_text = f"{np.rad2deg(result.preferred_direction):.1f}°"
        else:
            preferred_text = "N/A"
        stats_lines = [
            "Head Direction Cell Analysis",
            "",
            f"Classification: {'HD Cell' if result.is_hd else 'Non-HD Cell'}",
            f"MVL (HD): {result.mvl_hd:.3f}",
        ]
        # "is not None" so a legitimate theta MVL of 0.0 is still reported.
        if result.mvl_theta is not None:
            stats_lines.append(f"MVL (Theta): {result.mvl_theta:.3f}")
        stats_lines += [
            f"Preferred Direction: {preferred_text}",
            f"Rayleigh p-value: {result.rayleigh_p:.6f}",
            "",
            f"Peak Firing Rate: {np.max(firing_rates):.2f} Hz",
            f"Mean Firing Rate: {np.mean(firing_rates):.2f} Hz",
        ]
        ax3.text(
            0.1,
            0.5,
            "\n".join(stats_lines),
            fontsize=10,
            verticalalignment="center",
            family="monospace",
        )
        ax3.set_title("HD Statistics")

    # fig.tight_layout (not pyplot's current-figure state) so the layout is
    # applied to this figure even if another figure became current meanwhile.
    fig.tight_layout()
    return fig
|
|
@@ -280,7 +280,7 @@ class PlotConfigs:
|
|
|
280
280
|
return PlotConfig.for_static_plot(**defaults)
|
|
281
281
|
|
|
282
282
|
@staticmethod
|
|
283
|
-
def
|
|
283
|
+
def cohospace_trajectory_2d(**kwargs: Any) -> PlotConfig:
|
|
284
284
|
defaults: dict[str, Any] = {
|
|
285
285
|
"title": "CohoSpace trajectory",
|
|
286
286
|
"xlabel": "theta1 (deg)",
|
|
@@ -291,7 +291,18 @@ class PlotConfigs:
|
|
|
291
291
|
return PlotConfig.for_static_plot(**defaults)
|
|
292
292
|
|
|
293
293
|
@staticmethod
|
|
294
|
-
def
|
|
294
|
+
def cohospace_trajectory_1d(**kwargs: Any) -> PlotConfig:
    """Static-plot config for a 1D CohoSpace trajectory.

    Defaults: title "CohoSpace trajectory (1D)", x label "cos(theta)",
    y label "sin(theta)", 6x6 figure. Any key in ``kwargs`` replaces the
    default of the same name; the merged mapping is forwarded to
    ``PlotConfig.for_static_plot``.
    """
    return PlotConfig.for_static_plot(
        **{
            "title": "CohoSpace trajectory (1D)",
            "xlabel": "cos(theta)",
            "ylabel": "sin(theta)",
            "figsize": (6, 6),
            **kwargs,
        }
    )
|
|
303
|
+
|
|
304
|
+
@staticmethod
|
|
305
|
+
def cohospace_neuron_2d(**kwargs: Any) -> PlotConfig:
|
|
295
306
|
defaults: dict[str, Any] = {
|
|
296
307
|
"title": "Neuron activity on coho-space",
|
|
297
308
|
"xlabel": "Theta 1 (deg)",
|
|
@@ -302,7 +313,18 @@ class PlotConfigs:
|
|
|
302
313
|
return PlotConfig.for_static_plot(**defaults)
|
|
303
314
|
|
|
304
315
|
@staticmethod
|
|
305
|
-
def
|
|
316
|
+
def cohospace_neuron_1d(**kwargs: Any) -> PlotConfig:
    """Static-plot config for single-neuron activity on a 1D coho-space.

    Defaults: title "Neuron activity on coho-space (1D)", x label
    "cos(theta)", y label "sin(theta)", 6x6 figure. Caller-supplied
    ``kwargs`` take precedence over these defaults before everything is
    passed to ``PlotConfig.for_static_plot``.
    """
    return PlotConfig.for_static_plot(
        **{
            "title": "Neuron activity on coho-space (1D)",
            "xlabel": "cos(theta)",
            "ylabel": "sin(theta)",
            "figsize": (6, 6),
            **kwargs,
        }
    )
|
|
325
|
+
|
|
326
|
+
@staticmethod
|
|
327
|
+
def cohospace_population_2d(**kwargs: Any) -> PlotConfig:
|
|
306
328
|
defaults: dict[str, Any] = {
|
|
307
329
|
"title": "Population activity on coho-space",
|
|
308
330
|
"xlabel": "Theta 1 (deg)",
|
|
@@ -312,6 +334,17 @@ class PlotConfigs:
|
|
|
312
334
|
defaults.update(kwargs)
|
|
313
335
|
return PlotConfig.for_static_plot(**defaults)
|
|
314
336
|
|
|
337
|
+
@staticmethod
def cohospace_population_1d(**kwargs: Any) -> PlotConfig:
    """Static-plot config for population activity on a 1D coho-space.

    Defaults: title "Population activity on coho-space (1D)", x label
    "cos(theta)", y label "sin(theta)", 6x6 figure. Entries in ``kwargs``
    override the defaults; the result comes from
    ``PlotConfig.for_static_plot``.
    """
    return PlotConfig.for_static_plot(
        **{
            "title": "Population activity on coho-space (1D)",
            "xlabel": "cos(theta)",
            "ylabel": "sin(theta)",
            "figsize": (6, 6),
            **kwargs,
        }
    )
|
|
347
|
+
|
|
315
348
|
@staticmethod
|
|
316
349
|
def fr_heatmap(**kwargs: Any) -> PlotConfig:
|
|
317
350
|
defaults: dict[str, Any] = {
|
|
@@ -337,7 +370,7 @@ class PlotConfigs:
|
|
|
337
370
|
return PlotConfig.for_static_plot(**defaults)
|
|
338
371
|
|
|
339
372
|
@staticmethod
|
|
340
|
-
def
|
|
373
|
+
def path_compare_2d(**kwargs: Any) -> PlotConfig:
|
|
341
374
|
defaults: dict[str, Any] = {
|
|
342
375
|
"title": "Path Compare",
|
|
343
376
|
"figsize": (12, 5),
|
|
@@ -345,6 +378,15 @@ class PlotConfigs:
|
|
|
345
378
|
defaults.update(kwargs)
|
|
346
379
|
return PlotConfig.for_static_plot(**defaults)
|
|
347
380
|
|
|
381
|
+
@staticmethod
def path_compare_1d(**kwargs: Any) -> PlotConfig:
    """Static-plot config for the 1D path-comparison figure.

    Defaults: title "Path Compare (1D)" and a 12x5 figure; any key in
    ``kwargs`` overrides them before ``PlotConfig.for_static_plot`` is called.
    """
    return PlotConfig.for_static_plot(
        **{
            "title": "Path Compare (1D)",
            "figsize": (12, 5),
            **kwargs,
        }
    )
|
|
389
|
+
|
|
348
390
|
@staticmethod
|
|
349
391
|
def raster_plot(mode: str = "block", **kwargs: Any) -> PlotConfig:
|
|
350
392
|
defaults: dict[str, Any] = {
|
canns/data/__init__.py
CHANGED
|
@@ -16,11 +16,13 @@ from .datasets import (
|
|
|
16
16
|
get_data_dir,
|
|
17
17
|
get_dataset_path,
|
|
18
18
|
get_huggingface_upload_guide,
|
|
19
|
+
get_left_right_data_session,
|
|
20
|
+
get_left_right_npz,
|
|
19
21
|
list_datasets,
|
|
20
22
|
load,
|
|
21
23
|
quick_setup,
|
|
22
24
|
)
|
|
23
|
-
from .loaders import load_grid_data, load_roi_data
|
|
25
|
+
from .loaders import load_grid_data, load_left_right_npz, load_roi_data
|
|
24
26
|
|
|
25
27
|
__all__ = [
|
|
26
28
|
# Dataset registry and management
|
|
@@ -31,6 +33,8 @@ __all__ = [
|
|
|
31
33
|
"list_datasets",
|
|
32
34
|
"download_dataset",
|
|
33
35
|
"get_dataset_path",
|
|
36
|
+
"get_left_right_data_session",
|
|
37
|
+
"get_left_right_npz",
|
|
34
38
|
"quick_setup",
|
|
35
39
|
"get_huggingface_upload_guide",
|
|
36
40
|
# Generic loading
|
|
@@ -38,4 +42,5 @@ __all__ = [
|
|
|
38
42
|
# Specialized loaders
|
|
39
43
|
"load_roi_data",
|
|
40
44
|
"load_grid_data",
|
|
45
|
+
"load_left_right_npz",
|
|
41
46
|
]
|