canns 0.15.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/analyzer/data/asa/__init__.py +2 -23
- canns/analyzer/data/asa/coho.py +17 -93
- canns/analyzer/data/asa/cohomap.py +70 -25
- canns/analyzer/data/asa/cohomap_vectors.py +73 -19
- canns/analyzer/data/asa/cohospace.py +0 -30
- canns/analyzer/data/asa/cohospace_phase_centers.py +3 -21
- canns/analyzer/data/asa/cohospace_scatter.py +5 -110
- canns/analyzer/data/asa/embedding.py +2 -3
- canns/analyzer/data/asa/fr.py +1 -8
- canns/analyzer/data/asa/path.py +70 -0
- canns/analyzer/data/asa/plotting.py +1 -26
- canns/analyzer/data/asa/utils.py +160 -0
- canns/analyzer/data/cell_classification/visualization/btn_plots.py +1 -18
- canns-1.0.0.dist-info/METADATA +257 -0
- {canns-0.15.0.dist-info → canns-1.0.0.dist-info}/RECORD +18 -19
- canns/analyzer/data/asa/cohomap_scatter.py +0 -10
- canns/analyzer/data/asa/filters.py +0 -208
- canns-0.15.0.dist-info/METADATA +0 -245
- {canns-0.15.0.dist-info → canns-1.0.0.dist-info}/WHEEL +0 -0
- {canns-0.15.0.dist-info → canns-1.0.0.dist-info}/entry_points.txt +0 -0
- {canns-0.15.0.dist-info → canns-1.0.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -2,15 +2,9 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from .cohomap import (
|
|
4
4
|
cohomap,
|
|
5
|
-
cohomap_upgrade,
|
|
6
|
-
ecohomap,
|
|
7
5
|
fit_cohomap_stripes,
|
|
8
|
-
fit_cohomap_stripes_upgrade,
|
|
9
6
|
plot_cohomap,
|
|
10
|
-
plot_cohomap_upgrade,
|
|
11
|
-
plot_ecohomap,
|
|
12
7
|
)
|
|
13
|
-
from .cohomap_scatter import plot_cohomap_scatter, plot_cohomap_scatter_multi
|
|
14
8
|
from .cohomap_vectors import (
|
|
15
9
|
cohomap_vectors,
|
|
16
10
|
plot_cohomap_stripes,
|
|
@@ -18,14 +12,8 @@ from .cohomap_vectors import (
|
|
|
18
12
|
)
|
|
19
13
|
from .cohospace import (
|
|
20
14
|
cohospace,
|
|
21
|
-
cohospace_upgrade,
|
|
22
|
-
ecohospace,
|
|
23
15
|
plot_cohospace,
|
|
24
16
|
plot_cohospace_skewed,
|
|
25
|
-
plot_cohospace_upgrade,
|
|
26
|
-
plot_cohospace_upgrade_skewed,
|
|
27
|
-
plot_ecohospace,
|
|
28
|
-
plot_ecohospace_skewed,
|
|
29
17
|
)
|
|
30
18
|
from .cohospace_phase_centers import (
|
|
31
19
|
cohospace_phase_centers,
|
|
@@ -81,6 +69,8 @@ from .path import (
|
|
|
81
69
|
from .plotting import (
|
|
82
70
|
plot_2d_bump_on_manifold,
|
|
83
71
|
plot_3d_bump_on_torus,
|
|
72
|
+
plot_cohomap_scatter,
|
|
73
|
+
plot_cohomap_scatter_multi,
|
|
84
74
|
plot_path_compare_1d,
|
|
85
75
|
plot_path_compare_2d,
|
|
86
76
|
plot_projection,
|
|
@@ -109,22 +99,11 @@ __all__ = [
|
|
|
109
99
|
"plot_3d_bump_on_torus",
|
|
110
100
|
"plot_2d_bump_on_manifold",
|
|
111
101
|
"cohomap",
|
|
112
|
-
"cohomap_upgrade",
|
|
113
102
|
"fit_cohomap_stripes",
|
|
114
|
-
"fit_cohomap_stripes_upgrade",
|
|
115
103
|
"plot_cohomap",
|
|
116
|
-
"plot_cohomap_upgrade",
|
|
117
104
|
"cohospace",
|
|
118
|
-
"cohospace_upgrade",
|
|
119
105
|
"plot_cohospace",
|
|
120
106
|
"plot_cohospace_skewed",
|
|
121
|
-
"plot_cohospace_upgrade",
|
|
122
|
-
"plot_cohospace_upgrade_skewed",
|
|
123
|
-
"ecohomap",
|
|
124
|
-
"ecohospace",
|
|
125
|
-
"plot_ecohomap",
|
|
126
|
-
"plot_ecohospace",
|
|
127
|
-
"plot_ecohospace_skewed",
|
|
128
107
|
"cohomap_vectors",
|
|
129
108
|
"plot_cohomap_stripes",
|
|
130
109
|
"plot_cohomap_vectors",
|
canns/analyzer/data/asa/coho.py
CHANGED
|
@@ -2,96 +2,20 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
import
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
return factory(*args, **defaults)
|
|
23
|
-
return config
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
def _ensure_parent_dir(save_path: str | None) -> None:
|
|
27
|
-
if save_path:
|
|
28
|
-
parent = os.path.dirname(save_path)
|
|
29
|
-
if parent:
|
|
30
|
-
os.makedirs(parent, exist_ok=True)
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
def _circmean(x: np.ndarray) -> float:
|
|
34
|
-
return float(np.arctan2(np.mean(np.sin(x)), np.mean(np.cos(x))))
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
def _smooth_circular_map(
|
|
38
|
-
mtot: np.ndarray,
|
|
39
|
-
smooth_sigma: float,
|
|
40
|
-
*,
|
|
41
|
-
fill_nan: bool = False,
|
|
42
|
-
fill_sigma: float | None = None,
|
|
43
|
-
fill_min_weight: float = 1e-3,
|
|
44
|
-
) -> np.ndarray:
|
|
45
|
-
mtot = np.asarray(mtot, dtype=float)
|
|
46
|
-
nans = np.isnan(mtot)
|
|
47
|
-
mask = (~nans).astype(float)
|
|
48
|
-
sintot = np.sin(mtot)
|
|
49
|
-
costot = np.cos(mtot)
|
|
50
|
-
sintot[nans] = 0.0
|
|
51
|
-
costot[nans] = 0.0
|
|
52
|
-
|
|
53
|
-
if fill_nan:
|
|
54
|
-
if fill_sigma is None:
|
|
55
|
-
fill_sigma = smooth_sigma if smooth_sigma and smooth_sigma > 0 else 1.0
|
|
56
|
-
weight = gaussian_filter(mask, fill_sigma)
|
|
57
|
-
sintot = gaussian_filter(sintot * mask, fill_sigma)
|
|
58
|
-
costot = gaussian_filter(costot * mask, fill_sigma)
|
|
59
|
-
min_weight = max(float(fill_min_weight), 0.0)
|
|
60
|
-
valid = weight > min_weight
|
|
61
|
-
sintot = np.divide(sintot, weight, out=np.zeros_like(sintot), where=valid)
|
|
62
|
-
costot = np.divide(costot, weight, out=np.zeros_like(costot), where=valid)
|
|
63
|
-
mtot = np.arctan2(sintot, costot)
|
|
64
|
-
if fill_min_weight > 0:
|
|
65
|
-
mtot[~valid] = np.nan
|
|
66
|
-
return mtot
|
|
67
|
-
|
|
68
|
-
if smooth_sigma and smooth_sigma > 0:
|
|
69
|
-
sintot = gaussian_filter(sintot, smooth_sigma)
|
|
70
|
-
costot = gaussian_filter(costot, smooth_sigma)
|
|
71
|
-
mtot = np.arctan2(sintot, costot)
|
|
72
|
-
mtot[nans] = np.nan
|
|
73
|
-
return mtot
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
def _extract_coords_and_times(
|
|
77
|
-
decoding_result: dict[str, Any],
|
|
78
|
-
coords_key: str | None = None,
|
|
79
|
-
) -> tuple[np.ndarray, np.ndarray | None]:
|
|
80
|
-
if coords_key is not None:
|
|
81
|
-
if coords_key not in decoding_result:
|
|
82
|
-
raise KeyError(f"coords_key '{coords_key}' not found in decoding_result.")
|
|
83
|
-
coords = np.asarray(decoding_result[coords_key])
|
|
84
|
-
elif "coordsbox" in decoding_result:
|
|
85
|
-
coords = np.asarray(decoding_result["coordsbox"])
|
|
86
|
-
else:
|
|
87
|
-
coords, _ = find_coords_matrix(decoding_result)
|
|
88
|
-
|
|
89
|
-
times_box, _ = find_times_box(decoding_result)
|
|
90
|
-
return coords, times_box
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
def _phase_map_valid_fraction(phase_map: np.ndarray) -> float:
|
|
94
|
-
valid = np.isfinite(phase_map)
|
|
95
|
-
if valid.size == 0:
|
|
96
|
-
return 0.0
|
|
97
|
-
return float(np.mean(valid))
|
|
5
|
+
from .utils import (
|
|
6
|
+
_circmean,
|
|
7
|
+
_ensure_parent_dir,
|
|
8
|
+
_ensure_plot_config,
|
|
9
|
+
_extract_coords_and_times,
|
|
10
|
+
_phase_map_valid_fraction,
|
|
11
|
+
_smooth_circular_map,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
__all__ = [
|
|
15
|
+
"_ensure_plot_config",
|
|
16
|
+
"_ensure_parent_dir",
|
|
17
|
+
"_circmean",
|
|
18
|
+
"_smooth_circular_map",
|
|
19
|
+
"_extract_coords_and_times",
|
|
20
|
+
"_phase_map_valid_fraction",
|
|
21
|
+
]
|
|
@@ -57,6 +57,36 @@ def _rot_coord(
|
|
|
57
57
|
c2: np.ndarray,
|
|
58
58
|
p: tuple[int, int],
|
|
59
59
|
) -> np.ndarray:
|
|
60
|
+
"""Transform and align decoded coordinates based on stripe orientation.
|
|
61
|
+
|
|
62
|
+
This function rotates and aligns coordinates to match the dominant stripe
|
|
63
|
+
orientation in the CohoMap. It handles the torus geometry by choosing the
|
|
64
|
+
appropriate coordinate system and applying reflections/rotations.
|
|
65
|
+
|
|
66
|
+
Parameters
|
|
67
|
+
----------
|
|
68
|
+
params1 : np.ndarray
|
|
69
|
+
Stripe fit parameters for first dimension (includes orientation angle).
|
|
70
|
+
params2 : np.ndarray
|
|
71
|
+
Stripe fit parameters for second dimension (includes orientation angle).
|
|
72
|
+
c1 : np.ndarray
|
|
73
|
+
Decoded coordinates for first dimension.
|
|
74
|
+
c2 : np.ndarray
|
|
75
|
+
Decoded coordinates for second dimension.
|
|
76
|
+
p : tuple[int, int]
|
|
77
|
+
Direction indicators (1 or -1) for each dimension, controlling reflections.
|
|
78
|
+
|
|
79
|
+
Returns
|
|
80
|
+
-------
|
|
81
|
+
np.ndarray
|
|
82
|
+
Aligned coordinates of shape (N, 2), wrapped to [0, 2π).
|
|
83
|
+
|
|
84
|
+
Notes
|
|
85
|
+
-----
|
|
86
|
+
The function selects the coordinate system based on which dimension has
|
|
87
|
+
stronger horizontal alignment (larger |cos(angle)|), then applies geometric
|
|
88
|
+
transformations to align the coordinates with the stripe pattern.
|
|
89
|
+
"""
|
|
60
90
|
if abs(np.cos(params1[0])) < abs(np.cos(params2[0])):
|
|
61
91
|
cc1 = c2.copy()
|
|
62
92
|
cc2 = c1.copy()
|
|
@@ -96,6 +126,46 @@ def _toroidal_align_coords(
|
|
|
96
126
|
trim: int,
|
|
97
127
|
grid_size: int | None,
|
|
98
128
|
) -> tuple[np.ndarray, float, float]:
|
|
129
|
+
"""Align decoded coordinates to CohoMap stripe patterns.
|
|
130
|
+
|
|
131
|
+
This function fits stripe patterns to both phase maps, determines the
|
|
132
|
+
optimal orientation and phase signs, then transforms the decoded coordinates
|
|
133
|
+
to align with the detected stripe structure.
|
|
134
|
+
|
|
135
|
+
Parameters
|
|
136
|
+
----------
|
|
137
|
+
coords : np.ndarray
|
|
138
|
+
Decoded coordinates of shape (N, 2+). Only first two columns are used.
|
|
139
|
+
phase_map1 : np.ndarray
|
|
140
|
+
Phase map for first dimension (2D grid).
|
|
141
|
+
phase_map2 : np.ndarray
|
|
142
|
+
Phase map for second dimension (2D grid), must match phase_map1 shape.
|
|
143
|
+
trim : int
|
|
144
|
+
Number of edge bins to trim when fitting stripes.
|
|
145
|
+
grid_size : int, optional
|
|
146
|
+
Grid size for stripe fitting. If None, inferred from phase_map shape.
|
|
147
|
+
|
|
148
|
+
Returns
|
|
149
|
+
-------
|
|
150
|
+
aligned : np.ndarray
|
|
151
|
+
Aligned coordinates of shape (N, 2), wrapped to [0, 2π).
|
|
152
|
+
f1 : float
|
|
153
|
+
Fit quality score for first dimension (higher is better).
|
|
154
|
+
f2 : float
|
|
155
|
+
Fit quality score for second dimension (higher is better).
|
|
156
|
+
|
|
157
|
+
Raises
|
|
158
|
+
------
|
|
159
|
+
ValueError
|
|
160
|
+
If phase_map shapes don't match or coords has wrong shape.
|
|
161
|
+
|
|
162
|
+
Notes
|
|
163
|
+
-----
|
|
164
|
+
The alignment process:
|
|
165
|
+
1. Fits stripe patterns to both phase maps using FFT-based detection
|
|
166
|
+
2. Determines phase signs (±1) for each dimension
|
|
167
|
+
3. Applies geometric transformations via `_rot_coord()` to align coordinates
|
|
168
|
+
"""
|
|
99
169
|
if phase_map1.shape != phase_map2.shape:
|
|
100
170
|
raise ValueError("phase_map shapes do not match for alignment")
|
|
101
171
|
coords = np.asarray(coords)
|
|
@@ -381,28 +451,3 @@ def plot_cohomap(
|
|
|
381
451
|
_ensure_parent_dir(config.save_path)
|
|
382
452
|
finalize_figure(fig, config)
|
|
383
453
|
return fig
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
def cohomap_upgrade(*args, **kwargs) -> dict[str, Any]:
|
|
387
|
-
"""Legacy alias for EcohoMap (formerly cohomap_upgrade)."""
|
|
388
|
-
return cohomap(*args, **kwargs)
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
def ecohomap(*args, **kwargs) -> dict[str, Any]:
|
|
392
|
-
"""Alias for EcohoMap (GridCellTorus-style)."""
|
|
393
|
-
return cohomap(*args, **kwargs)
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
def fit_cohomap_stripes_upgrade(*args, **kwargs) -> tuple[np.ndarray, float]:
|
|
397
|
-
"""Legacy alias for EcohoMap stripe fitting."""
|
|
398
|
-
return fit_cohomap_stripes(*args, **kwargs)
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
def plot_cohomap_upgrade(*args, **kwargs) -> plt.Figure:
|
|
402
|
-
"""Legacy alias for EcohoMap plotting (formerly plot_cohomap_upgrade)."""
|
|
403
|
-
return plot_cohomap(*args, **kwargs)
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
def plot_ecohomap(*args, **kwargs) -> plt.Figure:
|
|
407
|
-
"""Alias for EcohoMap plotting."""
|
|
408
|
-
return plot_cohomap(*args, **kwargs)
|
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
import os
|
|
6
5
|
from typing import Any
|
|
7
6
|
|
|
8
7
|
import matplotlib.pyplot as plt
|
|
@@ -11,27 +10,37 @@ from scipy.ndimage import rotate
|
|
|
11
10
|
|
|
12
11
|
from ...visualization.core import PlotConfig, finalize_figure
|
|
13
12
|
from .cohomap import fit_cohomap_stripes
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
def _ensure_plot_config(
|
|
17
|
-
config: PlotConfig | None,
|
|
18
|
-
factory,
|
|
19
|
-
*args,
|
|
20
|
-
**defaults,
|
|
21
|
-
) -> PlotConfig:
|
|
22
|
-
if config is None:
|
|
23
|
-
return factory(*args, **defaults)
|
|
24
|
-
return config
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
def _ensure_parent_dir(save_path: str | None) -> None:
|
|
28
|
-
if save_path:
|
|
29
|
-
parent = os.path.dirname(save_path)
|
|
30
|
-
if parent:
|
|
31
|
-
os.makedirs(parent, exist_ok=True)
|
|
13
|
+
from .utils import _ensure_parent_dir, _ensure_plot_config
|
|
32
14
|
|
|
33
15
|
|
|
34
16
|
def _rot_para(params1: np.ndarray, params2: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
|
17
|
+
"""Transform stripe fit parameters to canonical orientation.
|
|
18
|
+
|
|
19
|
+
This function adjusts the orientation angles of stripe fit parameters to
|
|
20
|
+
align them with a canonical coordinate system. Unlike `_rot_coord()` which
|
|
21
|
+
transforms actual coordinate data, this operates on the fit parameters
|
|
22
|
+
themselves (orientation angles, wavelengths, etc.).
|
|
23
|
+
|
|
24
|
+
Parameters
|
|
25
|
+
----------
|
|
26
|
+
params1 : np.ndarray
|
|
27
|
+
Stripe fit parameters for first dimension. First element is orientation angle.
|
|
28
|
+
params2 : np.ndarray
|
|
29
|
+
Stripe fit parameters for second dimension. First element is orientation angle.
|
|
30
|
+
|
|
31
|
+
Returns
|
|
32
|
+
-------
|
|
33
|
+
x : np.ndarray
|
|
34
|
+
Transformed parameters for the dimension with stronger horizontal alignment.
|
|
35
|
+
y : np.ndarray
|
|
36
|
+
Transformed parameters for the other dimension.
|
|
37
|
+
|
|
38
|
+
Notes
|
|
39
|
+
-----
|
|
40
|
+
The function selects which parameter set becomes 'x' and 'y' based on which
|
|
41
|
+
has larger |cos(angle)|, then applies angle adjustments to ensure the stripe
|
|
42
|
+
vectors are in a canonical orientation for visualization and analysis.
|
|
43
|
+
"""
|
|
35
44
|
if abs(np.cos(params1[0])) < abs(np.cos(params2[0])):
|
|
36
45
|
y = params1.copy()
|
|
37
46
|
x = params2.copy()
|
|
@@ -204,12 +213,57 @@ def plot_cohomap_vectors(
|
|
|
204
213
|
|
|
205
214
|
|
|
206
215
|
def _resolve_grid_size(phase_map: np.ndarray, grid_size: int | None, trim: int) -> int:
|
|
216
|
+
"""Determine grid size for stripe fitting.
|
|
217
|
+
|
|
218
|
+
Parameters
|
|
219
|
+
----------
|
|
220
|
+
phase_map : np.ndarray
|
|
221
|
+
Phase map array (2D grid).
|
|
222
|
+
grid_size : int, optional
|
|
223
|
+
Explicit grid size. If None, inferred from phase_map shape.
|
|
224
|
+
trim : int
|
|
225
|
+
Number of edge bins to trim.
|
|
226
|
+
|
|
227
|
+
Returns
|
|
228
|
+
-------
|
|
229
|
+
int
|
|
230
|
+
Grid size to use for stripe fitting.
|
|
231
|
+
"""
|
|
207
232
|
if grid_size is None:
|
|
208
233
|
return int(phase_map.shape[0]) + 2 * trim + 1
|
|
209
234
|
return int(grid_size)
|
|
210
235
|
|
|
211
236
|
|
|
212
237
|
def _stripe_fit_map(params: np.ndarray, grid_size: int, trim: int) -> np.ndarray:
|
|
238
|
+
"""Generate a synthetic stripe pattern from fit parameters.
|
|
239
|
+
|
|
240
|
+
Creates a 2D cosine stripe pattern based on the fitted parameters (orientation,
|
|
241
|
+
phase, and spatial frequency). Used for visualizing the fitted stripe model
|
|
242
|
+
and comparing it with the actual phase map.
|
|
243
|
+
|
|
244
|
+
Parameters
|
|
245
|
+
----------
|
|
246
|
+
params : np.ndarray
|
|
247
|
+
Stripe fit parameters: [angle, phase, frequency].
|
|
248
|
+
- angle: Orientation angle in radians
|
|
249
|
+
- phase: Phase offset in radians
|
|
250
|
+
- frequency: Spatial frequency (inverse wavelength)
|
|
251
|
+
grid_size : int
|
|
252
|
+
Size of the grid to generate (before trimming).
|
|
253
|
+
trim : int
|
|
254
|
+
Number of edge bins to trim from the generated pattern.
|
|
255
|
+
|
|
256
|
+
Returns
|
|
257
|
+
-------
|
|
258
|
+
np.ndarray
|
|
259
|
+
2D array of shape (grid_size-2*trim, grid_size-2*trim) containing
|
|
260
|
+
the cosine stripe pattern with values in [-1, 1].
|
|
261
|
+
|
|
262
|
+
Notes
|
|
263
|
+
-----
|
|
264
|
+
The pattern is generated on a [0, 3π] × [0, 3π] domain to cover the
|
|
265
|
+
extended torus space, then rotated by the fitted angle and trimmed.
|
|
266
|
+
"""
|
|
213
267
|
numangsint = grid_size
|
|
214
268
|
x, _ = np.meshgrid(
|
|
215
269
|
np.linspace(0, 3 * np.pi, numangsint - 1),
|
|
@@ -219,33 +219,3 @@ def plot_cohospace_skewed(
|
|
|
219
219
|
_ensure_parent_dir(config.save_path)
|
|
220
220
|
finalize_figure(fig, config)
|
|
221
221
|
return fig
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
def cohospace_upgrade(*args, **kwargs) -> dict[str, Any]:
|
|
225
|
-
"""Legacy alias for EcohoSpace (formerly cohospace_upgrade)."""
|
|
226
|
-
return cohospace(*args, **kwargs)
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
def ecohospace(*args, **kwargs) -> dict[str, Any]:
|
|
230
|
-
"""Alias for EcohoSpace (GridCellTorus-style)."""
|
|
231
|
-
return cohospace(*args, **kwargs)
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
def plot_cohospace_upgrade(*args, **kwargs) -> plt.Figure:
|
|
235
|
-
"""Legacy alias for EcohoSpace plotting (formerly plot_cohospace_upgrade)."""
|
|
236
|
-
return plot_cohospace(*args, **kwargs)
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
def plot_cohospace_upgrade_skewed(*args, **kwargs) -> plt.Figure:
|
|
240
|
-
"""Legacy alias for EcohoSpace skewed plotting."""
|
|
241
|
-
return plot_cohospace_skewed(*args, **kwargs)
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
def plot_ecohospace(*args, **kwargs) -> plt.Figure:
|
|
245
|
-
"""Alias for EcohoSpace plotting."""
|
|
246
|
-
return plot_cohospace(*args, **kwargs)
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
def plot_ecohospace_skewed(*args, **kwargs) -> plt.Figure:
|
|
250
|
-
"""Alias for EcohoSpace skewed plotting."""
|
|
251
|
-
return plot_cohospace_skewed(*args, **kwargs)
|
|
@@ -2,32 +2,14 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
import os
|
|
6
5
|
from typing import Any
|
|
7
6
|
|
|
8
7
|
import matplotlib.pyplot as plt
|
|
9
8
|
import numpy as np
|
|
10
9
|
|
|
11
10
|
from ...visualization.core import PlotConfig, finalize_figure
|
|
12
|
-
from .
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
def _ensure_plot_config(
|
|
16
|
-
config: PlotConfig | None,
|
|
17
|
-
factory,
|
|
18
|
-
*args,
|
|
19
|
-
**defaults,
|
|
20
|
-
) -> PlotConfig:
|
|
21
|
-
if config is None:
|
|
22
|
-
return factory(*args, **defaults)
|
|
23
|
-
return config
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
def _ensure_parent_dir(save_path: str | None) -> None:
|
|
27
|
-
if save_path:
|
|
28
|
-
parent = os.path.dirname(save_path)
|
|
29
|
-
if parent:
|
|
30
|
-
os.makedirs(parent, exist_ok=True)
|
|
11
|
+
from .path import skew_transform
|
|
12
|
+
from .utils import _ensure_parent_dir, _ensure_plot_config
|
|
31
13
|
|
|
32
14
|
|
|
33
15
|
def cohospace_phase_centers(cohospace_result: dict[str, Any]) -> dict[str, Any]:
|
|
@@ -40,7 +22,7 @@ def cohospace_phase_centers(cohospace_result: dict[str, Any]) -> dict[str, Any]:
|
|
|
40
22
|
Output from `data.cohospace(...)` (must include `centers`).
|
|
41
23
|
"""
|
|
42
24
|
centers = np.asarray(cohospace_result["centers"], dtype=float) % (2 * np.pi)
|
|
43
|
-
centers_skew =
|
|
25
|
+
centers_skew = skew_transform(centers)
|
|
44
26
|
return {
|
|
45
27
|
"centers": centers,
|
|
46
28
|
"centers_skew": centers_skew,
|
|
@@ -2,38 +2,13 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
import os
|
|
6
|
-
|
|
7
5
|
import matplotlib.pyplot as plt
|
|
8
6
|
import numpy as np
|
|
9
7
|
from scipy.stats import circvar
|
|
10
8
|
|
|
11
9
|
from ...visualization.core import PlotConfig, finalize_figure
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
def _ensure_plot_config(
|
|
15
|
-
config: PlotConfig | None,
|
|
16
|
-
factory,
|
|
17
|
-
*,
|
|
18
|
-
kwargs: dict | None = None,
|
|
19
|
-
**defaults,
|
|
20
|
-
) -> PlotConfig:
|
|
21
|
-
if config is None:
|
|
22
|
-
defaults.update({"kwargs": kwargs or {}})
|
|
23
|
-
return factory(**defaults)
|
|
24
|
-
|
|
25
|
-
if kwargs:
|
|
26
|
-
config_kwargs = config.kwargs or {}
|
|
27
|
-
config_kwargs.update(kwargs)
|
|
28
|
-
config.kwargs = config_kwargs
|
|
29
|
-
return config
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
def _ensure_parent_dir(save_path: str | None) -> None:
|
|
33
|
-
if save_path:
|
|
34
|
-
parent = os.path.dirname(save_path)
|
|
35
|
-
if parent:
|
|
36
|
-
os.makedirs(parent, exist_ok=True)
|
|
10
|
+
from .path import _align_activity_to_coords, skew_transform
|
|
11
|
+
from .utils import _ensure_parent_dir, _ensure_plot_config
|
|
37
12
|
|
|
38
13
|
|
|
39
14
|
# =====================================================================
|
|
@@ -48,51 +23,6 @@ def _coho_coords_to_degrees(coords: np.ndarray) -> np.ndarray:
|
|
|
48
23
|
return np.degrees(coords % (2 * np.pi))
|
|
49
24
|
|
|
50
25
|
|
|
51
|
-
def _align_activity_to_coords(
|
|
52
|
-
coords: np.ndarray,
|
|
53
|
-
activity: np.ndarray,
|
|
54
|
-
times: np.ndarray | None = None,
|
|
55
|
-
*,
|
|
56
|
-
label: str = "activity",
|
|
57
|
-
auto_filter: bool = True,
|
|
58
|
-
) -> np.ndarray:
|
|
59
|
-
"""
|
|
60
|
-
Align activity to coords by optional time indices and validate lengths.
|
|
61
|
-
"""
|
|
62
|
-
coords = np.asarray(coords)
|
|
63
|
-
activity = np.asarray(activity)
|
|
64
|
-
|
|
65
|
-
if times is not None:
|
|
66
|
-
times = np.asarray(times)
|
|
67
|
-
try:
|
|
68
|
-
activity = activity[times]
|
|
69
|
-
except Exception as exc:
|
|
70
|
-
raise ValueError(
|
|
71
|
-
f"Failed to index {label} with `times`. Ensure `times` indexes the original time axis."
|
|
72
|
-
) from exc
|
|
73
|
-
|
|
74
|
-
if activity.shape[0] != coords.shape[0]:
|
|
75
|
-
# Try to reproduce decode's zero-spike filtering if lengths mismatch.
|
|
76
|
-
if auto_filter and times is None and activity.ndim == 2:
|
|
77
|
-
mask = np.sum(activity > 0, axis=1) >= 1
|
|
78
|
-
if mask.sum() == coords.shape[0]:
|
|
79
|
-
activity = activity[mask]
|
|
80
|
-
else:
|
|
81
|
-
raise ValueError(
|
|
82
|
-
f"{label} length must match coords length. Got {activity.shape[0]} vs {coords.shape[0]}. "
|
|
83
|
-
"If coords are computed on a subset of timepoints (e.g., decode['times']), pass "
|
|
84
|
-
"`times=decoding['times']` or slice the activity accordingly."
|
|
85
|
-
)
|
|
86
|
-
else:
|
|
87
|
-
raise ValueError(
|
|
88
|
-
f"{label} length must match coords length. Got {activity.shape[0]} vs {coords.shape[0]}. "
|
|
89
|
-
"If coords are computed on a subset of timepoints (e.g., decode['times']), pass "
|
|
90
|
-
"`times=decoding['times']` or slice the activity accordingly."
|
|
91
|
-
)
|
|
92
|
-
|
|
93
|
-
return activity
|
|
94
|
-
|
|
95
|
-
|
|
96
26
|
def plot_cohospace_scatter_trajectory_2d(
|
|
97
27
|
coords: np.ndarray,
|
|
98
28
|
times: np.ndarray | None = None,
|
|
@@ -802,41 +732,6 @@ def compute_cohoscore_scatter_1d(
|
|
|
802
732
|
return scores
|
|
803
733
|
|
|
804
734
|
|
|
805
|
-
def skew_transform_torus_scatter(coords):
|
|
806
|
-
"""
|
|
807
|
-
Convert torus angles (theta1, theta2) into coordinates in a skewed parallelogram fundamental domain.
|
|
808
|
-
|
|
809
|
-
Given theta1, theta2 in radians, map:
|
|
810
|
-
x = theta1 + 0.5 * theta2
|
|
811
|
-
y = (sqrt(3)/2) * theta2
|
|
812
|
-
|
|
813
|
-
This is a linear change of basis that turns the square [0, 2π)×[0, 2π) into a 60-degree
|
|
814
|
-
parallelogram, which is convenient for visualizing wrap-around behavior on a 2-torus.
|
|
815
|
-
|
|
816
|
-
Parameters
|
|
817
|
-
----------
|
|
818
|
-
coords : ndarray, shape (T, 2)
|
|
819
|
-
Angles (theta1, theta2) in radians.
|
|
820
|
-
|
|
821
|
-
Returns
|
|
822
|
-
-------
|
|
823
|
-
xy : ndarray, shape (T, 2)
|
|
824
|
-
Skewed planar coordinates.
|
|
825
|
-
"""
|
|
826
|
-
coords = np.asarray(coords)
|
|
827
|
-
if coords.ndim != 2 or coords.shape[1] != 2:
|
|
828
|
-
raise ValueError(f"coords must be (T,2), got {coords.shape}")
|
|
829
|
-
|
|
830
|
-
theta1 = coords[:, 0]
|
|
831
|
-
theta2 = coords[:, 1]
|
|
832
|
-
|
|
833
|
-
# Linear change of basis (NO nonlinear scaling)
|
|
834
|
-
x = theta1 + 0.5 * theta2
|
|
835
|
-
y = (np.sqrt(3) / 2.0) * theta2
|
|
836
|
-
|
|
837
|
-
return np.stack([x, y], axis=1)
|
|
838
|
-
|
|
839
|
-
|
|
840
735
|
def draw_torus_parallelogram_grid_scatter(ax, n_tiles=1, color="0.7", lw=1.0, alpha=0.8):
|
|
841
736
|
"""
|
|
842
737
|
Draw parallelogram grid corresponding to torus fundamental domain.
|
|
@@ -874,7 +769,7 @@ def tile_parallelogram_points_scatter(xy, n_tiles=1):
|
|
|
874
769
|
Parameters
|
|
875
770
|
----------
|
|
876
771
|
points : ndarray, shape (T, 2)
|
|
877
|
-
Points in the skewed plane (same coordinates as returned by `
|
|
772
|
+
Points in the skewed plane (same coordinates as returned by `skew_transform`).
|
|
878
773
|
n_tiles : int
|
|
879
774
|
Number of tiles to extend around the base domain.
|
|
880
775
|
- n_tiles=1 produces a 3x3 tiling
|
|
@@ -1020,7 +915,7 @@ def plot_cohospace_scatter_neuron_skewed(
|
|
|
1020
915
|
)
|
|
1021
916
|
|
|
1022
917
|
# --- skew transform
|
|
1023
|
-
xy =
|
|
918
|
+
xy = skew_transform(coords[mask])
|
|
1024
919
|
|
|
1025
920
|
# Tiling: if points are tiled, values must be tiled too (FR mode) to keep lengths consistent
|
|
1026
921
|
if n_tiles and n_tiles > 0:
|
|
@@ -1165,7 +1060,7 @@ def plot_cohospace_scatter_population_skewed(
|
|
|
1165
1060
|
thr = np.percentile(a, 100 - top_percent)
|
|
1166
1061
|
mask = a >= thr
|
|
1167
1062
|
|
|
1168
|
-
xy =
|
|
1063
|
+
xy = skew_transform(coords[mask])
|
|
1169
1064
|
|
|
1170
1065
|
if n_tiles and n_tiles > 0:
|
|
1171
1066
|
xy = tile_parallelogram_points_scatter(xy, n_tiles=n_tiles)
|
|
@@ -6,7 +6,6 @@ import numpy as np
|
|
|
6
6
|
from scipy.ndimage import gaussian_filter1d
|
|
7
7
|
|
|
8
8
|
from .config import DataLoadError, ProcessingError, SpikeEmbeddingConfig
|
|
9
|
-
from .filters import _gaussian_filter1d
|
|
10
9
|
|
|
11
10
|
|
|
12
11
|
def embed_spike_trains(spike_trains, config: SpikeEmbeddingConfig | None = None, **kwargs):
|
|
@@ -262,8 +261,8 @@ def _load_pos(t, x, y, res=100000, dt=1000):
|
|
|
262
261
|
xx = np.interp(tt, t, x)
|
|
263
262
|
yy = np.interp(tt, t, y)
|
|
264
263
|
|
|
265
|
-
xxs =
|
|
266
|
-
yys =
|
|
264
|
+
xxs = gaussian_filter1d(xx - np.min(xx), sigma=100)
|
|
265
|
+
yys = gaussian_filter1d(yy - np.min(yy), sigma=100)
|
|
267
266
|
dx = (xxs[1:] - xxs[:-1]) * 100
|
|
268
267
|
dy = (yys[1:] - yys[:-1]) * 100
|
|
269
268
|
speed = np.sqrt(dx**2 + dy**2) / 0.01
|
canns/analyzer/data/asa/fr.py
CHANGED
|
@@ -1,18 +1,11 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import os
|
|
4
3
|
from dataclasses import dataclass
|
|
5
4
|
|
|
6
5
|
import numpy as np
|
|
7
6
|
|
|
8
7
|
from ...visualization.core import PlotConfig, finalize_figure
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
def _ensure_parent_dir(save_path: str | None) -> None:
|
|
12
|
-
if save_path:
|
|
13
|
-
parent = os.path.dirname(save_path)
|
|
14
|
-
if parent:
|
|
15
|
-
os.makedirs(parent, exist_ok=True)
|
|
8
|
+
from .utils import _ensure_parent_dir
|
|
16
9
|
|
|
17
10
|
|
|
18
11
|
def _slice_range(r: tuple[int, int] | None, length: int) -> slice:
|