canns 0.13.1__py3-none-any.whl → 0.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. canns/analyzer/data/__init__.py +5 -1
  2. canns/analyzer/data/asa/__init__.py +27 -12
  3. canns/analyzer/data/asa/cohospace.py +336 -10
  4. canns/analyzer/data/asa/config.py +3 -0
  5. canns/analyzer/data/asa/embedding.py +48 -45
  6. canns/analyzer/data/asa/path.py +104 -2
  7. canns/analyzer/data/asa/plotting.py +88 -19
  8. canns/analyzer/data/asa/tda.py +11 -4
  9. canns/analyzer/data/cell_classification/__init__.py +97 -0
  10. canns/analyzer/data/cell_classification/core/__init__.py +26 -0
  11. canns/analyzer/data/cell_classification/core/grid_cells.py +633 -0
  12. canns/analyzer/data/cell_classification/core/grid_modules_leiden.py +288 -0
  13. canns/analyzer/data/cell_classification/core/head_direction.py +347 -0
  14. canns/analyzer/data/cell_classification/core/spatial_analysis.py +431 -0
  15. canns/analyzer/data/cell_classification/io/__init__.py +5 -0
  16. canns/analyzer/data/cell_classification/io/matlab_loader.py +417 -0
  17. canns/analyzer/data/cell_classification/utils/__init__.py +39 -0
  18. canns/analyzer/data/cell_classification/utils/circular_stats.py +383 -0
  19. canns/analyzer/data/cell_classification/utils/correlation.py +318 -0
  20. canns/analyzer/data/cell_classification/utils/geometry.py +442 -0
  21. canns/analyzer/data/cell_classification/utils/image_processing.py +416 -0
  22. canns/analyzer/data/cell_classification/visualization/__init__.py +19 -0
  23. canns/analyzer/data/cell_classification/visualization/grid_plots.py +292 -0
  24. canns/analyzer/data/cell_classification/visualization/hd_plots.py +200 -0
  25. canns/analyzer/metrics/__init__.py +2 -1
  26. canns/analyzer/visualization/core/config.py +46 -4
  27. canns/data/__init__.py +6 -1
  28. canns/data/datasets.py +154 -1
  29. canns/data/loaders.py +37 -0
  30. canns/pipeline/__init__.py +13 -9
  31. canns/pipeline/__main__.py +6 -0
  32. canns/pipeline/asa/runner.py +105 -41
  33. canns/pipeline/asa_gui/__init__.py +68 -0
  34. canns/pipeline/asa_gui/__main__.py +6 -0
  35. canns/pipeline/asa_gui/analysis_modes/__init__.py +42 -0
  36. canns/pipeline/asa_gui/analysis_modes/base.py +39 -0
  37. canns/pipeline/asa_gui/analysis_modes/batch_mode.py +21 -0
  38. canns/pipeline/asa_gui/analysis_modes/cohomap_mode.py +56 -0
  39. canns/pipeline/asa_gui/analysis_modes/cohospace_mode.py +194 -0
  40. canns/pipeline/asa_gui/analysis_modes/decode_mode.py +52 -0
  41. canns/pipeline/asa_gui/analysis_modes/fr_mode.py +81 -0
  42. canns/pipeline/asa_gui/analysis_modes/frm_mode.py +92 -0
  43. canns/pipeline/asa_gui/analysis_modes/gridscore_mode.py +123 -0
  44. canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +199 -0
  45. canns/pipeline/asa_gui/analysis_modes/tda_mode.py +112 -0
  46. canns/pipeline/asa_gui/app.py +29 -0
  47. canns/pipeline/asa_gui/controllers/__init__.py +6 -0
  48. canns/pipeline/asa_gui/controllers/analysis_controller.py +59 -0
  49. canns/pipeline/asa_gui/controllers/preprocess_controller.py +89 -0
  50. canns/pipeline/asa_gui/core/__init__.py +15 -0
  51. canns/pipeline/asa_gui/core/cache.py +14 -0
  52. canns/pipeline/asa_gui/core/runner.py +1936 -0
  53. canns/pipeline/asa_gui/core/state.py +324 -0
  54. canns/pipeline/asa_gui/core/worker.py +260 -0
  55. canns/pipeline/asa_gui/main_window.py +184 -0
  56. canns/pipeline/asa_gui/models/__init__.py +7 -0
  57. canns/pipeline/asa_gui/models/config.py +14 -0
  58. canns/pipeline/asa_gui/models/job.py +31 -0
  59. canns/pipeline/asa_gui/models/presets.py +21 -0
  60. canns/pipeline/asa_gui/resources/__init__.py +16 -0
  61. canns/pipeline/asa_gui/resources/dark.qss +167 -0
  62. canns/pipeline/asa_gui/resources/light.qss +163 -0
  63. canns/pipeline/asa_gui/resources/styles.qss +130 -0
  64. canns/pipeline/asa_gui/utils/__init__.py +1 -0
  65. canns/pipeline/asa_gui/utils/formatters.py +15 -0
  66. canns/pipeline/asa_gui/utils/io_adapters.py +40 -0
  67. canns/pipeline/asa_gui/utils/validators.py +41 -0
  68. canns/pipeline/asa_gui/views/__init__.py +1 -0
  69. canns/pipeline/asa_gui/views/help_content.py +171 -0
  70. canns/pipeline/asa_gui/views/pages/__init__.py +6 -0
  71. canns/pipeline/asa_gui/views/pages/analysis_page.py +565 -0
  72. canns/pipeline/asa_gui/views/pages/preprocess_page.py +492 -0
  73. canns/pipeline/asa_gui/views/panels/__init__.py +1 -0
  74. canns/pipeline/asa_gui/views/widgets/__init__.py +21 -0
  75. canns/pipeline/asa_gui/views/widgets/artifacts_tab.py +44 -0
  76. canns/pipeline/asa_gui/views/widgets/drop_zone.py +80 -0
  77. canns/pipeline/asa_gui/views/widgets/file_list.py +27 -0
  78. canns/pipeline/asa_gui/views/widgets/gridscore_tab.py +308 -0
  79. canns/pipeline/asa_gui/views/widgets/help_dialog.py +27 -0
  80. canns/pipeline/asa_gui/views/widgets/image_tab.py +50 -0
  81. canns/pipeline/asa_gui/views/widgets/image_viewer.py +97 -0
  82. canns/pipeline/asa_gui/views/widgets/log_box.py +16 -0
  83. canns/pipeline/asa_gui/views/widgets/pathcompare_tab.py +200 -0
  84. canns/pipeline/asa_gui/views/widgets/popup_combo.py +25 -0
  85. canns/pipeline/gallery/__init__.py +15 -5
  86. canns/pipeline/gallery/__main__.py +11 -0
  87. canns/pipeline/gallery/app.py +705 -0
  88. canns/pipeline/gallery/runner.py +790 -0
  89. canns/pipeline/gallery/state.py +51 -0
  90. canns/pipeline/gallery/styles.tcss +123 -0
  91. canns/pipeline/launcher.py +81 -0
  92. {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/METADATA +11 -1
  93. canns-0.14.0.dist-info/RECORD +163 -0
  94. canns-0.14.0.dist-info/entry_points.txt +5 -0
  95. canns/pipeline/_base.py +0 -50
  96. canns-0.13.1.dist-info/RECORD +0 -89
  97. canns-0.13.1.dist-info/entry_points.txt +0 -3
  98. {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/WHEEL +0 -0
  99. {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,292 @@
1
+ """Grid cell visualization utilities."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+ from matplotlib import patches
9
+ from matplotlib import pyplot as plt
10
+
11
+ from canns.analyzer.visualization.core.config import PlotConfig, PlotConfigs, finalize_figure
12
+
13
+
14
+ def _ensure_plot_config(
15
+ config: PlotConfig | None,
16
+ factory,
17
+ *,
18
+ kwargs: dict[str, Any] | None = None,
19
+ **defaults: Any,
20
+ ) -> PlotConfig:
21
+ if config is None:
22
+ defaults.update({"kwargs": kwargs or {}})
23
+ return factory(**defaults)
24
+
25
+ if kwargs:
26
+ config_kwargs = config.kwargs or {}
27
+ config_kwargs.update(kwargs)
28
+ config.kwargs = config_kwargs
29
+ return config
30
+
31
+
32
def plot_autocorrelogram(
    autocorr: np.ndarray,
    config: PlotConfig | None = None,
    *,
    gridness_score: float | None = None,
    center_radius: float | None = None,
    peak_locations: np.ndarray | None = None,
    title: str = "Spatial Autocorrelation",
    xlabel: str = "X Lag (bins)",
    ylabel: str = "Y Lag (bins)",
    figsize: tuple[int, int] = (6, 6),
    save_path: str | None = None,
    show: bool = True,
    ax: plt.Axes | None = None,
    **kwargs: Any,
) -> tuple[plt.Figure, plt.Axes]:
    """Plot a 2D spatial autocorrelogram with optional grid annotations.

    Parameters
    ----------
    autocorr : np.ndarray
        2D autocorrelation map, drawn with ``imshow`` (``origin="lower"`` and
        ``interpolation="bilinear"`` unless overridden via ``config``/``kwargs``).
    config : PlotConfig, optional
        Plot configuration. When omitted, one is built from
        ``PlotConfigs.grid_autocorrelation`` using the keyword arguments below.
    gridness_score : float, optional
        Appended to the title as ``"(Gridness: x.xxx)"`` when given.
    center_radius : float, optional
        Radius (in bins) of a dashed red circle drawn around the map centre.
    peak_locations : np.ndarray, optional
        Peak coordinates in bins; column 0 is plotted on the x axis and
        column 1 on the y axis. A single ``(2,)`` point is also accepted.
    title, xlabel, ylabel, figsize, save_path, show
        Used only when ``config`` is ``None``.
    ax : matplotlib Axes, optional
        Axes to draw into. When omitted, a new figure is created and passed
        to ``finalize_figure``; when provided, the caller finalizes the figure.
    **kwargs
        Extra keyword arguments merged into the config's ``kwargs``.

    Returns
    -------
    tuple[Figure, Axes]
        The figure and axes that were drawn on.
    """
    config = _ensure_plot_config(
        config,
        PlotConfigs.grid_autocorrelation,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
        figsize=figsize,
        save_path=save_path,
        show=show,
        kwargs=kwargs,
    )

    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(1, 1, figsize=config.figsize)
    else:
        fig = ax.figure

    # Autocorrelogram image.
    plot_kwargs = config.to_matplotlib_kwargs()
    plot_kwargs.setdefault("origin", "lower")
    plot_kwargs.setdefault("interpolation", "bilinear")
    im = ax.imshow(autocorr, **plot_kwargs)
    fig.colorbar(im, ax=ax, label="Correlation")

    if center_radius is not None:
        # With imshow's default extent, pixel (row, col) is centred at
        # x=col, y=row, so the geometric centre of an (H, W) map is
        # ((W - 1) / 2, (H - 1) / 2). Using ``shape / 2`` puts the circle half
        # a bin off-centre.
        # NOTE(review): assumes no custom ``extent`` is passed via kwargs.
        n_rows, n_cols = np.shape(autocorr)[:2]
        circle = patches.Circle(
            ((n_cols - 1) / 2, (n_rows - 1) / 2),
            center_radius,
            fill=False,
            edgecolor="red",
            linewidth=2,
            linestyle="--",
            label="Center field",
        )
        ax.add_patch(circle)

    if peak_locations is not None:
        # atleast_2d lets a single (x, y) peak be passed as a 1D array.
        peaks = np.atleast_2d(np.asarray(peak_locations))
        ax.plot(
            peaks[:, 0],
            peaks[:, 1],
            "r+",
            markersize=10,
            markeredgewidth=2,
            label="Grid peaks",
        )

    title_text = config.title
    if gridness_score is not None:
        title_text = f"{config.title} (Gridness: {gridness_score:.3f})"

    ax.set_title(title_text)
    ax.set_xlabel(config.xlabel)
    ax.set_ylabel(config.ylabel)

    if center_radius is not None or peak_locations is not None:
        ax.legend(loc="upper right")

    if created_fig:
        fig.tight_layout()
        finalize_figure(fig, config, rasterize_artists=[im] if config.rasterized else None)

    return fig, ax
117
+
118
+
119
def plot_gridness_analysis(
    rate_map: np.ndarray,
    autocorr: np.ndarray,
    result,
    config: PlotConfig | None = None,
    *,
    title: str = "Grid Cell Analysis",
    figsize: tuple[int, int] = (15, 5),
    save_path: str | None = None,
    show: bool = True,
) -> plt.Figure:
    """Draw a three-panel grid-cell summary: rate map, autocorrelogram, statistics.

    Parameters
    ----------
    rate_map : np.ndarray
        2D firing-rate map, drawn with the ``"hot"`` colormap.
    autocorr : np.ndarray
        2D spatial autocorrelogram, drawn by :func:`plot_autocorrelogram`.
    result
        Gridness analysis result. Attributes read here: ``score``,
        ``center_radius``, ``optimal_radius``, ``peak_locations``, ``spacing``,
        ``orientation``, ``ellipse`` (centre x/y followed by the two radii) and
        ``ellipse_theta_deg``. ``spacing`` and ``orientation`` may hold any
        number of values; each is listed on its own line.
    config : PlotConfig, optional
        Plot configuration. When omitted, one is built with
        ``PlotConfig.for_static_plot`` from ``title``/``figsize``/``save_path``/``show``.

    Returns
    -------
    Figure
        The new figure, after it has been passed to ``finalize_figure``.
    """
    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title=title,
        figsize=figsize,
        save_path=save_path,
        show=show,
        kwargs={},
    )

    fig, axes = plt.subplots(1, 3, figsize=config.figsize)

    # Panel 1: firing-rate map.
    rate_ax = axes[0]
    im_rate = rate_ax.imshow(rate_map, cmap="hot", origin="lower")
    rate_ax.set_title("Firing Rate Map")
    rate_ax.set_xlabel("X (bins)")
    rate_ax.set_ylabel("Y (bins)")
    # Attach to this figure explicitly instead of pyplot's "current figure".
    fig.colorbar(im_rate, ax=rate_ax, label="Rate (Hz)")

    # Panel 2: autocorrelogram with centre circle and peak markers.
    plot_autocorrelogram(
        autocorr,
        gridness_score=result.score,
        center_radius=result.center_radius,
        peak_locations=result.peak_locations,
        title="Autocorrelogram",
        ax=axes[1],
    )

    # Panel 3: text summary. Spacing/orientation are iterated rather than
    # indexed [0..2], so fewer than three entries no longer raise IndexError
    # and extra entries are not silently dropped.
    stats_lines = [
        "Grid Cell Analysis",
        "",
        f"Gridness Score: {result.score:.3f}",
        f"Center Radius: {result.center_radius:.1f} bins",
        f"Optimal Radius: {result.optimal_radius:.1f} bins",
        "",
        "Grid Spacing (bins):",
        *(f"  {value:.2f}" for value in result.spacing),
        "",
        "Grid Orientation (°):",
        *(f"  {value:.1f}" for value in result.orientation),
        "",
        "Ellipse Parameters:",
        f"  Center: ({result.ellipse[0]:.1f}, {result.ellipse[1]:.1f})",
        f"  Radii: ({result.ellipse[2]:.1f}, {result.ellipse[3]:.1f})",
        f"  Angle: {result.ellipse_theta_deg:.1f}°",
    ]
    stats_ax = axes[2]
    stats_ax.axis("off")
    stats_ax.text(
        0.1,
        0.5,
        "\n".join(stats_lines),
        fontsize=10,
        verticalalignment="center",
        family="monospace",
    )
    stats_ax.set_title("Grid Statistics")

    fig.suptitle(config.title, fontsize=14, fontweight="bold")
    fig.tight_layout()
    finalize_figure(fig, config)
    return fig
191
+
192
+
193
def plot_rate_map(
    rate_map: np.ndarray,
    config: PlotConfig | None = None,
    *,
    title: str = "Firing Field (Rate Map)",
    xlabel: str = "X Position (bins)",
    ylabel: str = "Y Position (bins)",
    figsize: tuple[int, int] = (6, 6),
    colorbar: bool = True,
    save_path: str | None = None,
    show: bool = True,
    ax: plt.Axes | None = None,
    **kwargs: Any,
) -> tuple[plt.Figure, plt.Axes]:
    """Draw a 2D spatial firing-rate map as an image.

    Parameters
    ----------
    rate_map : np.ndarray
        2D firing-rate map, drawn with ``imshow`` (``origin="lower"`` and
        ``interpolation="bilinear"`` unless the config overrides them).
    config : PlotConfig, optional
        Plot configuration; built from ``PlotConfigs.firing_field_heatmap``
        and the keyword arguments below when omitted.
    colorbar : bool
        Whether to add a "Firing Rate (Hz)" colorbar.
    ax : matplotlib Axes, optional
        Target axes. When omitted a new figure is created and finalized here;
        otherwise the caller finalizes the figure.
    **kwargs
        Extra keyword arguments merged into the config's ``kwargs``.

    Returns
    -------
    tuple[Figure, Axes]
        The figure and axes that were drawn on.
    """
    cfg = _ensure_plot_config(
        config,
        PlotConfigs.firing_field_heatmap,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
        figsize=figsize,
        save_path=save_path,
        show=show,
        kwargs=kwargs,
    )

    owns_figure = ax is None
    if owns_figure:
        fig, ax = plt.subplots(1, 1, figsize=cfg.figsize)
    else:
        fig = ax.figure

    imshow_kwargs = cfg.to_matplotlib_kwargs()
    for key, fallback in (("origin", "lower"), ("interpolation", "bilinear")):
        imshow_kwargs.setdefault(key, fallback)
    image = ax.imshow(rate_map, **imshow_kwargs)

    if colorbar:
        fig.colorbar(image, ax=ax, label="Firing Rate (Hz)")

    ax.set_title(cfg.title)
    ax.set_xlabel(cfg.xlabel)
    ax.set_ylabel(cfg.ylabel)

    if owns_figure:
        fig.tight_layout()
        raster_targets = [image] if cfg.rasterized else None
        finalize_figure(fig, cfg, rasterize_artists=raster_targets)

    return fig, ax
244
+
245
+
246
def plot_grid_score_histogram(
    scores: np.ndarray,
    config: PlotConfig | None = None,
    *,
    bins: int = 30,
    title: str = "Grid Score Distribution",
    xlabel: str = "Grid Score",
    ylabel: str = "Count",
    figsize: tuple[int, int] = (6, 4),
    save_path: str | None = None,
    show: bool = True,
    ax: plt.Axes | None = None,
    **kwargs: Any,
) -> tuple[plt.Figure, plt.Axes]:
    """Draw a histogram of gridness scores, ignoring non-finite values.

    Parameters
    ----------
    scores : array-like
        Gridness scores; converted to float, and NaN/inf entries are dropped.
    config : PlotConfig, optional
        Plot configuration; built with ``PlotConfig.for_static_plot`` and the
        keyword arguments below when omitted.
    bins : int
        Number of histogram bins.
    ax : matplotlib Axes, optional
        Target axes. When omitted a new figure is created and finalized here.
    **kwargs
        Extra keyword arguments merged into the config's ``kwargs``.

    Returns
    -------
    tuple[Figure, Axes]
        The figure and axes that were drawn on.
    """
    cfg = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
        figsize=figsize,
        save_path=save_path,
        show=show,
        kwargs=kwargs,
    )

    values = np.asarray(scores, dtype=float)
    finite_values = values[np.isfinite(values)]

    owns_figure = ax is None
    fig, ax = plt.subplots(1, 1, figsize=cfg.figsize) if owns_figure else (ax.figure, ax)

    ax.hist(finite_values, bins=bins, **cfg.to_matplotlib_kwargs())
    ax.set_title(cfg.title)
    ax.set_xlabel(cfg.xlabel)
    ax.set_ylabel(cfg.ylabel)

    if owns_figure:
        fig.tight_layout()
        finalize_figure(fig, cfg)

    return fig, ax
@@ -0,0 +1,200 @@
1
+ """
2
+ Head Direction Cell Visualization
3
+
4
+ Plotting functions for head direction cell analysis.
5
+ """
6
+
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+
10
+
11
def plot_polar_tuning(
    angles: np.ndarray,
    rates: np.ndarray,
    preferred_direction: float | None = None,
    mvl: float | None = None,
    title: str = "Directional Tuning",
    ax: plt.Axes | None = None,
) -> plt.Axes:
    """
    Plot directional tuning curve in polar coordinates.

    The curve is closed by repeating the first bin one full turn later, so
    there is no visual gap between the last and the first angular bin.

    Parameters
    ----------
    angles : np.ndarray
        Angular bins in radians (1D).
    rates : np.ndarray
        Firing rates for each bin. NaN entries are ignored when sizing the
        preferred-direction marker.
    preferred_direction : float, optional
        Preferred direction to mark (radians).
    mvl : float, optional
        Mean Vector Length, appended to the title.
    title : str, optional
        Plot title.
    ax : plt.Axes, optional
        Polar axes to plot on. If None, creates a new 6x6 figure.

    Returns
    -------
    ax : plt.Axes
        The axes object.
    """
    if ax is None:
        fig = plt.figure(figsize=(6, 6))
        ax = fig.add_subplot(111, projection="polar")

    angles = np.asarray(angles, dtype=float)
    rates = np.asarray(rates, dtype=float)

    # Close the loop: without this, the line and fill stop at the last bin
    # centre and leave a wedge-shaped gap before the first bin.
    if angles.size:
        curve_angles = np.append(angles, angles[0] + 2 * np.pi)
        curve_rates = np.append(rates, rates[0])
    else:
        curve_angles, curve_rates = angles, rates

    ax.plot(curve_angles, curve_rates, "b-", linewidth=2, label="Tuning curve")
    ax.fill_between(curve_angles, 0, curve_rates, alpha=0.3)

    if preferred_direction is not None:
        # nanmax: a single NaN rate (e.g. an unsampled bin) would otherwise
        # make the marker length NaN.
        max_rate = np.nanmax(rates)
        ax.plot(
            [preferred_direction, preferred_direction],
            [0, max_rate],
            "r--",
            linewidth=2,
            label=f"Preferred: {np.rad2deg(preferred_direction):.1f}°",
        )

    if mvl is not None:
        title = f"{title} (MVL: {mvl:.3f})"

    ax.set_title(title, pad=20)
    ax.set_theta_zero_location("E")  # 0° to the right
    ax.set_theta_direction(1)  # Counterclockwise
    ax.legend(loc="upper right", bbox_to_anchor=(1.3, 1.1))

    return ax
71
+
72
+
73
def plot_temporal_autocorr(
    lags: np.ndarray,
    acorr: np.ndarray,
    title: str = "Temporal Autocorrelation",
    ax: plt.Axes | None = None,
    *,
    xlabel: str = "Time Lag (ms)",
) -> plt.Axes:
    """
    Plot temporal autocorrelation as bar plot.

    Parameters
    ----------
    lags : np.ndarray
        Time lags (ms or bins).
    acorr : np.ndarray
        Autocorrelation values.
    title : str, optional
        Plot title.
    ax : plt.Axes, optional
        Axes to plot on. If None, creates a new 8x4 figure.
    xlabel : str, optional
        X-axis label. Defaults to ``"Time Lag (ms)"``; pass e.g.
        ``"Time Lag (bins)"`` when ``lags`` are in bins.

    Returns
    -------
    ax : plt.Axes
        The axes object.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(8, 4))

    lags = np.asarray(lags)
    # Each bar spans 80% of the first lag step. With fewer than two lags
    # there is no step (np.diff would be empty and [0] would raise), so fall
    # back to 0.8 lag units.
    width = np.diff(lags)[0] * 0.8 if lags.size > 1 else 0.8

    ax.bar(lags, acorr, width=width, color="blue", alpha=0.7)
    ax.axhline(y=0, color="k", linestyle="--", linewidth=0.5)

    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Autocorrelation")
    ax.grid(True, alpha=0.3)

    return ax
110
+
111
+
112
def plot_hd_analysis(
    result,
    time_stamps: np.ndarray | None = None,
    head_directions: np.ndarray | None = None,
    spike_times: np.ndarray | None = None,
    figsize: tuple = (15, 5),
) -> plt.Figure:
    """
    Comprehensive head direction analysis plot.

    Draws three panels: a polar tuning curve, a linear tuning curve, and
    either the head-direction trajectory with spikes (when both
    ``time_stamps`` and ``head_directions`` are given) or a text summary.

    Parameters
    ----------
    result : HDCellResult
        Results from HeadDirectionAnalyzer. Attributes read here:
        ``tuning_curve`` as ``(bin_centers, firing_rates)`` with bin centres in
        radians, ``preferred_direction`` (radians, may be None), ``mvl_hd``,
        ``mvl_theta`` (may be None), ``is_hd`` and ``rayleigh_p``.
    time_stamps : np.ndarray, optional
        Time stamps for plotting the trajectory.
    head_directions : np.ndarray, optional
        Head direction time series in radians; displayed wrapped to
        [-180°, 180°) to match the panel's fixed y range.
    spike_times : np.ndarray, optional
        Spike times, in the same units as ``time_stamps``.
    figsize : tuple, optional
        Figure size.

    Returns
    -------
    fig : plt.Figure
        The figure object.
    """
    fig = plt.figure(figsize=figsize)
    bin_centers, firing_rates = result.tuning_curve
    preferred = result.preferred_direction

    # Panel 1: polar tuning curve.
    ax1 = fig.add_subplot(131, projection="polar")
    plot_polar_tuning(
        bin_centers,
        firing_rates,
        preferred_direction=preferred,
        mvl=result.mvl_hd,
        ax=ax1,
    )

    # Panel 2: linear tuning curve.
    ax2 = fig.add_subplot(132)
    bin_deg = np.rad2deg(bin_centers)
    ax2.plot(bin_deg, firing_rates, "b-", linewidth=2)
    ax2.fill_between(bin_deg, 0, firing_rates, alpha=0.3)
    if preferred is not None:
        ax2.axvline(
            np.rad2deg(preferred),
            color="r",
            linestyle="--",
            linewidth=2,
            label="Preferred direction",
        )
        # Only the axvline is labelled; calling legend() without it makes
        # matplotlib warn about missing labelled artists.
        ax2.legend()
    ax2.set_xlabel("Head Direction (°)")
    ax2.set_ylabel("Firing Rate (Hz)")
    ax2.set_title("Linear Tuning Curve")
    ax2.grid(True, alpha=0.3)

    # Panel 3: trajectory with spikes, or summary statistics.
    ax3 = fig.add_subplot(133)
    if time_stamps is not None and head_directions is not None:
        hd = np.asarray(head_directions, dtype=float)
        # Wrap into [-pi, pi) so data stored as [0, 2*pi) is not clipped by
        # the fixed [-180, 180] y range below.
        hd_wrapped = np.mod(hd + np.pi, 2 * np.pi) - np.pi
        ax3.plot(time_stamps, np.rad2deg(hd_wrapped), "k-", linewidth=0.5, alpha=0.5)
        if spike_times is not None:
            # Interpolate sin/cos rather than the raw angle: linear
            # interpolation across the +/-180° seam would pass through 0°.
            # arctan2 returns angles in [-pi, pi], matching the y range.
            spike_hd = np.arctan2(
                np.interp(spike_times, time_stamps, np.sin(hd)),
                np.interp(spike_times, time_stamps, np.cos(hd)),
            )
            ax3.plot(spike_times, np.rad2deg(spike_hd), "r.", markersize=3)
        ax3.set_xlabel("Time (s)")
        ax3.set_ylabel("Head Direction (°)")
        ax3.set_title("HD Trajectory with Spikes")
        ax3.set_ylim([-180, 180])
    else:
        ax3.axis("off")
        stats_lines = [
            "Head Direction Cell Analysis",
            "",
            f"Classification: {'HD Cell' if result.is_hd else 'Non-HD Cell'}",
            f"MVL (HD): {result.mvl_hd:.3f}",
        ]
        # `is not None` so a legitimate MVL of 0.0 is still reported.
        if result.mvl_theta is not None:
            stats_lines.append(f"MVL (Theta): {result.mvl_theta:.3f}")
        # preferred_direction may be None (guarded above too); formatting it
        # unguarded would raise a TypeError.
        if preferred is not None:
            stats_lines.append(f"Preferred Direction: {np.rad2deg(preferred):.1f}°")
        else:
            stats_lines.append("Preferred Direction: n/a")
        stats_lines += [
            f"Rayleigh p-value: {result.rayleigh_p:.6f}",
            "",
            f"Peak Firing Rate: {np.nanmax(firing_rates):.2f} Hz",
            f"Mean Firing Rate: {np.nanmean(firing_rates):.2f} Hz",
        ]
        ax3.text(
            0.1,
            0.5,
            "\n".join(stats_lines),
            fontsize=10,
            verticalalignment="center",
            family="monospace",
        )
        ax3.set_title("HD Statistics")

    fig.tight_layout()
    return fig
@@ -1,8 +1,9 @@
1
1
  """Model metrics computation utilities."""
2
2
 
3
- from . import spatial_metrics, utils
3
+ from . import spatial_metrics, systematic_ratemap, utils
4
4
 
5
5
  __all__ = [
6
6
  "spatial_metrics",
7
+ "systematic_ratemap",
7
8
  "utils",
8
9
  ]
@@ -280,7 +280,7 @@ class PlotConfigs:
280
280
  return PlotConfig.for_static_plot(**defaults)
281
281
 
282
282
  @staticmethod
283
- def cohospace_trajectory(**kwargs: Any) -> PlotConfig:
283
+ def cohospace_trajectory_2d(**kwargs: Any) -> PlotConfig:
284
284
  defaults: dict[str, Any] = {
285
285
  "title": "CohoSpace trajectory",
286
286
  "xlabel": "theta1 (deg)",
@@ -291,7 +291,18 @@ class PlotConfigs:
291
291
  return PlotConfig.for_static_plot(**defaults)
292
292
 
293
293
  @staticmethod
294
- def cohospace_neuron(**kwargs: Any) -> PlotConfig:
294
+ def cohospace_trajectory_1d(**kwargs: Any) -> PlotConfig:
295
+ defaults: dict[str, Any] = {
296
+ "title": "CohoSpace trajectory (1D)",
297
+ "xlabel": "cos(theta)",
298
+ "ylabel": "sin(theta)",
299
+ "figsize": (6, 6),
300
+ }
301
+ defaults.update(kwargs)
302
+ return PlotConfig.for_static_plot(**defaults)
303
+
304
+ @staticmethod
305
+ def cohospace_neuron_2d(**kwargs: Any) -> PlotConfig:
295
306
  defaults: dict[str, Any] = {
296
307
  "title": "Neuron activity on coho-space",
297
308
  "xlabel": "Theta 1 (deg)",
@@ -302,7 +313,18 @@ class PlotConfigs:
302
313
  return PlotConfig.for_static_plot(**defaults)
303
314
 
304
315
  @staticmethod
305
- def cohospace_population(**kwargs: Any) -> PlotConfig:
316
+ def cohospace_neuron_1d(**kwargs: Any) -> PlotConfig:
317
+ defaults: dict[str, Any] = {
318
+ "title": "Neuron activity on coho-space (1D)",
319
+ "xlabel": "cos(theta)",
320
+ "ylabel": "sin(theta)",
321
+ "figsize": (6, 6),
322
+ }
323
+ defaults.update(kwargs)
324
+ return PlotConfig.for_static_plot(**defaults)
325
+
326
+ @staticmethod
327
+ def cohospace_population_2d(**kwargs: Any) -> PlotConfig:
306
328
  defaults: dict[str, Any] = {
307
329
  "title": "Population activity on coho-space",
308
330
  "xlabel": "Theta 1 (deg)",
@@ -312,6 +334,17 @@ class PlotConfigs:
312
334
  defaults.update(kwargs)
313
335
  return PlotConfig.for_static_plot(**defaults)
314
336
 
337
+ @staticmethod
338
+ def cohospace_population_1d(**kwargs: Any) -> PlotConfig:
339
+ defaults: dict[str, Any] = {
340
+ "title": "Population activity on coho-space (1D)",
341
+ "xlabel": "cos(theta)",
342
+ "ylabel": "sin(theta)",
343
+ "figsize": (6, 6),
344
+ }
345
+ defaults.update(kwargs)
346
+ return PlotConfig.for_static_plot(**defaults)
347
+
315
348
  @staticmethod
316
349
  def fr_heatmap(**kwargs: Any) -> PlotConfig:
317
350
  defaults: dict[str, Any] = {
@@ -337,7 +370,7 @@ class PlotConfigs:
337
370
  return PlotConfig.for_static_plot(**defaults)
338
371
 
339
372
  @staticmethod
340
- def path_compare(**kwargs: Any) -> PlotConfig:
373
+ def path_compare_2d(**kwargs: Any) -> PlotConfig:
341
374
  defaults: dict[str, Any] = {
342
375
  "title": "Path Compare",
343
376
  "figsize": (12, 5),
@@ -345,6 +378,15 @@ class PlotConfigs:
345
378
  defaults.update(kwargs)
346
379
  return PlotConfig.for_static_plot(**defaults)
347
380
 
381
+ @staticmethod
382
+ def path_compare_1d(**kwargs: Any) -> PlotConfig:
383
+ defaults: dict[str, Any] = {
384
+ "title": "Path Compare (1D)",
385
+ "figsize": (12, 5),
386
+ }
387
+ defaults.update(kwargs)
388
+ return PlotConfig.for_static_plot(**defaults)
389
+
348
390
  @staticmethod
349
391
  def raster_plot(mode: str = "block", **kwargs: Any) -> PlotConfig:
350
392
  defaults: dict[str, Any] = {
canns/data/__init__.py CHANGED
@@ -16,11 +16,13 @@ from .datasets import (
16
16
  get_data_dir,
17
17
  get_dataset_path,
18
18
  get_huggingface_upload_guide,
19
+ get_left_right_data_session,
20
+ get_left_right_npz,
19
21
  list_datasets,
20
22
  load,
21
23
  quick_setup,
22
24
  )
23
- from .loaders import load_grid_data, load_roi_data
25
+ from .loaders import load_grid_data, load_left_right_npz, load_roi_data
24
26
 
25
27
  __all__ = [
26
28
  # Dataset registry and management
@@ -31,6 +33,8 @@ __all__ = [
31
33
  "list_datasets",
32
34
  "download_dataset",
33
35
  "get_dataset_path",
36
+ "get_left_right_data_session",
37
+ "get_left_right_npz",
34
38
  "quick_setup",
35
39
  "get_huggingface_upload_guide",
36
40
  # Generic loading
@@ -38,4 +42,5 @@ __all__ = [
38
42
  # Specialized loaders
39
43
  "load_roi_data",
40
44
  "load_grid_data",
45
+ "load_left_right_npz",
41
46
  ]