canns 0.15.0__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/analyzer/data/asa/__init__.py +2 -23
- canns/analyzer/data/asa/coho.py +17 -93
- canns/analyzer/data/asa/cohomap.py +70 -25
- canns/analyzer/data/asa/cohomap_vectors.py +73 -19
- canns/analyzer/data/asa/cohospace.py +0 -30
- canns/analyzer/data/asa/cohospace_phase_centers.py +3 -21
- canns/analyzer/data/asa/cohospace_scatter.py +5 -110
- canns/analyzer/data/asa/embedding.py +2 -3
- canns/analyzer/data/asa/fr.py +1 -8
- canns/analyzer/data/asa/path.py +70 -0
- canns/analyzer/data/asa/plotting.py +1 -26
- canns/analyzer/data/asa/utils.py +160 -0
- canns/analyzer/data/cell_classification/visualization/btn_plots.py +1 -18
- {canns-0.15.0.dist-info → canns-0.15.1.dist-info}/METADATA +1 -1
- {canns-0.15.0.dist-info → canns-0.15.1.dist-info}/RECORD +18 -19
- canns/analyzer/data/asa/cohomap_scatter.py +0 -10
- canns/analyzer/data/asa/filters.py +0 -208
- {canns-0.15.0.dist-info → canns-0.15.1.dist-info}/WHEEL +0 -0
- {canns-0.15.0.dist-info → canns-0.15.1.dist-info}/entry_points.txt +0 -0
- {canns-0.15.0.dist-info → canns-0.15.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -2,15 +2,9 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from .cohomap import (
|
|
4
4
|
cohomap,
|
|
5
|
-
cohomap_upgrade,
|
|
6
|
-
ecohomap,
|
|
7
5
|
fit_cohomap_stripes,
|
|
8
|
-
fit_cohomap_stripes_upgrade,
|
|
9
6
|
plot_cohomap,
|
|
10
|
-
plot_cohomap_upgrade,
|
|
11
|
-
plot_ecohomap,
|
|
12
7
|
)
|
|
13
|
-
from .cohomap_scatter import plot_cohomap_scatter, plot_cohomap_scatter_multi
|
|
14
8
|
from .cohomap_vectors import (
|
|
15
9
|
cohomap_vectors,
|
|
16
10
|
plot_cohomap_stripes,
|
|
@@ -18,14 +12,8 @@ from .cohomap_vectors import (
|
|
|
18
12
|
)
|
|
19
13
|
from .cohospace import (
|
|
20
14
|
cohospace,
|
|
21
|
-
cohospace_upgrade,
|
|
22
|
-
ecohospace,
|
|
23
15
|
plot_cohospace,
|
|
24
16
|
plot_cohospace_skewed,
|
|
25
|
-
plot_cohospace_upgrade,
|
|
26
|
-
plot_cohospace_upgrade_skewed,
|
|
27
|
-
plot_ecohospace,
|
|
28
|
-
plot_ecohospace_skewed,
|
|
29
17
|
)
|
|
30
18
|
from .cohospace_phase_centers import (
|
|
31
19
|
cohospace_phase_centers,
|
|
@@ -81,6 +69,8 @@ from .path import (
|
|
|
81
69
|
from .plotting import (
|
|
82
70
|
plot_2d_bump_on_manifold,
|
|
83
71
|
plot_3d_bump_on_torus,
|
|
72
|
+
plot_cohomap_scatter,
|
|
73
|
+
plot_cohomap_scatter_multi,
|
|
84
74
|
plot_path_compare_1d,
|
|
85
75
|
plot_path_compare_2d,
|
|
86
76
|
plot_projection,
|
|
@@ -109,22 +99,11 @@ __all__ = [
|
|
|
109
99
|
"plot_3d_bump_on_torus",
|
|
110
100
|
"plot_2d_bump_on_manifold",
|
|
111
101
|
"cohomap",
|
|
112
|
-
"cohomap_upgrade",
|
|
113
102
|
"fit_cohomap_stripes",
|
|
114
|
-
"fit_cohomap_stripes_upgrade",
|
|
115
103
|
"plot_cohomap",
|
|
116
|
-
"plot_cohomap_upgrade",
|
|
117
104
|
"cohospace",
|
|
118
|
-
"cohospace_upgrade",
|
|
119
105
|
"plot_cohospace",
|
|
120
106
|
"plot_cohospace_skewed",
|
|
121
|
-
"plot_cohospace_upgrade",
|
|
122
|
-
"plot_cohospace_upgrade_skewed",
|
|
123
|
-
"ecohomap",
|
|
124
|
-
"ecohospace",
|
|
125
|
-
"plot_ecohomap",
|
|
126
|
-
"plot_ecohospace",
|
|
127
|
-
"plot_ecohospace_skewed",
|
|
128
107
|
"cohomap_vectors",
|
|
129
108
|
"plot_cohomap_stripes",
|
|
130
109
|
"plot_cohomap_vectors",
|
canns/analyzer/data/asa/coho.py
CHANGED
|
@@ -2,96 +2,20 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
import
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
return factory(*args, **defaults)
|
|
23
|
-
return config
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
def _ensure_parent_dir(save_path: str | None) -> None:
|
|
27
|
-
if save_path:
|
|
28
|
-
parent = os.path.dirname(save_path)
|
|
29
|
-
if parent:
|
|
30
|
-
os.makedirs(parent, exist_ok=True)
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
def _circmean(x: np.ndarray) -> float:
|
|
34
|
-
return float(np.arctan2(np.mean(np.sin(x)), np.mean(np.cos(x))))
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
def _smooth_circular_map(
|
|
38
|
-
mtot: np.ndarray,
|
|
39
|
-
smooth_sigma: float,
|
|
40
|
-
*,
|
|
41
|
-
fill_nan: bool = False,
|
|
42
|
-
fill_sigma: float | None = None,
|
|
43
|
-
fill_min_weight: float = 1e-3,
|
|
44
|
-
) -> np.ndarray:
|
|
45
|
-
mtot = np.asarray(mtot, dtype=float)
|
|
46
|
-
nans = np.isnan(mtot)
|
|
47
|
-
mask = (~nans).astype(float)
|
|
48
|
-
sintot = np.sin(mtot)
|
|
49
|
-
costot = np.cos(mtot)
|
|
50
|
-
sintot[nans] = 0.0
|
|
51
|
-
costot[nans] = 0.0
|
|
52
|
-
|
|
53
|
-
if fill_nan:
|
|
54
|
-
if fill_sigma is None:
|
|
55
|
-
fill_sigma = smooth_sigma if smooth_sigma and smooth_sigma > 0 else 1.0
|
|
56
|
-
weight = gaussian_filter(mask, fill_sigma)
|
|
57
|
-
sintot = gaussian_filter(sintot * mask, fill_sigma)
|
|
58
|
-
costot = gaussian_filter(costot * mask, fill_sigma)
|
|
59
|
-
min_weight = max(float(fill_min_weight), 0.0)
|
|
60
|
-
valid = weight > min_weight
|
|
61
|
-
sintot = np.divide(sintot, weight, out=np.zeros_like(sintot), where=valid)
|
|
62
|
-
costot = np.divide(costot, weight, out=np.zeros_like(costot), where=valid)
|
|
63
|
-
mtot = np.arctan2(sintot, costot)
|
|
64
|
-
if fill_min_weight > 0:
|
|
65
|
-
mtot[~valid] = np.nan
|
|
66
|
-
return mtot
|
|
67
|
-
|
|
68
|
-
if smooth_sigma and smooth_sigma > 0:
|
|
69
|
-
sintot = gaussian_filter(sintot, smooth_sigma)
|
|
70
|
-
costot = gaussian_filter(costot, smooth_sigma)
|
|
71
|
-
mtot = np.arctan2(sintot, costot)
|
|
72
|
-
mtot[nans] = np.nan
|
|
73
|
-
return mtot
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
def _extract_coords_and_times(
|
|
77
|
-
decoding_result: dict[str, Any],
|
|
78
|
-
coords_key: str | None = None,
|
|
79
|
-
) -> tuple[np.ndarray, np.ndarray | None]:
|
|
80
|
-
if coords_key is not None:
|
|
81
|
-
if coords_key not in decoding_result:
|
|
82
|
-
raise KeyError(f"coords_key '{coords_key}' not found in decoding_result.")
|
|
83
|
-
coords = np.asarray(decoding_result[coords_key])
|
|
84
|
-
elif "coordsbox" in decoding_result:
|
|
85
|
-
coords = np.asarray(decoding_result["coordsbox"])
|
|
86
|
-
else:
|
|
87
|
-
coords, _ = find_coords_matrix(decoding_result)
|
|
88
|
-
|
|
89
|
-
times_box, _ = find_times_box(decoding_result)
|
|
90
|
-
return coords, times_box
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
def _phase_map_valid_fraction(phase_map: np.ndarray) -> float:
|
|
94
|
-
valid = np.isfinite(phase_map)
|
|
95
|
-
if valid.size == 0:
|
|
96
|
-
return 0.0
|
|
97
|
-
return float(np.mean(valid))
|
|
5
|
+
from .utils import (
|
|
6
|
+
_circmean,
|
|
7
|
+
_ensure_parent_dir,
|
|
8
|
+
_ensure_plot_config,
|
|
9
|
+
_extract_coords_and_times,
|
|
10
|
+
_phase_map_valid_fraction,
|
|
11
|
+
_smooth_circular_map,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
__all__ = [
|
|
15
|
+
"_ensure_plot_config",
|
|
16
|
+
"_ensure_parent_dir",
|
|
17
|
+
"_circmean",
|
|
18
|
+
"_smooth_circular_map",
|
|
19
|
+
"_extract_coords_and_times",
|
|
20
|
+
"_phase_map_valid_fraction",
|
|
21
|
+
]
|
|
@@ -57,6 +57,36 @@ def _rot_coord(
|
|
|
57
57
|
c2: np.ndarray,
|
|
58
58
|
p: tuple[int, int],
|
|
59
59
|
) -> np.ndarray:
|
|
60
|
+
"""Transform and align decoded coordinates based on stripe orientation.
|
|
61
|
+
|
|
62
|
+
This function rotates and aligns coordinates to match the dominant stripe
|
|
63
|
+
orientation in the CohoMap. It handles the torus geometry by choosing the
|
|
64
|
+
appropriate coordinate system and applying reflections/rotations.
|
|
65
|
+
|
|
66
|
+
Parameters
|
|
67
|
+
----------
|
|
68
|
+
params1 : np.ndarray
|
|
69
|
+
Stripe fit parameters for first dimension (includes orientation angle).
|
|
70
|
+
params2 : np.ndarray
|
|
71
|
+
Stripe fit parameters for second dimension (includes orientation angle).
|
|
72
|
+
c1 : np.ndarray
|
|
73
|
+
Decoded coordinates for first dimension.
|
|
74
|
+
c2 : np.ndarray
|
|
75
|
+
Decoded coordinates for second dimension.
|
|
76
|
+
p : tuple[int, int]
|
|
77
|
+
Direction indicators (1 or -1) for each dimension, controlling reflections.
|
|
78
|
+
|
|
79
|
+
Returns
|
|
80
|
+
-------
|
|
81
|
+
np.ndarray
|
|
82
|
+
Aligned coordinates of shape (N, 2), wrapped to [0, 2π).
|
|
83
|
+
|
|
84
|
+
Notes
|
|
85
|
+
-----
|
|
86
|
+
The function selects the coordinate system based on which dimension has
|
|
87
|
+
stronger horizontal alignment (larger |cos(angle)|), then applies geometric
|
|
88
|
+
transformations to align the coordinates with the stripe pattern.
|
|
89
|
+
"""
|
|
60
90
|
if abs(np.cos(params1[0])) < abs(np.cos(params2[0])):
|
|
61
91
|
cc1 = c2.copy()
|
|
62
92
|
cc2 = c1.copy()
|
|
@@ -96,6 +126,46 @@ def _toroidal_align_coords(
|
|
|
96
126
|
trim: int,
|
|
97
127
|
grid_size: int | None,
|
|
98
128
|
) -> tuple[np.ndarray, float, float]:
|
|
129
|
+
"""Align decoded coordinates to CohoMap stripe patterns.
|
|
130
|
+
|
|
131
|
+
This function fits stripe patterns to both phase maps, determines the
|
|
132
|
+
optimal orientation and phase signs, then transforms the decoded coordinates
|
|
133
|
+
to align with the detected stripe structure.
|
|
134
|
+
|
|
135
|
+
Parameters
|
|
136
|
+
----------
|
|
137
|
+
coords : np.ndarray
|
|
138
|
+
Decoded coordinates of shape (N, 2+). Only first two columns are used.
|
|
139
|
+
phase_map1 : np.ndarray
|
|
140
|
+
Phase map for first dimension (2D grid).
|
|
141
|
+
phase_map2 : np.ndarray
|
|
142
|
+
Phase map for second dimension (2D grid), must match phase_map1 shape.
|
|
143
|
+
trim : int
|
|
144
|
+
Number of edge bins to trim when fitting stripes.
|
|
145
|
+
grid_size : int, optional
|
|
146
|
+
Grid size for stripe fitting. If None, inferred from phase_map shape.
|
|
147
|
+
|
|
148
|
+
Returns
|
|
149
|
+
-------
|
|
150
|
+
aligned : np.ndarray
|
|
151
|
+
Aligned coordinates of shape (N, 2), wrapped to [0, 2π).
|
|
152
|
+
f1 : float
|
|
153
|
+
Fit quality score for first dimension (higher is better).
|
|
154
|
+
f2 : float
|
|
155
|
+
Fit quality score for second dimension (higher is better).
|
|
156
|
+
|
|
157
|
+
Raises
|
|
158
|
+
------
|
|
159
|
+
ValueError
|
|
160
|
+
If phase_map shapes don't match or coords has wrong shape.
|
|
161
|
+
|
|
162
|
+
Notes
|
|
163
|
+
-----
|
|
164
|
+
The alignment process:
|
|
165
|
+
1. Fits stripe patterns to both phase maps using FFT-based detection
|
|
166
|
+
2. Determines phase signs (±1) for each dimension
|
|
167
|
+
3. Applies geometric transformations via `_rot_coord()` to align coordinates
|
|
168
|
+
"""
|
|
99
169
|
if phase_map1.shape != phase_map2.shape:
|
|
100
170
|
raise ValueError("phase_map shapes do not match for alignment")
|
|
101
171
|
coords = np.asarray(coords)
|
|
@@ -381,28 +451,3 @@ def plot_cohomap(
|
|
|
381
451
|
_ensure_parent_dir(config.save_path)
|
|
382
452
|
finalize_figure(fig, config)
|
|
383
453
|
return fig
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
def cohomap_upgrade(*args, **kwargs) -> dict[str, Any]:
|
|
387
|
-
"""Legacy alias for EcohoMap (formerly cohomap_upgrade)."""
|
|
388
|
-
return cohomap(*args, **kwargs)
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
def ecohomap(*args, **kwargs) -> dict[str, Any]:
|
|
392
|
-
"""Alias for EcohoMap (GridCellTorus-style)."""
|
|
393
|
-
return cohomap(*args, **kwargs)
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
def fit_cohomap_stripes_upgrade(*args, **kwargs) -> tuple[np.ndarray, float]:
|
|
397
|
-
"""Legacy alias for EcohoMap stripe fitting."""
|
|
398
|
-
return fit_cohomap_stripes(*args, **kwargs)
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
def plot_cohomap_upgrade(*args, **kwargs) -> plt.Figure:
|
|
402
|
-
"""Legacy alias for EcohoMap plotting (formerly plot_cohomap_upgrade)."""
|
|
403
|
-
return plot_cohomap(*args, **kwargs)
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
def plot_ecohomap(*args, **kwargs) -> plt.Figure:
|
|
407
|
-
"""Alias for EcohoMap plotting."""
|
|
408
|
-
return plot_cohomap(*args, **kwargs)
|
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
import os
|
|
6
5
|
from typing import Any
|
|
7
6
|
|
|
8
7
|
import matplotlib.pyplot as plt
|
|
@@ -11,27 +10,37 @@ from scipy.ndimage import rotate
|
|
|
11
10
|
|
|
12
11
|
from ...visualization.core import PlotConfig, finalize_figure
|
|
13
12
|
from .cohomap import fit_cohomap_stripes
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
def _ensure_plot_config(
|
|
17
|
-
config: PlotConfig | None,
|
|
18
|
-
factory,
|
|
19
|
-
*args,
|
|
20
|
-
**defaults,
|
|
21
|
-
) -> PlotConfig:
|
|
22
|
-
if config is None:
|
|
23
|
-
return factory(*args, **defaults)
|
|
24
|
-
return config
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
def _ensure_parent_dir(save_path: str | None) -> None:
|
|
28
|
-
if save_path:
|
|
29
|
-
parent = os.path.dirname(save_path)
|
|
30
|
-
if parent:
|
|
31
|
-
os.makedirs(parent, exist_ok=True)
|
|
13
|
+
from .utils import _ensure_parent_dir, _ensure_plot_config
|
|
32
14
|
|
|
33
15
|
|
|
34
16
|
def _rot_para(params1: np.ndarray, params2: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
|
17
|
+
"""Transform stripe fit parameters to canonical orientation.
|
|
18
|
+
|
|
19
|
+
This function adjusts the orientation angles of stripe fit parameters to
|
|
20
|
+
align them with a canonical coordinate system. Unlike `_rot_coord()` which
|
|
21
|
+
transforms actual coordinate data, this operates on the fit parameters
|
|
22
|
+
themselves (orientation angles, wavelengths, etc.).
|
|
23
|
+
|
|
24
|
+
Parameters
|
|
25
|
+
----------
|
|
26
|
+
params1 : np.ndarray
|
|
27
|
+
Stripe fit parameters for first dimension. First element is orientation angle.
|
|
28
|
+
params2 : np.ndarray
|
|
29
|
+
Stripe fit parameters for second dimension. First element is orientation angle.
|
|
30
|
+
|
|
31
|
+
Returns
|
|
32
|
+
-------
|
|
33
|
+
x : np.ndarray
|
|
34
|
+
Transformed parameters for the dimension with stronger horizontal alignment.
|
|
35
|
+
y : np.ndarray
|
|
36
|
+
Transformed parameters for the other dimension.
|
|
37
|
+
|
|
38
|
+
Notes
|
|
39
|
+
-----
|
|
40
|
+
The function selects which parameter set becomes 'x' and 'y' based on which
|
|
41
|
+
has larger |cos(angle)|, then applies angle adjustments to ensure the stripe
|
|
42
|
+
vectors are in a canonical orientation for visualization and analysis.
|
|
43
|
+
"""
|
|
35
44
|
if abs(np.cos(params1[0])) < abs(np.cos(params2[0])):
|
|
36
45
|
y = params1.copy()
|
|
37
46
|
x = params2.copy()
|
|
@@ -204,12 +213,57 @@ def plot_cohomap_vectors(
|
|
|
204
213
|
|
|
205
214
|
|
|
206
215
|
def _resolve_grid_size(phase_map: np.ndarray, grid_size: int | None, trim: int) -> int:
|
|
216
|
+
"""Determine grid size for stripe fitting.
|
|
217
|
+
|
|
218
|
+
Parameters
|
|
219
|
+
----------
|
|
220
|
+
phase_map : np.ndarray
|
|
221
|
+
Phase map array (2D grid).
|
|
222
|
+
grid_size : int, optional
|
|
223
|
+
Explicit grid size. If None, inferred from phase_map shape.
|
|
224
|
+
trim : int
|
|
225
|
+
Number of edge bins to trim.
|
|
226
|
+
|
|
227
|
+
Returns
|
|
228
|
+
-------
|
|
229
|
+
int
|
|
230
|
+
Grid size to use for stripe fitting.
|
|
231
|
+
"""
|
|
207
232
|
if grid_size is None:
|
|
208
233
|
return int(phase_map.shape[0]) + 2 * trim + 1
|
|
209
234
|
return int(grid_size)
|
|
210
235
|
|
|
211
236
|
|
|
212
237
|
def _stripe_fit_map(params: np.ndarray, grid_size: int, trim: int) -> np.ndarray:
|
|
238
|
+
"""Generate a synthetic stripe pattern from fit parameters.
|
|
239
|
+
|
|
240
|
+
Creates a 2D cosine stripe pattern based on the fitted parameters (orientation,
|
|
241
|
+
phase, and spatial frequency). Used for visualizing the fitted stripe model
|
|
242
|
+
and comparing it with the actual phase map.
|
|
243
|
+
|
|
244
|
+
Parameters
|
|
245
|
+
----------
|
|
246
|
+
params : np.ndarray
|
|
247
|
+
Stripe fit parameters: [angle, phase, frequency].
|
|
248
|
+
- angle: Orientation angle in radians
|
|
249
|
+
- phase: Phase offset in radians
|
|
250
|
+
- frequency: Spatial frequency (inverse wavelength)
|
|
251
|
+
grid_size : int
|
|
252
|
+
Size of the grid to generate (before trimming).
|
|
253
|
+
trim : int
|
|
254
|
+
Number of edge bins to trim from the generated pattern.
|
|
255
|
+
|
|
256
|
+
Returns
|
|
257
|
+
-------
|
|
258
|
+
np.ndarray
|
|
259
|
+
2D array of shape (grid_size-2*trim, grid_size-2*trim) containing
|
|
260
|
+
the cosine stripe pattern with values in [-1, 1].
|
|
261
|
+
|
|
262
|
+
Notes
|
|
263
|
+
-----
|
|
264
|
+
The pattern is generated on a [0, 3π] × [0, 3π] domain to cover the
|
|
265
|
+
extended torus space, then rotated by the fitted angle and trimmed.
|
|
266
|
+
"""
|
|
213
267
|
numangsint = grid_size
|
|
214
268
|
x, _ = np.meshgrid(
|
|
215
269
|
np.linspace(0, 3 * np.pi, numangsint - 1),
|
|
@@ -219,33 +219,3 @@ def plot_cohospace_skewed(
|
|
|
219
219
|
_ensure_parent_dir(config.save_path)
|
|
220
220
|
finalize_figure(fig, config)
|
|
221
221
|
return fig
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
def cohospace_upgrade(*args, **kwargs) -> dict[str, Any]:
|
|
225
|
-
"""Legacy alias for EcohoSpace (formerly cohospace_upgrade)."""
|
|
226
|
-
return cohospace(*args, **kwargs)
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
def ecohospace(*args, **kwargs) -> dict[str, Any]:
|
|
230
|
-
"""Alias for EcohoSpace (GridCellTorus-style)."""
|
|
231
|
-
return cohospace(*args, **kwargs)
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
def plot_cohospace_upgrade(*args, **kwargs) -> plt.Figure:
|
|
235
|
-
"""Legacy alias for EcohoSpace plotting (formerly plot_cohospace_upgrade)."""
|
|
236
|
-
return plot_cohospace(*args, **kwargs)
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
def plot_cohospace_upgrade_skewed(*args, **kwargs) -> plt.Figure:
|
|
240
|
-
"""Legacy alias for EcohoSpace skewed plotting."""
|
|
241
|
-
return plot_cohospace_skewed(*args, **kwargs)
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
def plot_ecohospace(*args, **kwargs) -> plt.Figure:
|
|
245
|
-
"""Alias for EcohoSpace plotting."""
|
|
246
|
-
return plot_cohospace(*args, **kwargs)
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
def plot_ecohospace_skewed(*args, **kwargs) -> plt.Figure:
|
|
250
|
-
"""Alias for EcohoSpace skewed plotting."""
|
|
251
|
-
return plot_cohospace_skewed(*args, **kwargs)
|
|
@@ -2,32 +2,14 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
import os
|
|
6
5
|
from typing import Any
|
|
7
6
|
|
|
8
7
|
import matplotlib.pyplot as plt
|
|
9
8
|
import numpy as np
|
|
10
9
|
|
|
11
10
|
from ...visualization.core import PlotConfig, finalize_figure
|
|
12
|
-
from .
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
def _ensure_plot_config(
|
|
16
|
-
config: PlotConfig | None,
|
|
17
|
-
factory,
|
|
18
|
-
*args,
|
|
19
|
-
**defaults,
|
|
20
|
-
) -> PlotConfig:
|
|
21
|
-
if config is None:
|
|
22
|
-
return factory(*args, **defaults)
|
|
23
|
-
return config
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
def _ensure_parent_dir(save_path: str | None) -> None:
|
|
27
|
-
if save_path:
|
|
28
|
-
parent = os.path.dirname(save_path)
|
|
29
|
-
if parent:
|
|
30
|
-
os.makedirs(parent, exist_ok=True)
|
|
11
|
+
from .path import skew_transform
|
|
12
|
+
from .utils import _ensure_parent_dir, _ensure_plot_config
|
|
31
13
|
|
|
32
14
|
|
|
33
15
|
def cohospace_phase_centers(cohospace_result: dict[str, Any]) -> dict[str, Any]:
|
|
@@ -40,7 +22,7 @@ def cohospace_phase_centers(cohospace_result: dict[str, Any]) -> dict[str, Any]:
|
|
|
40
22
|
Output from `data.cohospace(...)` (must include `centers`).
|
|
41
23
|
"""
|
|
42
24
|
centers = np.asarray(cohospace_result["centers"], dtype=float) % (2 * np.pi)
|
|
43
|
-
centers_skew =
|
|
25
|
+
centers_skew = skew_transform(centers)
|
|
44
26
|
return {
|
|
45
27
|
"centers": centers,
|
|
46
28
|
"centers_skew": centers_skew,
|
|
@@ -2,38 +2,13 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
import os
|
|
6
|
-
|
|
7
5
|
import matplotlib.pyplot as plt
|
|
8
6
|
import numpy as np
|
|
9
7
|
from scipy.stats import circvar
|
|
10
8
|
|
|
11
9
|
from ...visualization.core import PlotConfig, finalize_figure
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
def _ensure_plot_config(
|
|
15
|
-
config: PlotConfig | None,
|
|
16
|
-
factory,
|
|
17
|
-
*,
|
|
18
|
-
kwargs: dict | None = None,
|
|
19
|
-
**defaults,
|
|
20
|
-
) -> PlotConfig:
|
|
21
|
-
if config is None:
|
|
22
|
-
defaults.update({"kwargs": kwargs or {}})
|
|
23
|
-
return factory(**defaults)
|
|
24
|
-
|
|
25
|
-
if kwargs:
|
|
26
|
-
config_kwargs = config.kwargs or {}
|
|
27
|
-
config_kwargs.update(kwargs)
|
|
28
|
-
config.kwargs = config_kwargs
|
|
29
|
-
return config
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
def _ensure_parent_dir(save_path: str | None) -> None:
|
|
33
|
-
if save_path:
|
|
34
|
-
parent = os.path.dirname(save_path)
|
|
35
|
-
if parent:
|
|
36
|
-
os.makedirs(parent, exist_ok=True)
|
|
10
|
+
from .path import _align_activity_to_coords, skew_transform
|
|
11
|
+
from .utils import _ensure_parent_dir, _ensure_plot_config
|
|
37
12
|
|
|
38
13
|
|
|
39
14
|
# =====================================================================
|
|
@@ -48,51 +23,6 @@ def _coho_coords_to_degrees(coords: np.ndarray) -> np.ndarray:
|
|
|
48
23
|
return np.degrees(coords % (2 * np.pi))
|
|
49
24
|
|
|
50
25
|
|
|
51
|
-
def _align_activity_to_coords(
|
|
52
|
-
coords: np.ndarray,
|
|
53
|
-
activity: np.ndarray,
|
|
54
|
-
times: np.ndarray | None = None,
|
|
55
|
-
*,
|
|
56
|
-
label: str = "activity",
|
|
57
|
-
auto_filter: bool = True,
|
|
58
|
-
) -> np.ndarray:
|
|
59
|
-
"""
|
|
60
|
-
Align activity to coords by optional time indices and validate lengths.
|
|
61
|
-
"""
|
|
62
|
-
coords = np.asarray(coords)
|
|
63
|
-
activity = np.asarray(activity)
|
|
64
|
-
|
|
65
|
-
if times is not None:
|
|
66
|
-
times = np.asarray(times)
|
|
67
|
-
try:
|
|
68
|
-
activity = activity[times]
|
|
69
|
-
except Exception as exc:
|
|
70
|
-
raise ValueError(
|
|
71
|
-
f"Failed to index {label} with `times`. Ensure `times` indexes the original time axis."
|
|
72
|
-
) from exc
|
|
73
|
-
|
|
74
|
-
if activity.shape[0] != coords.shape[0]:
|
|
75
|
-
# Try to reproduce decode's zero-spike filtering if lengths mismatch.
|
|
76
|
-
if auto_filter and times is None and activity.ndim == 2:
|
|
77
|
-
mask = np.sum(activity > 0, axis=1) >= 1
|
|
78
|
-
if mask.sum() == coords.shape[0]:
|
|
79
|
-
activity = activity[mask]
|
|
80
|
-
else:
|
|
81
|
-
raise ValueError(
|
|
82
|
-
f"{label} length must match coords length. Got {activity.shape[0]} vs {coords.shape[0]}. "
|
|
83
|
-
"If coords are computed on a subset of timepoints (e.g., decode['times']), pass "
|
|
84
|
-
"`times=decoding['times']` or slice the activity accordingly."
|
|
85
|
-
)
|
|
86
|
-
else:
|
|
87
|
-
raise ValueError(
|
|
88
|
-
f"{label} length must match coords length. Got {activity.shape[0]} vs {coords.shape[0]}. "
|
|
89
|
-
"If coords are computed on a subset of timepoints (e.g., decode['times']), pass "
|
|
90
|
-
"`times=decoding['times']` or slice the activity accordingly."
|
|
91
|
-
)
|
|
92
|
-
|
|
93
|
-
return activity
|
|
94
|
-
|
|
95
|
-
|
|
96
26
|
def plot_cohospace_scatter_trajectory_2d(
|
|
97
27
|
coords: np.ndarray,
|
|
98
28
|
times: np.ndarray | None = None,
|
|
@@ -802,41 +732,6 @@ def compute_cohoscore_scatter_1d(
|
|
|
802
732
|
return scores
|
|
803
733
|
|
|
804
734
|
|
|
805
|
-
def skew_transform_torus_scatter(coords):
|
|
806
|
-
"""
|
|
807
|
-
Convert torus angles (theta1, theta2) into coordinates in a skewed parallelogram fundamental domain.
|
|
808
|
-
|
|
809
|
-
Given theta1, theta2 in radians, map:
|
|
810
|
-
x = theta1 + 0.5 * theta2
|
|
811
|
-
y = (sqrt(3)/2) * theta2
|
|
812
|
-
|
|
813
|
-
This is a linear change of basis that turns the square [0, 2π)×[0, 2π) into a 60-degree
|
|
814
|
-
parallelogram, which is convenient for visualizing wrap-around behavior on a 2-torus.
|
|
815
|
-
|
|
816
|
-
Parameters
|
|
817
|
-
----------
|
|
818
|
-
coords : ndarray, shape (T, 2)
|
|
819
|
-
Angles (theta1, theta2) in radians.
|
|
820
|
-
|
|
821
|
-
Returns
|
|
822
|
-
-------
|
|
823
|
-
xy : ndarray, shape (T, 2)
|
|
824
|
-
Skewed planar coordinates.
|
|
825
|
-
"""
|
|
826
|
-
coords = np.asarray(coords)
|
|
827
|
-
if coords.ndim != 2 or coords.shape[1] != 2:
|
|
828
|
-
raise ValueError(f"coords must be (T,2), got {coords.shape}")
|
|
829
|
-
|
|
830
|
-
theta1 = coords[:, 0]
|
|
831
|
-
theta2 = coords[:, 1]
|
|
832
|
-
|
|
833
|
-
# Linear change of basis (NO nonlinear scaling)
|
|
834
|
-
x = theta1 + 0.5 * theta2
|
|
835
|
-
y = (np.sqrt(3) / 2.0) * theta2
|
|
836
|
-
|
|
837
|
-
return np.stack([x, y], axis=1)
|
|
838
|
-
|
|
839
|
-
|
|
840
735
|
def draw_torus_parallelogram_grid_scatter(ax, n_tiles=1, color="0.7", lw=1.0, alpha=0.8):
|
|
841
736
|
"""
|
|
842
737
|
Draw parallelogram grid corresponding to torus fundamental domain.
|
|
@@ -874,7 +769,7 @@ def tile_parallelogram_points_scatter(xy, n_tiles=1):
|
|
|
874
769
|
Parameters
|
|
875
770
|
----------
|
|
876
771
|
points : ndarray, shape (T, 2)
|
|
877
|
-
Points in the skewed plane (same coordinates as returned by `
|
|
772
|
+
Points in the skewed plane (same coordinates as returned by `skew_transform`).
|
|
878
773
|
n_tiles : int
|
|
879
774
|
Number of tiles to extend around the base domain.
|
|
880
775
|
- n_tiles=1 produces a 3x3 tiling
|
|
@@ -1020,7 +915,7 @@ def plot_cohospace_scatter_neuron_skewed(
|
|
|
1020
915
|
)
|
|
1021
916
|
|
|
1022
917
|
# --- skew transform
|
|
1023
|
-
xy =
|
|
918
|
+
xy = skew_transform(coords[mask])
|
|
1024
919
|
|
|
1025
920
|
# Tiling: if points are tiled, values must be tiled too (FR mode) to keep lengths consistent
|
|
1026
921
|
if n_tiles and n_tiles > 0:
|
|
@@ -1165,7 +1060,7 @@ def plot_cohospace_scatter_population_skewed(
|
|
|
1165
1060
|
thr = np.percentile(a, 100 - top_percent)
|
|
1166
1061
|
mask = a >= thr
|
|
1167
1062
|
|
|
1168
|
-
xy =
|
|
1063
|
+
xy = skew_transform(coords[mask])
|
|
1169
1064
|
|
|
1170
1065
|
if n_tiles and n_tiles > 0:
|
|
1171
1066
|
xy = tile_parallelogram_points_scatter(xy, n_tiles=n_tiles)
|
|
@@ -6,7 +6,6 @@ import numpy as np
|
|
|
6
6
|
from scipy.ndimage import gaussian_filter1d
|
|
7
7
|
|
|
8
8
|
from .config import DataLoadError, ProcessingError, SpikeEmbeddingConfig
|
|
9
|
-
from .filters import _gaussian_filter1d
|
|
10
9
|
|
|
11
10
|
|
|
12
11
|
def embed_spike_trains(spike_trains, config: SpikeEmbeddingConfig | None = None, **kwargs):
|
|
@@ -262,8 +261,8 @@ def _load_pos(t, x, y, res=100000, dt=1000):
|
|
|
262
261
|
xx = np.interp(tt, t, x)
|
|
263
262
|
yy = np.interp(tt, t, y)
|
|
264
263
|
|
|
265
|
-
xxs =
|
|
266
|
-
yys =
|
|
264
|
+
xxs = gaussian_filter1d(xx - np.min(xx), sigma=100)
|
|
265
|
+
yys = gaussian_filter1d(yy - np.min(yy), sigma=100)
|
|
267
266
|
dx = (xxs[1:] - xxs[:-1]) * 100
|
|
268
267
|
dy = (yys[1:] - yys[:-1]) * 100
|
|
269
268
|
speed = np.sqrt(dx**2 + dy**2) / 0.01
|
canns/analyzer/data/asa/fr.py
CHANGED
|
@@ -1,18 +1,11 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import os
|
|
4
3
|
from dataclasses import dataclass
|
|
5
4
|
|
|
6
5
|
import numpy as np
|
|
7
6
|
|
|
8
7
|
from ...visualization.core import PlotConfig, finalize_figure
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
def _ensure_parent_dir(save_path: str | None) -> None:
|
|
12
|
-
if save_path:
|
|
13
|
-
parent = os.path.dirname(save_path)
|
|
14
|
-
if parent:
|
|
15
|
-
os.makedirs(parent, exist_ok=True)
|
|
8
|
+
from .utils import _ensure_parent_dir
|
|
16
9
|
|
|
17
10
|
|
|
18
11
|
def _slice_range(r: tuple[int, int] | None, length: int) -> slice:
|
canns/analyzer/data/asa/path.py
CHANGED
|
@@ -240,6 +240,76 @@ def interp_coords_to_full_1d(idx_map: np.ndarray, coords1: np.ndarray, T_full: i
|
|
|
240
240
|
return np.mod(out, 2 * np.pi)[:, None]
|
|
241
241
|
|
|
242
242
|
|
|
243
|
+
def _align_activity_to_coords(
|
|
244
|
+
coords: np.ndarray,
|
|
245
|
+
activity: np.ndarray,
|
|
246
|
+
times: np.ndarray | None = None,
|
|
247
|
+
*,
|
|
248
|
+
label: str = "activity",
|
|
249
|
+
auto_filter: bool = True,
|
|
250
|
+
) -> np.ndarray:
|
|
251
|
+
"""
|
|
252
|
+
Align activity to coords by optional time indices and validate lengths.
|
|
253
|
+
|
|
254
|
+
Parameters
|
|
255
|
+
----------
|
|
256
|
+
coords : ndarray
|
|
257
|
+
Decoded coordinates array.
|
|
258
|
+
activity : ndarray
|
|
259
|
+
Activity matrix (firing rate or spikes).
|
|
260
|
+
times : ndarray, optional
|
|
261
|
+
Optional time indices to align activity to coords when coords are computed
|
|
262
|
+
on a subset of timepoints.
|
|
263
|
+
label : str
|
|
264
|
+
Label for error messages (default: "activity").
|
|
265
|
+
auto_filter : bool
|
|
266
|
+
If True and lengths mismatch, auto-filter activity with activity>0 to mimic
|
|
267
|
+
decode filtering.
|
|
268
|
+
|
|
269
|
+
Returns
|
|
270
|
+
-------
|
|
271
|
+
ndarray
|
|
272
|
+
Aligned activity array.
|
|
273
|
+
|
|
274
|
+
Raises
|
|
275
|
+
------
|
|
276
|
+
ValueError
|
|
277
|
+
If activity length doesn't match coords length after alignment attempts.
|
|
278
|
+
"""
|
|
279
|
+
coords = np.asarray(coords)
|
|
280
|
+
activity = np.asarray(activity)
|
|
281
|
+
|
|
282
|
+
if times is not None:
|
|
283
|
+
times = np.asarray(times)
|
|
284
|
+
try:
|
|
285
|
+
activity = activity[times]
|
|
286
|
+
except Exception as exc:
|
|
287
|
+
raise ValueError(
|
|
288
|
+
f"Failed to index {label} with `times`. Ensure `times` indexes the original time axis."
|
|
289
|
+
) from exc
|
|
290
|
+
|
|
291
|
+
if activity.shape[0] != coords.shape[0]:
|
|
292
|
+
# Try to reproduce decode's zero-spike filtering if lengths mismatch.
|
|
293
|
+
if auto_filter and times is None and activity.ndim == 2:
|
|
294
|
+
mask = np.sum(activity > 0, axis=1) >= 1
|
|
295
|
+
if mask.sum() == coords.shape[0]:
|
|
296
|
+
activity = activity[mask]
|
|
297
|
+
else:
|
|
298
|
+
raise ValueError(
|
|
299
|
+
f"{label} length must match coords length. Got {activity.shape[0]} vs {coords.shape[0]}. "
|
|
300
|
+
"If coords are computed on a subset of timepoints (e.g., decode['times']), pass "
|
|
301
|
+
"`times=decoding['times']` or slice the activity accordingly."
|
|
302
|
+
)
|
|
303
|
+
else:
|
|
304
|
+
raise ValueError(
|
|
305
|
+
f"{label} length must match coords length. Got {activity.shape[0]} vs {coords.shape[0]}. "
|
|
306
|
+
"If coords are computed on a subset of timepoints (e.g., decode['times']), pass "
|
|
307
|
+
"`times=decoding['times']` or slice the activity accordingly."
|
|
308
|
+
)
|
|
309
|
+
|
|
310
|
+
return activity
|
|
311
|
+
|
|
312
|
+
|
|
243
313
|
def align_coords_to_position_2d(
|
|
244
314
|
t_full: np.ndarray,
|
|
245
315
|
x_full: np.ndarray,
|
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import os
|
|
4
3
|
from typing import Any
|
|
5
4
|
|
|
6
5
|
import matplotlib.pyplot as plt
|
|
@@ -24,31 +23,7 @@ from ...visualization.core import (
|
|
|
24
23
|
from ...visualization.core.jupyter_utils import display_animation_in_jupyter, is_jupyter_environment
|
|
25
24
|
from .config import CANN2DPlotConfig, ProcessingError, SpikeEmbeddingConfig
|
|
26
25
|
from .embedding import embed_spike_trains
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
def _ensure_plot_config(
|
|
30
|
-
config: PlotConfig | None,
|
|
31
|
-
factory,
|
|
32
|
-
*,
|
|
33
|
-
kwargs: dict[str, Any] | None = None,
|
|
34
|
-
**defaults: Any,
|
|
35
|
-
) -> PlotConfig:
|
|
36
|
-
if config is None:
|
|
37
|
-
defaults.update({"kwargs": kwargs or {}})
|
|
38
|
-
return factory(**defaults)
|
|
39
|
-
|
|
40
|
-
if kwargs:
|
|
41
|
-
config_kwargs = config.kwargs or {}
|
|
42
|
-
config_kwargs.update(kwargs)
|
|
43
|
-
config.kwargs = config_kwargs
|
|
44
|
-
return config
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
def _ensure_parent_dir(save_path: str | None) -> None:
|
|
48
|
-
if save_path:
|
|
49
|
-
parent = os.path.dirname(save_path)
|
|
50
|
-
if parent:
|
|
51
|
-
os.makedirs(parent, exist_ok=True)
|
|
26
|
+
from .utils import _ensure_parent_dir, _ensure_plot_config
|
|
52
27
|
|
|
53
28
|
|
|
54
29
|
def _render_torus_frame(frame_index: int, frame_data: dict[str, Any]) -> np.ndarray:
|
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
"""Shared utility functions for ASA (Attractor State Analysis) modules."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from scipy.ndimage import gaussian_filter
|
|
10
|
+
|
|
11
|
+
from ...visualization.core import PlotConfig
|
|
12
|
+
from .path import find_coords_matrix, find_times_box
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _ensure_plot_config(
    config: PlotConfig | None,
    factory,
    *args,
    kwargs: dict | None = None,
    **defaults,
) -> PlotConfig:
    """Return ``config``, or build a new PlotConfig with ``factory`` if it is None.

    Args:
        config: Optional existing PlotConfig.
        factory: Callable used to create a PlotConfig when ``config`` is None.
        *args: Positional arguments forwarded to ``factory``.
        kwargs: Optional mapping of extra plotting kwargs. When a new config is
            built, a copy is passed to ``factory`` as ``kwargs=`` (only if it is
            non-empty). When ``config`` already exists, it is merged over
            ``config.kwargs`` (keys from ``kwargs`` win).
        **defaults: Keyword arguments forwarded to ``factory``.

    Returns:
        The newly built PlotConfig, or the same ``config`` object that was passed
        in (possibly with an updated ``kwargs`` attribute).
    """
    if config is None:
        if kwargs:
            # Hand the factory its own copy so the caller's dict is never aliased.
            defaults["kwargs"] = dict(kwargs)
        return factory(*args, **defaults)

    if kwargs:
        # Build a fresh dict rather than updating ``config.kwargs`` in place:
        # that dict may be shared (with the caller or another config), and an
        # in-place update would leak these overrides into every holder of it.
        config.kwargs = {**(config.kwargs or {}), **kwargs}
    return config
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _ensure_parent_dir(save_path: str | None) -> None:
    """Make sure the directory that will hold ``save_path`` exists.

    Does nothing when ``save_path`` is falsy (None or "") or when it has no
    directory component (a bare filename). Existing directories are left as-is.

    Args:
        save_path: Destination file path, or None.
    """
    if not save_path:
        return
    directory = os.path.dirname(save_path)
    if not directory:
        return
    os.makedirs(directory, exist_ok=True)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _circmean(x: np.ndarray) -> float:
    """Return the circular mean of the angles ``x`` (radians).

    Each angle is mapped onto the unit circle, the sine and cosine components
    are averaged, and the direction of that mean vector is returned via
    ``arctan2``, so the result lies in [-pi, pi].

    Args:
        x: Array of angles in radians.

    Returns:
        Circular mean in radians, as a Python float.
    """
    mean_sin = np.mean(np.sin(x))
    mean_cos = np.mean(np.cos(x))
    return float(np.arctan2(mean_sin, mean_cos))
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _smooth_circular_map(
    mtot: np.ndarray,
    smooth_sigma: float,
    *,
    fill_nan: bool = False,
    fill_sigma: float | None = None,
    fill_min_weight: float = 1e-3,
) -> np.ndarray:
    """Smooth a circular phase map using Gaussian filtering in sin/cos space.

    Angles are smoothed through their sine and cosine components and then
    recombined with ``arctan2``, so wrap-around at +/-pi is handled correctly.
    The output is always in [-pi, pi].

    Non-finite entries (NaN and +/-inf) are treated as missing. They contribute
    nothing to the filtered sums, so a single bad value cannot spread NaN across
    the Gaussian kernel footprint.

    Args:
        mtot: Phase map array (angles in radians). The input is not modified.
        smooth_sigma: Gaussian smoothing sigma. A falsy or non-positive value
            disables smoothing (the map is then only wrapped to [-pi, pi]).
        fill_nan: If True, fill missing entries by mask-normalized Gaussian
            interpolation with ``fill_sigma``. This also smooths valid entries;
            ``smooth_sigma`` then only serves as the default for ``fill_sigma``.
        fill_sigma: Sigma for NaN filling (defaults to ``smooth_sigma`` when it is
            positive, else 1.0).
        fill_min_weight: Minimum normalized weight for an interpolated value to be
            kept. When this is > 0, positions whose weight is <= it become NaN.

    Returns:
        Smoothed phase map as a float array with the same shape as ``mtot``.
        In the non-filling path, missing input positions remain NaN.
    """
    mtot = np.asarray(mtot, dtype=float)
    missing = ~np.isfinite(mtot)
    mask = (~missing).astype(float)
    # Substitute 0 before sin/cos so inf never produces NaN (or a warning);
    # the missing components are then zeroed so they add nothing to the sums.
    safe = np.where(missing, 0.0, mtot)
    sintot = np.sin(safe)
    costot = np.cos(safe)
    sintot[missing] = 0.0
    costot[missing] = 0.0

    if fill_nan:
        if fill_sigma is None:
            fill_sigma = smooth_sigma if smooth_sigma and smooth_sigma > 0 else 1.0
        # Normalized convolution: filter the (already zero-where-missing)
        # components and the mask itself, then divide so missing neighbours
        # do not bias the local average.
        weight = gaussian_filter(mask, fill_sigma)
        sintot = gaussian_filter(sintot, fill_sigma)
        costot = gaussian_filter(costot, fill_sigma)
        min_weight = max(float(fill_min_weight), 0.0)
        valid = weight > min_weight
        sintot = np.divide(sintot, weight, out=np.zeros_like(sintot), where=valid)
        costot = np.divide(costot, weight, out=np.zeros_like(costot), where=valid)
        result = np.arctan2(sintot, costot)
        if fill_min_weight > 0:
            result[~valid] = np.nan
        return result

    if smooth_sigma and smooth_sigma > 0:
        sintot = gaussian_filter(sintot, smooth_sigma)
        costot = gaussian_filter(costot, smooth_sigma)
    result = np.arctan2(sintot, costot)
    result[missing] = np.nan
    return result
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def _extract_coords_and_times(
    decoding_result: dict[str, Any],
    coords_key: str | None = None,
) -> tuple[np.ndarray, np.ndarray | None]:
    """Pull decoded coordinates and their time indices out of a decoding result.

    Coordinates are looked up in this order:
      1. ``decoding_result[coords_key]`` when ``coords_key`` is given (a missing
         key raises KeyError);
      2. ``decoding_result["coordsbox"]`` when that key is present;
      3. otherwise the first element returned by ``find_coords_matrix``.

    Time indices always come from the first element returned by
    ``find_times_box``.

    Args:
        decoding_result: Dictionary containing decoding results.
        coords_key: Optional explicit key for the coordinates.

    Returns:
        Tuple ``(coords, times_box)``. Presumably ``coords`` is (T, 2) and
        ``times_box`` may be None; confirm against the ``path`` helpers.

    Raises:
        KeyError: If ``coords_key`` is given but absent from ``decoding_result``.
    """
    if coords_key is None:
        if "coordsbox" in decoding_result:
            coords = np.asarray(decoding_result["coordsbox"])
        else:
            coords, _matrix_info = find_coords_matrix(decoding_result)
    elif coords_key in decoding_result:
        coords = np.asarray(decoding_result[coords_key])
    else:
        raise KeyError(f"coords_key '{coords_key}' not found in decoding_result.")

    times_box, _times_info = find_times_box(decoding_result)
    return coords, times_box
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def _phase_map_valid_fraction(phase_map: np.ndarray) -> float:
    """Return the fraction of finite entries in ``phase_map``.

    NaN and +/-inf both count as invalid. An empty map yields 0.0 instead of the
    NaN that averaging an empty array would give.

    Args:
        phase_map: Phase map array.

    Returns:
        Fraction of finite values, between 0.0 and 1.0.
    """
    finite = np.isfinite(phase_map)
    return float(np.mean(finite)) if finite.size else 0.0
|
|
@@ -8,6 +8,7 @@ import numpy as np
|
|
|
8
8
|
from matplotlib import pyplot as plt
|
|
9
9
|
from scipy.ndimage import gaussian_filter1d
|
|
10
10
|
|
|
11
|
+
from canns.analyzer.data.asa.utils import _ensure_plot_config
|
|
11
12
|
from canns.analyzer.visualization.core.config import PlotConfig, finalize_figure
|
|
12
13
|
|
|
13
14
|
_DEFAULT_BTN_COLORS = {
|
|
@@ -17,24 +18,6 @@ _DEFAULT_BTN_COLORS = {
|
|
|
17
18
|
}
|
|
18
19
|
|
|
19
20
|
|
|
20
|
-
def _ensure_plot_config(
|
|
21
|
-
config: PlotConfig | None,
|
|
22
|
-
factory,
|
|
23
|
-
*,
|
|
24
|
-
kwargs: dict[str, Any] | None = None,
|
|
25
|
-
**defaults: Any,
|
|
26
|
-
) -> PlotConfig:
|
|
27
|
-
if config is None:
|
|
28
|
-
defaults.update({"kwargs": kwargs or {}})
|
|
29
|
-
return factory(**defaults)
|
|
30
|
-
|
|
31
|
-
if kwargs:
|
|
32
|
-
config_kwargs = config.kwargs or {}
|
|
33
|
-
config_kwargs.update(kwargs)
|
|
34
|
-
config.kwargs = config_kwargs
|
|
35
|
-
return config
|
|
36
|
-
|
|
37
|
-
|
|
38
21
|
def _canonical_label(label: str) -> str:
|
|
39
22
|
lab = label.strip().lower()
|
|
40
23
|
if lab in ("b", "bursty"):
|
|
@@ -3,23 +3,22 @@ canns/_version.py,sha256=zIvJPOGBFvo4VV6f586rlO_bvhuFp1fsxjf6xhsqkJY,1547
|
|
|
3
3
|
canns/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
4
4
|
canns/analyzer/__init__.py,sha256=EQ02fYHkpMADp-ojpVCVtapuSPkl6j5WVfdPy0mOTs4,506
|
|
5
5
|
canns/analyzer/data/__init__.py,sha256=RfS8vwApLkNF05Y_lfPaJpN_bRv-mOA_uFziaduDHgI,354
|
|
6
|
-
canns/analyzer/data/asa/__init__.py,sha256=
|
|
7
|
-
canns/analyzer/data/asa/coho.py,sha256=
|
|
8
|
-
canns/analyzer/data/asa/cohomap.py,sha256=
|
|
9
|
-
canns/analyzer/data/asa/
|
|
10
|
-
canns/analyzer/data/asa/
|
|
11
|
-
canns/analyzer/data/asa/
|
|
12
|
-
canns/analyzer/data/asa/
|
|
13
|
-
canns/analyzer/data/asa/cohospace_scatter.py,sha256=z13lLFdYkODmjKeOYildM5TgvT2QxmFSh7q8BYS5WjA,37163
|
|
6
|
+
canns/analyzer/data/asa/__init__.py,sha256=JtXp1QgoqsA1_XiZ4TNFnTGP_iKo5fHeJ_V3wpdhN6g,3327
|
|
7
|
+
canns/analyzer/data/asa/coho.py,sha256=775bOruTVSNRjM0A4YRTeiWKVQtEMCR3zTYXgVTeLiw,444
|
|
8
|
+
canns/analyzer/data/asa/cohomap.py,sha256=ZIiH1Mu306jg1lLmeV-akEzbojXQQVKW3nl0PbFKD4M,15217
|
|
9
|
+
canns/analyzer/data/asa/cohomap_vectors.py,sha256=o2eHMqiAvmhZfebOvHZ1zypqZKKgvZu0WLs3KD3Nfis,11273
|
|
10
|
+
canns/analyzer/data/asa/cohospace.py,sha256=8sq1Z0QGDxzC-5YY1n-nDEtqH3vnIsVA0Q1ST7egxTE,6701
|
|
11
|
+
canns/analyzer/data/asa/cohospace_phase_centers.py,sha256=ZQDhd_cBv8SFcVxgGM0nJabTLjKQZ1mg9Ho6duUG0xY,3738
|
|
12
|
+
canns/analyzer/data/asa/cohospace_scatter.py,sha256=lBubii5rz-MiTT_rYa5u59nzw2vkuOcFDnM6UhvvJ9M,33893
|
|
14
13
|
canns/analyzer/data/asa/config.py,sha256=qm0k0nt0xuDUK5t63MG7ii7fgs2XbxyLxKOaOKJuB_s,6398
|
|
15
14
|
canns/analyzer/data/asa/decode.py,sha256=NG8vVx2cPG7uSJDovnC2vzk0dsqU8oR4jaNPxxrvCc0,16501
|
|
16
|
-
canns/analyzer/data/asa/embedding.py,sha256=
|
|
17
|
-
canns/analyzer/data/asa/filters.py,sha256=D-1mDVn4hBEAphKUgx1gQEUfgbghKcNQhZmr4xEExQA,7146
|
|
15
|
+
canns/analyzer/data/asa/embedding.py,sha256=V1TzOD4wzyx3xmwCBHLLWgX92rBkMjWkFz-F-4xPPxo,9581
|
|
18
16
|
canns/analyzer/data/asa/fly_roi.py,sha256=_scBOd-4t9yv_1tHk7wbXJwPieU-L-QtFJY6fhHpxDI,38031
|
|
19
|
-
canns/analyzer/data/asa/fr.py,sha256
|
|
20
|
-
canns/analyzer/data/asa/path.py,sha256=
|
|
21
|
-
canns/analyzer/data/asa/plotting.py,sha256=
|
|
17
|
+
canns/analyzer/data/asa/fr.py,sha256=-MuA50eTVdVJB7CUBeLelvnm0i0c0pK5NRICLoSD7oo,13292
|
|
18
|
+
canns/analyzer/data/asa/path.py,sha256=JM8NqkhwFx7qDs26n_t2bL5md9V4oR2E2ksXwYAwTug,17965
|
|
19
|
+
canns/analyzer/data/asa/plotting.py,sha256=PuYGZgal67uDRZ4xBczJKbVblqbza4a03re105Wj7vA,42243
|
|
22
20
|
canns/analyzer/data/asa/tda.py,sha256=7IdxhBNEE99qenG6Zi4B5tv_L9K6gAW6HHxYGiErx4c,30574
|
|
21
|
+
canns/analyzer/data/asa/utils.py,sha256=s9R1d6K4op9MnjcBaI4vrnEAC-9KyZE6S9dg_zR0ONs,4899
|
|
23
22
|
canns/analyzer/data/cell_classification/__init__.py,sha256=Ri0VJYn2OI3ygC4m-Xc9rjFvgPLaEykc-D94VxmUclQ,2447
|
|
24
23
|
canns/analyzer/data/cell_classification/core/__init__.py,sha256=J9uqjx2wTK-uh3OWFqP8BkY_ySz3rU_VWRNQ_1t3EbM,865
|
|
25
24
|
canns/analyzer/data/cell_classification/core/btn.py,sha256=rgZdEoMgqgOPah5KHEQAx7pNWDsWBzZ2pNQ5BstTYFM,8546
|
|
@@ -35,7 +34,7 @@ canns/analyzer/data/cell_classification/utils/correlation.py,sha256=57Ckn8OQGLip
|
|
|
35
34
|
canns/analyzer/data/cell_classification/utils/geometry.py,sha256=jOLh3GeO-riR5a7r7Q7uON3HU_bYOZZJLbokU5bjCOQ,12683
|
|
36
35
|
canns/analyzer/data/cell_classification/utils/image_processing.py,sha256=o9bLT4ycJ_IF7SKBe2RqSWIQwNcpi9v4AI-N5vpm_jM,12805
|
|
37
36
|
canns/analyzer/data/cell_classification/visualization/__init__.py,sha256=fmEHZBcurW6y6FwySLoq65b6CH2kNUB02NCVw2ou6Nc,590
|
|
38
|
-
canns/analyzer/data/cell_classification/visualization/btn_plots.py,sha256=
|
|
37
|
+
canns/analyzer/data/cell_classification/visualization/btn_plots.py,sha256=wHP6jyzzqvD2bgYriQFu1tZQiHevlYDy8V1fFNBw4M0,7345
|
|
39
38
|
canns/analyzer/data/cell_classification/visualization/grid_plots.py,sha256=NFtyYOe2Szt0EOIwQmZradwEvvRjjm7mm6VnnGThDQ0,7914
|
|
40
39
|
canns/analyzer/data/cell_classification/visualization/hd_plots.py,sha256=nzw1jck3VHvAFsJAGelhrJf1q27A5PI0r3NKVgeea8U,5670
|
|
41
40
|
canns/analyzer/metrics/__init__.py,sha256=DTsrv1HW133_RgvhWzz7Gx-bP2hOZbPO2unCPPyf9gs,178
|
|
@@ -166,8 +165,8 @@ canns/trainer/utils.py,sha256=ZdoLiRqFLfKXsWi0KX3wGUp0OqFikwiou8dPf3xvFhE,2847
|
|
|
166
165
|
canns/typing/__init__.py,sha256=mXySdfmD8fA56WqZTb1Nj-ZovcejwLzNjuk6PRfTwmA,156
|
|
167
166
|
canns/utils/__init__.py,sha256=OMyZ5jqZAIUS2Jr0qcnvvrx6YM-BZ1EJy5uZYeA3HC0,366
|
|
168
167
|
canns/utils/benchmark.py,sha256=oJ7nvbvnQMh4_MZh7z160NPLp-197X0rEnmnLHYlev4,1361
|
|
169
|
-
canns-0.15.
|
|
170
|
-
canns-0.15.
|
|
171
|
-
canns-0.15.
|
|
172
|
-
canns-0.15.
|
|
173
|
-
canns-0.15.
|
|
168
|
+
canns-0.15.1.dist-info/METADATA,sha256=Rxhq9ndM1JJjHYGYaB_BjffkvlLXO-22xIzwzOm8h1I,9799
|
|
169
|
+
canns-0.15.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
170
|
+
canns-0.15.1.dist-info/entry_points.txt,sha256=57YF2HZp_BG3GeGB8L0m3wR1sSfNyMXF1q4CKEjce6U,164
|
|
171
|
+
canns-0.15.1.dist-info/licenses/LICENSE,sha256=u6NJ1N-QSnf5yTwSk5UvFAdU2yKD0jxG0Xa91n1cPO4,11306
|
|
172
|
+
canns-0.15.1.dist-info/RECORD,,
|
|
@@ -1,208 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
import numbers
|
|
4
|
-
|
|
5
|
-
import numpy as np
|
|
6
|
-
from numpy.exceptions import AxisError
|
|
7
|
-
from scipy.ndimage import _nd_image, _ni_support
|
|
8
|
-
from scipy.ndimage._filters import _invalid_origin
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
def _gaussian_filter1d(
|
|
12
|
-
input,
|
|
13
|
-
sigma,
|
|
14
|
-
axis=-1,
|
|
15
|
-
order=0,
|
|
16
|
-
output=None,
|
|
17
|
-
mode="reflect",
|
|
18
|
-
cval=0.0,
|
|
19
|
-
truncate=4.0,
|
|
20
|
-
*,
|
|
21
|
-
radius=None,
|
|
22
|
-
):
|
|
23
|
-
"""1-D Gaussian filter.
|
|
24
|
-
|
|
25
|
-
Parameters
|
|
26
|
-
----------
|
|
27
|
-
%(input)s
|
|
28
|
-
sigma : scalar
|
|
29
|
-
standard deviation for Gaussian kernel
|
|
30
|
-
%(axis)s
|
|
31
|
-
order : int, optional
|
|
32
|
-
An order of 0 corresponds to convolution with a Gaussian
|
|
33
|
-
kernel. A positive order corresponds to convolution with
|
|
34
|
-
that derivative of a Gaussian.
|
|
35
|
-
%(output)s
|
|
36
|
-
%(mode_reflect)s
|
|
37
|
-
%(cval)s
|
|
38
|
-
truncate : float, optional
|
|
39
|
-
Truncate the filter at this many standard deviations.
|
|
40
|
-
Default is 4.0.
|
|
41
|
-
radius : None or int, optional
|
|
42
|
-
Radius of the Gaussian kernel. If specified, the size of
|
|
43
|
-
the kernel will be ``2*radius + 1``, and `truncate` is ignored.
|
|
44
|
-
Default is None.
|
|
45
|
-
|
|
46
|
-
Returns
|
|
47
|
-
-------
|
|
48
|
-
gaussian_filter1d : ndarray
|
|
49
|
-
|
|
50
|
-
Notes
|
|
51
|
-
-----
|
|
52
|
-
The Gaussian kernel will have size ``2*radius + 1`` along each axis. If
|
|
53
|
-
`radius` is None, a default ``radius = round(truncate * sigma)`` will be
|
|
54
|
-
used.
|
|
55
|
-
|
|
56
|
-
Examples
|
|
57
|
-
--------
|
|
58
|
-
>>> from scipy.ndimage import gaussian_filter1d
|
|
59
|
-
>>> import numpy as np
|
|
60
|
-
>>> gaussian_filter1d([1.0, 2.0, 3.0, 4.0, 5.0], 1)
|
|
61
|
-
array([ 1.42704095, 2.06782203, 3. , 3.93217797, 4.57295905])
|
|
62
|
-
>>> _gaussian_filter1d([1.0, 2.0, 3.0, 4.0, 5.0], 4)
|
|
63
|
-
array([ 2.91948343, 2.95023502, 3. , 3.04976498, 3.08051657])
|
|
64
|
-
>>> import matplotlib.pyplot as plt
|
|
65
|
-
>>> rng = np.random.default_rng()
|
|
66
|
-
>>> x = rng.standard_normal(101).cumsum()
|
|
67
|
-
>>> y3 = _gaussian_filter1d(x, 3)
|
|
68
|
-
>>> y6 = _gaussian_filter1d(x, 6)
|
|
69
|
-
>>> plt.plot(x, 'k', label='original data')
|
|
70
|
-
>>> plt.plot(y3, '--', label='filtered, sigma=3')
|
|
71
|
-
>>> plt.plot(y6, ':', label='filtered, sigma=6')
|
|
72
|
-
>>> plt.legend()
|
|
73
|
-
>>> plt.grid()
|
|
74
|
-
>>> plt.show()
|
|
75
|
-
|
|
76
|
-
"""
|
|
77
|
-
sd = float(sigma)
|
|
78
|
-
# make the radius of the filter equal to truncate standard deviations
|
|
79
|
-
lw = int(truncate * sd + 0.5)
|
|
80
|
-
if radius is not None:
|
|
81
|
-
lw = radius
|
|
82
|
-
if not isinstance(lw, numbers.Integral) or lw < 0:
|
|
83
|
-
raise ValueError("Radius must be a nonnegative integer.")
|
|
84
|
-
# Since we are calling correlate, not convolve, revert the kernel
|
|
85
|
-
weights = _gaussian_kernel1d(sigma, order, lw)[::-1]
|
|
86
|
-
return _correlate1d(input, weights, axis, output, mode, cval, 0)
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
def _gaussian_kernel1d(sigma, order, radius):
|
|
90
|
-
"""
|
|
91
|
-
Computes a 1-D Gaussian convolution kernel.
|
|
92
|
-
"""
|
|
93
|
-
if order < 0:
|
|
94
|
-
raise ValueError("order must be non-negative")
|
|
95
|
-
exponent_range = np.arange(order + 1)
|
|
96
|
-
sigma2 = sigma * sigma
|
|
97
|
-
x = np.arange(-radius, radius + 1)
|
|
98
|
-
phi_x = np.exp(-0.5 / sigma2 * x**2)
|
|
99
|
-
phi_x = phi_x / phi_x.sum()
|
|
100
|
-
|
|
101
|
-
if order == 0:
|
|
102
|
-
return phi_x
|
|
103
|
-
else:
|
|
104
|
-
# f(x) = q(x) * phi(x) = q(x) * exp(p(x))
|
|
105
|
-
# f'(x) = (q'(x) + q(x) * p'(x)) * phi(x)
|
|
106
|
-
# p'(x) = -1 / sigma ** 2
|
|
107
|
-
# Implement q'(x) + q(x) * p'(x) as a matrix operator and apply to the
|
|
108
|
-
# coefficients of q(x)
|
|
109
|
-
q = np.zeros(order + 1)
|
|
110
|
-
q[0] = 1
|
|
111
|
-
D = np.diag(exponent_range[1:], 1) # D @ q(x) = q'(x)
|
|
112
|
-
P = np.diag(np.ones(order) / -sigma2, -1) # P @ q(x) = q(x) * p'(x)
|
|
113
|
-
Q_deriv = D + P
|
|
114
|
-
for _ in range(order):
|
|
115
|
-
q = Q_deriv.dot(q)
|
|
116
|
-
q = (x[:, None] ** exponent_range).dot(q)
|
|
117
|
-
return q * phi_x
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
def _correlate1d(input, weights, axis=-1, output=None, mode="reflect", cval=0.0, origin=0):
|
|
121
|
-
"""Calculate a 1-D correlation along the given axis.
|
|
122
|
-
|
|
123
|
-
The lines of the array along the given axis are correlated with the
|
|
124
|
-
given weights.
|
|
125
|
-
|
|
126
|
-
Parameters
|
|
127
|
-
----------
|
|
128
|
-
%(input)s
|
|
129
|
-
weights : array
|
|
130
|
-
1-D sequence of numbers.
|
|
131
|
-
%(axis)s
|
|
132
|
-
%(output)s
|
|
133
|
-
%(mode_reflect)s
|
|
134
|
-
%(cval)s
|
|
135
|
-
%(origin)s
|
|
136
|
-
|
|
137
|
-
Returns
|
|
138
|
-
-------
|
|
139
|
-
result : ndarray
|
|
140
|
-
Correlation result. Has the same shape as `input`.
|
|
141
|
-
|
|
142
|
-
Examples
|
|
143
|
-
--------
|
|
144
|
-
>>> from scipy.ndimage import correlate1d
|
|
145
|
-
>>> correlate1d([2, 8, 0, 4, 1, 9, 9, 0], weights=[1, 3])
|
|
146
|
-
array([ 8, 26, 8, 12, 7, 28, 36, 9])
|
|
147
|
-
"""
|
|
148
|
-
input = np.asarray(input)
|
|
149
|
-
weights = np.asarray(weights)
|
|
150
|
-
complex_input = input.dtype.kind == "c"
|
|
151
|
-
complex_weights = weights.dtype.kind == "c"
|
|
152
|
-
if complex_input or complex_weights:
|
|
153
|
-
if complex_weights:
|
|
154
|
-
weights = weights.conj()
|
|
155
|
-
weights = weights.astype(np.complex128, copy=False)
|
|
156
|
-
kwargs = dict(axis=axis, mode=mode, origin=origin)
|
|
157
|
-
output = _ni_support._get_output(output, input, complex_output=True)
|
|
158
|
-
return _complex_via_real_components(_correlate1d, input, weights, output, cval, **kwargs)
|
|
159
|
-
|
|
160
|
-
output = _ni_support._get_output(output, input)
|
|
161
|
-
weights = np.asarray(weights, dtype=np.float64)
|
|
162
|
-
if weights.ndim != 1 or weights.shape[0] < 1:
|
|
163
|
-
raise RuntimeError("no filter weights given")
|
|
164
|
-
if not weights.flags.contiguous:
|
|
165
|
-
weights = weights.copy()
|
|
166
|
-
axis = _normalize_axis_index(axis, input.ndim)
|
|
167
|
-
if _invalid_origin(origin, len(weights)):
|
|
168
|
-
raise ValueError(
|
|
169
|
-
"Invalid origin; origin must satisfy "
|
|
170
|
-
"-(len(weights) // 2) <= origin <= "
|
|
171
|
-
"(len(weights)-1) // 2"
|
|
172
|
-
)
|
|
173
|
-
mode = _ni_support._extend_mode_to_code(mode)
|
|
174
|
-
_nd_image.correlate1d(input, weights, axis, output, mode, cval, origin)
|
|
175
|
-
return output
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
def _complex_via_real_components(func, input, weights, output, cval, **kwargs):
|
|
179
|
-
"""Complex convolution via a linear combination of real convolutions."""
|
|
180
|
-
complex_input = input.dtype.kind == "c"
|
|
181
|
-
complex_weights = weights.dtype.kind == "c"
|
|
182
|
-
if complex_input and complex_weights:
|
|
183
|
-
# real component of the output
|
|
184
|
-
func(input.real, weights.real, output=output.real, cval=np.real(cval), **kwargs)
|
|
185
|
-
output.real -= func(input.imag, weights.imag, output=None, cval=np.imag(cval), **kwargs)
|
|
186
|
-
# imaginary component of the output
|
|
187
|
-
func(input.real, weights.imag, output=output.imag, cval=np.real(cval), **kwargs)
|
|
188
|
-
output.imag += func(input.imag, weights.real, output=None, cval=np.imag(cval), **kwargs)
|
|
189
|
-
elif complex_input:
|
|
190
|
-
func(input.real, weights, output=output.real, cval=np.real(cval), **kwargs)
|
|
191
|
-
func(input.imag, weights, output=output.imag, cval=np.imag(cval), **kwargs)
|
|
192
|
-
else:
|
|
193
|
-
if np.iscomplexobj(cval):
|
|
194
|
-
raise ValueError("Cannot provide a complex-valued cval when the input is real.")
|
|
195
|
-
func(input, weights.real, output=output.real, cval=cval, **kwargs)
|
|
196
|
-
func(input, weights.imag, output=output.imag, cval=cval, **kwargs)
|
|
197
|
-
return output
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
def _normalize_axis_index(axis, ndim):
|
|
201
|
-
# Check if `axis` is in the correct range and normalize it
|
|
202
|
-
if axis < -ndim or axis >= ndim:
|
|
203
|
-
msg = f"axis {axis} is out of bounds for array of dimension {ndim}"
|
|
204
|
-
raise AxisError(msg)
|
|
205
|
-
|
|
206
|
-
if axis < 0:
|
|
207
|
-
axis = axis + ndim
|
|
208
|
-
return axis
|
|
File without changes
|
|
File without changes
|
|
File without changes
|