canns 0.15.0__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,15 +2,9 @@ from __future__ import annotations
2
2
 
3
3
  from .cohomap import (
4
4
  cohomap,
5
- cohomap_upgrade,
6
- ecohomap,
7
5
  fit_cohomap_stripes,
8
- fit_cohomap_stripes_upgrade,
9
6
  plot_cohomap,
10
- plot_cohomap_upgrade,
11
- plot_ecohomap,
12
7
  )
13
- from .cohomap_scatter import plot_cohomap_scatter, plot_cohomap_scatter_multi
14
8
  from .cohomap_vectors import (
15
9
  cohomap_vectors,
16
10
  plot_cohomap_stripes,
@@ -18,14 +12,8 @@ from .cohomap_vectors import (
18
12
  )
19
13
  from .cohospace import (
20
14
  cohospace,
21
- cohospace_upgrade,
22
- ecohospace,
23
15
  plot_cohospace,
24
16
  plot_cohospace_skewed,
25
- plot_cohospace_upgrade,
26
- plot_cohospace_upgrade_skewed,
27
- plot_ecohospace,
28
- plot_ecohospace_skewed,
29
17
  )
30
18
  from .cohospace_phase_centers import (
31
19
  cohospace_phase_centers,
@@ -81,6 +69,8 @@ from .path import (
81
69
  from .plotting import (
82
70
  plot_2d_bump_on_manifold,
83
71
  plot_3d_bump_on_torus,
72
+ plot_cohomap_scatter,
73
+ plot_cohomap_scatter_multi,
84
74
  plot_path_compare_1d,
85
75
  plot_path_compare_2d,
86
76
  plot_projection,
@@ -109,22 +99,11 @@ __all__ = [
109
99
  "plot_3d_bump_on_torus",
110
100
  "plot_2d_bump_on_manifold",
111
101
  "cohomap",
112
- "cohomap_upgrade",
113
102
  "fit_cohomap_stripes",
114
- "fit_cohomap_stripes_upgrade",
115
103
  "plot_cohomap",
116
- "plot_cohomap_upgrade",
117
104
  "cohospace",
118
- "cohospace_upgrade",
119
105
  "plot_cohospace",
120
106
  "plot_cohospace_skewed",
121
- "plot_cohospace_upgrade",
122
- "plot_cohospace_upgrade_skewed",
123
- "ecohomap",
124
- "ecohospace",
125
- "plot_ecohomap",
126
- "plot_ecohospace",
127
- "plot_ecohospace_skewed",
128
107
  "cohomap_vectors",
129
108
  "plot_cohomap_stripes",
130
109
  "plot_cohomap_vectors",
@@ -2,96 +2,20 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- import os
6
- from typing import Any
7
-
8
- import numpy as np
9
- from scipy.ndimage import gaussian_filter
10
-
11
- from ...visualization.core import PlotConfig
12
- from .path import find_coords_matrix, find_times_box
13
-
14
-
15
- def _ensure_plot_config(
16
- config: PlotConfig | None,
17
- factory,
18
- *args,
19
- **defaults,
20
- ) -> PlotConfig:
21
- if config is None:
22
- return factory(*args, **defaults)
23
- return config
24
-
25
-
26
- def _ensure_parent_dir(save_path: str | None) -> None:
27
- if save_path:
28
- parent = os.path.dirname(save_path)
29
- if parent:
30
- os.makedirs(parent, exist_ok=True)
31
-
32
-
33
- def _circmean(x: np.ndarray) -> float:
34
- return float(np.arctan2(np.mean(np.sin(x)), np.mean(np.cos(x))))
35
-
36
-
37
- def _smooth_circular_map(
38
- mtot: np.ndarray,
39
- smooth_sigma: float,
40
- *,
41
- fill_nan: bool = False,
42
- fill_sigma: float | None = None,
43
- fill_min_weight: float = 1e-3,
44
- ) -> np.ndarray:
45
- mtot = np.asarray(mtot, dtype=float)
46
- nans = np.isnan(mtot)
47
- mask = (~nans).astype(float)
48
- sintot = np.sin(mtot)
49
- costot = np.cos(mtot)
50
- sintot[nans] = 0.0
51
- costot[nans] = 0.0
52
-
53
- if fill_nan:
54
- if fill_sigma is None:
55
- fill_sigma = smooth_sigma if smooth_sigma and smooth_sigma > 0 else 1.0
56
- weight = gaussian_filter(mask, fill_sigma)
57
- sintot = gaussian_filter(sintot * mask, fill_sigma)
58
- costot = gaussian_filter(costot * mask, fill_sigma)
59
- min_weight = max(float(fill_min_weight), 0.0)
60
- valid = weight > min_weight
61
- sintot = np.divide(sintot, weight, out=np.zeros_like(sintot), where=valid)
62
- costot = np.divide(costot, weight, out=np.zeros_like(costot), where=valid)
63
- mtot = np.arctan2(sintot, costot)
64
- if fill_min_weight > 0:
65
- mtot[~valid] = np.nan
66
- return mtot
67
-
68
- if smooth_sigma and smooth_sigma > 0:
69
- sintot = gaussian_filter(sintot, smooth_sigma)
70
- costot = gaussian_filter(costot, smooth_sigma)
71
- mtot = np.arctan2(sintot, costot)
72
- mtot[nans] = np.nan
73
- return mtot
74
-
75
-
76
- def _extract_coords_and_times(
77
- decoding_result: dict[str, Any],
78
- coords_key: str | None = None,
79
- ) -> tuple[np.ndarray, np.ndarray | None]:
80
- if coords_key is not None:
81
- if coords_key not in decoding_result:
82
- raise KeyError(f"coords_key '{coords_key}' not found in decoding_result.")
83
- coords = np.asarray(decoding_result[coords_key])
84
- elif "coordsbox" in decoding_result:
85
- coords = np.asarray(decoding_result["coordsbox"])
86
- else:
87
- coords, _ = find_coords_matrix(decoding_result)
88
-
89
- times_box, _ = find_times_box(decoding_result)
90
- return coords, times_box
91
-
92
-
93
- def _phase_map_valid_fraction(phase_map: np.ndarray) -> float:
94
- valid = np.isfinite(phase_map)
95
- if valid.size == 0:
96
- return 0.0
97
- return float(np.mean(valid))
5
+ from .utils import (
6
+ _circmean,
7
+ _ensure_parent_dir,
8
+ _ensure_plot_config,
9
+ _extract_coords_and_times,
10
+ _phase_map_valid_fraction,
11
+ _smooth_circular_map,
12
+ )
13
+
14
+ __all__ = [
15
+ "_ensure_plot_config",
16
+ "_ensure_parent_dir",
17
+ "_circmean",
18
+ "_smooth_circular_map",
19
+ "_extract_coords_and_times",
20
+ "_phase_map_valid_fraction",
21
+ ]
@@ -57,6 +57,36 @@ def _rot_coord(
57
57
  c2: np.ndarray,
58
58
  p: tuple[int, int],
59
59
  ) -> np.ndarray:
60
+ """Transform and align decoded coordinates based on stripe orientation.
61
+
62
+ This function rotates and aligns coordinates to match the dominant stripe
63
+ orientation in the CohoMap. It handles the torus geometry by choosing the
64
+ appropriate coordinate system and applying reflections/rotations.
65
+
66
+ Parameters
67
+ ----------
68
+ params1 : np.ndarray
69
+ Stripe fit parameters for first dimension (includes orientation angle).
70
+ params2 : np.ndarray
71
+ Stripe fit parameters for second dimension (includes orientation angle).
72
+ c1 : np.ndarray
73
+ Decoded coordinates for first dimension.
74
+ c2 : np.ndarray
75
+ Decoded coordinates for second dimension.
76
+ p : tuple[int, int]
77
+ Direction indicators (1 or -1) for each dimension, controlling reflections.
78
+
79
+ Returns
80
+ -------
81
+ np.ndarray
82
+ Aligned coordinates of shape (N, 2), wrapped to [0, 2π).
83
+
84
+ Notes
85
+ -----
86
+ The function selects the coordinate system based on which dimension has
87
+ stronger horizontal alignment (larger |cos(angle)|), then applies geometric
88
+ transformations to align the coordinates with the stripe pattern.
89
+ """
60
90
  if abs(np.cos(params1[0])) < abs(np.cos(params2[0])):
61
91
  cc1 = c2.copy()
62
92
  cc2 = c1.copy()
@@ -96,6 +126,46 @@ def _toroidal_align_coords(
96
126
  trim: int,
97
127
  grid_size: int | None,
98
128
  ) -> tuple[np.ndarray, float, float]:
129
+ """Align decoded coordinates to CohoMap stripe patterns.
130
+
131
+ This function fits stripe patterns to both phase maps, determines the
132
+ optimal orientation and phase signs, then transforms the decoded coordinates
133
+ to align with the detected stripe structure.
134
+
135
+ Parameters
136
+ ----------
137
+ coords : np.ndarray
138
+ Decoded coordinates of shape (N, 2+). Only first two columns are used.
139
+ phase_map1 : np.ndarray
140
+ Phase map for first dimension (2D grid).
141
+ phase_map2 : np.ndarray
142
+ Phase map for second dimension (2D grid), must match phase_map1 shape.
143
+ trim : int
144
+ Number of edge bins to trim when fitting stripes.
145
+ grid_size : int, optional
146
+ Grid size for stripe fitting. If None, inferred from phase_map shape.
147
+
148
+ Returns
149
+ -------
150
+ aligned : np.ndarray
151
+ Aligned coordinates of shape (N, 2), wrapped to [0, 2π).
152
+ f1 : float
153
+ Fit quality score for first dimension (higher is better).
154
+ f2 : float
155
+ Fit quality score for second dimension (higher is better).
156
+
157
+ Raises
158
+ ------
159
+ ValueError
160
+ If phase_map shapes don't match or coords has wrong shape.
161
+
162
+ Notes
163
+ -----
164
+ The alignment process:
165
+ 1. Fits stripe patterns to both phase maps using FFT-based detection
166
+ 2. Determines phase signs (±1) for each dimension
167
+ 3. Applies geometric transformations via `_rot_coord()` to align coordinates
168
+ """
99
169
  if phase_map1.shape != phase_map2.shape:
100
170
  raise ValueError("phase_map shapes do not match for alignment")
101
171
  coords = np.asarray(coords)
@@ -381,28 +451,3 @@ def plot_cohomap(
381
451
  _ensure_parent_dir(config.save_path)
382
452
  finalize_figure(fig, config)
383
453
  return fig
384
-
385
-
386
- def cohomap_upgrade(*args, **kwargs) -> dict[str, Any]:
387
- """Legacy alias for EcohoMap (formerly cohomap_upgrade)."""
388
- return cohomap(*args, **kwargs)
389
-
390
-
391
- def ecohomap(*args, **kwargs) -> dict[str, Any]:
392
- """Alias for EcohoMap (GridCellTorus-style)."""
393
- return cohomap(*args, **kwargs)
394
-
395
-
396
- def fit_cohomap_stripes_upgrade(*args, **kwargs) -> tuple[np.ndarray, float]:
397
- """Legacy alias for EcohoMap stripe fitting."""
398
- return fit_cohomap_stripes(*args, **kwargs)
399
-
400
-
401
- def plot_cohomap_upgrade(*args, **kwargs) -> plt.Figure:
402
- """Legacy alias for EcohoMap plotting (formerly plot_cohomap_upgrade)."""
403
- return plot_cohomap(*args, **kwargs)
404
-
405
-
406
- def plot_ecohomap(*args, **kwargs) -> plt.Figure:
407
- """Alias for EcohoMap plotting."""
408
- return plot_cohomap(*args, **kwargs)
@@ -2,7 +2,6 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- import os
6
5
  from typing import Any
7
6
 
8
7
  import matplotlib.pyplot as plt
@@ -11,27 +10,37 @@ from scipy.ndimage import rotate
11
10
 
12
11
  from ...visualization.core import PlotConfig, finalize_figure
13
12
  from .cohomap import fit_cohomap_stripes
14
-
15
-
16
- def _ensure_plot_config(
17
- config: PlotConfig | None,
18
- factory,
19
- *args,
20
- **defaults,
21
- ) -> PlotConfig:
22
- if config is None:
23
- return factory(*args, **defaults)
24
- return config
25
-
26
-
27
- def _ensure_parent_dir(save_path: str | None) -> None:
28
- if save_path:
29
- parent = os.path.dirname(save_path)
30
- if parent:
31
- os.makedirs(parent, exist_ok=True)
13
+ from .utils import _ensure_parent_dir, _ensure_plot_config
32
14
 
33
15
 
34
16
  def _rot_para(params1: np.ndarray, params2: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
17
+ """Transform stripe fit parameters to canonical orientation.
18
+
19
+ This function adjusts the orientation angles of stripe fit parameters to
20
+ align them with a canonical coordinate system. Unlike `_rot_coord()` which
21
+ transforms actual coordinate data, this operates on the fit parameters
22
+ themselves (orientation angles, wavelengths, etc.).
23
+
24
+ Parameters
25
+ ----------
26
+ params1 : np.ndarray
27
+ Stripe fit parameters for first dimension. First element is orientation angle.
28
+ params2 : np.ndarray
29
+ Stripe fit parameters for second dimension. First element is orientation angle.
30
+
31
+ Returns
32
+ -------
33
+ x : np.ndarray
34
+ Transformed parameters for the dimension with stronger horizontal alignment.
35
+ y : np.ndarray
36
+ Transformed parameters for the other dimension.
37
+
38
+ Notes
39
+ -----
40
+ The function selects which parameter set becomes 'x' and 'y' based on which
41
+ has larger |cos(angle)|, then applies angle adjustments to ensure the stripe
42
+ vectors are in a canonical orientation for visualization and analysis.
43
+ """
35
44
  if abs(np.cos(params1[0])) < abs(np.cos(params2[0])):
36
45
  y = params1.copy()
37
46
  x = params2.copy()
@@ -204,12 +213,57 @@ def plot_cohomap_vectors(
204
213
 
205
214
 
206
215
  def _resolve_grid_size(phase_map: np.ndarray, grid_size: int | None, trim: int) -> int:
216
+ """Determine grid size for stripe fitting.
217
+
218
+ Parameters
219
+ ----------
220
+ phase_map : np.ndarray
221
+ Phase map array (2D grid).
222
+ grid_size : int, optional
223
+ Explicit grid size. If None, inferred from phase_map shape.
224
+ trim : int
225
+ Number of edge bins to trim.
226
+
227
+ Returns
228
+ -------
229
+ int
230
+ Grid size to use for stripe fitting.
231
+ """
207
232
  if grid_size is None:
208
233
  return int(phase_map.shape[0]) + 2 * trim + 1
209
234
  return int(grid_size)
210
235
 
211
236
 
212
237
  def _stripe_fit_map(params: np.ndarray, grid_size: int, trim: int) -> np.ndarray:
238
+ """Generate a synthetic stripe pattern from fit parameters.
239
+
240
+ Creates a 2D cosine stripe pattern based on the fitted parameters (orientation,
241
+ phase, and spatial frequency). Used for visualizing the fitted stripe model
242
+ and comparing it with the actual phase map.
243
+
244
+ Parameters
245
+ ----------
246
+ params : np.ndarray
247
+ Stripe fit parameters: [angle, phase, frequency].
248
+ - angle: Orientation angle in radians
249
+ - phase: Phase offset in radians
250
+ - frequency: Spatial frequency (inverse wavelength)
251
+ grid_size : int
252
+ Size of the grid to generate (before trimming).
253
+ trim : int
254
+ Number of edge bins to trim from the generated pattern.
255
+
256
+ Returns
257
+ -------
258
+ np.ndarray
259
+ 2D array of shape (grid_size-2*trim, grid_size-2*trim) containing
260
+ the cosine stripe pattern with values in [-1, 1].
261
+
262
+ Notes
263
+ -----
264
+ The pattern is generated on a [0, 3π] × [0, 3π] domain to cover the
265
+ extended torus space, then rotated by the fitted angle and trimmed.
266
+ """
213
267
  numangsint = grid_size
214
268
  x, _ = np.meshgrid(
215
269
  np.linspace(0, 3 * np.pi, numangsint - 1),
@@ -219,33 +219,3 @@ def plot_cohospace_skewed(
219
219
  _ensure_parent_dir(config.save_path)
220
220
  finalize_figure(fig, config)
221
221
  return fig
222
-
223
-
224
- def cohospace_upgrade(*args, **kwargs) -> dict[str, Any]:
225
- """Legacy alias for EcohoSpace (formerly cohospace_upgrade)."""
226
- return cohospace(*args, **kwargs)
227
-
228
-
229
- def ecohospace(*args, **kwargs) -> dict[str, Any]:
230
- """Alias for EcohoSpace (GridCellTorus-style)."""
231
- return cohospace(*args, **kwargs)
232
-
233
-
234
- def plot_cohospace_upgrade(*args, **kwargs) -> plt.Figure:
235
- """Legacy alias for EcohoSpace plotting (formerly plot_cohospace_upgrade)."""
236
- return plot_cohospace(*args, **kwargs)
237
-
238
-
239
- def plot_cohospace_upgrade_skewed(*args, **kwargs) -> plt.Figure:
240
- """Legacy alias for EcohoSpace skewed plotting."""
241
- return plot_cohospace_skewed(*args, **kwargs)
242
-
243
-
244
- def plot_ecohospace(*args, **kwargs) -> plt.Figure:
245
- """Alias for EcohoSpace plotting."""
246
- return plot_cohospace(*args, **kwargs)
247
-
248
-
249
- def plot_ecohospace_skewed(*args, **kwargs) -> plt.Figure:
250
- """Alias for EcohoSpace skewed plotting."""
251
- return plot_cohospace_skewed(*args, **kwargs)
@@ -2,32 +2,14 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- import os
6
5
  from typing import Any
7
6
 
8
7
  import matplotlib.pyplot as plt
9
8
  import numpy as np
10
9
 
11
10
  from ...visualization.core import PlotConfig, finalize_figure
12
- from .cohospace_scatter import skew_transform_torus_scatter
13
-
14
-
15
- def _ensure_plot_config(
16
- config: PlotConfig | None,
17
- factory,
18
- *args,
19
- **defaults,
20
- ) -> PlotConfig:
21
- if config is None:
22
- return factory(*args, **defaults)
23
- return config
24
-
25
-
26
- def _ensure_parent_dir(save_path: str | None) -> None:
27
- if save_path:
28
- parent = os.path.dirname(save_path)
29
- if parent:
30
- os.makedirs(parent, exist_ok=True)
11
+ from .path import skew_transform
12
+ from .utils import _ensure_parent_dir, _ensure_plot_config
31
13
 
32
14
 
33
15
  def cohospace_phase_centers(cohospace_result: dict[str, Any]) -> dict[str, Any]:
@@ -40,7 +22,7 @@ def cohospace_phase_centers(cohospace_result: dict[str, Any]) -> dict[str, Any]:
40
22
  Output from `data.cohospace(...)` (must include `centers`).
41
23
  """
42
24
  centers = np.asarray(cohospace_result["centers"], dtype=float) % (2 * np.pi)
43
- centers_skew = skew_transform_torus_scatter(centers)
25
+ centers_skew = skew_transform(centers)
44
26
  return {
45
27
  "centers": centers,
46
28
  "centers_skew": centers_skew,
@@ -2,38 +2,13 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- import os
6
-
7
5
  import matplotlib.pyplot as plt
8
6
  import numpy as np
9
7
  from scipy.stats import circvar
10
8
 
11
9
  from ...visualization.core import PlotConfig, finalize_figure
12
-
13
-
14
- def _ensure_plot_config(
15
- config: PlotConfig | None,
16
- factory,
17
- *,
18
- kwargs: dict | None = None,
19
- **defaults,
20
- ) -> PlotConfig:
21
- if config is None:
22
- defaults.update({"kwargs": kwargs or {}})
23
- return factory(**defaults)
24
-
25
- if kwargs:
26
- config_kwargs = config.kwargs or {}
27
- config_kwargs.update(kwargs)
28
- config.kwargs = config_kwargs
29
- return config
30
-
31
-
32
- def _ensure_parent_dir(save_path: str | None) -> None:
33
- if save_path:
34
- parent = os.path.dirname(save_path)
35
- if parent:
36
- os.makedirs(parent, exist_ok=True)
10
+ from .path import _align_activity_to_coords, skew_transform
11
+ from .utils import _ensure_parent_dir, _ensure_plot_config
37
12
 
38
13
 
39
14
  # =====================================================================
@@ -48,51 +23,6 @@ def _coho_coords_to_degrees(coords: np.ndarray) -> np.ndarray:
48
23
  return np.degrees(coords % (2 * np.pi))
49
24
 
50
25
 
51
- def _align_activity_to_coords(
52
- coords: np.ndarray,
53
- activity: np.ndarray,
54
- times: np.ndarray | None = None,
55
- *,
56
- label: str = "activity",
57
- auto_filter: bool = True,
58
- ) -> np.ndarray:
59
- """
60
- Align activity to coords by optional time indices and validate lengths.
61
- """
62
- coords = np.asarray(coords)
63
- activity = np.asarray(activity)
64
-
65
- if times is not None:
66
- times = np.asarray(times)
67
- try:
68
- activity = activity[times]
69
- except Exception as exc:
70
- raise ValueError(
71
- f"Failed to index {label} with `times`. Ensure `times` indexes the original time axis."
72
- ) from exc
73
-
74
- if activity.shape[0] != coords.shape[0]:
75
- # Try to reproduce decode's zero-spike filtering if lengths mismatch.
76
- if auto_filter and times is None and activity.ndim == 2:
77
- mask = np.sum(activity > 0, axis=1) >= 1
78
- if mask.sum() == coords.shape[0]:
79
- activity = activity[mask]
80
- else:
81
- raise ValueError(
82
- f"{label} length must match coords length. Got {activity.shape[0]} vs {coords.shape[0]}. "
83
- "If coords are computed on a subset of timepoints (e.g., decode['times']), pass "
84
- "`times=decoding['times']` or slice the activity accordingly."
85
- )
86
- else:
87
- raise ValueError(
88
- f"{label} length must match coords length. Got {activity.shape[0]} vs {coords.shape[0]}. "
89
- "If coords are computed on a subset of timepoints (e.g., decode['times']), pass "
90
- "`times=decoding['times']` or slice the activity accordingly."
91
- )
92
-
93
- return activity
94
-
95
-
96
26
  def plot_cohospace_scatter_trajectory_2d(
97
27
  coords: np.ndarray,
98
28
  times: np.ndarray | None = None,
@@ -802,41 +732,6 @@ def compute_cohoscore_scatter_1d(
802
732
  return scores
803
733
 
804
734
 
805
- def skew_transform_torus_scatter(coords):
806
- """
807
- Convert torus angles (theta1, theta2) into coordinates in a skewed parallelogram fundamental domain.
808
-
809
- Given theta1, theta2 in radians, map:
810
- x = theta1 + 0.5 * theta2
811
- y = (sqrt(3)/2) * theta2
812
-
813
- This is a linear change of basis that turns the square [0, 2π)×[0, 2π) into a 60-degree
814
- parallelogram, which is convenient for visualizing wrap-around behavior on a 2-torus.
815
-
816
- Parameters
817
- ----------
818
- coords : ndarray, shape (T, 2)
819
- Angles (theta1, theta2) in radians.
820
-
821
- Returns
822
- -------
823
- xy : ndarray, shape (T, 2)
824
- Skewed planar coordinates.
825
- """
826
- coords = np.asarray(coords)
827
- if coords.ndim != 2 or coords.shape[1] != 2:
828
- raise ValueError(f"coords must be (T,2), got {coords.shape}")
829
-
830
- theta1 = coords[:, 0]
831
- theta2 = coords[:, 1]
832
-
833
- # Linear change of basis (NO nonlinear scaling)
834
- x = theta1 + 0.5 * theta2
835
- y = (np.sqrt(3) / 2.0) * theta2
836
-
837
- return np.stack([x, y], axis=1)
838
-
839
-
840
735
  def draw_torus_parallelogram_grid_scatter(ax, n_tiles=1, color="0.7", lw=1.0, alpha=0.8):
841
736
  """
842
737
  Draw parallelogram grid corresponding to torus fundamental domain.
@@ -874,7 +769,7 @@ def tile_parallelogram_points_scatter(xy, n_tiles=1):
874
769
  Parameters
875
770
  ----------
876
771
  points : ndarray, shape (T, 2)
877
- Points in the skewed plane (same coordinates as returned by `skew_transform_torus_scatter`).
772
+ Points in the skewed plane (same coordinates as returned by `skew_transform`).
878
773
  n_tiles : int
879
774
  Number of tiles to extend around the base domain.
880
775
  - n_tiles=1 produces a 3x3 tiling
@@ -1020,7 +915,7 @@ def plot_cohospace_scatter_neuron_skewed(
1020
915
  )
1021
916
 
1022
917
  # --- skew transform
1023
- xy = skew_transform_torus_scatter(coords[mask])
918
+ xy = skew_transform(coords[mask])
1024
919
 
1025
920
  # Tiling: if points are tiled, values must be tiled too (FR mode) to keep lengths consistent
1026
921
  if n_tiles and n_tiles > 0:
@@ -1165,7 +1060,7 @@ def plot_cohospace_scatter_population_skewed(
1165
1060
  thr = np.percentile(a, 100 - top_percent)
1166
1061
  mask = a >= thr
1167
1062
 
1168
- xy = skew_transform_torus_scatter(coords[mask])
1063
+ xy = skew_transform(coords[mask])
1169
1064
 
1170
1065
  if n_tiles and n_tiles > 0:
1171
1066
  xy = tile_parallelogram_points_scatter(xy, n_tiles=n_tiles)
@@ -6,7 +6,6 @@ import numpy as np
6
6
  from scipy.ndimage import gaussian_filter1d
7
7
 
8
8
  from .config import DataLoadError, ProcessingError, SpikeEmbeddingConfig
9
- from .filters import _gaussian_filter1d
10
9
 
11
10
 
12
11
  def embed_spike_trains(spike_trains, config: SpikeEmbeddingConfig | None = None, **kwargs):
@@ -262,8 +261,8 @@ def _load_pos(t, x, y, res=100000, dt=1000):
262
261
  xx = np.interp(tt, t, x)
263
262
  yy = np.interp(tt, t, y)
264
263
 
265
- xxs = _gaussian_filter1d(xx - np.min(xx), sigma=100)
266
- yys = _gaussian_filter1d(yy - np.min(yy), sigma=100)
264
+ xxs = gaussian_filter1d(xx - np.min(xx), sigma=100)
265
+ yys = gaussian_filter1d(yy - np.min(yy), sigma=100)
267
266
  dx = (xxs[1:] - xxs[:-1]) * 100
268
267
  dy = (yys[1:] - yys[:-1]) * 100
269
268
  speed = np.sqrt(dx**2 + dy**2) / 0.01
@@ -1,18 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
- import os
4
3
  from dataclasses import dataclass
5
4
 
6
5
  import numpy as np
7
6
 
8
7
  from ...visualization.core import PlotConfig, finalize_figure
9
-
10
-
11
- def _ensure_parent_dir(save_path: str | None) -> None:
12
- if save_path:
13
- parent = os.path.dirname(save_path)
14
- if parent:
15
- os.makedirs(parent, exist_ok=True)
8
+ from .utils import _ensure_parent_dir
16
9
 
17
10
 
18
11
  def _slice_range(r: tuple[int, int] | None, length: int) -> slice:
@@ -240,6 +240,76 @@ def interp_coords_to_full_1d(idx_map: np.ndarray, coords1: np.ndarray, T_full: i
240
240
  return np.mod(out, 2 * np.pi)[:, None]
241
241
 
242
242
 
243
+ def _align_activity_to_coords(
244
+ coords: np.ndarray,
245
+ activity: np.ndarray,
246
+ times: np.ndarray | None = None,
247
+ *,
248
+ label: str = "activity",
249
+ auto_filter: bool = True,
250
+ ) -> np.ndarray:
251
+ """
252
+ Align activity to coords by optional time indices and validate lengths.
253
+
254
+ Parameters
255
+ ----------
256
+ coords : ndarray
257
+ Decoded coordinates array.
258
+ activity : ndarray
259
+ Activity matrix (firing rate or spikes).
260
+ times : ndarray, optional
261
+ Optional time indices to align activity to coords when coords are computed
262
+ on a subset of timepoints.
263
+ label : str
264
+ Label for error messages (default: "activity").
265
+ auto_filter : bool
266
+ If True and lengths mismatch, auto-filter activity with activity>0 to mimic
267
+ decode filtering.
268
+
269
+ Returns
270
+ -------
271
+ ndarray
272
+ Aligned activity array.
273
+
274
+ Raises
275
+ ------
276
+ ValueError
277
+ If activity length doesn't match coords length after alignment attempts.
278
+ """
279
+ coords = np.asarray(coords)
280
+ activity = np.asarray(activity)
281
+
282
+ if times is not None:
283
+ times = np.asarray(times)
284
+ try:
285
+ activity = activity[times]
286
+ except Exception as exc:
287
+ raise ValueError(
288
+ f"Failed to index {label} with `times`. Ensure `times` indexes the original time axis."
289
+ ) from exc
290
+
291
+ if activity.shape[0] != coords.shape[0]:
292
+ # Try to reproduce decode's zero-spike filtering if lengths mismatch.
293
+ if auto_filter and times is None and activity.ndim == 2:
294
+ mask = np.sum(activity > 0, axis=1) >= 1
295
+ if mask.sum() == coords.shape[0]:
296
+ activity = activity[mask]
297
+ else:
298
+ raise ValueError(
299
+ f"{label} length must match coords length. Got {activity.shape[0]} vs {coords.shape[0]}. "
300
+ "If coords are computed on a subset of timepoints (e.g., decode['times']), pass "
301
+ "`times=decoding['times']` or slice the activity accordingly."
302
+ )
303
+ else:
304
+ raise ValueError(
305
+ f"{label} length must match coords length. Got {activity.shape[0]} vs {coords.shape[0]}. "
306
+ "If coords are computed on a subset of timepoints (e.g., decode['times']), pass "
307
+ "`times=decoding['times']` or slice the activity accordingly."
308
+ )
309
+
310
+ return activity
311
+
312
+
243
313
  def align_coords_to_position_2d(
244
314
  t_full: np.ndarray,
245
315
  x_full: np.ndarray,
@@ -1,6 +1,5 @@
1
1
  from __future__ import annotations
2
2
 
3
- import os
4
3
  from typing import Any
5
4
 
6
5
  import matplotlib.pyplot as plt
@@ -24,31 +23,7 @@ from ...visualization.core import (
24
23
  from ...visualization.core.jupyter_utils import display_animation_in_jupyter, is_jupyter_environment
25
24
  from .config import CANN2DPlotConfig, ProcessingError, SpikeEmbeddingConfig
26
25
  from .embedding import embed_spike_trains
27
-
28
-
29
- def _ensure_plot_config(
30
- config: PlotConfig | None,
31
- factory,
32
- *,
33
- kwargs: dict[str, Any] | None = None,
34
- **defaults: Any,
35
- ) -> PlotConfig:
36
- if config is None:
37
- defaults.update({"kwargs": kwargs or {}})
38
- return factory(**defaults)
39
-
40
- if kwargs:
41
- config_kwargs = config.kwargs or {}
42
- config_kwargs.update(kwargs)
43
- config.kwargs = config_kwargs
44
- return config
45
-
46
-
47
- def _ensure_parent_dir(save_path: str | None) -> None:
48
- if save_path:
49
- parent = os.path.dirname(save_path)
50
- if parent:
51
- os.makedirs(parent, exist_ok=True)
26
+ from .utils import _ensure_parent_dir, _ensure_plot_config
52
27
 
53
28
 
54
29
  def _render_torus_frame(frame_index: int, frame_data: dict[str, Any]) -> np.ndarray:
@@ -0,0 +1,160 @@
1
+ """Shared utility functions for ASA (Attractor State Analysis) modules."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+ from scipy.ndimage import gaussian_filter
10
+
11
+ from ...visualization.core import PlotConfig
12
+ from .path import find_coords_matrix, find_times_box
13
+
14
+
15
+ def _ensure_plot_config(
16
+ config: PlotConfig | None,
17
+ factory,
18
+ *args,
19
+ kwargs: dict | None = None,
20
+ **defaults,
21
+ ) -> PlotConfig:
22
+ """Ensure a PlotConfig exists, creating one from factory if needed.
23
+
24
+ Args:
25
+ config: Optional existing PlotConfig.
26
+ factory: Factory function to create PlotConfig if config is None.
27
+ *args: Positional arguments for factory.
28
+ kwargs: Optional dict to merge into config.kwargs.
29
+ **defaults: Keyword arguments for factory.
30
+
31
+ Returns:
32
+ PlotConfig instance.
33
+ """
34
+ if config is None:
35
+ if kwargs:
36
+ defaults.update({"kwargs": kwargs})
37
+ return factory(*args, **defaults)
38
+
39
+ # If config exists and kwargs provided, merge them
40
+ if kwargs:
41
+ config_kwargs = config.kwargs or {}
42
+ config_kwargs.update(kwargs)
43
+ config.kwargs = config_kwargs
44
+ return config
45
+
46
+
47
+ def _ensure_parent_dir(save_path: str | None) -> None:
48
+ """Create parent directory for save_path if it doesn't exist.
49
+
50
+ Args:
51
+ save_path: Optional file path. If provided, creates parent directory.
52
+ """
53
+ if save_path:
54
+ parent = os.path.dirname(save_path)
55
+ if parent:
56
+ os.makedirs(parent, exist_ok=True)
57
+
58
+
59
+ def _circmean(x: np.ndarray) -> float:
60
+ """Compute circular mean of angles.
61
+
62
+ Args:
63
+ x: Array of angles in radians.
64
+
65
+ Returns:
66
+ Circular mean in radians.
67
+ """
68
+ return float(np.arctan2(np.mean(np.sin(x)), np.mean(np.cos(x))))
69
+
70
+
71
+ def _smooth_circular_map(
72
+ mtot: np.ndarray,
73
+ smooth_sigma: float,
74
+ *,
75
+ fill_nan: bool = False,
76
+ fill_sigma: float | None = None,
77
+ fill_min_weight: float = 1e-3,
78
+ ) -> np.ndarray:
79
+ """Smooth a circular phase map using Gaussian filtering in sin/cos space.
80
+
81
+ Args:
82
+ mtot: Phase map array (angles in radians).
83
+ smooth_sigma: Gaussian smoothing sigma.
84
+ fill_nan: Whether to fill NaN values using weighted interpolation.
85
+ fill_sigma: Sigma for NaN filling (defaults to smooth_sigma).
86
+ fill_min_weight: Minimum weight threshold for valid interpolation.
87
+
88
+ Returns:
89
+ Smoothed phase map.
90
+ """
91
+ mtot = np.asarray(mtot, dtype=float)
92
+ nans = np.isnan(mtot)
93
+ mask = (~nans).astype(float)
94
+ sintot = np.sin(mtot)
95
+ costot = np.cos(mtot)
96
+ sintot[nans] = 0.0
97
+ costot[nans] = 0.0
98
+
99
+ if fill_nan:
100
+ if fill_sigma is None:
101
+ fill_sigma = smooth_sigma if smooth_sigma and smooth_sigma > 0 else 1.0
102
+ weight = gaussian_filter(mask, fill_sigma)
103
+ sintot = gaussian_filter(sintot * mask, fill_sigma)
104
+ costot = gaussian_filter(costot * mask, fill_sigma)
105
+ min_weight = max(float(fill_min_weight), 0.0)
106
+ valid = weight > min_weight
107
+ sintot = np.divide(sintot, weight, out=np.zeros_like(sintot), where=valid)
108
+ costot = np.divide(costot, weight, out=np.zeros_like(costot), where=valid)
109
+ mtot = np.arctan2(sintot, costot)
110
+ if fill_min_weight > 0:
111
+ mtot[~valid] = np.nan
112
+ return mtot
113
+
114
+ if smooth_sigma and smooth_sigma > 0:
115
+ sintot = gaussian_filter(sintot, smooth_sigma)
116
+ costot = gaussian_filter(costot, smooth_sigma)
117
+ mtot = np.arctan2(sintot, costot)
118
+ mtot[nans] = np.nan
119
+ return mtot
120
+
121
+
122
+ def _extract_coords_and_times(
123
+ decoding_result: dict[str, Any],
124
+ coords_key: str | None = None,
125
+ ) -> tuple[np.ndarray, np.ndarray | None]:
126
+ """Extract coordinates and time indices from decoding result.
127
+
128
+ Args:
129
+ decoding_result: Dictionary containing decoding results.
130
+ coords_key: Optional key for coordinates (defaults to 'coordsbox' or auto-detect).
131
+
132
+ Returns:
133
+ Tuple of (coords, times_box) where coords is (T, 2) and times_box is optional.
134
+ """
135
+ if coords_key is not None:
136
+ if coords_key not in decoding_result:
137
+ raise KeyError(f"coords_key '{coords_key}' not found in decoding_result.")
138
+ coords = np.asarray(decoding_result[coords_key])
139
+ elif "coordsbox" in decoding_result:
140
+ coords = np.asarray(decoding_result["coordsbox"])
141
+ else:
142
+ coords, _ = find_coords_matrix(decoding_result)
143
+
144
+ times_box, _ = find_times_box(decoding_result)
145
+ return coords, times_box
146
+
147
+
148
+ def _phase_map_valid_fraction(phase_map: np.ndarray) -> float:
149
+ """Calculate fraction of valid (non-NaN) values in phase map.
150
+
151
+ Args:
152
+ phase_map: Phase map array.
153
+
154
+ Returns:
155
+ Fraction of valid values (0.0 to 1.0).
156
+ """
157
+ valid = np.isfinite(phase_map)
158
+ if valid.size == 0:
159
+ return 0.0
160
+ return float(np.mean(valid))
@@ -8,6 +8,7 @@ import numpy as np
8
8
  from matplotlib import pyplot as plt
9
9
  from scipy.ndimage import gaussian_filter1d
10
10
 
11
+ from canns.analyzer.data.asa.utils import _ensure_plot_config
11
12
  from canns.analyzer.visualization.core.config import PlotConfig, finalize_figure
12
13
 
13
14
  _DEFAULT_BTN_COLORS = {
@@ -17,24 +18,6 @@ _DEFAULT_BTN_COLORS = {
17
18
  }
18
19
 
19
20
 
20
- def _ensure_plot_config(
21
- config: PlotConfig | None,
22
- factory,
23
- *,
24
- kwargs: dict[str, Any] | None = None,
25
- **defaults: Any,
26
- ) -> PlotConfig:
27
- if config is None:
28
- defaults.update({"kwargs": kwargs or {}})
29
- return factory(**defaults)
30
-
31
- if kwargs:
32
- config_kwargs = config.kwargs or {}
33
- config_kwargs.update(kwargs)
34
- config.kwargs = config_kwargs
35
- return config
36
-
37
-
38
21
  def _canonical_label(label: str) -> str:
39
22
  lab = label.strip().lower()
40
23
  if lab in ("b", "bursty"):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: canns
3
- Version: 0.15.0
3
+ Version: 0.15.1
4
4
  Summary: A Python Library for Continuous Attractor Neural Networks
5
5
  Project-URL: Repository, https://github.com/routhleck/canns
6
6
  Author-email: Sichao He <sichaohe@outlook.com>
@@ -3,23 +3,22 @@ canns/_version.py,sha256=zIvJPOGBFvo4VV6f586rlO_bvhuFp1fsxjf6xhsqkJY,1547
3
3
  canns/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
4
  canns/analyzer/__init__.py,sha256=EQ02fYHkpMADp-ojpVCVtapuSPkl6j5WVfdPy0mOTs4,506
5
5
  canns/analyzer/data/__init__.py,sha256=RfS8vwApLkNF05Y_lfPaJpN_bRv-mOA_uFziaduDHgI,354
6
- canns/analyzer/data/asa/__init__.py,sha256=nXaRd49ePAnrjJaUO-FNkN_RlldEkC55e9jszEhmYKM,3897
7
- canns/analyzer/data/asa/coho.py,sha256=rYKb0ae2ZK7PX_XkDd7JOVNmPRStG6ZLCKvUsseoK1g,2884
8
- canns/analyzer/data/asa/cohomap.py,sha256=EmlG13f948de1O7Sm9e0YACtjbp9lEoFWeG3-hWRu6M,13437
9
- canns/analyzer/data/asa/cohomap_scatter.py,sha256=qu5F9UBCG-JuRrQL0Ag0ACAM61bzAGXEYpPgZ-Oqciw,231
10
- canns/analyzer/data/asa/cohomap_vectors.py,sha256=krzLr-MSw4Az8VD0AZruXHkv6G9i6VOE6AgavDJCGIw,9113
11
- canns/analyzer/data/asa/cohospace.py,sha256=dQ_HDNW6xGy6RRBUD8zK3FnQksjiupd84ORIwy6NKCE,7668
12
- canns/analyzer/data/asa/cohospace_phase_centers.py,sha256=JqgDH_SJZOln5wFhnNJ_qnPTlsTBFWb3M5VZX5b0Wrk,4113
13
- canns/analyzer/data/asa/cohospace_scatter.py,sha256=z13lLFdYkODmjKeOYildM5TgvT2QxmFSh7q8BYS5WjA,37163
6
+ canns/analyzer/data/asa/__init__.py,sha256=JtXp1QgoqsA1_XiZ4TNFnTGP_iKo5fHeJ_V3wpdhN6g,3327
7
+ canns/analyzer/data/asa/coho.py,sha256=775bOruTVSNRjM0A4YRTeiWKVQtEMCR3zTYXgVTeLiw,444
8
+ canns/analyzer/data/asa/cohomap.py,sha256=ZIiH1Mu306jg1lLmeV-akEzbojXQQVKW3nl0PbFKD4M,15217
9
+ canns/analyzer/data/asa/cohomap_vectors.py,sha256=o2eHMqiAvmhZfebOvHZ1zypqZKKgvZu0WLs3KD3Nfis,11273
10
+ canns/analyzer/data/asa/cohospace.py,sha256=8sq1Z0QGDxzC-5YY1n-nDEtqH3vnIsVA0Q1ST7egxTE,6701
11
+ canns/analyzer/data/asa/cohospace_phase_centers.py,sha256=ZQDhd_cBv8SFcVxgGM0nJabTLjKQZ1mg9Ho6duUG0xY,3738
12
+ canns/analyzer/data/asa/cohospace_scatter.py,sha256=lBubii5rz-MiTT_rYa5u59nzw2vkuOcFDnM6UhvvJ9M,33893
14
13
  canns/analyzer/data/asa/config.py,sha256=qm0k0nt0xuDUK5t63MG7ii7fgs2XbxyLxKOaOKJuB_s,6398
15
14
  canns/analyzer/data/asa/decode.py,sha256=NG8vVx2cPG7uSJDovnC2vzk0dsqU8oR4jaNPxxrvCc0,16501
16
- canns/analyzer/data/asa/embedding.py,sha256=CupBvkqZJ7zDMQWH-aIfqrOPTdUfv6wfTtZrI4tj0f0,9623
17
- canns/analyzer/data/asa/filters.py,sha256=D-1mDVn4hBEAphKUgx1gQEUfgbghKcNQhZmr4xEExQA,7146
15
+ canns/analyzer/data/asa/embedding.py,sha256=V1TzOD4wzyx3xmwCBHLLWgX92rBkMjWkFz-F-4xPPxo,9581
18
16
  canns/analyzer/data/asa/fly_roi.py,sha256=_scBOd-4t9yv_1tHk7wbXJwPieU-L-QtFJY6fhHpxDI,38031
19
- canns/analyzer/data/asa/fr.py,sha256=jt99H50e1RRAQgMIdkfK0rBbembZJEr9SMrxK-ZI_LA,13449
20
- canns/analyzer/data/asa/path.py,sha256=dL6hsqBoPFfC4ZrHDVFDWprbRfJAAYpiq4tIkZ6NvHY,15540
21
- canns/analyzer/data/asa/plotting.py,sha256=lpQfN55Sy8B5P1YSQhch_wc0e6SIEVCw9Wch7wHw04o,42798
17
+ canns/analyzer/data/asa/fr.py,sha256=-MuA50eTVdVJB7CUBeLelvnm0i0c0pK5NRICLoSD7oo,13292
18
+ canns/analyzer/data/asa/path.py,sha256=JM8NqkhwFx7qDs26n_t2bL5md9V4oR2E2ksXwYAwTug,17965
19
+ canns/analyzer/data/asa/plotting.py,sha256=PuYGZgal67uDRZ4xBczJKbVblqbza4a03re105Wj7vA,42243
22
20
  canns/analyzer/data/asa/tda.py,sha256=7IdxhBNEE99qenG6Zi4B5tv_L9K6gAW6HHxYGiErx4c,30574
21
+ canns/analyzer/data/asa/utils.py,sha256=s9R1d6K4op9MnjcBaI4vrnEAC-9KyZE6S9dg_zR0ONs,4899
23
22
  canns/analyzer/data/cell_classification/__init__.py,sha256=Ri0VJYn2OI3ygC4m-Xc9rjFvgPLaEykc-D94VxmUclQ,2447
24
23
  canns/analyzer/data/cell_classification/core/__init__.py,sha256=J9uqjx2wTK-uh3OWFqP8BkY_ySz3rU_VWRNQ_1t3EbM,865
25
24
  canns/analyzer/data/cell_classification/core/btn.py,sha256=rgZdEoMgqgOPah5KHEQAx7pNWDsWBzZ2pNQ5BstTYFM,8546
@@ -35,7 +34,7 @@ canns/analyzer/data/cell_classification/utils/correlation.py,sha256=57Ckn8OQGLip
35
34
  canns/analyzer/data/cell_classification/utils/geometry.py,sha256=jOLh3GeO-riR5a7r7Q7uON3HU_bYOZZJLbokU5bjCOQ,12683
36
35
  canns/analyzer/data/cell_classification/utils/image_processing.py,sha256=o9bLT4ycJ_IF7SKBe2RqSWIQwNcpi9v4AI-N5vpm_jM,12805
37
36
  canns/analyzer/data/cell_classification/visualization/__init__.py,sha256=fmEHZBcurW6y6FwySLoq65b6CH2kNUB02NCVw2ou6Nc,590
38
- canns/analyzer/data/cell_classification/visualization/btn_plots.py,sha256=nl29Ihe-gayCu_poJIWLN9oT6Srg1yjC-okPZ0IeRjo,7702
37
+ canns/analyzer/data/cell_classification/visualization/btn_plots.py,sha256=wHP6jyzzqvD2bgYriQFu1tZQiHevlYDy8V1fFNBw4M0,7345
39
38
  canns/analyzer/data/cell_classification/visualization/grid_plots.py,sha256=NFtyYOe2Szt0EOIwQmZradwEvvRjjm7mm6VnnGThDQ0,7914
40
39
  canns/analyzer/data/cell_classification/visualization/hd_plots.py,sha256=nzw1jck3VHvAFsJAGelhrJf1q27A5PI0r3NKVgeea8U,5670
41
40
  canns/analyzer/metrics/__init__.py,sha256=DTsrv1HW133_RgvhWzz7Gx-bP2hOZbPO2unCPPyf9gs,178
@@ -166,8 +165,8 @@ canns/trainer/utils.py,sha256=ZdoLiRqFLfKXsWi0KX3wGUp0OqFikwiou8dPf3xvFhE,2847
166
165
  canns/typing/__init__.py,sha256=mXySdfmD8fA56WqZTb1Nj-ZovcejwLzNjuk6PRfTwmA,156
167
166
  canns/utils/__init__.py,sha256=OMyZ5jqZAIUS2Jr0qcnvvrx6YM-BZ1EJy5uZYeA3HC0,366
168
167
  canns/utils/benchmark.py,sha256=oJ7nvbvnQMh4_MZh7z160NPLp-197X0rEnmnLHYlev4,1361
169
- canns-0.15.0.dist-info/METADATA,sha256=18mW4kldZzRDEw-pm3KnNp_mIQWjoGbj_sFw43GnaSE,9799
170
- canns-0.15.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
171
- canns-0.15.0.dist-info/entry_points.txt,sha256=57YF2HZp_BG3GeGB8L0m3wR1sSfNyMXF1q4CKEjce6U,164
172
- canns-0.15.0.dist-info/licenses/LICENSE,sha256=u6NJ1N-QSnf5yTwSk5UvFAdU2yKD0jxG0Xa91n1cPO4,11306
173
- canns-0.15.0.dist-info/RECORD,,
168
+ canns-0.15.1.dist-info/METADATA,sha256=Rxhq9ndM1JJjHYGYaB_BjffkvlLXO-22xIzwzOm8h1I,9799
169
+ canns-0.15.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
170
+ canns-0.15.1.dist-info/entry_points.txt,sha256=57YF2HZp_BG3GeGB8L0m3wR1sSfNyMXF1q4CKEjce6U,164
171
+ canns-0.15.1.dist-info/licenses/LICENSE,sha256=u6NJ1N-QSnf5yTwSk5UvFAdU2yKD0jxG0Xa91n1cPO4,11306
172
+ canns-0.15.1.dist-info/RECORD,,
@@ -1,10 +0,0 @@
1
- """Scatter-style CohoMap plotting helpers."""
2
-
3
- from __future__ import annotations
4
-
5
- from .plotting import plot_cohomap_scatter, plot_cohomap_scatter_multi
6
-
7
- __all__ = [
8
- "plot_cohomap_scatter",
9
- "plot_cohomap_scatter_multi",
10
- ]
@@ -1,208 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import numbers
4
-
5
- import numpy as np
6
- from numpy.exceptions import AxisError
7
- from scipy.ndimage import _nd_image, _ni_support
8
- from scipy.ndimage._filters import _invalid_origin
9
-
10
-
11
- def _gaussian_filter1d(
12
- input,
13
- sigma,
14
- axis=-1,
15
- order=0,
16
- output=None,
17
- mode="reflect",
18
- cval=0.0,
19
- truncate=4.0,
20
- *,
21
- radius=None,
22
- ):
23
- """1-D Gaussian filter.
24
-
25
- Parameters
26
- ----------
27
- %(input)s
28
- sigma : scalar
29
- standard deviation for Gaussian kernel
30
- %(axis)s
31
- order : int, optional
32
- An order of 0 corresponds to convolution with a Gaussian
33
- kernel. A positive order corresponds to convolution with
34
- that derivative of a Gaussian.
35
- %(output)s
36
- %(mode_reflect)s
37
- %(cval)s
38
- truncate : float, optional
39
- Truncate the filter at this many standard deviations.
40
- Default is 4.0.
41
- radius : None or int, optional
42
- Radius of the Gaussian kernel. If specified, the size of
43
- the kernel will be ``2*radius + 1``, and `truncate` is ignored.
44
- Default is None.
45
-
46
- Returns
47
- -------
48
- gaussian_filter1d : ndarray
49
-
50
- Notes
51
- -----
52
- The Gaussian kernel will have size ``2*radius + 1`` along each axis. If
53
- `radius` is None, a default ``radius = round(truncate * sigma)`` will be
54
- used.
55
-
56
- Examples
57
- --------
58
- >>> from scipy.ndimage import gaussian_filter1d
59
- >>> import numpy as np
60
- >>> gaussian_filter1d([1.0, 2.0, 3.0, 4.0, 5.0], 1)
61
- array([ 1.42704095, 2.06782203, 3. , 3.93217797, 4.57295905])
62
- >>> _gaussian_filter1d([1.0, 2.0, 3.0, 4.0, 5.0], 4)
63
- array([ 2.91948343, 2.95023502, 3. , 3.04976498, 3.08051657])
64
- >>> import matplotlib.pyplot as plt
65
- >>> rng = np.random.default_rng()
66
- >>> x = rng.standard_normal(101).cumsum()
67
- >>> y3 = _gaussian_filter1d(x, 3)
68
- >>> y6 = _gaussian_filter1d(x, 6)
69
- >>> plt.plot(x, 'k', label='original data')
70
- >>> plt.plot(y3, '--', label='filtered, sigma=3')
71
- >>> plt.plot(y6, ':', label='filtered, sigma=6')
72
- >>> plt.legend()
73
- >>> plt.grid()
74
- >>> plt.show()
75
-
76
- """
77
- sd = float(sigma)
78
- # make the radius of the filter equal to truncate standard deviations
79
- lw = int(truncate * sd + 0.5)
80
- if radius is not None:
81
- lw = radius
82
- if not isinstance(lw, numbers.Integral) or lw < 0:
83
- raise ValueError("Radius must be a nonnegative integer.")
84
- # Since we are calling correlate, not convolve, revert the kernel
85
- weights = _gaussian_kernel1d(sigma, order, lw)[::-1]
86
- return _correlate1d(input, weights, axis, output, mode, cval, 0)
87
-
88
-
89
- def _gaussian_kernel1d(sigma, order, radius):
90
- """
91
- Computes a 1-D Gaussian convolution kernel.
92
- """
93
- if order < 0:
94
- raise ValueError("order must be non-negative")
95
- exponent_range = np.arange(order + 1)
96
- sigma2 = sigma * sigma
97
- x = np.arange(-radius, radius + 1)
98
- phi_x = np.exp(-0.5 / sigma2 * x**2)
99
- phi_x = phi_x / phi_x.sum()
100
-
101
- if order == 0:
102
- return phi_x
103
- else:
104
- # f(x) = q(x) * phi(x) = q(x) * exp(p(x))
105
- # f'(x) = (q'(x) + q(x) * p'(x)) * phi(x)
106
- # p'(x) = -1 / sigma ** 2
107
- # Implement q'(x) + q(x) * p'(x) as a matrix operator and apply to the
108
- # coefficients of q(x)
109
- q = np.zeros(order + 1)
110
- q[0] = 1
111
- D = np.diag(exponent_range[1:], 1) # D @ q(x) = q'(x)
112
- P = np.diag(np.ones(order) / -sigma2, -1) # P @ q(x) = q(x) * p'(x)
113
- Q_deriv = D + P
114
- for _ in range(order):
115
- q = Q_deriv.dot(q)
116
- q = (x[:, None] ** exponent_range).dot(q)
117
- return q * phi_x
118
-
119
-
120
- def _correlate1d(input, weights, axis=-1, output=None, mode="reflect", cval=0.0, origin=0):
121
- """Calculate a 1-D correlation along the given axis.
122
-
123
- The lines of the array along the given axis are correlated with the
124
- given weights.
125
-
126
- Parameters
127
- ----------
128
- %(input)s
129
- weights : array
130
- 1-D sequence of numbers.
131
- %(axis)s
132
- %(output)s
133
- %(mode_reflect)s
134
- %(cval)s
135
- %(origin)s
136
-
137
- Returns
138
- -------
139
- result : ndarray
140
- Correlation result. Has the same shape as `input`.
141
-
142
- Examples
143
- --------
144
- >>> from scipy.ndimage import correlate1d
145
- >>> correlate1d([2, 8, 0, 4, 1, 9, 9, 0], weights=[1, 3])
146
- array([ 8, 26, 8, 12, 7, 28, 36, 9])
147
- """
148
- input = np.asarray(input)
149
- weights = np.asarray(weights)
150
- complex_input = input.dtype.kind == "c"
151
- complex_weights = weights.dtype.kind == "c"
152
- if complex_input or complex_weights:
153
- if complex_weights:
154
- weights = weights.conj()
155
- weights = weights.astype(np.complex128, copy=False)
156
- kwargs = dict(axis=axis, mode=mode, origin=origin)
157
- output = _ni_support._get_output(output, input, complex_output=True)
158
- return _complex_via_real_components(_correlate1d, input, weights, output, cval, **kwargs)
159
-
160
- output = _ni_support._get_output(output, input)
161
- weights = np.asarray(weights, dtype=np.float64)
162
- if weights.ndim != 1 or weights.shape[0] < 1:
163
- raise RuntimeError("no filter weights given")
164
- if not weights.flags.contiguous:
165
- weights = weights.copy()
166
- axis = _normalize_axis_index(axis, input.ndim)
167
- if _invalid_origin(origin, len(weights)):
168
- raise ValueError(
169
- "Invalid origin; origin must satisfy "
170
- "-(len(weights) // 2) <= origin <= "
171
- "(len(weights)-1) // 2"
172
- )
173
- mode = _ni_support._extend_mode_to_code(mode)
174
- _nd_image.correlate1d(input, weights, axis, output, mode, cval, origin)
175
- return output
176
-
177
-
178
- def _complex_via_real_components(func, input, weights, output, cval, **kwargs):
179
- """Complex convolution via a linear combination of real convolutions."""
180
- complex_input = input.dtype.kind == "c"
181
- complex_weights = weights.dtype.kind == "c"
182
- if complex_input and complex_weights:
183
- # real component of the output
184
- func(input.real, weights.real, output=output.real, cval=np.real(cval), **kwargs)
185
- output.real -= func(input.imag, weights.imag, output=None, cval=np.imag(cval), **kwargs)
186
- # imaginary component of the output
187
- func(input.real, weights.imag, output=output.imag, cval=np.real(cval), **kwargs)
188
- output.imag += func(input.imag, weights.real, output=None, cval=np.imag(cval), **kwargs)
189
- elif complex_input:
190
- func(input.real, weights, output=output.real, cval=np.real(cval), **kwargs)
191
- func(input.imag, weights, output=output.imag, cval=np.imag(cval), **kwargs)
192
- else:
193
- if np.iscomplexobj(cval):
194
- raise ValueError("Cannot provide a complex-valued cval when the input is real.")
195
- func(input, weights.real, output=output.real, cval=cval, **kwargs)
196
- func(input, weights.imag, output=output.imag, cval=cval, **kwargs)
197
- return output
198
-
199
-
200
- def _normalize_axis_index(axis, ndim):
201
- # Check if `axis` is in the correct range and normalize it
202
- if axis < -ndim or axis >= ndim:
203
- msg = f"axis {axis} is out of bounds for array of dimension {ndim}"
204
- raise AxisError(msg)
205
-
206
- if axis < 0:
207
- axis = axis + ndim
208
- return axis
File without changes