canns 0.14.3__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/analyzer/data/asa/__init__.py +56 -21
- canns/analyzer/data/asa/coho.py +21 -0
- canns/analyzer/data/asa/cohomap.py +453 -0
- canns/analyzer/data/asa/cohomap_vectors.py +365 -0
- canns/analyzer/data/asa/cohospace.py +155 -1165
- canns/analyzer/data/asa/cohospace_phase_centers.py +119 -0
- canns/analyzer/data/asa/cohospace_scatter.py +1115 -0
- canns/analyzer/data/asa/embedding.py +5 -7
- canns/analyzer/data/asa/fr.py +1 -8
- canns/analyzer/data/asa/path.py +70 -0
- canns/analyzer/data/asa/plotting.py +5 -30
- canns/analyzer/data/asa/utils.py +160 -0
- canns/analyzer/data/cell_classification/__init__.py +10 -0
- canns/analyzer/data/cell_classification/core/__init__.py +4 -0
- canns/analyzer/data/cell_classification/core/btn.py +272 -0
- canns/analyzer/data/cell_classification/visualization/__init__.py +3 -0
- canns/analyzer/data/cell_classification/visualization/btn_plots.py +241 -0
- canns/analyzer/visualization/__init__.py +2 -0
- canns/analyzer/visualization/core/config.py +20 -0
- canns/analyzer/visualization/theta_sweep_plots.py +142 -0
- canns/pipeline/asa/runner.py +19 -19
- canns/pipeline/asa_gui/__init__.py +5 -3
- canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +3 -1
- canns/pipeline/asa_gui/core/runner.py +23 -23
- canns/pipeline/asa_gui/views/pages/preprocess_page.py +7 -12
- {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/METADATA +1 -1
- {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/RECORD +30 -23
- canns/analyzer/data/asa/filters.py +0 -208
- {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/WHEEL +0 -0
- {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/entry_points.txt +0 -0
- {canns-0.14.3.dist-info → canns-0.15.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,15 +1,37 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
from .cohomap import (
|
|
4
|
+
cohomap,
|
|
5
|
+
fit_cohomap_stripes,
|
|
6
|
+
plot_cohomap,
|
|
7
|
+
)
|
|
8
|
+
from .cohomap_vectors import (
|
|
9
|
+
cohomap_vectors,
|
|
10
|
+
plot_cohomap_stripes,
|
|
11
|
+
plot_cohomap_vectors,
|
|
12
|
+
)
|
|
4
13
|
from .cohospace import (
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
14
|
+
cohospace,
|
|
15
|
+
plot_cohospace,
|
|
16
|
+
plot_cohospace_skewed,
|
|
17
|
+
)
|
|
18
|
+
from .cohospace_phase_centers import (
|
|
19
|
+
cohospace_phase_centers,
|
|
20
|
+
plot_cohospace_phase_centers,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
# Coho-space (scatter) analysis + visualization
|
|
24
|
+
from .cohospace_scatter import (
|
|
25
|
+
compute_cohoscore_scatter_1d,
|
|
26
|
+
compute_cohoscore_scatter_2d,
|
|
27
|
+
plot_cohospace_scatter_neuron_1d,
|
|
28
|
+
plot_cohospace_scatter_neuron_2d,
|
|
29
|
+
plot_cohospace_scatter_neuron_skewed,
|
|
30
|
+
plot_cohospace_scatter_population_1d,
|
|
31
|
+
plot_cohospace_scatter_population_2d,
|
|
32
|
+
plot_cohospace_scatter_population_skewed,
|
|
33
|
+
plot_cohospace_scatter_trajectory_1d,
|
|
34
|
+
plot_cohospace_scatter_trajectory_2d,
|
|
13
35
|
)
|
|
14
36
|
from .config import (
|
|
15
37
|
CANN2DError,
|
|
@@ -47,8 +69,8 @@ from .path import (
|
|
|
47
69
|
from .plotting import (
|
|
48
70
|
plot_2d_bump_on_manifold,
|
|
49
71
|
plot_3d_bump_on_torus,
|
|
50
|
-
|
|
51
|
-
|
|
72
|
+
plot_cohomap_scatter,
|
|
73
|
+
plot_cohomap_scatter_multi,
|
|
52
74
|
plot_path_compare_1d,
|
|
53
75
|
plot_path_compare_2d,
|
|
54
76
|
plot_projection,
|
|
@@ -72,10 +94,21 @@ __all__ = [
|
|
|
72
94
|
"plot_projection",
|
|
73
95
|
"plot_path_compare_1d",
|
|
74
96
|
"plot_path_compare_2d",
|
|
75
|
-
"
|
|
76
|
-
"
|
|
97
|
+
"plot_cohomap_scatter",
|
|
98
|
+
"plot_cohomap_scatter_multi",
|
|
77
99
|
"plot_3d_bump_on_torus",
|
|
78
100
|
"plot_2d_bump_on_manifold",
|
|
101
|
+
"cohomap",
|
|
102
|
+
"fit_cohomap_stripes",
|
|
103
|
+
"plot_cohomap",
|
|
104
|
+
"cohospace",
|
|
105
|
+
"plot_cohospace",
|
|
106
|
+
"plot_cohospace_skewed",
|
|
107
|
+
"cohomap_vectors",
|
|
108
|
+
"plot_cohomap_stripes",
|
|
109
|
+
"plot_cohomap_vectors",
|
|
110
|
+
"cohospace_phase_centers",
|
|
111
|
+
"plot_cohospace_phase_centers",
|
|
79
112
|
"BumpFitsConfig",
|
|
80
113
|
"CANN1DPlotConfig",
|
|
81
114
|
"create_1d_bump_animation",
|
|
@@ -85,14 +118,16 @@ __all__ = [
|
|
|
85
118
|
"FRMResult",
|
|
86
119
|
"compute_frm",
|
|
87
120
|
"plot_frm",
|
|
88
|
-
"
|
|
89
|
-
"
|
|
90
|
-
"
|
|
91
|
-
"
|
|
92
|
-
"
|
|
93
|
-
"
|
|
94
|
-
"
|
|
95
|
-
"
|
|
121
|
+
"plot_cohospace_scatter_trajectory_1d",
|
|
122
|
+
"plot_cohospace_scatter_trajectory_2d",
|
|
123
|
+
"plot_cohospace_scatter_neuron_1d",
|
|
124
|
+
"plot_cohospace_scatter_neuron_2d",
|
|
125
|
+
"plot_cohospace_scatter_population_1d",
|
|
126
|
+
"plot_cohospace_scatter_population_2d",
|
|
127
|
+
"plot_cohospace_scatter_neuron_skewed",
|
|
128
|
+
"plot_cohospace_scatter_population_skewed",
|
|
129
|
+
"compute_cohoscore_scatter_1d",
|
|
130
|
+
"compute_cohoscore_scatter_2d",
|
|
96
131
|
"align_coords_to_position_1d",
|
|
97
132
|
"align_coords_to_position_2d",
|
|
98
133
|
"apply_angle_scale",
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Shared helpers for CohoMap/CohoSpace analysis."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from .utils import (
|
|
6
|
+
_circmean,
|
|
7
|
+
_ensure_parent_dir,
|
|
8
|
+
_ensure_plot_config,
|
|
9
|
+
_extract_coords_and_times,
|
|
10
|
+
_phase_map_valid_fraction,
|
|
11
|
+
_smooth_circular_map,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
# Re-export the underscore-prefixed helpers from ``.utils`` so the sibling
# CohoMap/CohoSpace modules (e.g. ``cohomap.py``) can import them from one
# shared place. Because they are listed in ``__all__``, ``from .coho import *``
# would expose them as well, despite the leading underscores.
__all__ = [
    "_ensure_plot_config",
    "_ensure_parent_dir",
    "_circmean",
    "_smooth_circular_map",
    "_extract_coords_and_times",
    "_phase_map_valid_fraction",
]
|
|
@@ -0,0 +1,453 @@
|
|
|
1
|
+
"""CohoMap (Ecoho-style) computation and plotting."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import matplotlib.pyplot as plt
|
|
8
|
+
import numpy as np
|
|
9
|
+
from scipy.ndimage import rotate
|
|
10
|
+
from scipy.optimize import minimize
|
|
11
|
+
from scipy.stats import binned_statistic_2d
|
|
12
|
+
|
|
13
|
+
from ...visualization.core import PlotConfig, finalize_figure
|
|
14
|
+
from .coho import (
|
|
15
|
+
_circmean,
|
|
16
|
+
_ensure_parent_dir,
|
|
17
|
+
_ensure_plot_config,
|
|
18
|
+
_extract_coords_and_times,
|
|
19
|
+
_phase_map_valid_fraction,
|
|
20
|
+
_smooth_circular_map,
|
|
21
|
+
)
|
|
22
|
+
from .path import parse_times_box_to_indices
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _select_phase_sign(
|
|
26
|
+
phase_map: np.ndarray,
|
|
27
|
+
params: np.ndarray,
|
|
28
|
+
*,
|
|
29
|
+
grid_size: int,
|
|
30
|
+
trim: int,
|
|
31
|
+
) -> int:
|
|
32
|
+
mtot = np.asarray(phase_map)
|
|
33
|
+
expected = grid_size - 1 - 2 * trim
|
|
34
|
+
if mtot.ndim != 2 or mtot.shape[0] != expected or mtot.shape[1] != expected:
|
|
35
|
+
raise ValueError("phase_map shape does not match grid_size/trim for alignment")
|
|
36
|
+
nnans = ~np.isnan(mtot)
|
|
37
|
+
if not np.any(nnans):
|
|
38
|
+
return 1
|
|
39
|
+
x, _ = np.meshgrid(
|
|
40
|
+
np.linspace(0, 3 * np.pi, grid_size - 1),
|
|
41
|
+
np.linspace(0, 3 * np.pi, grid_size - 1),
|
|
42
|
+
)
|
|
43
|
+
x1 = rotate(x, params[0] * 360.0 / (2 * np.pi), reshape=False)
|
|
44
|
+
base = params[2] * x1[trim:-trim, trim:-trim] + params[1]
|
|
45
|
+
mtot_vals = (mtot[nnans]) % (2 * np.pi)
|
|
46
|
+
pm1 = (base % (2 * np.pi))[nnans] - mtot_vals
|
|
47
|
+
pm2 = ((2 * np.pi - base) % (2 * np.pi))[nnans] - mtot_vals
|
|
48
|
+
if np.sum(np.abs(pm1)) > np.sum(np.abs(pm2)):
|
|
49
|
+
return -1
|
|
50
|
+
return 1
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _rot_coord(
|
|
54
|
+
params1: np.ndarray,
|
|
55
|
+
params2: np.ndarray,
|
|
56
|
+
c1: np.ndarray,
|
|
57
|
+
c2: np.ndarray,
|
|
58
|
+
p: tuple[int, int],
|
|
59
|
+
) -> np.ndarray:
|
|
60
|
+
"""Transform and align decoded coordinates based on stripe orientation.
|
|
61
|
+
|
|
62
|
+
This function rotates and aligns coordinates to match the dominant stripe
|
|
63
|
+
orientation in the CohoMap. It handles the torus geometry by choosing the
|
|
64
|
+
appropriate coordinate system and applying reflections/rotations.
|
|
65
|
+
|
|
66
|
+
Parameters
|
|
67
|
+
----------
|
|
68
|
+
params1 : np.ndarray
|
|
69
|
+
Stripe fit parameters for first dimension (includes orientation angle).
|
|
70
|
+
params2 : np.ndarray
|
|
71
|
+
Stripe fit parameters for second dimension (includes orientation angle).
|
|
72
|
+
c1 : np.ndarray
|
|
73
|
+
Decoded coordinates for first dimension.
|
|
74
|
+
c2 : np.ndarray
|
|
75
|
+
Decoded coordinates for second dimension.
|
|
76
|
+
p : tuple[int, int]
|
|
77
|
+
Direction indicators (1 or -1) for each dimension, controlling reflections.
|
|
78
|
+
|
|
79
|
+
Returns
|
|
80
|
+
-------
|
|
81
|
+
np.ndarray
|
|
82
|
+
Aligned coordinates of shape (N, 2), wrapped to [0, 2π).
|
|
83
|
+
|
|
84
|
+
Notes
|
|
85
|
+
-----
|
|
86
|
+
The function selects the coordinate system based on which dimension has
|
|
87
|
+
stronger horizontal alignment (larger |cos(angle)|), then applies geometric
|
|
88
|
+
transformations to align the coordinates with the stripe pattern.
|
|
89
|
+
"""
|
|
90
|
+
if abs(np.cos(params1[0])) < abs(np.cos(params2[0])):
|
|
91
|
+
cc1 = c2.copy()
|
|
92
|
+
cc2 = c1.copy()
|
|
93
|
+
y = params1.copy()
|
|
94
|
+
x = params2.copy()
|
|
95
|
+
p = (p[1], p[0])
|
|
96
|
+
else:
|
|
97
|
+
cc1 = c1.copy()
|
|
98
|
+
cc2 = c2.copy()
|
|
99
|
+
x = params1.copy()
|
|
100
|
+
y = params2.copy()
|
|
101
|
+
|
|
102
|
+
if p[1] == -1:
|
|
103
|
+
cc2 = 2 * np.pi - cc2
|
|
104
|
+
if p[0] == -1:
|
|
105
|
+
cc1 = 2 * np.pi - cc1
|
|
106
|
+
|
|
107
|
+
alpha = y[0] - x[0]
|
|
108
|
+
if (alpha < 0) and (abs(alpha) > np.pi / 2):
|
|
109
|
+
cctmp = cc2.copy()
|
|
110
|
+
cc2 = cc1.copy()
|
|
111
|
+
cc1 = cctmp
|
|
112
|
+
|
|
113
|
+
if (alpha < 0) and (abs(alpha) < np.pi / 2):
|
|
114
|
+
cc1 = 2 * np.pi - cc1 + (np.pi / 3) * cc2
|
|
115
|
+
elif abs(alpha) > np.pi / 2:
|
|
116
|
+
cc2 = cc2 + (np.pi / 3) * cc1
|
|
117
|
+
|
|
118
|
+
return np.stack([cc1, cc2], axis=1) % (2 * np.pi)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def _toroidal_align_coords(
    coords: np.ndarray,
    phase_map1: np.ndarray,
    phase_map2: np.ndarray,
    *,
    trim: int,
    grid_size: int | None,
) -> tuple[np.ndarray, float, float]:
    """Align decoded torus coordinates to the stripe structure of two CohoMaps.

    Procedure:

    1. Fit a cosine stripe model to each phase map with ``fit_cohomap_stripes``
       (a coarse grid search refined by SLSQP).
    2. Choose a phase sign (+1/-1) for each map with ``_select_phase_sign``.
    3. Reorder, reflect and shear the first two coordinate columns with
       ``_rot_coord``.

    Parameters
    ----------
    coords : np.ndarray
        Decoded coordinates of shape ``(N, 2+)``; only the first two columns are
        used.
    phase_map1, phase_map2 : np.ndarray
        2D phase maps for the two dimensions; they must have identical shapes.
    trim : int
        Edge bins trimmed from the model grid during stripe fitting.
    grid_size : int or None
        Grid size for stripe fitting. If None, it is inferred as
        ``phase_map1.shape[0] + 2*trim + 1``.

    Returns
    -------
    aligned : np.ndarray
        Aligned coordinates of shape ``(N, 2)``, wrapped to ``[0, 2*pi)``.
    f1, f2 : float
        Stripe-fit mean squared errors for ``phase_map1`` and ``phase_map2``
        (lower is better).

    Raises
    ------
    ValueError
        If the phase-map shapes differ or ``coords`` is not ``(N, 2+)``. The
        stripe-fit and sign-selection helpers may also raise ValueError.
    """
    if phase_map1.shape != phase_map2.shape:
        raise ValueError("phase_map shapes do not match for alignment")

    arr = np.asarray(coords)
    if not (arr.ndim == 2 and arr.shape[1] >= 2):
        raise ValueError(f"coords must be (N,2+) array, got {arr.shape}")

    size = grid_size if grid_size is not None else int(phase_map1.shape[0]) + 2 * trim + 1

    (params_a, err_a), (params_b, err_b) = [
        fit_cohomap_stripes(pm, grid_size=size, trim=trim) for pm in (phase_map1, phase_map2)
    ]
    sign_a = _select_phase_sign(phase_map1, params_a, grid_size=size, trim=trim)
    sign_b = _select_phase_sign(phase_map2, params_b, grid_size=size, trim=trim)

    aligned = _rot_coord(params_a, params_b, arr[:, 0], arr[:, 1], (sign_a, sign_b))
    return aligned, float(err_a), float(err_b)
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def cohomap(
    decoding_result: dict[str, Any],
    position_data: dict[str, Any],
    *,
    coords_key: str | None = None,
    bins: int = 101,
    margin_frac: float = 0.0025,
    smooth_sigma: float = 1.0,
    fill_nan: bool = True,
    fill_sigma: float | None = None,
    fill_min_weight: float = 1e-3,
    align_torus: bool = True,
    align_trim: int = 25,
    align_grid_size: int | None = None,
    align_min_valid_frac: float | None = None,
    align_max_fit_error: float | None = None,
) -> dict[str, Any]:
    """
    Compute EcohoMap phase maps using circular-mean binning.

    This mirrors GridCellTorus get_ang_hist. Spatial positions are binned, the
    circular mean of each decoded angle is computed within each spatial bin, and
    the result is smoothed in sin/cos space. Optional toroidal alignment follows
    the GridCellTorus stripe fit + rotation. You can gate alignment by valid
    fraction or fit error thresholds.

    Parameters
    ----------
    decoding_result : dict
        Decoding output; coordinates and optional ``times_box`` are extracted via
        ``_extract_coords_and_times(decoding_result, coords_key)``.
    position_data : dict
        Must contain ``"x"`` and ``"y"``. It may contain ``"t"``, which is used
        to map ``times_box`` to sample indices; without ``"t"``, ``times_box`` is
        used directly as integer indices.
    coords_key : str, optional
        Forwarded to ``_extract_coords_and_times``.
    bins : int
        Number of bin *edges* per spatial axis, so each map is
        ``(bins - 1, bins - 1)``.
    margin_frac : float
        Fraction of the position range trimmed from each side of the binning
        range.
    smooth_sigma, fill_nan, fill_sigma, fill_min_weight
        Forwarded to ``_smooth_circular_map``.
    align_torus : bool
        If True, attempt toroidal alignment of the decoded coordinates.
    align_trim, align_grid_size
        Forwarded to the stripe fit as ``trim`` / ``grid_size``.
    align_min_valid_frac : float, optional
        Skip alignment if either raw map's valid fraction is below this value.
    align_max_fit_error : float, optional
        Discard alignment if either stripe-fit error exceeds this value.

    Returns
    -------
    dict
        Contains the following keys:

        - ``phase_map1``, ``phase_map2``: final maps (aligned if alignment
          succeeded).
        - ``phase_map1_raw``, ``phase_map2_raw``: maps before alignment.
        - ``x_edge``, ``y_edge``: bin edges.
        - The input settings echoed back: ``bins``, ``margin_frac``,
          ``smooth_sigma``, ``fill_nan``, ``fill_sigma``, ``fill_min_weight``,
          ``align_min_valid_frac``, ``align_max_fit_error``.
        - ``aligned``: bool.
        - ``align_error``: str or None; the reason alignment was skipped or
          failed.
        - ``align_valid_frac1``, ``align_valid_frac2``,
          ``align_fit_error1``, ``align_fit_error2``: diagnostics, or None.

    Raises
    ------
    ValueError
        If coords are not ``(N, 2+)``, lengths mismatch after ``times_box``, or
        no position sample has finite (non-NaN) x and y.
    """
    coords, times_box = _extract_coords_and_times(decoding_result, coords_key)
    if coords.ndim != 2 or coords.shape[1] < 2:
        raise ValueError(f"coords must be (N,2+) array, got {coords.shape}")

    xx = np.asarray(position_data["x"], dtype=float)
    yy = np.asarray(position_data["y"], dtype=float)

    if times_box is not None:
        # Restrict positions to the samples that the decoded coords refer to.
        if "t" in position_data:
            idx, _ = parse_times_box_to_indices(times_box, np.asarray(position_data["t"]))
        else:
            # Without timestamps, times_box is interpreted as sample indices.
            idx = np.asarray(times_box).astype(int)
        xx = xx[idx]
        yy = yy[idx]

    if len(xx) != coords.shape[0]:
        raise ValueError(
            "Length mismatch: coords length does not match position length after times_box."
        )

    # Positions may contain NaNs (e.g. tracking dropouts). A plain np.min/np.max
    # would then return NaN and produce NaN bin edges, so exclude such samples
    # both from the range computation and from binning.
    pos_valid = ~np.isnan(xx) & ~np.isnan(yy)
    if not np.any(pos_valid):
        raise ValueError("position_data has no valid (non-NaN) x/y samples")

    x_valid = xx[pos_valid]
    y_valid = yy[pos_valid]
    x_min, x_max = float(np.min(x_valid)), float(np.max(x_valid))
    y_min, y_max = float(np.min(y_valid)), float(np.max(y_valid))
    x_pad = (x_max - x_min) * margin_frac
    y_pad = (y_max - y_min) * margin_frac

    binsx = np.linspace(x_min + x_pad, x_max - x_pad, bins)
    binsy = np.linspace(y_min + y_pad, y_max - y_pad, bins)

    def _angle_hist(values: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Circular mean of ``values`` per spatial bin, then circular smoothing."""
        keep = pos_valid & ~np.isnan(values)
        mtot, x_edge, y_edge, _ = binned_statistic_2d(
            xx[keep],
            yy[keep],
            values[keep],
            statistic=_circmean,
            bins=(binsx, binsy),
            range=None,
            expand_binnumbers=True,
        )
        mtot = _smooth_circular_map(
            mtot,
            smooth_sigma,
            fill_nan=fill_nan,
            fill_sigma=fill_sigma,
            fill_min_weight=fill_min_weight,
        )
        return mtot, x_edge, y_edge

    coords_use = np.asarray(coords, float)
    m1_raw, x_edge, y_edge = _angle_hist(coords_use[:, 0])
    m2_raw, _, _ = _angle_hist(coords_use[:, 1])
    aligned = False
    align_error = None
    align_valid_frac1 = None
    align_valid_frac2 = None
    align_fit_error1 = None
    align_fit_error2 = None

    if align_torus:
        # Alignment is best-effort: any failure is recorded in ``align_error``
        # and the unaligned maps are returned instead.
        try:
            align_valid_frac1 = _phase_map_valid_fraction(m1_raw)
            align_valid_frac2 = _phase_map_valid_fraction(m2_raw)
            min_valid = min(align_valid_frac1, align_valid_frac2)

            if align_min_valid_frac is not None and min_valid < align_min_valid_frac:
                align_error = (
                    f"valid fraction too low ({min_valid:.3f} < {align_min_valid_frac:.3f})"
                )
            else:
                coords_aligned, f1, f2 = _toroidal_align_coords(
                    coords_use[:, :2],
                    m1_raw,
                    m2_raw,
                    trim=align_trim,
                    grid_size=align_grid_size,
                )
                align_fit_error1 = f1
                align_fit_error2 = f2
                if align_max_fit_error is not None and (
                    f1 > align_max_fit_error or f2 > align_max_fit_error
                ):
                    align_error = (
                        f"fit error too high ({f1:.4f}, {f2:.4f} > {align_max_fit_error:.4f})"
                    )
                else:
                    coords_use = coords_use.copy()
                    coords_use[:, :2] = coords_aligned
                    aligned = True
        except Exception as exc:
            align_error = str(exc)

    if aligned:
        m1, x_edge, y_edge = _angle_hist(coords_use[:, 0])
        m2, _, _ = _angle_hist(coords_use[:, 1])
    else:
        m1, m2 = m1_raw, m2_raw

    return {
        "phase_map1": m1,
        "phase_map2": m2,
        "phase_map1_raw": m1_raw,
        "phase_map2_raw": m2_raw,
        "x_edge": x_edge,
        "y_edge": y_edge,
        "bins": bins,
        "margin_frac": margin_frac,
        "smooth_sigma": smooth_sigma,
        "fill_nan": fill_nan,
        "fill_sigma": fill_sigma,
        "fill_min_weight": fill_min_weight,
        "aligned": aligned,
        "align_error": align_error,
        "align_min_valid_frac": align_min_valid_frac,
        "align_max_fit_error": align_max_fit_error,
        "align_valid_frac1": align_valid_frac1,
        "align_valid_frac2": align_valid_frac2,
        "align_fit_error1": align_fit_error1,
        "align_fit_error2": align_fit_error2,
    }
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def fit_cohomap_stripes(
    phase_map: np.ndarray,
    *,
    grid_size: int | None = 151,
    trim: int = 25,
    angle_grid: int = 10,
    phase_grid: int = 10,
    spacing_grid: int = 10,
    spacing_range: tuple[float, float] = (1.0, 6.0),
) -> tuple[np.ndarray, float]:
    """
    Fit a cosine stripe model to a phase map, mirroring GridCellTorus fit_sine_wave.

    The model is ``cos(spacing * R_angle(x) + phase)``. Here ``x`` is a ramp
    grid of shape ``(grid_size - 1, grid_size - 1)`` whose values run over
    ``[0, 3*pi]`` along the columns. ``R_angle`` rotates that grid by ``angle``
    radians (``scipy.ndimage.rotate`` with the same output shape), after which
    ``trim`` bins are removed from every edge.

    The objective is the mean squared error between the model and
    ``cos(phase_map)`` over the non-NaN entries. A coarse grid search over
    ``angle in [0, pi]``, ``phase in [0, 2*pi]`` and ``spacing in spacing_range``
    seeds an SLSQP refinement.

    Parameters
    ----------
    phase_map : np.ndarray
        Square 2D phase map of side ``grid_size - 1 - 2*trim``.
    grid_size : int or None
        Model grid size. If None, it is inferred as
        ``phase_map.shape[0] + 2*trim + 1``.
    trim : int
        Bins removed from each model-grid edge; ``0`` keeps the full grid.
    angle_grid, phase_grid, spacing_grid : int
        Number of coarse-search samples per parameter.
    spacing_range : tuple[float, float]
        Range of the spacing values used in the coarse search.

    Returns
    -------
    params : np.ndarray
        Fitted ``[angle (rad), phase, spacing]``.
    error : float
        Mean squared error of the fitted model (lower is better).

    Raises
    ------
    ValueError
        If ``phase_map`` is not a square 2D array of side
        ``grid_size - 1 - 2*trim``, or has no non-NaN entries.
    """
    mtot = np.asarray(phase_map)
    if mtot.ndim != 2:
        raise ValueError(f"phase_map must be 2D, got {mtot.shape}")

    if grid_size is None:
        grid_size = mtot.shape[0] + 2 * trim + 1

    expected = grid_size - 1 - 2 * trim
    # Check both axes: a non-square map would otherwise fail later with an
    # opaque boolean-index error.
    if mtot.shape != (expected, expected):
        raise ValueError(
            f"grid_size/trim incompatible with phase_map shape: "
            f"expected ({expected}, {expected}) but got {mtot.shape}"
        )

    nnans = ~np.isnan(mtot)
    if not np.any(nnans):
        # Without this check the fit silently returns a NaN error, which
        # compares False against any threshold and so slips through gating.
        raise ValueError("phase_map has no valid (non-NaN) entries to fit")
    target = np.cos(mtot[nnans])

    axis = np.linspace(0, 3 * np.pi, grid_size - 1)
    x, _ = np.meshgrid(axis, axis)
    # Explicit stop index: ``x1[trim:-trim]`` would be EMPTY when trim == 0.
    stop = x.shape[0] - trim

    def _rotated_trimmed(angle: float) -> np.ndarray:
        """Rotate the ramp grid by ``angle`` radians and trim its edges."""
        x1 = rotate(x, angle * 360.0 / (2 * np.pi), reshape=False)
        return x1[trim:stop, trim:stop]

    def _residual(x_trim: np.ndarray, phase: float, spacing: float) -> float:
        """Mean squared error of the cosine model on the valid entries."""
        model = np.cos(spacing * x_trim + phase)
        return float(np.mean(np.square(model[nnans] - target)))

    def cos_wave(p: np.ndarray) -> float:
        return _residual(_rotated_trimmed(p[0]), p[1], p[2])

    angle_space = np.linspace(0, np.pi, angle_grid)
    phase_space = np.linspace(0, 2 * np.pi, phase_grid)
    spacing_space = np.linspace(spacing_range[0], spacing_range[1], spacing_grid)

    grid = np.zeros((angle_grid, phase_grid, spacing_grid))
    for i, ang in enumerate(angle_space):
        # The rotation depends only on the angle, so compute it once per angle
        # rather than once per (angle, phase, spacing) triple.
        x_trim = _rotated_trimmed(ang)
        for j, ph in enumerate(phase_space):
            for k, sp in enumerate(spacing_space):
                grid[i, j, k] = _residual(x_trim, ph, sp)

    p_ind = np.unravel_index(np.argmin(grid), grid.shape)
    p0 = np.array([angle_space[p_ind[0]], phase_space[p_ind[1]], spacing_space[p_ind[2]]])
    res = minimize(cos_wave, p0, method="SLSQP", options={"disp": False})
    return res["x"], float(res["fun"])
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
def plot_cohomap(
    cohomap_result: dict[str, Any],
    *,
    config: PlotConfig | None = None,
    save_path: str | None = None,
    show: bool = False,
    figsize: tuple[int, int] = (10, 4),
    cmap: str = "viridis",
    mode: str = "cos",
) -> plt.Figure:
    """
    Plot EcohoMap phase maps (two panels: phase_map1/phase_map2).

    Parameters
    ----------
    cohomap_result : dict
        Output of ``cohomap``; reads ``phase_map1``, ``phase_map2``, ``x_edge``
        and ``y_edge``.
    config : PlotConfig, optional
        Plot configuration. When omitted, one is built via
        ``PlotConfig.for_static_plot`` from ``figsize``, ``save_path`` and ``show``.
    save_path, show, figsize
        Used when building the default config.
    cmap : str
        Matplotlib colormap name.
    mode : str
        The following values are accepted:

        - ``"phase"`` shows the raw phase (radians), with a color range of
          ``[-pi, pi]``.
        - ``"cos"`` or ``"sin"`` shows the cosine/sine of the phase, as in
          GridCellTorus.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.

    Raises
    ------
    ValueError
        If ``mode`` is not one of ``"phase"``, ``"cos"`` or ``"sin"``. This is
        checked before any figure is created, so no figure is leaked.
    """
    # mode -> (transform, colorbar label, vmin, vmax)
    # NOTE(review): "phase" clips to [-pi, pi]; if phase maps are in [0, 2*pi)
    # the upper half will saturate -- confirm the _circmean output range.
    mode_specs = {
        "phase": (lambda m: m, "Phase (rad)", -np.pi, np.pi),
        "cos": (np.cos, "cos(phase)", -1.0, 1.0),
        "sin": (np.sin, "sin(phase)", -1.0, 1.0),
    }
    if mode not in mode_specs:
        raise ValueError(f"Unknown mode '{mode}'. Use 'phase', 'cos', or 'sin'.")
    transform, cbar_label, vmin, vmax = mode_specs[mode]

    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title="EcohoMap",
        xlabel="",
        ylabel="",
        figsize=figsize,
        save_path=save_path,
        show=show,
    )

    m1 = cohomap_result["phase_map1"]
    m2 = cohomap_result["phase_map2"]
    x_edge = cohomap_result["x_edge"]
    y_edge = cohomap_result["y_edge"]
    extent = [x_edge[0], x_edge[-1], y_edge[0], y_edge[-1]]

    fig, ax = plt.subplots(1, 2, figsize=config.figsize)
    for panel, mtot, title in ((ax[0], m1, "Phase Map 1"), (ax[1], m2, "Phase Map 2")):
        im = panel.imshow(
            transform(mtot),
            origin="lower",
            extent=extent,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
        )
        panel.set_title(title, fontsize=10)
        panel.set_aspect("equal", "box")
        panel.set_xticks([])
        panel.set_yticks([])
        # Attach to this figure explicitly rather than pyplot's "current" one.
        fig.colorbar(im, ax=panel, fraction=0.046, pad=0.04, label=cbar_label)

    fig.tight_layout()
    _ensure_parent_dir(config.save_path)
    finalize_figure(fig, config)
    return fig
|