canns 0.14.3__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/analyzer/data/asa/__init__.py +77 -21
- canns/analyzer/data/asa/coho.py +97 -0
- canns/analyzer/data/asa/cohomap.py +408 -0
- canns/analyzer/data/asa/cohomap_scatter.py +10 -0
- canns/analyzer/data/asa/cohomap_vectors.py +311 -0
- canns/analyzer/data/asa/cohospace.py +173 -1153
- canns/analyzer/data/asa/cohospace_phase_centers.py +137 -0
- canns/analyzer/data/asa/cohospace_scatter.py +1220 -0
- canns/analyzer/data/asa/embedding.py +3 -4
- canns/analyzer/data/asa/plotting.py +4 -4
- canns/analyzer/data/cell_classification/__init__.py +10 -0
- canns/analyzer/data/cell_classification/core/__init__.py +4 -0
- canns/analyzer/data/cell_classification/core/btn.py +272 -0
- canns/analyzer/data/cell_classification/visualization/__init__.py +3 -0
- canns/analyzer/data/cell_classification/visualization/btn_plots.py +258 -0
- canns/analyzer/visualization/__init__.py +2 -0
- canns/analyzer/visualization/core/config.py +20 -0
- canns/analyzer/visualization/theta_sweep_plots.py +142 -0
- canns/pipeline/asa/runner.py +19 -19
- canns/pipeline/asa_gui/__init__.py +5 -3
- canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +3 -1
- canns/pipeline/asa_gui/core/runner.py +23 -23
- canns/pipeline/asa_gui/views/pages/preprocess_page.py +7 -12
- {canns-0.14.3.dist-info → canns-0.15.0.dist-info}/METADATA +1 -1
- {canns-0.14.3.dist-info → canns-0.15.0.dist-info}/RECORD +28 -20
- {canns-0.14.3.dist-info → canns-0.15.0.dist-info}/WHEEL +0 -0
- {canns-0.14.3.dist-info → canns-0.15.0.dist-info}/entry_points.txt +0 -0
- {canns-0.14.3.dist-info → canns-0.15.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,15 +1,49 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
from .cohomap import (
|
|
4
|
+
cohomap,
|
|
5
|
+
cohomap_upgrade,
|
|
6
|
+
ecohomap,
|
|
7
|
+
fit_cohomap_stripes,
|
|
8
|
+
fit_cohomap_stripes_upgrade,
|
|
9
|
+
plot_cohomap,
|
|
10
|
+
plot_cohomap_upgrade,
|
|
11
|
+
plot_ecohomap,
|
|
12
|
+
)
|
|
13
|
+
from .cohomap_scatter import plot_cohomap_scatter, plot_cohomap_scatter_multi
|
|
14
|
+
from .cohomap_vectors import (
|
|
15
|
+
cohomap_vectors,
|
|
16
|
+
plot_cohomap_stripes,
|
|
17
|
+
plot_cohomap_vectors,
|
|
18
|
+
)
|
|
4
19
|
from .cohospace import (
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
20
|
+
cohospace,
|
|
21
|
+
cohospace_upgrade,
|
|
22
|
+
ecohospace,
|
|
23
|
+
plot_cohospace,
|
|
24
|
+
plot_cohospace_skewed,
|
|
25
|
+
plot_cohospace_upgrade,
|
|
26
|
+
plot_cohospace_upgrade_skewed,
|
|
27
|
+
plot_ecohospace,
|
|
28
|
+
plot_ecohospace_skewed,
|
|
29
|
+
)
|
|
30
|
+
from .cohospace_phase_centers import (
|
|
31
|
+
cohospace_phase_centers,
|
|
32
|
+
plot_cohospace_phase_centers,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
# Coho-space (scatter) analysis + visualization
|
|
36
|
+
from .cohospace_scatter import (
|
|
37
|
+
compute_cohoscore_scatter_1d,
|
|
38
|
+
compute_cohoscore_scatter_2d,
|
|
39
|
+
plot_cohospace_scatter_neuron_1d,
|
|
40
|
+
plot_cohospace_scatter_neuron_2d,
|
|
41
|
+
plot_cohospace_scatter_neuron_skewed,
|
|
42
|
+
plot_cohospace_scatter_population_1d,
|
|
43
|
+
plot_cohospace_scatter_population_2d,
|
|
44
|
+
plot_cohospace_scatter_population_skewed,
|
|
45
|
+
plot_cohospace_scatter_trajectory_1d,
|
|
46
|
+
plot_cohospace_scatter_trajectory_2d,
|
|
13
47
|
)
|
|
14
48
|
from .config import (
|
|
15
49
|
CANN2DError,
|
|
@@ -47,8 +81,6 @@ from .path import (
|
|
|
47
81
|
from .plotting import (
|
|
48
82
|
plot_2d_bump_on_manifold,
|
|
49
83
|
plot_3d_bump_on_torus,
|
|
50
|
-
plot_cohomap,
|
|
51
|
-
plot_cohomap_multi,
|
|
52
84
|
plot_path_compare_1d,
|
|
53
85
|
plot_path_compare_2d,
|
|
54
86
|
plot_projection,
|
|
@@ -72,10 +104,32 @@ __all__ = [
|
|
|
72
104
|
"plot_projection",
|
|
73
105
|
"plot_path_compare_1d",
|
|
74
106
|
"plot_path_compare_2d",
|
|
75
|
-
"
|
|
76
|
-
"
|
|
107
|
+
"plot_cohomap_scatter",
|
|
108
|
+
"plot_cohomap_scatter_multi",
|
|
77
109
|
"plot_3d_bump_on_torus",
|
|
78
110
|
"plot_2d_bump_on_manifold",
|
|
111
|
+
"cohomap",
|
|
112
|
+
"cohomap_upgrade",
|
|
113
|
+
"fit_cohomap_stripes",
|
|
114
|
+
"fit_cohomap_stripes_upgrade",
|
|
115
|
+
"plot_cohomap",
|
|
116
|
+
"plot_cohomap_upgrade",
|
|
117
|
+
"cohospace",
|
|
118
|
+
"cohospace_upgrade",
|
|
119
|
+
"plot_cohospace",
|
|
120
|
+
"plot_cohospace_skewed",
|
|
121
|
+
"plot_cohospace_upgrade",
|
|
122
|
+
"plot_cohospace_upgrade_skewed",
|
|
123
|
+
"ecohomap",
|
|
124
|
+
"ecohospace",
|
|
125
|
+
"plot_ecohomap",
|
|
126
|
+
"plot_ecohospace",
|
|
127
|
+
"plot_ecohospace_skewed",
|
|
128
|
+
"cohomap_vectors",
|
|
129
|
+
"plot_cohomap_stripes",
|
|
130
|
+
"plot_cohomap_vectors",
|
|
131
|
+
"cohospace_phase_centers",
|
|
132
|
+
"plot_cohospace_phase_centers",
|
|
79
133
|
"BumpFitsConfig",
|
|
80
134
|
"CANN1DPlotConfig",
|
|
81
135
|
"create_1d_bump_animation",
|
|
@@ -85,14 +139,16 @@ __all__ = [
|
|
|
85
139
|
"FRMResult",
|
|
86
140
|
"compute_frm",
|
|
87
141
|
"plot_frm",
|
|
88
|
-
"
|
|
89
|
-
"
|
|
90
|
-
"
|
|
91
|
-
"
|
|
92
|
-
"
|
|
93
|
-
"
|
|
94
|
-
"
|
|
95
|
-
"
|
|
142
|
+
"plot_cohospace_scatter_trajectory_1d",
|
|
143
|
+
"plot_cohospace_scatter_trajectory_2d",
|
|
144
|
+
"plot_cohospace_scatter_neuron_1d",
|
|
145
|
+
"plot_cohospace_scatter_neuron_2d",
|
|
146
|
+
"plot_cohospace_scatter_population_1d",
|
|
147
|
+
"plot_cohospace_scatter_population_2d",
|
|
148
|
+
"plot_cohospace_scatter_neuron_skewed",
|
|
149
|
+
"plot_cohospace_scatter_population_skewed",
|
|
150
|
+
"compute_cohoscore_scatter_1d",
|
|
151
|
+
"compute_cohoscore_scatter_2d",
|
|
96
152
|
"align_coords_to_position_1d",
|
|
97
153
|
"align_coords_to_position_2d",
|
|
98
154
|
"apply_angle_scale",
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
"""Shared helpers for CohoMap/CohoSpace analysis."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from scipy.ndimage import gaussian_filter
|
|
10
|
+
|
|
11
|
+
from ...visualization.core import PlotConfig
|
|
12
|
+
from .path import find_coords_matrix, find_times_box
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _ensure_plot_config(
    config: PlotConfig | None,
    factory,
    *args,
    **defaults,
) -> PlotConfig:
    """Return ``config`` as-is, or build a new one when it is ``None``.

    Args:
        config: An existing plot configuration, or ``None``.
        factory: Callable used to build a configuration; it is only invoked
            when ``config`` is ``None``.
        *args: Positional arguments forwarded to ``factory``.
        **defaults: Keyword arguments forwarded to ``factory``.

    Returns:
        ``config`` itself when provided, otherwise ``factory(*args, **defaults)``.
    """
    return config if config is not None else factory(*args, **defaults)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _ensure_parent_dir(save_path: str | None) -> None:
    """Create the directory that will contain ``save_path`` if it is missing.

    Nothing happens when ``save_path`` is empty/``None`` or when it has no
    directory component (a bare file name).
    """
    if not save_path:
        return
    directory = os.path.dirname(save_path)
    if not directory:
        return
    os.makedirs(directory, exist_ok=True)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _circmean(x: np.ndarray) -> float:
    """Return the circular mean of the angles ``x`` (radians) as a Python float.

    The mean is taken on the unit circle: the sines and cosines are averaged
    separately and recombined with ``arctan2``.
    """
    mean_sin = np.mean(np.sin(x))
    mean_cos = np.mean(np.cos(x))
    return float(np.arctan2(mean_sin, mean_cos))
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _smooth_circular_map(
    mtot: np.ndarray,
    smooth_sigma: float,
    *,
    fill_nan: bool = False,
    fill_sigma: float | None = None,
    fill_min_weight: float = 1e-3,
) -> np.ndarray:
    """Gaussian-smooth a map of angles by filtering its sine and cosine.

    Args:
        mtot: Map of angles in radians; NaN marks cells without data.
        smooth_sigma: Gaussian sigma used when ``fill_nan`` is false. A falsy
            or non-positive value disables smoothing.
        fill_nan: When true, use mask-normalised filtering so that NaN cells
            are filled from their neighbours.
        fill_sigma: Sigma for the fill filter. ``None`` falls back to
            ``smooth_sigma`` if that is positive, otherwise ``1.0``.
        fill_min_weight: Cells whose filtered mask weight does not exceed this
            value (clamped at 0) are not normalised; they are reset to NaN
            when ``fill_min_weight`` is positive.

    Returns:
        The smoothed angle map. When ``fill_nan`` is false, the original NaN
        cells stay NaN; with smoothing disabled as well, the float view of the
        input is returned directly.
    """
    phase = np.asarray(mtot, dtype=float)
    missing = np.isnan(phase)
    observed = (~missing).astype(float)

    # Missing cells contribute nothing to either component.
    sin_part = np.sin(phase)
    cos_part = np.cos(phase)
    sin_part[missing] = 0.0
    cos_part[missing] = 0.0

    if fill_nan:
        sigma = fill_sigma
        if sigma is None:
            sigma = smooth_sigma if smooth_sigma and smooth_sigma > 0 else 1.0
        # Filtering the mask gives the per-cell amount of real data that the
        # filtered components were built from.
        weight = gaussian_filter(observed, sigma)
        sin_blur = gaussian_filter(sin_part * observed, sigma)
        cos_blur = gaussian_filter(cos_part * observed, sigma)
        threshold = max(float(fill_min_weight), 0.0)
        trusted = weight > threshold
        sin_norm = np.divide(sin_blur, weight, out=np.zeros_like(sin_blur), where=trusted)
        cos_norm = np.divide(cos_blur, weight, out=np.zeros_like(cos_blur), where=trusted)
        filled = np.arctan2(sin_norm, cos_norm)
        if fill_min_weight > 0:
            filled[~trusted] = np.nan
        return filled

    if not (smooth_sigma and smooth_sigma > 0):
        phase[missing] = np.nan
        return phase

    smoothed = np.arctan2(
        gaussian_filter(sin_part, smooth_sigma),
        gaussian_filter(cos_part, smooth_sigma),
    )
    smoothed[missing] = np.nan
    return smoothed
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _extract_coords_and_times(
    decoding_result: dict[str, Any],
    coords_key: str | None = None,
) -> tuple[np.ndarray, np.ndarray | None]:
    """Pull the decoded coordinates and the time selection out of a decoding result.

    Args:
        decoding_result: Mapping produced by the decoding step.
        coords_key: Explicit key holding the coordinates. When ``None`` the
            ``"coordsbox"`` entry is used if present, otherwise the lookup is
            delegated to ``find_coords_matrix``.

    Returns:
        ``(coords, times_box)`` where ``times_box`` is the first value returned
        by ``find_times_box``.

    Raises:
        KeyError: If ``coords_key`` is given but absent from ``decoding_result``.
    """
    if coords_key is None:
        if "coordsbox" in decoding_result:
            coords = np.asarray(decoding_result["coordsbox"])
        else:
            coords, _ = find_coords_matrix(decoding_result)
    elif coords_key in decoding_result:
        coords = np.asarray(decoding_result[coords_key])
    else:
        raise KeyError(f"coords_key '{coords_key}' not found in decoding_result.")

    times_box, _ = find_times_box(decoding_result)
    return coords, times_box
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _phase_map_valid_fraction(phase_map: np.ndarray) -> float:
    """Return the fraction of finite entries in ``phase_map`` (0.0 if it is empty)."""
    finite_mask = np.isfinite(phase_map)
    return float(np.mean(finite_mask)) if finite_mask.size != 0 else 0.0
|
|
@@ -0,0 +1,408 @@
|
|
|
1
|
+
"""CohoMap (Ecoho-style) computation and plotting."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import matplotlib.pyplot as plt
|
|
8
|
+
import numpy as np
|
|
9
|
+
from scipy.ndimage import rotate
|
|
10
|
+
from scipy.optimize import minimize
|
|
11
|
+
from scipy.stats import binned_statistic_2d
|
|
12
|
+
|
|
13
|
+
from ...visualization.core import PlotConfig, finalize_figure
|
|
14
|
+
from .coho import (
|
|
15
|
+
_circmean,
|
|
16
|
+
_ensure_parent_dir,
|
|
17
|
+
_ensure_plot_config,
|
|
18
|
+
_extract_coords_and_times,
|
|
19
|
+
_phase_map_valid_fraction,
|
|
20
|
+
_smooth_circular_map,
|
|
21
|
+
)
|
|
22
|
+
from .path import parse_times_box_to_indices
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _select_phase_sign(
    phase_map: np.ndarray,
    params: np.ndarray,
    *,
    grid_size: int,
    trim: int,
) -> int:
    """Decide whether a fitted stripe model or its mirror image matches a phase map.

    The model phase is ``params[2] * x_rot + params[1]``, where ``x_rot`` is a
    ``[0, 3*pi]`` coordinate grid rotated by ``params[0]`` radians and cropped
    by ``trim`` cells on every side. The model and its reflection
    (``2*pi - model``) are both compared with the map modulo ``2*pi``.

    Args:
        phase_map: Square 2D phase map with side ``grid_size - 1 - 2 * trim``.
        params: Fitted ``[angle, phase, spacing]`` stripe parameters.
        grid_size: Size used to build the coordinate grid (``grid_size - 1``
            samples per axis).
        trim: Number of cells cropped from each border of the rotated grid.
            ``0`` means no cropping.

    Returns:
        ``-1`` when the reflected model has the smaller summed absolute
        residual, otherwise ``1`` (including when the map has no non-NaN cell).

    Raises:
        ValueError: If ``phase_map`` does not have the expected square shape.
    """
    mtot = np.asarray(phase_map)
    expected = grid_size - 1 - 2 * trim
    if mtot.ndim != 2 or mtot.shape[0] != expected or mtot.shape[1] != expected:
        raise ValueError("phase_map shape does not match grid_size/trim for alignment")
    nnans = ~np.isnan(mtot)
    if not np.any(nnans):
        return 1

    axis = np.linspace(0, 3 * np.pi, grid_size - 1)
    x, _ = np.meshgrid(axis, axis)
    x1 = rotate(x, params[0] * 360.0 / (2 * np.pi), reshape=False)

    # An explicit stop index is required: ``x1[trim:-trim]`` is empty when
    # trim == 0 because ``-0`` is just ``0``.
    stop = x1.shape[0] - trim
    base = params[2] * x1[trim:stop, trim:stop] + params[1]

    mtot_vals = (mtot[nnans]) % (2 * np.pi)
    pm1 = (base % (2 * np.pi))[nnans] - mtot_vals
    pm2 = ((2 * np.pi - base) % (2 * np.pi))[nnans] - mtot_vals
    if np.sum(np.abs(pm1)) > np.sum(np.abs(pm2)):
        return -1
    return 1
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _rot_coord(
    params1: np.ndarray,
    params2: np.ndarray,
    c1: np.ndarray,
    c2: np.ndarray,
    p: tuple[int, int],
) -> np.ndarray:
    """Re-express two angular coordinates using the two fitted stripe orientations.

    NOTE(review): presumably a port of the GridCellTorus rotation step that
    ``cohomap``'s docstring refers to -- confirm against that reference.

    Args:
        params1: Stripe-fit parameters for the first coordinate; only
            ``params1[0]`` (the fitted angle) is read here.
        params2: Stripe-fit parameters for the second coordinate; only
            ``params2[0]`` is read here.
        c1: Values of the first angular coordinate (assumed 1-D, radians --
            TODO confirm).
        c2: Values of the second angular coordinate, same length as ``c1``.
        p: Pair of signs (``1`` or ``-1``) for the two coordinates, as returned
            by ``_select_phase_sign``.

    Returns:
        The two transformed coordinates stacked along axis 1 and wrapped to
        ``[0, 2*pi)``; for 1-D inputs this is an ``(N, 2)`` array.
    """
    # The fit whose angle has the larger |cos| plays the role of ``x``; when
    # that is params2, the coordinates, parameters and signs are all swapped.
    if abs(np.cos(params1[0])) < abs(np.cos(params2[0])):
        cc1 = c2.copy()
        cc2 = c1.copy()
        y = params1.copy()
        x = params2.copy()
        p = (p[1], p[0])
    else:
        cc1 = c1.copy()
        cc2 = c2.copy()
        x = params1.copy()
        y = params2.copy()

    # A sign of -1 reflects the corresponding coordinate (theta -> 2*pi - theta).
    if p[1] == -1:
        cc2 = 2 * np.pi - cc2
    if p[0] == -1:
        cc1 = 2 * np.pi - cc1

    # Difference between the two fitted stripe angles.
    alpha = y[0] - x[0]
    # Negative and wider than a right angle: exchange the two coordinates.
    if (alpha < 0) and (abs(alpha) > np.pi / 2):
        cctmp = cc2.copy()
        cc2 = cc1.copy()
        cc1 = cctmp

    # Mix one coordinate into the other depending on the angle difference.
    # These branches run after the optional exchange above, so statement order
    # matters. No adjustment is made when alpha >= 0 with |alpha| <= pi/2, or
    # when |alpha| == pi/2 exactly.
    if (alpha < 0) and (abs(alpha) < np.pi / 2):
        cc1 = 2 * np.pi - cc1 + (np.pi / 3) * cc2
    elif abs(alpha) > np.pi / 2:
        cc2 = cc2 + (np.pi / 3) * cc1

    return np.stack([cc1, cc2], axis=1) % (2 * np.pi)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _toroidal_align_coords(
    coords: np.ndarray,
    phase_map1: np.ndarray,
    phase_map2: np.ndarray,
    *,
    trim: int,
    grid_size: int | None,
) -> tuple[np.ndarray, float, float]:
    """Align the first two decoded coordinates using stripe fits of both phase maps.

    Each phase map is fitted with ``fit_cohomap_stripes``, the sign of each fit
    is chosen with ``_select_phase_sign``, and the first two columns of
    ``coords`` are then transformed by ``_rot_coord``.

    Args:
        coords: Array with at least two columns; only columns 0 and 1 are used.
        phase_map1: Phase map of the first coordinate.
        phase_map2: Phase map of the second coordinate (same shape).
        trim: Border crop passed to the stripe fit and sign selection.
        grid_size: Grid size for the fit; ``None`` derives it from the map
            size as ``rows + 2 * trim + 1``.

    Returns:
        ``(aligned_coords, fit_error1, fit_error2)``.

    Raises:
        ValueError: If the two maps differ in shape or ``coords`` is not 2D
            with at least two columns.
    """
    if phase_map1.shape != phase_map2.shape:
        raise ValueError("phase_map shapes do not match for alignment")

    coords = np.asarray(coords)
    if not (coords.ndim == 2 and coords.shape[1] >= 2):
        raise ValueError(f"coords must be (N,2+) array, got {coords.shape}")

    if grid_size is None:
        grid_size = int(phase_map1.shape[0]) + 2 * trim + 1

    maps = (phase_map1, phase_map2)
    # Both fits run before either sign selection.
    fits = [fit_cohomap_stripes(pm, grid_size=grid_size, trim=trim) for pm in maps]
    signs = tuple(
        _select_phase_sign(pm, fitted, grid_size=grid_size, trim=trim)
        for pm, (fitted, _) in zip(maps, fits)
    )

    (fit_params1, fit_err1), (fit_params2, fit_err2) = fits
    rotated = _rot_coord(fit_params1, fit_params2, coords[:, 0], coords[:, 1], signs)
    return rotated, float(fit_err1), float(fit_err2)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def cohomap(
    decoding_result: dict[str, Any],
    position_data: dict[str, Any],
    *,
    coords_key: str | None = None,
    bins: int = 101,
    margin_frac: float = 0.0025,
    smooth_sigma: float = 1.0,
    fill_nan: bool = True,
    fill_sigma: float | None = None,
    fill_min_weight: float = 1e-3,
    align_torus: bool = True,
    align_trim: int = 25,
    align_grid_size: int | None = None,
    align_min_valid_frac: float | None = None,
    align_max_fit_error: float | None = None,
) -> dict[str, Any]:
    """
    Compute EcohoMap phase maps using circular-mean binning.

    This mirrors GridCellTorus get_ang_hist: bin spatial positions and compute the
    circular mean of each decoded angle within spatial bins, then smooth in sin/cos
    space. Optional toroidal alignment follows the GridCellTorus stripe fit + rotation.
    You can gate alignment by valid fraction or fit error thresholds.

    Args:
        decoding_result: Decoding output; coordinates and the time selection are
            read from it via ``_extract_coords_and_times``.
        position_data: Mapping with ``"x"`` and ``"y"`` position arrays and,
            optionally, a ``"t"`` array used to resolve the time selection.
        coords_key: Key of the coordinate array in ``decoding_result``. ``None``
            uses ``"coordsbox"`` when present, else ``find_coords_matrix``.
        bins: Number of bin edges per axis (``bins - 1`` cells per axis).
        margin_frac: Fraction of the x/y extent removed from each side before
            the bin edges are laid out.
        smooth_sigma: Gaussian sigma passed to ``_smooth_circular_map``.
        fill_nan: Whether empty bins are filled during smoothing.
        fill_sigma: Sigma for the fill filter (``None`` lets the smoother choose).
        fill_min_weight: Minimum filtered mask weight for a filled bin to be kept.
        align_torus: Whether to attempt toroidal alignment of the first two
            coordinates.
        align_trim: Border crop used by the stripe fit.
        align_grid_size: Grid size for the stripe fit (``None`` derives it from
            the map size).
        align_min_valid_frac: If set, alignment is skipped when the smaller
            finite fraction of the two raw maps is below this value.
        align_max_fit_error: If set, the alignment is discarded when either
            stripe-fit error exceeds this value.

    Returns:
        A dict with the final maps (``phase_map1``, ``phase_map2``), the
        pre-alignment maps (``phase_map1_raw``, ``phase_map2_raw``), the bin
        edges (``x_edge``, ``y_edge``), an echo of the binning/smoothing/gating
        arguments, and the alignment diagnostics ``aligned``, ``align_error``,
        ``align_valid_frac1/2`` and ``align_fit_error1/2``.

    Raises:
        KeyError: If ``coords_key`` is given but missing from ``decoding_result``.
        ValueError: If the coordinates are not 2D with at least two columns, or
            their length differs from the selected positions.

    Errors raised during alignment are not propagated; their message is stored
    in ``align_error`` and the raw maps are returned as the final maps.
    """
    coords, times_box = _extract_coords_and_times(decoding_result, coords_key)
    if coords.ndim != 2 or coords.shape[1] < 2:
        raise ValueError(f"coords must be (N,2+) array, got {coords.shape}")

    xx = np.asarray(position_data["x"])
    yy = np.asarray(position_data["y"])

    # Restrict positions to the samples the decoder used. With a time axis the
    # mapping is delegated to parse_times_box_to_indices; without one,
    # times_box is used directly as integer indices.
    if times_box is not None:
        if "t" in position_data:
            idx, _ = parse_times_box_to_indices(times_box, np.asarray(position_data["t"]))
            xx = xx[idx]
            yy = yy[idx]
        else:
            idx = np.asarray(times_box).astype(int)
            xx = xx[idx]
            yy = yy[idx]

    if len(xx) != coords.shape[0]:
        raise ValueError(
            "Length mismatch: coords length does not match position length after times_box."
        )

    x_min, x_max = float(np.min(xx)), float(np.max(xx))
    y_min, y_max = float(np.min(yy)), float(np.max(yy))
    x_pad = (x_max - x_min) * margin_frac
    y_pad = (y_max - y_min) * margin_frac

    # Bin edges are pulled inward by the padding on both sides.
    binsx = np.linspace(x_min + x_pad, x_max - x_pad, bins)
    binsy = np.linspace(y_min + y_pad, y_max - y_pad, bins)

    def _angle_hist(values: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        # Circular-mean histogram of one angular coordinate over the (xx, yy)
        # positions captured from the enclosing scope, followed by smoothing.
        nnans = ~np.isnan(values)
        mtot, x_edge, y_edge, _ = binned_statistic_2d(
            xx[nnans],
            yy[nnans],
            values[nnans],
            statistic=_circmean,
            bins=(binsx, binsy),
            range=None,
            expand_binnumbers=True,
        )
        mtot = _smooth_circular_map(
            mtot,
            smooth_sigma,
            fill_nan=fill_nan,
            fill_sigma=fill_sigma,
            fill_min_weight=fill_min_weight,
        )
        return mtot, x_edge, y_edge

    coords_use = np.asarray(coords, float)
    m1_raw, x_edge, y_edge = _angle_hist(coords_use[:, 0])
    m2_raw, _, _ = _angle_hist(coords_use[:, 1])
    aligned = False
    align_error = None
    align_valid_frac1 = None
    align_valid_frac2 = None
    align_fit_error1 = None
    align_fit_error2 = None

    if align_torus:
        # Best effort: any failure below is recorded in align_error and the
        # raw maps are used as the result.
        try:
            align_valid_frac1 = _phase_map_valid_fraction(m1_raw)
            align_valid_frac2 = _phase_map_valid_fraction(m2_raw)
            min_valid = min(align_valid_frac1, align_valid_frac2)

            if align_min_valid_frac is not None and min_valid < align_min_valid_frac:
                align_error = (
                    f"valid fraction too low ({min_valid:.3f} < {align_min_valid_frac:.3f})"
                )
            else:
                coords_aligned, f1, f2 = _toroidal_align_coords(
                    coords_use[:, :2],
                    m1_raw,
                    m2_raw,
                    trim=align_trim,
                    grid_size=align_grid_size,
                )
                # Fit errors are reported even when the gate below rejects
                # the alignment.
                align_fit_error1 = f1
                align_fit_error2 = f2
                if align_max_fit_error is not None and (
                    f1 > align_max_fit_error or f2 > align_max_fit_error
                ):
                    align_error = (
                        f"fit error too high ({f1:.4f}, {f2:.4f} > {align_max_fit_error:.4f})"
                    )
                else:
                    # Copy before writing so the array derived from the
                    # caller's coordinates is not modified in place.
                    coords_use = coords_use.copy()
                    coords_use[:, :2] = coords_aligned
                    aligned = True
        except Exception as exc:
            align_error = str(exc)

    # Re-bin with the aligned coordinates; otherwise reuse the raw maps.
    if aligned:
        m1, x_edge, y_edge = _angle_hist(coords_use[:, 0])
        m2, _, _ = _angle_hist(coords_use[:, 1])
    else:
        m1, m2 = m1_raw, m2_raw

    return {
        "phase_map1": m1,
        "phase_map2": m2,
        "phase_map1_raw": m1_raw,
        "phase_map2_raw": m2_raw,
        "x_edge": x_edge,
        "y_edge": y_edge,
        "bins": bins,
        "margin_frac": margin_frac,
        "smooth_sigma": smooth_sigma,
        "fill_nan": fill_nan,
        "fill_sigma": fill_sigma,
        "fill_min_weight": fill_min_weight,
        "aligned": aligned,
        "align_error": align_error,
        "align_min_valid_frac": align_min_valid_frac,
        "align_max_fit_error": align_max_fit_error,
        "align_valid_frac1": align_valid_frac1,
        "align_valid_frac2": align_valid_frac2,
        "align_fit_error1": align_fit_error1,
        "align_fit_error2": align_fit_error2,
    }
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def fit_cohomap_stripes(
    phase_map: np.ndarray,
    *,
    grid_size: int | None = 151,
    trim: int = 25,
    angle_grid: int = 10,
    phase_grid: int = 10,
    spacing_grid: int = 10,
    spacing_range: tuple[float, float] = (1.0, 6.0),
) -> tuple[np.ndarray, float]:
    """
    Fit a cosine stripe model to a phase map, mirroring GridCellTorus fit_sine_wave.

    The model is ``cos(spacing * x_rot + phase)``, where ``x_rot`` is a
    ``[0, 3*pi]`` coordinate grid rotated by ``angle`` radians and cropped by
    ``trim`` cells on every side. A coarse grid search picks the starting
    point, which is then refined with SLSQP. The error is the mean squared
    difference between the model and ``cos(phase_map)`` over non-NaN cells.

    Args:
        phase_map: Square 2D phase map (radians); NaN cells are ignored.
        grid_size: Size of the coordinate grid (``grid_size - 1`` samples per
            axis). ``None`` derives it as ``rows + 2 * trim + 1``.
        trim: Number of cells cropped from each border of the rotated grid.
            ``0`` means no cropping.
        angle_grid: Number of angles sampled in ``[0, pi]``.
        phase_grid: Number of phase offsets sampled in ``[0, 2*pi]``.
        spacing_grid: Number of spacing factors sampled across ``spacing_range``.
        spacing_range: ``(low, high)`` bounds of the spacing grid search.

    Returns:
        ``(params, error)`` with ``params = [angle, phase, spacing]`` from the
        optimiser and ``error`` the final objective value.

    Raises:
        ValueError: If ``phase_map`` is not 2D, is not square, or its size is
            incompatible with ``grid_size`` and ``trim``.
    """
    mtot = np.asarray(phase_map)
    if mtot.ndim != 2:
        raise ValueError(f"phase_map must be 2D, got {mtot.shape}")

    if grid_size is None:
        grid_size = mtot.shape[0] + 2 * trim + 1

    expected = grid_size - 1 - 2 * trim
    if expected != mtot.shape[0]:
        raise ValueError(
            f"grid_size/trim incompatible with phase_map shape: "
            f"expected {expected} but got {mtot.shape[0]}"
        )
    # The cropped grid is square, so a non-square map could only fail later
    # with an opaque boolean-index error; reject it up front instead.
    if mtot.shape[1] != expected:
        raise ValueError(f"phase_map must be square, got {mtot.shape}")

    axis = np.linspace(0, 3 * np.pi, grid_size - 1)
    x, _ = np.meshgrid(axis, axis)
    nnans = ~np.isnan(mtot)
    # Loop-invariant target, evaluated once instead of on every objective call.
    target = np.cos(mtot[nnans])
    # An explicit stop index is required: ``x1[trim:-trim]`` is empty when
    # trim == 0 because ``-0`` is just ``0``.
    stop = x.shape[0] - trim

    def cos_wave(p: np.ndarray) -> float:
        x1 = rotate(x, p[0] * 360.0 / (2 * np.pi), reshape=False)
        model = np.cos(p[2] * x1[trim:stop, trim:stop] + p[1])
        return float(np.mean(np.square(model[nnans] - target)))

    angle_space = np.linspace(0, np.pi, angle_grid)
    phase_space = np.linspace(0, 2 * np.pi, phase_grid)
    spacing_space = np.linspace(spacing_range[0], spacing_range[1], spacing_grid)

    # Coarse exhaustive search to seed the local optimiser.
    grid = np.zeros((angle_grid, phase_grid, spacing_grid))
    for i, ang in enumerate(angle_space):
        for j, ph in enumerate(phase_space):
            for k, sp in enumerate(spacing_space):
                grid[i, j, k] = cos_wave(np.array([ang, ph, sp]))

    p_ind = np.unravel_index(np.argmin(grid), grid.shape)
    p0 = np.array([angle_space[p_ind[0]], phase_space[p_ind[1]], spacing_space[p_ind[2]]])
    res = minimize(cos_wave, p0, method="SLSQP", options={"disp": False})
    return res["x"], float(res["fun"])
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def plot_cohomap(
    cohomap_result: dict[str, Any],
    *,
    config: PlotConfig | None = None,
    save_path: str | None = None,
    show: bool = False,
    figsize: tuple[int, int] = (10, 4),
    cmap: str = "viridis",
    mode: str = "cos",
) -> plt.Figure:
    """
    Plot EcohoMap phase maps (two panels: phase_map1/phase_map2).

    mode:
        "phase" to show raw phase (radians),
        "cos" or "sin" to show cosine/sine of phase like GridCellTorus.

    Args:
        cohomap_result: Mapping with ``"phase_map1"``, ``"phase_map2"``,
            ``"x_edge"`` and ``"y_edge"`` (as returned by ``cohomap``).
        config: Plot configuration. When ``None`` one is built with
            ``PlotConfig.for_static_plot`` from ``figsize``, ``save_path`` and
            ``show``; when provided, those three arguments are not used.
        save_path: Output path used for the default configuration.
        show: Show flag used for the default configuration.
        figsize: Figure size used for the default configuration.
        cmap: Matplotlib colormap name.
        mode: Display transform; colour limits are ``[-pi, pi]`` for
            ``"phase"`` and ``[-1, 1]`` for ``"cos"``/``"sin"``.

    Returns:
        The created figure, after it has been passed to ``finalize_figure``.

    Raises:
        ValueError: If ``mode`` is not one of ``"phase"``, ``"cos"``, ``"sin"``.
        KeyError: If a required entry is missing from ``cohomap_result``.
    """
    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title="EcohoMap",
        xlabel="",
        ylabel="",
        figsize=figsize,
        save_path=save_path,
        show=show,
    )

    m1 = cohomap_result["phase_map1"]
    m2 = cohomap_result["phase_map2"]
    x_edge = cohomap_result["x_edge"]
    y_edge = cohomap_result["y_edge"]

    # Resolve the display mode before creating the figure: raising after
    # plt.subplots() would leave an orphaned figure registered with pyplot.
    if mode == "phase":
        transform = None
        cbar_label = "Phase (rad)"
        vmin, vmax = -np.pi, np.pi
    elif mode == "cos":
        transform = np.cos
        cbar_label = "cos(phase)"
        vmin, vmax = -1.0, 1.0
    elif mode == "sin":
        transform = np.sin
        cbar_label = "sin(phase)"
        vmin, vmax = -1.0, 1.0
    else:
        raise ValueError(f"Unknown mode '{mode}'. Use 'phase', 'cos', or 'sin'.")

    fig, ax = plt.subplots(1, 2, figsize=config.figsize)
    for i, (mtot, title) in enumerate(((m1, "Phase Map 1"), (m2, "Phase Map 2"))):
        plot_map = mtot if transform is None else transform(mtot)
        im = ax[i].imshow(
            plot_map,
            origin="lower",
            extent=[x_edge[0], x_edge[-1], y_edge[0], y_edge[-1]],
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
        )
        ax[i].set_title(title, fontsize=10)
        ax[i].set_aspect("equal", "box")
        ax[i].set_xticks([])
        ax[i].set_yticks([])
        plt.colorbar(im, ax=ax[i], fraction=0.046, pad=0.04, label=cbar_label)

    fig.tight_layout()
    _ensure_parent_dir(config.save_path)
    finalize_figure(fig, config)
    return fig
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
def cohomap_upgrade(*positional, **named) -> dict[str, Any]:
    """Backward-compatible name for ``cohomap``; every argument is forwarded unchanged."""
    result = cohomap(*positional, **named)
    return result
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
def ecohomap(*positional, **named) -> dict[str, Any]:
    """EcohoMap-named entry point; forwards every argument to ``cohomap``."""
    result = cohomap(*positional, **named)
    return result
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
def fit_cohomap_stripes_upgrade(*positional, **named) -> tuple[np.ndarray, float]:
    """Backward-compatible name for ``fit_cohomap_stripes``; arguments are forwarded unchanged."""
    fitted = fit_cohomap_stripes(*positional, **named)
    return fitted
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def plot_cohomap_upgrade(*positional, **named) -> plt.Figure:
    """Backward-compatible name for ``plot_cohomap``; arguments are forwarded unchanged."""
    figure = plot_cohomap(*positional, **named)
    return figure
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
def plot_ecohomap(*positional, **named) -> plt.Figure:
    """EcohoMap-named plotting entry point; forwards every argument to ``plot_cohomap``."""
    figure = plot_cohomap(*positional, **named)
    return figure
|