canns 0.15.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,15 +2,9 @@ from __future__ import annotations
2
2
 
3
3
  from .cohomap import (
4
4
  cohomap,
5
- cohomap_upgrade,
6
- ecohomap,
7
5
  fit_cohomap_stripes,
8
- fit_cohomap_stripes_upgrade,
9
6
  plot_cohomap,
10
- plot_cohomap_upgrade,
11
- plot_ecohomap,
12
7
  )
13
- from .cohomap_scatter import plot_cohomap_scatter, plot_cohomap_scatter_multi
14
8
  from .cohomap_vectors import (
15
9
  cohomap_vectors,
16
10
  plot_cohomap_stripes,
@@ -18,14 +12,8 @@ from .cohomap_vectors import (
18
12
  )
19
13
  from .cohospace import (
20
14
  cohospace,
21
- cohospace_upgrade,
22
- ecohospace,
23
15
  plot_cohospace,
24
16
  plot_cohospace_skewed,
25
- plot_cohospace_upgrade,
26
- plot_cohospace_upgrade_skewed,
27
- plot_ecohospace,
28
- plot_ecohospace_skewed,
29
17
  )
30
18
  from .cohospace_phase_centers import (
31
19
  cohospace_phase_centers,
@@ -81,6 +69,8 @@ from .path import (
81
69
  from .plotting import (
82
70
  plot_2d_bump_on_manifold,
83
71
  plot_3d_bump_on_torus,
72
+ plot_cohomap_scatter,
73
+ plot_cohomap_scatter_multi,
84
74
  plot_path_compare_1d,
85
75
  plot_path_compare_2d,
86
76
  plot_projection,
@@ -109,22 +99,11 @@ __all__ = [
109
99
  "plot_3d_bump_on_torus",
110
100
  "plot_2d_bump_on_manifold",
111
101
  "cohomap",
112
- "cohomap_upgrade",
113
102
  "fit_cohomap_stripes",
114
- "fit_cohomap_stripes_upgrade",
115
103
  "plot_cohomap",
116
- "plot_cohomap_upgrade",
117
104
  "cohospace",
118
- "cohospace_upgrade",
119
105
  "plot_cohospace",
120
106
  "plot_cohospace_skewed",
121
- "plot_cohospace_upgrade",
122
- "plot_cohospace_upgrade_skewed",
123
- "ecohomap",
124
- "ecohospace",
125
- "plot_ecohomap",
126
- "plot_ecohospace",
127
- "plot_ecohospace_skewed",
128
107
  "cohomap_vectors",
129
108
  "plot_cohomap_stripes",
130
109
  "plot_cohomap_vectors",
@@ -2,96 +2,20 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- import os
6
- from typing import Any
7
-
8
- import numpy as np
9
- from scipy.ndimage import gaussian_filter
10
-
11
- from ...visualization.core import PlotConfig
12
- from .path import find_coords_matrix, find_times_box
13
-
14
-
15
- def _ensure_plot_config(
16
- config: PlotConfig | None,
17
- factory,
18
- *args,
19
- **defaults,
20
- ) -> PlotConfig:
21
- if config is None:
22
- return factory(*args, **defaults)
23
- return config
24
-
25
-
26
- def _ensure_parent_dir(save_path: str | None) -> None:
27
- if save_path:
28
- parent = os.path.dirname(save_path)
29
- if parent:
30
- os.makedirs(parent, exist_ok=True)
31
-
32
-
33
- def _circmean(x: np.ndarray) -> float:
34
- return float(np.arctan2(np.mean(np.sin(x)), np.mean(np.cos(x))))
35
-
36
-
37
- def _smooth_circular_map(
38
- mtot: np.ndarray,
39
- smooth_sigma: float,
40
- *,
41
- fill_nan: bool = False,
42
- fill_sigma: float | None = None,
43
- fill_min_weight: float = 1e-3,
44
- ) -> np.ndarray:
45
- mtot = np.asarray(mtot, dtype=float)
46
- nans = np.isnan(mtot)
47
- mask = (~nans).astype(float)
48
- sintot = np.sin(mtot)
49
- costot = np.cos(mtot)
50
- sintot[nans] = 0.0
51
- costot[nans] = 0.0
52
-
53
- if fill_nan:
54
- if fill_sigma is None:
55
- fill_sigma = smooth_sigma if smooth_sigma and smooth_sigma > 0 else 1.0
56
- weight = gaussian_filter(mask, fill_sigma)
57
- sintot = gaussian_filter(sintot * mask, fill_sigma)
58
- costot = gaussian_filter(costot * mask, fill_sigma)
59
- min_weight = max(float(fill_min_weight), 0.0)
60
- valid = weight > min_weight
61
- sintot = np.divide(sintot, weight, out=np.zeros_like(sintot), where=valid)
62
- costot = np.divide(costot, weight, out=np.zeros_like(costot), where=valid)
63
- mtot = np.arctan2(sintot, costot)
64
- if fill_min_weight > 0:
65
- mtot[~valid] = np.nan
66
- return mtot
67
-
68
- if smooth_sigma and smooth_sigma > 0:
69
- sintot = gaussian_filter(sintot, smooth_sigma)
70
- costot = gaussian_filter(costot, smooth_sigma)
71
- mtot = np.arctan2(sintot, costot)
72
- mtot[nans] = np.nan
73
- return mtot
74
-
75
-
76
- def _extract_coords_and_times(
77
- decoding_result: dict[str, Any],
78
- coords_key: str | None = None,
79
- ) -> tuple[np.ndarray, np.ndarray | None]:
80
- if coords_key is not None:
81
- if coords_key not in decoding_result:
82
- raise KeyError(f"coords_key '{coords_key}' not found in decoding_result.")
83
- coords = np.asarray(decoding_result[coords_key])
84
- elif "coordsbox" in decoding_result:
85
- coords = np.asarray(decoding_result["coordsbox"])
86
- else:
87
- coords, _ = find_coords_matrix(decoding_result)
88
-
89
- times_box, _ = find_times_box(decoding_result)
90
- return coords, times_box
91
-
92
-
93
- def _phase_map_valid_fraction(phase_map: np.ndarray) -> float:
94
- valid = np.isfinite(phase_map)
95
- if valid.size == 0:
96
- return 0.0
97
- return float(np.mean(valid))
5
+ from .utils import (
6
+ _circmean,
7
+ _ensure_parent_dir,
8
+ _ensure_plot_config,
9
+ _extract_coords_and_times,
10
+ _phase_map_valid_fraction,
11
+ _smooth_circular_map,
12
+ )
13
+
14
+ __all__ = [
15
+ "_ensure_plot_config",
16
+ "_ensure_parent_dir",
17
+ "_circmean",
18
+ "_smooth_circular_map",
19
+ "_extract_coords_and_times",
20
+ "_phase_map_valid_fraction",
21
+ ]
@@ -57,6 +57,36 @@ def _rot_coord(
57
57
  c2: np.ndarray,
58
58
  p: tuple[int, int],
59
59
  ) -> np.ndarray:
60
+ """Transform and align decoded coordinates based on stripe orientation.
61
+
62
+ This function rotates and aligns coordinates to match the dominant stripe
63
+ orientation in the CohoMap. It handles the torus geometry by choosing the
64
+ appropriate coordinate system and applying reflections/rotations.
65
+
66
+ Parameters
67
+ ----------
68
+ params1 : np.ndarray
69
+ Stripe fit parameters for first dimension (includes orientation angle).
70
+ params2 : np.ndarray
71
+ Stripe fit parameters for second dimension (includes orientation angle).
72
+ c1 : np.ndarray
73
+ Decoded coordinates for first dimension.
74
+ c2 : np.ndarray
75
+ Decoded coordinates for second dimension.
76
+ p : tuple[int, int]
77
+ Direction indicators (1 or -1) for each dimension, controlling reflections.
78
+
79
+ Returns
80
+ -------
81
+ np.ndarray
82
+ Aligned coordinates of shape (N, 2), wrapped to [0, 2π).
83
+
84
+ Notes
85
+ -----
86
+ The function selects the coordinate system based on which dimension has
87
+ stronger horizontal alignment (larger |cos(angle)|), then applies geometric
88
+ transformations to align the coordinates with the stripe pattern.
89
+ """
60
90
  if abs(np.cos(params1[0])) < abs(np.cos(params2[0])):
61
91
  cc1 = c2.copy()
62
92
  cc2 = c1.copy()
@@ -96,6 +126,46 @@ def _toroidal_align_coords(
96
126
  trim: int,
97
127
  grid_size: int | None,
98
128
  ) -> tuple[np.ndarray, float, float]:
129
+ """Align decoded coordinates to CohoMap stripe patterns.
130
+
131
+ This function fits stripe patterns to both phase maps, determines the
132
+ optimal orientation and phase signs, then transforms the decoded coordinates
133
+ to align with the detected stripe structure.
134
+
135
+ Parameters
136
+ ----------
137
+ coords : np.ndarray
138
+ Decoded coordinates of shape (N, 2+). Only first two columns are used.
139
+ phase_map1 : np.ndarray
140
+ Phase map for first dimension (2D grid).
141
+ phase_map2 : np.ndarray
142
+ Phase map for second dimension (2D grid), must match phase_map1 shape.
143
+ trim : int
144
+ Number of edge bins to trim when fitting stripes.
145
+ grid_size : int, optional
146
+ Grid size for stripe fitting. If None, inferred from phase_map shape.
147
+
148
+ Returns
149
+ -------
150
+ aligned : np.ndarray
151
+ Aligned coordinates of shape (N, 2), wrapped to [0, 2π).
152
+ f1 : float
153
+ Fit quality score for first dimension (higher is better).
154
+ f2 : float
155
+ Fit quality score for second dimension (higher is better).
156
+
157
+ Raises
158
+ ------
159
+ ValueError
160
+ If phase_map shapes don't match or coords has wrong shape.
161
+
162
+ Notes
163
+ -----
164
+ The alignment process:
165
+ 1. Fits stripe patterns to both phase maps using FFT-based detection
166
+ 2. Determines phase signs (±1) for each dimension
167
+ 3. Applies geometric transformations via `_rot_coord()` to align coordinates
168
+ """
99
169
  if phase_map1.shape != phase_map2.shape:
100
170
  raise ValueError("phase_map shapes do not match for alignment")
101
171
  coords = np.asarray(coords)
@@ -381,28 +451,3 @@ def plot_cohomap(
381
451
  _ensure_parent_dir(config.save_path)
382
452
  finalize_figure(fig, config)
383
453
  return fig
384
-
385
-
386
- def cohomap_upgrade(*args, **kwargs) -> dict[str, Any]:
387
- """Legacy alias for EcohoMap (formerly cohomap_upgrade)."""
388
- return cohomap(*args, **kwargs)
389
-
390
-
391
- def ecohomap(*args, **kwargs) -> dict[str, Any]:
392
- """Alias for EcohoMap (GridCellTorus-style)."""
393
- return cohomap(*args, **kwargs)
394
-
395
-
396
- def fit_cohomap_stripes_upgrade(*args, **kwargs) -> tuple[np.ndarray, float]:
397
- """Legacy alias for EcohoMap stripe fitting."""
398
- return fit_cohomap_stripes(*args, **kwargs)
399
-
400
-
401
- def plot_cohomap_upgrade(*args, **kwargs) -> plt.Figure:
402
- """Legacy alias for EcohoMap plotting (formerly plot_cohomap_upgrade)."""
403
- return plot_cohomap(*args, **kwargs)
404
-
405
-
406
- def plot_ecohomap(*args, **kwargs) -> plt.Figure:
407
- """Alias for EcohoMap plotting."""
408
- return plot_cohomap(*args, **kwargs)
@@ -2,7 +2,6 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- import os
6
5
  from typing import Any
7
6
 
8
7
  import matplotlib.pyplot as plt
@@ -11,27 +10,37 @@ from scipy.ndimage import rotate
11
10
 
12
11
  from ...visualization.core import PlotConfig, finalize_figure
13
12
  from .cohomap import fit_cohomap_stripes
14
-
15
-
16
- def _ensure_plot_config(
17
- config: PlotConfig | None,
18
- factory,
19
- *args,
20
- **defaults,
21
- ) -> PlotConfig:
22
- if config is None:
23
- return factory(*args, **defaults)
24
- return config
25
-
26
-
27
- def _ensure_parent_dir(save_path: str | None) -> None:
28
- if save_path:
29
- parent = os.path.dirname(save_path)
30
- if parent:
31
- os.makedirs(parent, exist_ok=True)
13
+ from .utils import _ensure_parent_dir, _ensure_plot_config
32
14
 
33
15
 
34
16
  def _rot_para(params1: np.ndarray, params2: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
17
+ """Transform stripe fit parameters to canonical orientation.
18
+
19
+ This function adjusts the orientation angles of stripe fit parameters to
20
+ align them with a canonical coordinate system. Unlike `_rot_coord()` which
21
+ transforms actual coordinate data, this operates on the fit parameters
22
+ themselves (orientation angles, wavelengths, etc.).
23
+
24
+ Parameters
25
+ ----------
26
+ params1 : np.ndarray
27
+ Stripe fit parameters for first dimension. First element is orientation angle.
28
+ params2 : np.ndarray
29
+ Stripe fit parameters for second dimension. First element is orientation angle.
30
+
31
+ Returns
32
+ -------
33
+ x : np.ndarray
34
+ Transformed parameters for the dimension with stronger horizontal alignment.
35
+ y : np.ndarray
36
+ Transformed parameters for the other dimension.
37
+
38
+ Notes
39
+ -----
40
+ The function selects which parameter set becomes 'x' and 'y' based on which
41
+ has larger |cos(angle)|, then applies angle adjustments to ensure the stripe
42
+ vectors are in a canonical orientation for visualization and analysis.
43
+ """
35
44
  if abs(np.cos(params1[0])) < abs(np.cos(params2[0])):
36
45
  y = params1.copy()
37
46
  x = params2.copy()
@@ -204,12 +213,57 @@ def plot_cohomap_vectors(
204
213
 
205
214
 
206
215
  def _resolve_grid_size(phase_map: np.ndarray, grid_size: int | None, trim: int) -> int:
216
+ """Determine grid size for stripe fitting.
217
+
218
+ Parameters
219
+ ----------
220
+ phase_map : np.ndarray
221
+ Phase map array (2D grid).
222
+ grid_size : int, optional
223
+ Explicit grid size. If None, inferred from phase_map shape.
224
+ trim : int
225
+ Number of edge bins to trim.
226
+
227
+ Returns
228
+ -------
229
+ int
230
+ Grid size to use for stripe fitting.
231
+ """
207
232
  if grid_size is None:
208
233
  return int(phase_map.shape[0]) + 2 * trim + 1
209
234
  return int(grid_size)
210
235
 
211
236
 
212
237
  def _stripe_fit_map(params: np.ndarray, grid_size: int, trim: int) -> np.ndarray:
238
+ """Generate a synthetic stripe pattern from fit parameters.
239
+
240
+ Creates a 2D cosine stripe pattern based on the fitted parameters (orientation,
241
+ phase, and spatial frequency). Used for visualizing the fitted stripe model
242
+ and comparing it with the actual phase map.
243
+
244
+ Parameters
245
+ ----------
246
+ params : np.ndarray
247
+ Stripe fit parameters: [angle, phase, frequency].
248
+ - angle: Orientation angle in radians
249
+ - phase: Phase offset in radians
250
+ - frequency: Spatial frequency (inverse wavelength)
251
+ grid_size : int
252
+ Size of the grid to generate (before trimming).
253
+ trim : int
254
+ Number of edge bins to trim from the generated pattern.
255
+
256
+ Returns
257
+ -------
258
+ np.ndarray
259
+ 2D array of shape (grid_size-2*trim, grid_size-2*trim) containing
260
+ the cosine stripe pattern with values in [-1, 1].
261
+
262
+ Notes
263
+ -----
264
+ The pattern is generated on a [0, 3π] × [0, 3π] domain to cover the
265
+ extended torus space, then rotated by the fitted angle and trimmed.
266
+ """
213
267
  numangsint = grid_size
214
268
  x, _ = np.meshgrid(
215
269
  np.linspace(0, 3 * np.pi, numangsint - 1),
@@ -219,33 +219,3 @@ def plot_cohospace_skewed(
219
219
  _ensure_parent_dir(config.save_path)
220
220
  finalize_figure(fig, config)
221
221
  return fig
222
-
223
-
224
- def cohospace_upgrade(*args, **kwargs) -> dict[str, Any]:
225
- """Legacy alias for EcohoSpace (formerly cohospace_upgrade)."""
226
- return cohospace(*args, **kwargs)
227
-
228
-
229
- def ecohospace(*args, **kwargs) -> dict[str, Any]:
230
- """Alias for EcohoSpace (GridCellTorus-style)."""
231
- return cohospace(*args, **kwargs)
232
-
233
-
234
- def plot_cohospace_upgrade(*args, **kwargs) -> plt.Figure:
235
- """Legacy alias for EcohoSpace plotting (formerly plot_cohospace_upgrade)."""
236
- return plot_cohospace(*args, **kwargs)
237
-
238
-
239
- def plot_cohospace_upgrade_skewed(*args, **kwargs) -> plt.Figure:
240
- """Legacy alias for EcohoSpace skewed plotting."""
241
- return plot_cohospace_skewed(*args, **kwargs)
242
-
243
-
244
- def plot_ecohospace(*args, **kwargs) -> plt.Figure:
245
- """Alias for EcohoSpace plotting."""
246
- return plot_cohospace(*args, **kwargs)
247
-
248
-
249
- def plot_ecohospace_skewed(*args, **kwargs) -> plt.Figure:
250
- """Alias for EcohoSpace skewed plotting."""
251
- return plot_cohospace_skewed(*args, **kwargs)
@@ -2,32 +2,14 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- import os
6
5
  from typing import Any
7
6
 
8
7
  import matplotlib.pyplot as plt
9
8
  import numpy as np
10
9
 
11
10
  from ...visualization.core import PlotConfig, finalize_figure
12
- from .cohospace_scatter import skew_transform_torus_scatter
13
-
14
-
15
- def _ensure_plot_config(
16
- config: PlotConfig | None,
17
- factory,
18
- *args,
19
- **defaults,
20
- ) -> PlotConfig:
21
- if config is None:
22
- return factory(*args, **defaults)
23
- return config
24
-
25
-
26
- def _ensure_parent_dir(save_path: str | None) -> None:
27
- if save_path:
28
- parent = os.path.dirname(save_path)
29
- if parent:
30
- os.makedirs(parent, exist_ok=True)
11
+ from .path import skew_transform
12
+ from .utils import _ensure_parent_dir, _ensure_plot_config
31
13
 
32
14
 
33
15
  def cohospace_phase_centers(cohospace_result: dict[str, Any]) -> dict[str, Any]:
@@ -40,7 +22,7 @@ def cohospace_phase_centers(cohospace_result: dict[str, Any]) -> dict[str, Any]:
40
22
  Output from `data.cohospace(...)` (must include `centers`).
41
23
  """
42
24
  centers = np.asarray(cohospace_result["centers"], dtype=float) % (2 * np.pi)
43
- centers_skew = skew_transform_torus_scatter(centers)
25
+ centers_skew = skew_transform(centers)
44
26
  return {
45
27
  "centers": centers,
46
28
  "centers_skew": centers_skew,
@@ -2,38 +2,13 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- import os
6
-
7
5
  import matplotlib.pyplot as plt
8
6
  import numpy as np
9
7
  from scipy.stats import circvar
10
8
 
11
9
  from ...visualization.core import PlotConfig, finalize_figure
12
-
13
-
14
- def _ensure_plot_config(
15
- config: PlotConfig | None,
16
- factory,
17
- *,
18
- kwargs: dict | None = None,
19
- **defaults,
20
- ) -> PlotConfig:
21
- if config is None:
22
- defaults.update({"kwargs": kwargs or {}})
23
- return factory(**defaults)
24
-
25
- if kwargs:
26
- config_kwargs = config.kwargs or {}
27
- config_kwargs.update(kwargs)
28
- config.kwargs = config_kwargs
29
- return config
30
-
31
-
32
- def _ensure_parent_dir(save_path: str | None) -> None:
33
- if save_path:
34
- parent = os.path.dirname(save_path)
35
- if parent:
36
- os.makedirs(parent, exist_ok=True)
10
+ from .path import _align_activity_to_coords, skew_transform
11
+ from .utils import _ensure_parent_dir, _ensure_plot_config
37
12
 
38
13
 
39
14
  # =====================================================================
@@ -48,51 +23,6 @@ def _coho_coords_to_degrees(coords: np.ndarray) -> np.ndarray:
48
23
  return np.degrees(coords % (2 * np.pi))
49
24
 
50
25
 
51
- def _align_activity_to_coords(
52
- coords: np.ndarray,
53
- activity: np.ndarray,
54
- times: np.ndarray | None = None,
55
- *,
56
- label: str = "activity",
57
- auto_filter: bool = True,
58
- ) -> np.ndarray:
59
- """
60
- Align activity to coords by optional time indices and validate lengths.
61
- """
62
- coords = np.asarray(coords)
63
- activity = np.asarray(activity)
64
-
65
- if times is not None:
66
- times = np.asarray(times)
67
- try:
68
- activity = activity[times]
69
- except Exception as exc:
70
- raise ValueError(
71
- f"Failed to index {label} with `times`. Ensure `times` indexes the original time axis."
72
- ) from exc
73
-
74
- if activity.shape[0] != coords.shape[0]:
75
- # Try to reproduce decode's zero-spike filtering if lengths mismatch.
76
- if auto_filter and times is None and activity.ndim == 2:
77
- mask = np.sum(activity > 0, axis=1) >= 1
78
- if mask.sum() == coords.shape[0]:
79
- activity = activity[mask]
80
- else:
81
- raise ValueError(
82
- f"{label} length must match coords length. Got {activity.shape[0]} vs {coords.shape[0]}. "
83
- "If coords are computed on a subset of timepoints (e.g., decode['times']), pass "
84
- "`times=decoding['times']` or slice the activity accordingly."
85
- )
86
- else:
87
- raise ValueError(
88
- f"{label} length must match coords length. Got {activity.shape[0]} vs {coords.shape[0]}. "
89
- "If coords are computed on a subset of timepoints (e.g., decode['times']), pass "
90
- "`times=decoding['times']` or slice the activity accordingly."
91
- )
92
-
93
- return activity
94
-
95
-
96
26
  def plot_cohospace_scatter_trajectory_2d(
97
27
  coords: np.ndarray,
98
28
  times: np.ndarray | None = None,
@@ -802,41 +732,6 @@ def compute_cohoscore_scatter_1d(
802
732
  return scores
803
733
 
804
734
 
805
- def skew_transform_torus_scatter(coords):
806
- """
807
- Convert torus angles (theta1, theta2) into coordinates in a skewed parallelogram fundamental domain.
808
-
809
- Given theta1, theta2 in radians, map:
810
- x = theta1 + 0.5 * theta2
811
- y = (sqrt(3)/2) * theta2
812
-
813
- This is a linear change of basis that turns the square [0, 2π)×[0, 2π) into a 60-degree
814
- parallelogram, which is convenient for visualizing wrap-around behavior on a 2-torus.
815
-
816
- Parameters
817
- ----------
818
- coords : ndarray, shape (T, 2)
819
- Angles (theta1, theta2) in radians.
820
-
821
- Returns
822
- -------
823
- xy : ndarray, shape (T, 2)
824
- Skewed planar coordinates.
825
- """
826
- coords = np.asarray(coords)
827
- if coords.ndim != 2 or coords.shape[1] != 2:
828
- raise ValueError(f"coords must be (T,2), got {coords.shape}")
829
-
830
- theta1 = coords[:, 0]
831
- theta2 = coords[:, 1]
832
-
833
- # Linear change of basis (NO nonlinear scaling)
834
- x = theta1 + 0.5 * theta2
835
- y = (np.sqrt(3) / 2.0) * theta2
836
-
837
- return np.stack([x, y], axis=1)
838
-
839
-
840
735
  def draw_torus_parallelogram_grid_scatter(ax, n_tiles=1, color="0.7", lw=1.0, alpha=0.8):
841
736
  """
842
737
  Draw parallelogram grid corresponding to torus fundamental domain.
@@ -874,7 +769,7 @@ def tile_parallelogram_points_scatter(xy, n_tiles=1):
874
769
  Parameters
875
770
  ----------
876
771
  points : ndarray, shape (T, 2)
877
- Points in the skewed plane (same coordinates as returned by `skew_transform_torus_scatter`).
772
+ Points in the skewed plane (same coordinates as returned by `skew_transform`).
878
773
  n_tiles : int
879
774
  Number of tiles to extend around the base domain.
880
775
  - n_tiles=1 produces a 3x3 tiling
@@ -1020,7 +915,7 @@ def plot_cohospace_scatter_neuron_skewed(
1020
915
  )
1021
916
 
1022
917
  # --- skew transform
1023
- xy = skew_transform_torus_scatter(coords[mask])
918
+ xy = skew_transform(coords[mask])
1024
919
 
1025
920
  # Tiling: if points are tiled, values must be tiled too (FR mode) to keep lengths consistent
1026
921
  if n_tiles and n_tiles > 0:
@@ -1165,7 +1060,7 @@ def plot_cohospace_scatter_population_skewed(
1165
1060
  thr = np.percentile(a, 100 - top_percent)
1166
1061
  mask = a >= thr
1167
1062
 
1168
- xy = skew_transform_torus_scatter(coords[mask])
1063
+ xy = skew_transform(coords[mask])
1169
1064
 
1170
1065
  if n_tiles and n_tiles > 0:
1171
1066
  xy = tile_parallelogram_points_scatter(xy, n_tiles=n_tiles)
@@ -6,7 +6,6 @@ import numpy as np
6
6
  from scipy.ndimage import gaussian_filter1d
7
7
 
8
8
  from .config import DataLoadError, ProcessingError, SpikeEmbeddingConfig
9
- from .filters import _gaussian_filter1d
10
9
 
11
10
 
12
11
  def embed_spike_trains(spike_trains, config: SpikeEmbeddingConfig | None = None, **kwargs):
@@ -262,8 +261,8 @@ def _load_pos(t, x, y, res=100000, dt=1000):
262
261
  xx = np.interp(tt, t, x)
263
262
  yy = np.interp(tt, t, y)
264
263
 
265
- xxs = _gaussian_filter1d(xx - np.min(xx), sigma=100)
266
- yys = _gaussian_filter1d(yy - np.min(yy), sigma=100)
264
+ xxs = gaussian_filter1d(xx - np.min(xx), sigma=100)
265
+ yys = gaussian_filter1d(yy - np.min(yy), sigma=100)
267
266
  dx = (xxs[1:] - xxs[:-1]) * 100
268
267
  dy = (yys[1:] - yys[:-1]) * 100
269
268
  speed = np.sqrt(dx**2 + dy**2) / 0.01
@@ -1,18 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
- import os
4
3
  from dataclasses import dataclass
5
4
 
6
5
  import numpy as np
7
6
 
8
7
  from ...visualization.core import PlotConfig, finalize_figure
9
-
10
-
11
- def _ensure_parent_dir(save_path: str | None) -> None:
12
- if save_path:
13
- parent = os.path.dirname(save_path)
14
- if parent:
15
- os.makedirs(parent, exist_ok=True)
8
+ from .utils import _ensure_parent_dir
16
9
 
17
10
 
18
11
  def _slice_range(r: tuple[int, int] | None, length: int) -> slice: