canns 0.14.2__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. canns/analyzer/data/asa/__init__.py +77 -21
  2. canns/analyzer/data/asa/coho.py +97 -0
  3. canns/analyzer/data/asa/cohomap.py +408 -0
  4. canns/analyzer/data/asa/cohomap_scatter.py +10 -0
  5. canns/analyzer/data/asa/cohomap_vectors.py +311 -0
  6. canns/analyzer/data/asa/cohospace.py +173 -1153
  7. canns/analyzer/data/asa/cohospace_phase_centers.py +137 -0
  8. canns/analyzer/data/asa/cohospace_scatter.py +1220 -0
  9. canns/analyzer/data/asa/embedding.py +3 -4
  10. canns/analyzer/data/asa/plotting.py +4 -4
  11. canns/analyzer/data/cell_classification/__init__.py +10 -0
  12. canns/analyzer/data/cell_classification/core/__init__.py +4 -0
  13. canns/analyzer/data/cell_classification/core/btn.py +272 -0
  14. canns/analyzer/data/cell_classification/visualization/__init__.py +3 -0
  15. canns/analyzer/data/cell_classification/visualization/btn_plots.py +258 -0
  16. canns/analyzer/visualization/__init__.py +2 -0
  17. canns/analyzer/visualization/core/config.py +20 -0
  18. canns/analyzer/visualization/theta_sweep_plots.py +142 -0
  19. canns/pipeline/asa/runner.py +19 -19
  20. canns/pipeline/asa_gui/__init__.py +5 -3
  21. canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +32 -4
  22. canns/pipeline/asa_gui/core/runner.py +23 -23
  23. canns/pipeline/asa_gui/views/pages/preprocess_page.py +250 -8
  24. {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/METADATA +2 -1
  25. {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/RECORD +28 -20
  26. {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/WHEEL +0 -0
  27. {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/entry_points.txt +0 -0
  28. {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,15 +1,49 @@
1
1
  from __future__ import annotations
2
2
 
3
- # Coho-space analysis + visualization
3
+ from .cohomap import (
4
+ cohomap,
5
+ cohomap_upgrade,
6
+ ecohomap,
7
+ fit_cohomap_stripes,
8
+ fit_cohomap_stripes_upgrade,
9
+ plot_cohomap,
10
+ plot_cohomap_upgrade,
11
+ plot_ecohomap,
12
+ )
13
+ from .cohomap_scatter import plot_cohomap_scatter, plot_cohomap_scatter_multi
14
+ from .cohomap_vectors import (
15
+ cohomap_vectors,
16
+ plot_cohomap_stripes,
17
+ plot_cohomap_vectors,
18
+ )
4
19
  from .cohospace import (
5
- compute_cohoscore_1d,
6
- compute_cohoscore_2d,
7
- plot_cohospace_neuron_1d,
8
- plot_cohospace_neuron_2d,
9
- plot_cohospace_population_1d,
10
- plot_cohospace_population_2d,
11
- plot_cohospace_trajectory_1d,
12
- plot_cohospace_trajectory_2d,
20
+ cohospace,
21
+ cohospace_upgrade,
22
+ ecohospace,
23
+ plot_cohospace,
24
+ plot_cohospace_skewed,
25
+ plot_cohospace_upgrade,
26
+ plot_cohospace_upgrade_skewed,
27
+ plot_ecohospace,
28
+ plot_ecohospace_skewed,
29
+ )
30
+ from .cohospace_phase_centers import (
31
+ cohospace_phase_centers,
32
+ plot_cohospace_phase_centers,
33
+ )
34
+
35
+ # Coho-space (scatter) analysis + visualization
36
+ from .cohospace_scatter import (
37
+ compute_cohoscore_scatter_1d,
38
+ compute_cohoscore_scatter_2d,
39
+ plot_cohospace_scatter_neuron_1d,
40
+ plot_cohospace_scatter_neuron_2d,
41
+ plot_cohospace_scatter_neuron_skewed,
42
+ plot_cohospace_scatter_population_1d,
43
+ plot_cohospace_scatter_population_2d,
44
+ plot_cohospace_scatter_population_skewed,
45
+ plot_cohospace_scatter_trajectory_1d,
46
+ plot_cohospace_scatter_trajectory_2d,
13
47
  )
14
48
  from .config import (
15
49
  CANN2DError,
@@ -47,8 +81,6 @@ from .path import (
47
81
  from .plotting import (
48
82
  plot_2d_bump_on_manifold,
49
83
  plot_3d_bump_on_torus,
50
- plot_cohomap,
51
- plot_cohomap_multi,
52
84
  plot_path_compare_1d,
53
85
  plot_path_compare_2d,
54
86
  plot_projection,
@@ -72,10 +104,32 @@ __all__ = [
72
104
  "plot_projection",
73
105
  "plot_path_compare_1d",
74
106
  "plot_path_compare_2d",
75
- "plot_cohomap",
76
- "plot_cohomap_multi",
107
+ "plot_cohomap_scatter",
108
+ "plot_cohomap_scatter_multi",
77
109
  "plot_3d_bump_on_torus",
78
110
  "plot_2d_bump_on_manifold",
111
+ "cohomap",
112
+ "cohomap_upgrade",
113
+ "fit_cohomap_stripes",
114
+ "fit_cohomap_stripes_upgrade",
115
+ "plot_cohomap",
116
+ "plot_cohomap_upgrade",
117
+ "cohospace",
118
+ "cohospace_upgrade",
119
+ "plot_cohospace",
120
+ "plot_cohospace_skewed",
121
+ "plot_cohospace_upgrade",
122
+ "plot_cohospace_upgrade_skewed",
123
+ "ecohomap",
124
+ "ecohospace",
125
+ "plot_ecohomap",
126
+ "plot_ecohospace",
127
+ "plot_ecohospace_skewed",
128
+ "cohomap_vectors",
129
+ "plot_cohomap_stripes",
130
+ "plot_cohomap_vectors",
131
+ "cohospace_phase_centers",
132
+ "plot_cohospace_phase_centers",
79
133
  "BumpFitsConfig",
80
134
  "CANN1DPlotConfig",
81
135
  "create_1d_bump_animation",
@@ -85,14 +139,16 @@ __all__ = [
85
139
  "FRMResult",
86
140
  "compute_frm",
87
141
  "plot_frm",
88
- "plot_cohospace_trajectory_1d",
89
- "plot_cohospace_trajectory_2d",
90
- "plot_cohospace_neuron_1d",
91
- "plot_cohospace_neuron_2d",
92
- "plot_cohospace_population_1d",
93
- "plot_cohospace_population_2d",
94
- "compute_cohoscore_1d",
95
- "compute_cohoscore_2d",
142
+ "plot_cohospace_scatter_trajectory_1d",
143
+ "plot_cohospace_scatter_trajectory_2d",
144
+ "plot_cohospace_scatter_neuron_1d",
145
+ "plot_cohospace_scatter_neuron_2d",
146
+ "plot_cohospace_scatter_population_1d",
147
+ "plot_cohospace_scatter_population_2d",
148
+ "plot_cohospace_scatter_neuron_skewed",
149
+ "plot_cohospace_scatter_population_skewed",
150
+ "compute_cohoscore_scatter_1d",
151
+ "compute_cohoscore_scatter_2d",
96
152
  "align_coords_to_position_1d",
97
153
  "align_coords_to_position_2d",
98
154
  "apply_angle_scale",
@@ -0,0 +1,97 @@
1
+ """Shared helpers for CohoMap/CohoSpace analysis."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+ from scipy.ndimage import gaussian_filter
10
+
11
+ from ...visualization.core import PlotConfig
12
+ from .path import find_coords_matrix, find_times_box
13
+
14
+
15
def _ensure_plot_config(
    config: PlotConfig | None,
    factory,
    *args,
    **defaults,
) -> PlotConfig:
    """Return ``config`` as-is, or build a default one when it is ``None``.

    Args:
        config: Caller-supplied plot configuration, or ``None``.
        factory: Callable used to build the default, invoked as
            ``factory(*args, **defaults)``.
        *args: Positional arguments forwarded to ``factory``.
        **defaults: Keyword arguments forwarded to ``factory``.

    Returns:
        The existing ``config`` or the freshly built default.
    """
    if config is not None:
        return config
    return factory(*args, **defaults)
24
+
25
+
26
def _ensure_parent_dir(save_path: str | None) -> None:
    """Create the directory that will hold ``save_path``, if one is needed.

    Does nothing when ``save_path`` is ``None``/empty or names a file with no
    directory component. Existing directories are left untouched.
    """
    if not save_path:
        return
    directory = os.path.dirname(save_path)
    if not directory:
        return
    os.makedirs(directory, exist_ok=True)
31
+
32
+
33
def _circmean(x: np.ndarray) -> float:
    """Return the circular mean of the angles ``x`` (radians), in ``[-pi, pi]``."""
    mean_sin = np.mean(np.sin(x))
    mean_cos = np.mean(np.cos(x))
    return float(np.arctan2(mean_sin, mean_cos))
35
+
36
+
37
def _smooth_circular_map(
    mtot: np.ndarray,
    smooth_sigma: float,
    *,
    fill_nan: bool = False,
    fill_sigma: float | None = None,
    fill_min_weight: float = 1e-3,
) -> np.ndarray:
    """Smooth a 2-D map of angles (radians) by Gaussian-filtering its sin/cos parts.

    NaN cells contribute zero to the filtered sin/cos sums. The smoothed angle
    is therefore a Gaussian-weighted circular mean of the valid neighbours.

    Args:
        mtot: Map of angles; NaN marks empty cells.
        smooth_sigma: Gaussian sigma (in cells) for plain smoothing. A falsy or
            non-positive value disables smoothing.
        fill_nan: If True, interpolate into NaN cells with a weight-normalised
            smoothing pass using ``fill_sigma``, instead of plain smoothing.
        fill_sigma: Sigma for the fill pass. Defaults to ``smooth_sigma`` when
            that is positive, else 1.0.
        fill_min_weight: In fill mode, cells whose smoothed valid-data weight is
            ``<=`` this value (clamped at 0) are returned as NaN. Cells with zero
            support are always NaN; they are never reported as a made-up phase.

    Returns:
        A new float array. In fill mode, or when smoothing is enabled, it holds
        angles in ``[-pi, pi]`` with NaN where undefined. In plain-smoothing
        mode, NaN input cells stay NaN. With no smoothing and no fill, it is a
        copy of the input.
    """
    mtot = np.asarray(mtot, dtype=float)
    nans = np.isnan(mtot)
    # Zero the NaN cells so they add nothing to the filtered sin/cos sums.
    sintot = np.where(nans, 0.0, np.sin(mtot))
    costot = np.where(nans, 0.0, np.cos(mtot))

    if fill_nan:
        if fill_sigma is None:
            fill_sigma = smooth_sigma if smooth_sigma and smooth_sigma > 0 else 1.0
        weight = gaussian_filter((~nans).astype(float), fill_sigma)
        sintot = gaussian_filter(sintot, fill_sigma)
        costot = gaussian_filter(costot, fill_sigma)
        min_weight = max(float(fill_min_weight), 0.0)
        valid = weight > min_weight
        sintot = np.divide(sintot, weight, out=np.zeros_like(sintot), where=valid)
        costot = np.divide(costot, weight, out=np.zeros_like(costot), where=valid)
        filled = np.arctan2(sintot, costot)
        # Cells without enough support would otherwise read arctan2(0, 0) == 0,
        # a fabricated phase; mark them undefined instead (also when the
        # threshold is 0, where only exactly-zero-weight cells are affected).
        filled[~valid] = np.nan
        return filled

    if smooth_sigma and smooth_sigma > 0:
        smoothed = np.arctan2(
            gaussian_filter(sintot, smooth_sigma),
            gaussian_filter(costot, smooth_sigma),
        )
    else:
        smoothed = mtot.copy()
    smoothed[nans] = np.nan
    return smoothed
74
+
75
+
76
def _extract_coords_and_times(
    decoding_result: dict[str, Any],
    coords_key: str | None = None,
) -> tuple[np.ndarray, np.ndarray | None]:
    """Pull the decoded coordinate array and the time-box info from a result dict.

    Coordinates are looked up in this order:

    1. ``decoding_result[coords_key]`` when ``coords_key`` is given (a missing
       key raises ``KeyError``);
    2. ``decoding_result["coordsbox"]`` when present;
    3. whatever ``find_coords_matrix`` locates.

    Returns:
        ``(coords, times_box)``, where ``times_box`` is the first value returned
        by ``find_times_box``.
    """
    key = coords_key
    if key is None and "coordsbox" in decoding_result:
        key = "coordsbox"

    if key is None:
        coords, _unused = find_coords_matrix(decoding_result)
    else:
        if key not in decoding_result:
            raise KeyError(f"coords_key '{key}' not found in decoding_result.")
        coords = np.asarray(decoding_result[key])

    times_box, _unused = find_times_box(decoding_result)
    return coords, times_box
91
+
92
+
93
def _phase_map_valid_fraction(phase_map: np.ndarray) -> float:
    """Return the fraction of finite (non-NaN, non-inf) cells in ``phase_map``.

    An empty map yields 0.0.
    """
    finite_mask = np.isfinite(phase_map)
    return 0.0 if finite_mask.size == 0 else float(np.mean(finite_mask))
@@ -0,0 +1,408 @@
1
+ """CohoMap (Ecoho-style) computation and plotting."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ from scipy.ndimage import rotate
10
+ from scipy.optimize import minimize
11
+ from scipy.stats import binned_statistic_2d
12
+
13
+ from ...visualization.core import PlotConfig, finalize_figure
14
+ from .coho import (
15
+ _circmean,
16
+ _ensure_parent_dir,
17
+ _ensure_plot_config,
18
+ _extract_coords_and_times,
19
+ _phase_map_valid_fraction,
20
+ _smooth_circular_map,
21
+ )
22
+ from .path import parse_times_box_to_indices
23
+
24
+
25
def _select_phase_sign(
    phase_map: np.ndarray,
    params: np.ndarray,
    *,
    grid_size: int,
    trim: int,
) -> int:
    """Decide whether a fitted stripe model matches ``phase_map`` as-is or reversed.

    The function rebuilds the rotated ramp used by the stripe fit. It then
    evaluates ``base = params[2] * ramp + params[1]`` on the trimmed window and
    compares the summed absolute (non-wrapped) differences between the map and
    ``base mod 2pi`` versus ``(2pi - base) mod 2pi``.

    Args:
        phase_map: ``(grid_size - 1 - 2*trim)``-square map of angles (radians).
        params: ``(angle, phase, spacing)`` as returned by ``fit_cohomap_stripes``.
        grid_size: Ramp grid size; the ramp has ``grid_size - 1`` samples per axis.
        trim: Border cells dropped on each side of the ramp (``>= 0``; 0 allowed).

    Returns:
        ``1`` for the forward direction, ``-1`` for the reversed one. Returns
        ``1`` if the map has no finite cells.

    Raises:
        ValueError: if ``trim`` is negative or the map shape does not match.
    """
    mtot = np.asarray(phase_map)
    if trim < 0:
        raise ValueError(f"trim must be non-negative, got {trim}")
    n = grid_size - 1
    expected = n - 2 * trim
    if mtot.ndim != 2 or mtot.shape[0] != expected or mtot.shape[1] != expected:
        raise ValueError("phase_map shape does not match grid_size/trim for alignment")
    valid = np.isfinite(mtot)
    if not np.any(valid):
        return 1

    ramp_axis = np.linspace(0, 3 * np.pi, n)
    x, _ = np.meshgrid(ramp_axis, ramp_axis)
    x1 = rotate(x, params[0] * 360.0 / (2 * np.pi), reshape=False)
    # Use an explicit stop index: ``x1[trim:-trim]`` is empty when trim == 0.
    x1 = x1[trim : n - trim, trim : n - trim]
    base = params[2] * x1 + params[1]

    mtot_vals = mtot[valid] % (2 * np.pi)
    pm1 = (base % (2 * np.pi))[valid] - mtot_vals
    pm2 = ((2 * np.pi - base) % (2 * np.pi))[valid] - mtot_vals
    if np.sum(np.abs(pm1)) > np.sum(np.abs(pm2)):
        return -1
    return 1
51
+
52
+
53
def _rot_coord(
    params1: np.ndarray,
    params2: np.ndarray,
    c1: np.ndarray,
    c2: np.ndarray,
    p: tuple[int, int],
) -> np.ndarray:
    """Remap two decoded angle columns into a stripe-aligned torus frame.

    The stripe whose fitted orientation (``params[0]``) has the larger
    ``|cos|`` becomes the reference axis; the columns and signs are swapped to
    follow it. Each column whose sign is ``-1`` is reflected (``2pi - c``).
    Depending on the relative orientation ``alpha`` of the two stripes, the
    columns may be exchanged and one of them sheared by ``pi/3`` times the
    other.
    NOTE(review): the pi/3 shear presumably reflects the 60-degree grid lattice
    of the GridCellTorus procedure this module mirrors; confirm.

    Args:
        params1, params2: Stripe-fit parameters; only index 0 (orientation) is used.
        c1, c2: Decoded angles (radians) for the first and second torus axis.
        p: Phase signs (``+1``/``-1``) for ``c1`` and ``c2``.

    Returns:
        ``(N, 2)`` array of remapped angles wrapped to ``[0, 2*pi)``.
    """
    two_pi = 2 * np.pi

    if abs(np.cos(params1[0])) < abs(np.cos(params2[0])):
        # params2 is the reference axis; columns and signs swap accordingly.
        first, second = c2, c1
        ref, other = params2, params1
        sign_first, sign_second = p[1], p[0]
    else:
        first, second = c1, c2
        ref, other = params1, params2
        sign_first, sign_second = p[0], p[1]

    if sign_first == -1:
        first = two_pi - first
    if sign_second == -1:
        second = two_pi - second

    alpha = other[0] - ref[0]
    if alpha < 0 and abs(alpha) > np.pi / 2:
        first, second = second, first

    if alpha < 0 and abs(alpha) < np.pi / 2:
        first = two_pi - first + (np.pi / 3) * second
    elif abs(alpha) > np.pi / 2:
        second = second + (np.pi / 3) * first

    return np.stack([first, second], axis=1) % two_pi
89
+
90
+
91
def _toroidal_align_coords(
    coords: np.ndarray,
    phase_map1: np.ndarray,
    phase_map2: np.ndarray,
    *,
    trim: int,
    grid_size: int | None,
) -> tuple[np.ndarray, float, float]:
    """Rotate decoded torus coordinates into the frame implied by two phase maps.

    The function fits a stripe model to each map (``fit_cohomap_stripes``) and
    picks each fit's phase direction (``_select_phase_sign``). It then remaps
    the first two coordinate columns with ``_rot_coord``.

    Args:
        coords: ``(N, 2+)`` decoded angles; only the first two columns are used.
        phase_map1, phase_map2: Same-shaped phase maps for the two angles.
        trim: Border trim used by the stripe fit.
        grid_size: Stripe-fit grid size. If ``None``, it is derived from the
            map size as ``rows + 2*trim + 1``.

    Returns:
        ``(aligned, fit_error1, fit_error2)``.

    Raises:
        ValueError: if the map shapes differ or ``coords`` is not ``(N, 2+)``;
            errors from the fitting helpers propagate.
    """
    if phase_map1.shape != phase_map2.shape:
        raise ValueError("phase_map shapes do not match for alignment")
    coords = np.asarray(coords)
    if not (coords.ndim == 2 and coords.shape[1] >= 2):
        raise ValueError(f"coords must be (N,2+) array, got {coords.shape}")

    size = grid_size
    if size is None:
        # Invert the fit's ``rows = grid_size - 1 - 2 * trim`` relation.
        size = int(phase_map1.shape[0]) + 2 * trim + 1

    maps = (phase_map1, phase_map2)
    fits = [fit_cohomap_stripes(m, grid_size=size, trim=trim) for m in maps]
    signs = tuple(
        _select_phase_sign(m, params, grid_size=size, trim=trim)
        for m, (params, _err) in zip(maps, fits)
    )
    (params1, err1), (params2, err2) = fits
    aligned = _rot_coord(params1, params2, coords[:, 0], coords[:, 1], signs)
    return aligned, float(err1), float(err2)
112
+
113
+
114
def cohomap(
    decoding_result: dict[str, Any],
    position_data: dict[str, Any],
    *,
    coords_key: str | None = None,
    bins: int = 101,
    margin_frac: float = 0.0025,
    smooth_sigma: float = 1.0,
    fill_nan: bool = True,
    fill_sigma: float | None = None,
    fill_min_weight: float = 1e-3,
    align_torus: bool = True,
    align_trim: int = 25,
    align_grid_size: int | None = None,
    align_min_valid_frac: float | None = None,
    align_max_fit_error: float | None = None,
) -> dict[str, Any]:
    """
    Compute EcohoMap phase maps using circular-mean binning.

    This mirrors GridCellTorus get_ang_hist: bin spatial positions and compute the
    circular mean of each decoded angle within spatial bins, then smooth in sin/cos
    space. Optional toroidal alignment follows the GridCellTorus stripe fit + rotation.
    You can gate alignment by valid fraction or fit error thresholds.

    Samples whose position (x or y) or decoded angle is NaN/inf are ignored.
    This applies both when choosing the spatial bin range and when binning, so
    tracking dropouts do not invalidate the whole map.

    Args:
        decoding_result: Decoder output. Coordinates come from ``coords_key``,
            else ``"coordsbox"``, else are located automatically.
        position_data: Mapping with ``"x"`` and ``"y"`` arrays, and optionally
            ``"t"`` (used to map ``times_box`` onto position indices).
        coords_key: Explicit key of the ``(N, 2+)`` decoded angle array.
        bins: Number of bin *edges* per spatial axis, so maps have
            ``bins - 1`` cells per axis.
        margin_frac: Fraction of the position range trimmed from each side
            before placing the bin edges.
        smooth_sigma, fill_nan, fill_sigma, fill_min_weight: Passed to the
            circular sin/cos smoothing step.
        align_torus: Try to align the coordinates to the stripe fit and re-bin.
        align_trim, align_grid_size: Stripe-fit geometry (see
            ``fit_cohomap_stripes``).
        align_min_valid_frac: Skip alignment when either raw map has a smaller
            fraction of finite cells.
        align_max_fit_error: Skip alignment when either stripe-fit error
            exceeds this value.

    Returns:
        Dict with ``phase_map1``/``phase_map2`` (aligned maps if alignment
        succeeded, otherwise the raw maps), ``phase_map1_raw``/``phase_map2_raw``,
        the bin edges ``x_edge``/``y_edge``, the smoothing settings, and the
        alignment diagnostics (``aligned``, ``align_error``, valid fractions,
        fit errors, thresholds). Maps are indexed ``[x_bin, y_bin]``.

    Raises:
        ValueError: if the coords are not ``(N, 2+)``, if coords and positions
            differ in length after ``times_box`` selection, or if no finite
            position sample exists. Alignment failures do not raise; they are
            reported in ``align_error``.
    """
    coords, times_box = _extract_coords_and_times(decoding_result, coords_key)
    if coords.ndim != 2 or coords.shape[1] < 2:
        raise ValueError(f"coords must be (N,2+) array, got {coords.shape}")

    xx = np.asarray(position_data["x"])
    yy = np.asarray(position_data["y"])

    if times_box is not None:
        if "t" in position_data:
            idx, _ = parse_times_box_to_indices(times_box, np.asarray(position_data["t"]))
        else:
            idx = np.asarray(times_box).astype(int)
        xx = xx[idx]
        yy = yy[idx]

    if len(xx) != coords.shape[0]:
        raise ValueError(
            "Length mismatch: coords length does not match position length after times_box."
        )

    # NaN/inf positions would turn np.min/np.max (and hence every bin edge)
    # into NaN, so derive the spatial range from finite samples only.
    pos_ok = np.isfinite(xx) & np.isfinite(yy)
    if not np.any(pos_ok):
        raise ValueError("position_data contains no finite (x, y) samples.")
    x_min, x_max = float(np.min(xx[pos_ok])), float(np.max(xx[pos_ok]))
    y_min, y_max = float(np.min(yy[pos_ok])), float(np.max(yy[pos_ok]))
    x_pad = (x_max - x_min) * margin_frac
    y_pad = (y_max - y_min) * margin_frac

    binsx = np.linspace(x_min + x_pad, x_max - x_pad, bins)
    binsy = np.linspace(y_min + y_pad, y_max - y_pad, bins)

    def _angle_hist(values: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        # Circular mean of ``values`` per spatial bin, then circular smoothing.
        keep = pos_ok & np.isfinite(values)
        mtot, x_edge, y_edge, _ = binned_statistic_2d(
            xx[keep],
            yy[keep],
            values[keep],
            statistic=_circmean,
            bins=(binsx, binsy),
            range=None,
            expand_binnumbers=True,
        )
        mtot = _smooth_circular_map(
            mtot,
            smooth_sigma,
            fill_nan=fill_nan,
            fill_sigma=fill_sigma,
            fill_min_weight=fill_min_weight,
        )
        return mtot, x_edge, y_edge

    coords_use = np.asarray(coords, float)
    m1_raw, x_edge, y_edge = _angle_hist(coords_use[:, 0])
    m2_raw, _, _ = _angle_hist(coords_use[:, 1])
    aligned = False
    align_error = None
    align_valid_frac1 = None
    align_valid_frac2 = None
    align_fit_error1 = None
    align_fit_error2 = None

    if align_torus:
        try:
            align_valid_frac1 = _phase_map_valid_fraction(m1_raw)
            align_valid_frac2 = _phase_map_valid_fraction(m2_raw)
            min_valid = min(align_valid_frac1, align_valid_frac2)

            if align_min_valid_frac is not None and min_valid < align_min_valid_frac:
                align_error = (
                    f"valid fraction too low ({min_valid:.3f} < {align_min_valid_frac:.3f})"
                )
            else:
                coords_aligned, f1, f2 = _toroidal_align_coords(
                    coords_use[:, :2],
                    m1_raw,
                    m2_raw,
                    trim=align_trim,
                    grid_size=align_grid_size,
                )
                align_fit_error1 = f1
                align_fit_error2 = f2
                if align_max_fit_error is not None and (
                    f1 > align_max_fit_error or f2 > align_max_fit_error
                ):
                    align_error = (
                        f"fit error too high ({f1:.4f}, {f2:.4f} > {align_max_fit_error:.4f})"
                    )
                else:
                    coords_use = coords_use.copy()
                    coords_use[:, :2] = coords_aligned
                    aligned = True
        except Exception as exc:
            # Alignment is best-effort: the failure reason is reported in the
            # result and the raw (unaligned) maps are returned.
            align_error = str(exc)

    if aligned:
        m1, x_edge, y_edge = _angle_hist(coords_use[:, 0])
        m2, _, _ = _angle_hist(coords_use[:, 1])
    else:
        m1, m2 = m1_raw, m2_raw

    return {
        "phase_map1": m1,
        "phase_map2": m2,
        "phase_map1_raw": m1_raw,
        "phase_map2_raw": m2_raw,
        "x_edge": x_edge,
        "y_edge": y_edge,
        "bins": bins,
        "margin_frac": margin_frac,
        "smooth_sigma": smooth_sigma,
        "fill_nan": fill_nan,
        "fill_sigma": fill_sigma,
        "fill_min_weight": fill_min_weight,
        "aligned": aligned,
        "align_error": align_error,
        "align_min_valid_frac": align_min_valid_frac,
        "align_max_fit_error": align_max_fit_error,
        "align_valid_frac1": align_valid_frac1,
        "align_valid_frac2": align_valid_frac2,
        "align_fit_error1": align_fit_error1,
        "align_fit_error2": align_fit_error2,
    }
260
+
261
+
262
def fit_cohomap_stripes(
    phase_map: np.ndarray,
    *,
    grid_size: int | None = 151,
    trim: int = 25,
    angle_grid: int = 10,
    phase_grid: int = 10,
    spacing_grid: int = 10,
    spacing_range: tuple[float, float] = (1.0, 6.0),
) -> tuple[np.ndarray, float]:
    """
    Fit a cosine stripe model to a phase map, mirroring GridCellTorus fit_sine_wave.

    The model is ``cos(spacing * R_angle(ramp) + phase)``. Here ``ramp`` is a
    ``(grid_size - 1)``-square grid whose value increases from 0 to ``3*pi``
    along the column axis, and ``R_angle`` rotates it about its centre. The
    model is cropped by ``trim`` cells on each side. Fitting minimises the mean
    squared difference to ``cos(phase_map)`` over the finite cells. The search
    is a coarse grid over (angle, phase, spacing), refined with SLSQP.

    Args:
        phase_map: Square map of angles (radians); NaN cells are ignored.
        grid_size: Ramp grid size. If ``None``, it is derived from the map as
            ``rows + 2*trim + 1``.
        trim: Border cells dropped on each side of the ramp (``>= 0``).
        angle_grid, phase_grid, spacing_grid: Coarse-search resolution.
        spacing_range: ``(min, max)`` spacing for the coarse search.

    Returns:
        ``(params, fit_error)`` with ``params = [angle, phase, spacing]``.

    Raises:
        ValueError: if the map is not 2-D, if ``trim`` is negative, if the map
            is not ``(grid_size - 1 - 2*trim)``-square, or if it has no finite
            cells.
    """
    mtot = np.asarray(phase_map)
    if mtot.ndim != 2:
        raise ValueError(f"phase_map must be 2D, got {mtot.shape}")
    if trim < 0:
        raise ValueError(f"trim must be non-negative, got {trim}")

    if grid_size is None:
        grid_size = mtot.shape[0] + 2 * trim + 1

    n = grid_size - 1
    expected = n - 2 * trim
    # Check both axes: a non-square map would otherwise fail later with an
    # opaque boolean-index error inside the objective.
    if mtot.shape != (expected, expected):
        raise ValueError(
            f"grid_size/trim incompatible with phase_map shape: "
            f"expected ({expected}, {expected}) but got {mtot.shape}"
        )

    valid = np.isfinite(mtot)
    if not np.any(valid):
        raise ValueError("phase_map has no finite values to fit")
    target = np.cos(mtot[valid])

    ramp_axis = np.linspace(0, 3 * np.pi, n)
    x, _ = np.meshgrid(ramp_axis, ramp_axis)

    def cos_wave(p: np.ndarray) -> float:
        # Mean squared error between the rotated stripe model and cos(phase_map).
        x1 = rotate(x, p[0] * 360.0 / (2 * np.pi), reshape=False)
        # Use an explicit stop index: ``x1[trim:-trim]`` is empty when trim == 0.
        x1 = x1[trim : n - trim, trim : n - trim]
        model = np.cos(p[2] * x1 + p[1])
        return float(np.mean(np.square(model[valid] - target)))

    angle_space = np.linspace(0, np.pi, angle_grid)
    phase_space = np.linspace(0, 2 * np.pi, phase_grid)
    spacing_space = np.linspace(spacing_range[0], spacing_range[1], spacing_grid)

    grid = np.zeros((angle_grid, phase_grid, spacing_grid))
    for i, ang in enumerate(angle_space):
        for j, ph in enumerate(phase_space):
            for k, sp in enumerate(spacing_space):
                grid[i, j, k] = cos_wave(np.array([ang, ph, sp]))

    p_ind = np.unravel_index(np.argmin(grid), grid.shape)
    p0 = np.array([angle_space[p_ind[0]], phase_space[p_ind[1]], spacing_space[p_ind[2]]])
    res = minimize(cos_wave, p0, method="SLSQP", options={"disp": False})
    return res["x"], float(res["fun"])
315
+
316
+
317
def plot_cohomap(
    cohomap_result: dict[str, Any],
    *,
    config: PlotConfig | None = None,
    save_path: str | None = None,
    show: bool = False,
    figsize: tuple[int, int] = (10, 4),
    cmap: str = "viridis",
    mode: str = "cos",
) -> plt.Figure:
    """
    Plot EcohoMap phase maps (two panels: phase_map1/phase_map2).

    mode:
        "phase" to show raw phase (radians),
        "cos" or "sin" to show cosine/sine of phase like GridCellTorus.

    The maps from ``cohomap`` come from ``binned_statistic_2d(x, y, ...)`` and
    are indexed ``[x_bin, y_bin]``. ``imshow`` draws rows vertically, so each
    map is transposed before plotting. This puts x on the horizontal axis, as
    the ``extent`` built from ``x_edge``/``y_edge`` declares.

    Args:
        cohomap_result: Dict with ``phase_map1``, ``phase_map2``, ``x_edge``
            and ``y_edge`` (as returned by ``cohomap``).
        config: Plot configuration. If ``None``, one is built from ``figsize``,
            ``save_path`` and ``show``.
        save_path, show, figsize: Used only when ``config`` is ``None``.
        cmap: Matplotlib colormap name.
        mode: One of ``"phase"``, ``"cos"``, ``"sin"``.

    Returns:
        The created figure.

    Raises:
        ValueError: for an unknown ``mode``. This is checked before any figure
            is created, so no figure is leaked.
    """
    display_modes = {
        "phase": (lambda m: m, "Phase (rad)", -np.pi, np.pi),
        "cos": (np.cos, "cos(phase)", -1.0, 1.0),
        "sin": (np.sin, "sin(phase)", -1.0, 1.0),
    }
    if mode not in display_modes:
        raise ValueError(f"Unknown mode '{mode}'. Use 'phase', 'cos', or 'sin'.")
    transform, cbar_label, vmin, vmax = display_modes[mode]

    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title="EcohoMap",
        xlabel="",
        ylabel="",
        figsize=figsize,
        save_path=save_path,
        show=show,
    )

    m1 = cohomap_result["phase_map1"]
    m2 = cohomap_result["phase_map2"]
    x_edge = cohomap_result["x_edge"]
    y_edge = cohomap_result["y_edge"]
    extent = [x_edge[0], x_edge[-1], y_edge[0], y_edge[-1]]

    fig, ax = plt.subplots(1, 2, figsize=config.figsize)
    for axis, (mtot, title) in zip(ax, ((m1, "Phase Map 1"), (m2, "Phase Map 2"))):
        # [x_bin, y_bin] -> [row=y, col=x] for imshow.
        plot_map = np.asarray(transform(np.asarray(mtot, dtype=float))).T
        im = axis.imshow(
            plot_map,
            origin="lower",
            extent=extent,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
        )
        axis.set_title(title, fontsize=10)
        axis.set_aspect("equal", "box")
        axis.set_xticks([])
        axis.set_yticks([])
        plt.colorbar(im, ax=axis, fraction=0.046, pad=0.04, label=cbar_label)

    fig.tight_layout()
    _ensure_parent_dir(config.save_path)
    finalize_figure(fig, config)
    return fig
384
+
385
+
386
def cohomap_upgrade(*args, **kwargs) -> dict[str, Any]:
    """Legacy alias for EcohoMap (formerly cohomap_upgrade).

    All positional and keyword arguments are passed straight to ``cohomap``,
    and its result dict is returned unchanged.
    """
    forwarded = cohomap(*args, **kwargs)
    return forwarded
389
+
390
+
391
def ecohomap(*args, **kwargs) -> dict[str, Any]:
    """Alias for EcohoMap (GridCellTorus-style).

    Forwards every argument to ``cohomap`` and returns its result dict.
    """
    forwarded = cohomap(*args, **kwargs)
    return forwarded
394
+
395
+
396
def fit_cohomap_stripes_upgrade(*args, **kwargs) -> tuple[np.ndarray, float]:
    """Legacy alias for EcohoMap stripe fitting.

    Forwards every argument to ``fit_cohomap_stripes`` and returns its
    ``(params, fit_error)`` tuple.
    """
    forwarded = fit_cohomap_stripes(*args, **kwargs)
    return forwarded
399
+
400
+
401
def plot_cohomap_upgrade(*args, **kwargs) -> plt.Figure:
    """Legacy alias for EcohoMap plotting (formerly plot_cohomap_upgrade).

    Forwards every argument to ``plot_cohomap`` and returns the figure it creates.
    """
    figure = plot_cohomap(*args, **kwargs)
    return figure
404
+
405
+
406
def plot_ecohomap(*args, **kwargs) -> plt.Figure:
    """Alias for EcohoMap plotting.

    Forwards every argument to ``plot_cohomap`` and returns the figure it creates.
    """
    figure = plot_cohomap(*args, **kwargs)
    return figure
@@ -0,0 +1,10 @@
1
+ """Scatter-style CohoMap plotting helpers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from .plotting import plot_cohomap_scatter, plot_cohomap_scatter_multi
6
+
7
# Public API of this module: the scatter plotters imported from ``.plotting``.
__all__ = ["plot_cohomap_scatter", "plot_cohomap_scatter_multi"]