canns 0.14.3__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. canns/analyzer/data/asa/__init__.py +56 -21
  2. canns/analyzer/data/asa/coho.py +21 -0
  3. canns/analyzer/data/asa/cohomap.py +453 -0
  4. canns/analyzer/data/asa/cohomap_vectors.py +365 -0
  5. canns/analyzer/data/asa/cohospace.py +155 -1165
  6. canns/analyzer/data/asa/cohospace_phase_centers.py +119 -0
  7. canns/analyzer/data/asa/cohospace_scatter.py +1115 -0
  8. canns/analyzer/data/asa/embedding.py +5 -7
  9. canns/analyzer/data/asa/fr.py +1 -8
  10. canns/analyzer/data/asa/path.py +70 -0
  11. canns/analyzer/data/asa/plotting.py +5 -30
  12. canns/analyzer/data/asa/utils.py +160 -0
  13. canns/analyzer/data/cell_classification/__init__.py +10 -0
  14. canns/analyzer/data/cell_classification/core/__init__.py +4 -0
  15. canns/analyzer/data/cell_classification/core/btn.py +272 -0
  16. canns/analyzer/data/cell_classification/visualization/__init__.py +3 -0
  17. canns/analyzer/data/cell_classification/visualization/btn_plots.py +241 -0
  18. canns/analyzer/visualization/__init__.py +2 -0
  19. canns/analyzer/visualization/core/config.py +20 -0
  20. canns/analyzer/visualization/theta_sweep_plots.py +142 -0
  21. canns/pipeline/asa/runner.py +19 -19
  22. canns/pipeline/asa_gui/__init__.py +5 -3
  23. canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +3 -1
  24. canns/pipeline/asa_gui/core/runner.py +23 -23
  25. canns/pipeline/asa_gui/views/pages/preprocess_page.py +7 -12
  26. {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/METADATA +1 -1
  27. {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/RECORD +30 -23
  28. canns/analyzer/data/asa/filters.py +0 -208
  29. {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/WHEEL +0 -0
  30. {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/entry_points.txt +0 -0
  31. {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,15 +1,37 @@
1
1
  from __future__ import annotations
2
2
 
3
- # Coho-space analysis + visualization
3
+ from .cohomap import (
4
+ cohomap,
5
+ fit_cohomap_stripes,
6
+ plot_cohomap,
7
+ )
8
+ from .cohomap_vectors import (
9
+ cohomap_vectors,
10
+ plot_cohomap_stripes,
11
+ plot_cohomap_vectors,
12
+ )
4
13
  from .cohospace import (
5
- compute_cohoscore_1d,
6
- compute_cohoscore_2d,
7
- plot_cohospace_neuron_1d,
8
- plot_cohospace_neuron_2d,
9
- plot_cohospace_population_1d,
10
- plot_cohospace_population_2d,
11
- plot_cohospace_trajectory_1d,
12
- plot_cohospace_trajectory_2d,
14
+ cohospace,
15
+ plot_cohospace,
16
+ plot_cohospace_skewed,
17
+ )
18
+ from .cohospace_phase_centers import (
19
+ cohospace_phase_centers,
20
+ plot_cohospace_phase_centers,
21
+ )
22
+
23
+ # Coho-space (scatter) analysis + visualization
24
+ from .cohospace_scatter import (
25
+ compute_cohoscore_scatter_1d,
26
+ compute_cohoscore_scatter_2d,
27
+ plot_cohospace_scatter_neuron_1d,
28
+ plot_cohospace_scatter_neuron_2d,
29
+ plot_cohospace_scatter_neuron_skewed,
30
+ plot_cohospace_scatter_population_1d,
31
+ plot_cohospace_scatter_population_2d,
32
+ plot_cohospace_scatter_population_skewed,
33
+ plot_cohospace_scatter_trajectory_1d,
34
+ plot_cohospace_scatter_trajectory_2d,
13
35
  )
14
36
  from .config import (
15
37
  CANN2DError,
@@ -47,8 +69,8 @@ from .path import (
47
69
  from .plotting import (
48
70
  plot_2d_bump_on_manifold,
49
71
  plot_3d_bump_on_torus,
50
- plot_cohomap,
51
- plot_cohomap_multi,
72
+ plot_cohomap_scatter,
73
+ plot_cohomap_scatter_multi,
52
74
  plot_path_compare_1d,
53
75
  plot_path_compare_2d,
54
76
  plot_projection,
@@ -72,10 +94,21 @@ __all__ = [
72
94
  "plot_projection",
73
95
  "plot_path_compare_1d",
74
96
  "plot_path_compare_2d",
75
- "plot_cohomap",
76
- "plot_cohomap_multi",
97
+ "plot_cohomap_scatter",
98
+ "plot_cohomap_scatter_multi",
77
99
  "plot_3d_bump_on_torus",
78
100
  "plot_2d_bump_on_manifold",
101
+ "cohomap",
102
+ "fit_cohomap_stripes",
103
+ "plot_cohomap",
104
+ "cohospace",
105
+ "plot_cohospace",
106
+ "plot_cohospace_skewed",
107
+ "cohomap_vectors",
108
+ "plot_cohomap_stripes",
109
+ "plot_cohomap_vectors",
110
+ "cohospace_phase_centers",
111
+ "plot_cohospace_phase_centers",
79
112
  "BumpFitsConfig",
80
113
  "CANN1DPlotConfig",
81
114
  "create_1d_bump_animation",
@@ -85,14 +118,16 @@ __all__ = [
85
118
  "FRMResult",
86
119
  "compute_frm",
87
120
  "plot_frm",
88
- "plot_cohospace_trajectory_1d",
89
- "plot_cohospace_trajectory_2d",
90
- "plot_cohospace_neuron_1d",
91
- "plot_cohospace_neuron_2d",
92
- "plot_cohospace_population_1d",
93
- "plot_cohospace_population_2d",
94
- "compute_cohoscore_1d",
95
- "compute_cohoscore_2d",
121
+ "plot_cohospace_scatter_trajectory_1d",
122
+ "plot_cohospace_scatter_trajectory_2d",
123
+ "plot_cohospace_scatter_neuron_1d",
124
+ "plot_cohospace_scatter_neuron_2d",
125
+ "plot_cohospace_scatter_population_1d",
126
+ "plot_cohospace_scatter_population_2d",
127
+ "plot_cohospace_scatter_neuron_skewed",
128
+ "plot_cohospace_scatter_population_skewed",
129
+ "compute_cohoscore_scatter_1d",
130
+ "compute_cohoscore_scatter_2d",
96
131
  "align_coords_to_position_1d",
97
132
  "align_coords_to_position_2d",
98
133
  "apply_angle_scale",
@@ -0,0 +1,21 @@
1
+ """Shared helpers for CohoMap/CohoSpace analysis."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from .utils import (
6
+ _circmean,
7
+ _ensure_parent_dir,
8
+ _ensure_plot_config,
9
+ _extract_coords_and_times,
10
+ _phase_map_valid_fraction,
11
+ _smooth_circular_map,
12
+ )
13
+
14
+ __all__ = [
15
+ "_ensure_plot_config",
16
+ "_ensure_parent_dir",
17
+ "_circmean",
18
+ "_smooth_circular_map",
19
+ "_extract_coords_and_times",
20
+ "_phase_map_valid_fraction",
21
+ ]
@@ -0,0 +1,453 @@
1
+ """CohoMap (Ecoho-style) computation and plotting."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ from scipy.ndimage import rotate
10
+ from scipy.optimize import minimize
11
+ from scipy.stats import binned_statistic_2d
12
+
13
+ from ...visualization.core import PlotConfig, finalize_figure
14
+ from .coho import (
15
+ _circmean,
16
+ _ensure_parent_dir,
17
+ _ensure_plot_config,
18
+ _extract_coords_and_times,
19
+ _phase_map_valid_fraction,
20
+ _smooth_circular_map,
21
+ )
22
+ from .path import parse_times_box_to_indices
23
+
24
+
25
+ def _select_phase_sign(
26
+ phase_map: np.ndarray,
27
+ params: np.ndarray,
28
+ *,
29
+ grid_size: int,
30
+ trim: int,
31
+ ) -> int:
32
+ mtot = np.asarray(phase_map)
33
+ expected = grid_size - 1 - 2 * trim
34
+ if mtot.ndim != 2 or mtot.shape[0] != expected or mtot.shape[1] != expected:
35
+ raise ValueError("phase_map shape does not match grid_size/trim for alignment")
36
+ nnans = ~np.isnan(mtot)
37
+ if not np.any(nnans):
38
+ return 1
39
+ x, _ = np.meshgrid(
40
+ np.linspace(0, 3 * np.pi, grid_size - 1),
41
+ np.linspace(0, 3 * np.pi, grid_size - 1),
42
+ )
43
+ x1 = rotate(x, params[0] * 360.0 / (2 * np.pi), reshape=False)
44
+ base = params[2] * x1[trim:-trim, trim:-trim] + params[1]
45
+ mtot_vals = (mtot[nnans]) % (2 * np.pi)
46
+ pm1 = (base % (2 * np.pi))[nnans] - mtot_vals
47
+ pm2 = ((2 * np.pi - base) % (2 * np.pi))[nnans] - mtot_vals
48
+ if np.sum(np.abs(pm1)) > np.sum(np.abs(pm2)):
49
+ return -1
50
+ return 1
51
+
52
+
53
+ def _rot_coord(
54
+ params1: np.ndarray,
55
+ params2: np.ndarray,
56
+ c1: np.ndarray,
57
+ c2: np.ndarray,
58
+ p: tuple[int, int],
59
+ ) -> np.ndarray:
60
+ """Transform and align decoded coordinates based on stripe orientation.
61
+
62
+ This function rotates and aligns coordinates to match the dominant stripe
63
+ orientation in the CohoMap. It handles the torus geometry by choosing the
64
+ appropriate coordinate system and applying reflections/rotations.
65
+
66
+ Parameters
67
+ ----------
68
+ params1 : np.ndarray
69
+ Stripe fit parameters for first dimension (includes orientation angle).
70
+ params2 : np.ndarray
71
+ Stripe fit parameters for second dimension (includes orientation angle).
72
+ c1 : np.ndarray
73
+ Decoded coordinates for first dimension.
74
+ c2 : np.ndarray
75
+ Decoded coordinates for second dimension.
76
+ p : tuple[int, int]
77
+ Direction indicators (1 or -1) for each dimension, controlling reflections.
78
+
79
+ Returns
80
+ -------
81
+ np.ndarray
82
+ Aligned coordinates of shape (N, 2), wrapped to [0, 2π).
83
+
84
+ Notes
85
+ -----
86
+ The function selects the coordinate system based on which dimension has
87
+ stronger horizontal alignment (larger |cos(angle)|), then applies geometric
88
+ transformations to align the coordinates with the stripe pattern.
89
+ """
90
+ if abs(np.cos(params1[0])) < abs(np.cos(params2[0])):
91
+ cc1 = c2.copy()
92
+ cc2 = c1.copy()
93
+ y = params1.copy()
94
+ x = params2.copy()
95
+ p = (p[1], p[0])
96
+ else:
97
+ cc1 = c1.copy()
98
+ cc2 = c2.copy()
99
+ x = params1.copy()
100
+ y = params2.copy()
101
+
102
+ if p[1] == -1:
103
+ cc2 = 2 * np.pi - cc2
104
+ if p[0] == -1:
105
+ cc1 = 2 * np.pi - cc1
106
+
107
+ alpha = y[0] - x[0]
108
+ if (alpha < 0) and (abs(alpha) > np.pi / 2):
109
+ cctmp = cc2.copy()
110
+ cc2 = cc1.copy()
111
+ cc1 = cctmp
112
+
113
+ if (alpha < 0) and (abs(alpha) < np.pi / 2):
114
+ cc1 = 2 * np.pi - cc1 + (np.pi / 3) * cc2
115
+ elif abs(alpha) > np.pi / 2:
116
+ cc2 = cc2 + (np.pi / 3) * cc1
117
+
118
+ return np.stack([cc1, cc2], axis=1) % (2 * np.pi)
119
+
120
+
121
def _toroidal_align_coords(
    coords: np.ndarray,
    phase_map1: np.ndarray,
    phase_map2: np.ndarray,
    *,
    trim: int,
    grid_size: int | None,
) -> tuple[np.ndarray, float, float]:
    """Align decoded coordinates to the stripe structure of two CohoMap phase maps.

    Each phase map is fitted with a rotated cosine stripe model by
    ``fit_cohomap_stripes``, which runs a coarse grid search over
    angle/phase/spacing followed by an SLSQP refinement. The phase direction
    of each fit is chosen with ``_select_phase_sign``. The decoded coordinates
    are then reflected/exchanged/sheared by ``_rot_coord``.

    Parameters
    ----------
    coords : np.ndarray
        Decoded coordinates of shape (N, 2+). Only the first two columns are used.
    phase_map1 : np.ndarray
        Phase map for the first decoded dimension (2-D grid).
    phase_map2 : np.ndarray
        Phase map for the second decoded dimension; must match ``phase_map1``'s shape.
    trim : int
        Number of edge bins trimmed from the reference ramp when fitting stripes.
    grid_size : int or None
        Reference grid size for the stripe fit. If None, it is inferred as
        ``phase_map1.shape[0] + 2 * trim + 1``.

    Returns
    -------
    aligned : np.ndarray
        Aligned coordinates of shape (N, 2), wrapped to [0, 2π).
    f1 : float
        Stripe-fit error for the first map: mean squared error between the
        cosine model and ``cos(phase_map1)``. Lower is better.
    f2 : float
        Same as ``f1`` for the second map.

    Raises
    ------
    ValueError
        If the phase map shapes differ, ``coords`` is not (N, 2+), or the stripe
        fit / sign selection rejects the map geometry.
    """
    # Accept array-likes; ``.shape`` below needs real arrays.
    phase_map1 = np.asarray(phase_map1)
    phase_map2 = np.asarray(phase_map2)
    if phase_map1.shape != phase_map2.shape:
        raise ValueError("phase_map shapes do not match for alignment")
    coords = np.asarray(coords)
    if coords.ndim != 2 or coords.shape[1] < 2:
        raise ValueError(f"coords must be (N,2+) array, got {coords.shape}")
    if grid_size is None:
        # Inverse of the map size relation: shape = grid_size - 1 - 2 * trim.
        grid_size = int(phase_map1.shape[0]) + 2 * trim + 1
    p1, f1 = fit_cohomap_stripes(phase_map1, grid_size=grid_size, trim=trim)
    p2, f2 = fit_cohomap_stripes(phase_map2, grid_size=grid_size, trim=trim)
    s1 = _select_phase_sign(phase_map1, p1, grid_size=grid_size, trim=trim)
    s2 = _select_phase_sign(phase_map2, p2, grid_size=grid_size, trim=trim)
    aligned = _rot_coord(p1, p2, coords[:, 0], coords[:, 1], (s1, s2))
    return aligned, float(f1), float(f2)
182
+
183
+
184
def cohomap(
    decoding_result: dict[str, Any],
    position_data: dict[str, Any],
    *,
    coords_key: str | None = None,
    bins: int = 101,
    margin_frac: float = 0.0025,
    smooth_sigma: float = 1.0,
    fill_nan: bool = True,
    fill_sigma: float | None = None,
    fill_min_weight: float = 1e-3,
    align_torus: bool = True,
    align_trim: int = 25,
    align_grid_size: int | None = None,
    align_min_valid_frac: float | None = None,
    align_max_fit_error: float | None = None,
) -> dict[str, Any]:
    """
    Compute EcohoMap phase maps using circular-mean binning.

    This mirrors GridCellTorus get_ang_hist: bin spatial positions and compute the
    circular mean of each decoded angle within spatial bins, then smooth in sin/cos
    space. Optional toroidal alignment follows the GridCellTorus stripe fit + rotation.
    You can gate alignment by valid fraction or fit error thresholds.

    Parameters
    ----------
    decoding_result : dict
        Decoding output. Decoded coordinates (N, 2+) and an optional
        ``times_box`` are obtained via ``_extract_coords_and_times``.
    position_data : dict
        Must contain ``"x"`` and ``"y"``. If ``times_box`` is present and
        ``"t"`` is available, ``parse_times_box_to_indices`` selects the matching
        position samples; otherwise ``times_box`` is used directly as integer indices.
    coords_key : str, optional
        Forwarded to ``_extract_coords_and_times``.
    bins : int
        Number of bin *edges* per spatial axis (maps have ``bins - 1`` bins per axis).
    margin_frac : float
        Fraction of the x/y range removed from each side before placing bin edges.
    smooth_sigma, fill_nan, fill_sigma, fill_min_weight
        Forwarded to ``_smooth_circular_map``.
    align_torus : bool
        Attempt stripe-based toroidal alignment of the decoded coordinates.
    align_trim, align_grid_size
        Forwarded to the stripe fit / alignment (``trim`` and ``grid_size``).
    align_min_valid_frac : float, optional
        Skip alignment if either raw map's valid fraction is below this value.
    align_max_fit_error : float, optional
        Reject alignment if either stripe-fit error exceeds this value.

    Returns
    -------
    dict
        The phase maps (aligned if alignment succeeded, else raw), the raw maps,
        the bin edges, the parameters used, and alignment diagnostics.
        Alignment failures are not raised; they are reported in ``"align_error"``.

    Raises
    ------
    ValueError
        If coords are not (N, 2+), ``bins < 2``, position/coord lengths
        mismatch, or there are no finite position samples.
    """
    coords, times_box = _extract_coords_and_times(decoding_result, coords_key)
    if coords.ndim != 2 or coords.shape[1] < 2:
        raise ValueError(f"coords must be (N,2+) array, got {coords.shape}")
    if bins < 2:
        raise ValueError(f"bins must be >= 2 (number of bin edges), got {bins}")

    xx = np.asarray(position_data["x"], dtype=float)
    yy = np.asarray(position_data["y"], dtype=float)

    if times_box is not None:
        if "t" in position_data:
            idx, _ = parse_times_box_to_indices(times_box, np.asarray(position_data["t"]))
        else:
            idx = np.asarray(times_box).astype(int)
        xx = xx[idx]
        yy = yy[idx]

    if len(xx) != coords.shape[0]:
        raise ValueError(
            "Length mismatch: coords length does not match position length after times_box."
        )
    if len(yy) != len(xx):
        raise ValueError(
            f"Length mismatch: position x has {len(xx)} samples but y has {len(yy)}."
        )

    # Tracking data often contains NaN samples; exclude them from the bin range
    # and from binning instead of letting them turn every bin edge into NaN.
    pos_ok = np.isfinite(xx) & np.isfinite(yy)
    if not np.any(pos_ok):
        raise ValueError("position_data has no finite (x, y) samples to bin.")

    x_min, x_max = float(np.min(xx[pos_ok])), float(np.max(xx[pos_ok]))
    y_min, y_max = float(np.min(yy[pos_ok])), float(np.max(yy[pos_ok]))
    x_pad = (x_max - x_min) * margin_frac
    y_pad = (y_max - y_min) * margin_frac

    binsx = np.linspace(x_min + x_pad, x_max - x_pad, bins)
    binsy = np.linspace(y_min + y_pad, y_max - y_pad, bins)

    def _angle_hist(values: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        # Circular mean of ``values`` per spatial bin, then circular smoothing.
        valid = ~np.isnan(values) & pos_ok
        mtot, x_edge, y_edge, _ = binned_statistic_2d(
            xx[valid],
            yy[valid],
            values[valid],
            statistic=_circmean,
            bins=(binsx, binsy),
            range=None,
            expand_binnumbers=True,
        )
        mtot = _smooth_circular_map(
            mtot,
            smooth_sigma,
            fill_nan=fill_nan,
            fill_sigma=fill_sigma,
            fill_min_weight=fill_min_weight,
        )
        return mtot, x_edge, y_edge

    coords_use = np.asarray(coords, float)
    m1_raw, x_edge, y_edge = _angle_hist(coords_use[:, 0])
    m2_raw, _, _ = _angle_hist(coords_use[:, 1])
    aligned = False
    align_error = None
    align_valid_frac1 = None
    align_valid_frac2 = None
    align_fit_error1 = None
    align_fit_error2 = None

    if align_torus:
        # Alignment is best-effort: any failure is recorded in ``align_error``
        # and the raw (unaligned) maps are returned instead.
        try:
            align_valid_frac1 = _phase_map_valid_fraction(m1_raw)
            align_valid_frac2 = _phase_map_valid_fraction(m2_raw)
            min_valid = min(align_valid_frac1, align_valid_frac2)

            if align_min_valid_frac is not None and min_valid < align_min_valid_frac:
                align_error = (
                    f"valid fraction too low ({min_valid:.3f} < {align_min_valid_frac:.3f})"
                )
            else:
                coords_aligned, f1, f2 = _toroidal_align_coords(
                    coords_use[:, :2],
                    m1_raw,
                    m2_raw,
                    trim=align_trim,
                    grid_size=align_grid_size,
                )
                align_fit_error1 = f1
                align_fit_error2 = f2
                if align_max_fit_error is not None and (
                    f1 > align_max_fit_error or f2 > align_max_fit_error
                ):
                    align_error = (
                        f"fit error too high ({f1:.4f}, {f2:.4f} > {align_max_fit_error:.4f})"
                    )
                else:
                    coords_use = coords_use.copy()
                    coords_use[:, :2] = coords_aligned
                    aligned = True
        except Exception as exc:
            align_error = str(exc)

    if aligned:
        m1, x_edge, y_edge = _angle_hist(coords_use[:, 0])
        m2, _, _ = _angle_hist(coords_use[:, 1])
    else:
        m1, m2 = m1_raw, m2_raw

    return {
        "phase_map1": m1,
        "phase_map2": m2,
        "phase_map1_raw": m1_raw,
        "phase_map2_raw": m2_raw,
        "x_edge": x_edge,
        "y_edge": y_edge,
        "bins": bins,
        "margin_frac": margin_frac,
        "smooth_sigma": smooth_sigma,
        "fill_nan": fill_nan,
        "fill_sigma": fill_sigma,
        "fill_min_weight": fill_min_weight,
        "aligned": aligned,
        "align_error": align_error,
        "align_min_valid_frac": align_min_valid_frac,
        "align_max_fit_error": align_max_fit_error,
        "align_valid_frac1": align_valid_frac1,
        "align_valid_frac2": align_valid_frac2,
        "align_fit_error1": align_fit_error1,
        "align_fit_error2": align_fit_error2,
    }
330
+
331
+
332
+ def fit_cohomap_stripes(
333
+ phase_map: np.ndarray,
334
+ *,
335
+ grid_size: int | None = 151,
336
+ trim: int = 25,
337
+ angle_grid: int = 10,
338
+ phase_grid: int = 10,
339
+ spacing_grid: int = 10,
340
+ spacing_range: tuple[float, float] = (1.0, 6.0),
341
+ ) -> tuple[np.ndarray, float]:
342
+ """
343
+ Fit a cosine stripe model to a phase map, mirroring GridCellTorus fit_sine_wave.
344
+ """
345
+ mtot = np.asarray(phase_map)
346
+ if mtot.ndim != 2:
347
+ raise ValueError(f"phase_map must be 2D, got {mtot.shape}")
348
+
349
+ if grid_size is None:
350
+ grid_size = mtot.shape[0] + 2 * trim + 1
351
+
352
+ expected = grid_size - 1 - 2 * trim
353
+ if expected != mtot.shape[0]:
354
+ raise ValueError(
355
+ f"grid_size/trim incompatible with phase_map shape: "
356
+ f"expected {expected} but got {mtot.shape[0]}"
357
+ )
358
+
359
+ numangsint = grid_size
360
+ x, _ = np.meshgrid(
361
+ np.linspace(0, 3 * np.pi, numangsint - 1),
362
+ np.linspace(0, 3 * np.pi, numangsint - 1),
363
+ )
364
+ nnans = ~np.isnan(mtot)
365
+
366
+ def cos_wave(p: np.ndarray) -> float:
367
+ x1 = rotate(x, p[0] * 360.0 / (2 * np.pi), reshape=False)
368
+ model = np.cos(p[2] * x1[trim:-trim, trim:-trim] + p[1])
369
+ return float(np.mean(np.square(model[nnans] - np.cos(mtot[nnans]))))
370
+
371
+ angle_space = np.linspace(0, np.pi, angle_grid)
372
+ phase_space = np.linspace(0, 2 * np.pi, phase_grid)
373
+ spacing_space = np.linspace(spacing_range[0], spacing_range[1], spacing_grid)
374
+
375
+ grid = np.zeros((angle_grid, phase_grid, spacing_grid))
376
+ for i, ang in enumerate(angle_space):
377
+ for j, ph in enumerate(phase_space):
378
+ for k, sp in enumerate(spacing_space):
379
+ grid[i, j, k] = cos_wave(np.array([ang, ph, sp]))
380
+
381
+ p_ind = np.unravel_index(np.argmin(grid), grid.shape)
382
+ p0 = np.array([angle_space[p_ind[0]], phase_space[p_ind[1]], spacing_space[p_ind[2]]])
383
+ res = minimize(cos_wave, p0, method="SLSQP", options={"disp": False})
384
+ return res["x"], float(res["fun"])
385
+
386
+
387
def plot_cohomap(
    cohomap_result: dict[str, Any],
    *,
    config: PlotConfig | None = None,
    save_path: str | None = None,
    show: bool = False,
    figsize: tuple[int, int] = (10, 4),
    cmap: str = "viridis",
    mode: str = "cos",
) -> plt.Figure:
    """
    Plot EcohoMap phase maps (two panels: phase_map1/phase_map2).

    Parameters
    ----------
    cohomap_result : dict
        Output of ``cohomap``; reads ``"phase_map1"``, ``"phase_map2"``,
        ``"x_edge"`` and ``"y_edge"``.
    config : PlotConfig, optional
        Passed through ``_ensure_plot_config`` together with
        ``PlotConfig.for_static_plot`` and the title/figsize/save_path/show
        defaults given here.
    save_path, show, figsize
        Used when building the default config.
    cmap : str
        Matplotlib colormap name.
    mode : {"phase", "cos", "sin"}
        "phase" to show raw phase (radians, colour range [-pi, pi]),
        "cos" or "sin" to show cosine/sine of phase like GridCellTorus.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure (also finalized via ``finalize_figure``).

    Raises
    ------
    ValueError
        If ``mode`` is unsupported. This is checked before any figure is created.
    """
    # Validate ``mode`` before creating the figure: raising inside the panel
    # loop would leave an open, half-drawn figure in pyplot's state.
    mode_specs = {
        "phase": (None, "Phase (rad)", (-np.pi, np.pi)),
        "cos": (np.cos, "cos(phase)", (-1.0, 1.0)),
        "sin": (np.sin, "sin(phase)", (-1.0, 1.0)),
    }
    spec = mode_specs.get(mode) if isinstance(mode, str) else None
    if spec is None:
        raise ValueError(f"Unknown mode '{mode}'. Use 'phase', 'cos', or 'sin'.")
    transform, cbar_label, (vmin, vmax) = spec

    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title="EcohoMap",
        xlabel="",
        ylabel="",
        figsize=figsize,
        save_path=save_path,
        show=show,
    )

    m1 = cohomap_result["phase_map1"]
    m2 = cohomap_result["phase_map2"]
    x_edge = cohomap_result["x_edge"]
    y_edge = cohomap_result["y_edge"]
    extent = [x_edge[0], x_edge[-1], y_edge[0], y_edge[-1]]

    fig, ax = plt.subplots(1, 2, figsize=config.figsize)
    panels = ((m1, "Phase Map 1"), (m2, "Phase Map 2"))
    for axis, (mtot, title) in zip(ax, panels):
        plot_map = mtot if transform is None else transform(mtot)
        im = axis.imshow(
            plot_map,
            origin="lower",
            extent=extent,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
        )
        axis.set_title(title, fontsize=10)
        axis.set_aspect("equal", "box")
        axis.set_xticks([])
        axis.set_yticks([])
        plt.colorbar(im, ax=axis, fraction=0.046, pad=0.04, label=cbar_label)

    fig.tight_layout()
    _ensure_parent_dir(config.save_path)
    finalize_figure(fig, config)
    return fig