canns 0.12.7__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. canns/analyzer/data/__init__.py +3 -11
  2. canns/analyzer/data/asa/__init__.py +74 -0
  3. canns/analyzer/data/asa/cohospace.py +905 -0
  4. canns/analyzer/data/asa/config.py +246 -0
  5. canns/analyzer/data/asa/decode.py +448 -0
  6. canns/analyzer/data/asa/embedding.py +269 -0
  7. canns/analyzer/data/asa/filters.py +208 -0
  8. canns/analyzer/data/asa/fr.py +439 -0
  9. canns/analyzer/data/asa/path.py +389 -0
  10. canns/analyzer/data/asa/plotting.py +1276 -0
  11. canns/analyzer/data/asa/tda.py +901 -0
  12. canns/analyzer/data/legacy/__init__.py +6 -0
  13. canns/analyzer/data/{cann1d.py → legacy/cann1d.py} +2 -2
  14. canns/analyzer/data/{cann2d.py → legacy/cann2d.py} +3 -3
  15. canns/analyzer/visualization/core/backend.py +1 -1
  16. canns/analyzer/visualization/core/config.py +77 -0
  17. canns/analyzer/visualization/core/rendering.py +10 -6
  18. canns/analyzer/visualization/energy_plots.py +22 -8
  19. canns/analyzer/visualization/spatial_plots.py +31 -11
  20. canns/analyzer/visualization/theta_sweep_plots.py +15 -6
  21. canns/pipeline/__init__.py +4 -8
  22. canns/pipeline/asa/__init__.py +21 -0
  23. canns/pipeline/asa/__main__.py +11 -0
  24. canns/pipeline/asa/app.py +1000 -0
  25. canns/pipeline/asa/runner.py +1095 -0
  26. canns/pipeline/asa/screens.py +215 -0
  27. canns/pipeline/asa/state.py +248 -0
  28. canns/pipeline/asa/styles.tcss +221 -0
  29. canns/pipeline/asa/widgets.py +233 -0
  30. canns/pipeline/gallery/__init__.py +7 -0
  31. canns/task/open_loop_navigation.py +3 -1
  32. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/METADATA +6 -3
  33. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/RECORD +36 -17
  34. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/entry_points.txt +1 -0
  35. canns/pipeline/theta_sweep.py +0 -573
  36. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/WHEEL +0 -0
  37. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,389 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+
5
+
6
def load_npz_any(path: str) -> dict:
    """Load a ``.npz`` archive (or a ``.npy`` holding a pickled dict) into a plain dict.

    Every array of an ``.npz`` archive is read eagerly, so the returned dict stays
    valid after the underlying file handle is closed.

    NOTE(security): ``allow_pickle=True`` lets ``np.load`` unpickle arbitrary
    objects, which can execute code. Only load files from trusted sources.

    Parameters
    ----------
    path : str
        Path to the file passed to ``np.load``.

    Returns
    -------
    dict
        ``{name: array}`` for ``.npz`` archives. For a 0-d object array (a
        pickled Python object saved with ``np.save``), the result is
        ``dict(stored_object)``. For anything else, the result is ``dict(obj)``.
    """
    obj = np.load(path, allow_pickle=True)
    # NpzFile has close(); plain ndarrays and unpickled objects usually do not.
    close = getattr(obj, "close", None)
    try:
        if hasattr(obj, "files"):
            return {k: obj[k] for k in obj.files}
        # A .npy written from a dict is a 0-d object array; dict() cannot iterate it.
        if isinstance(obj, np.ndarray) and obj.dtype == object and obj.shape == ():
            return dict(obj.item())
        return dict(obj)
    finally:
        if callable(close):
            close()
18
+
19
+
20
def unwrap_container(x):
    """Return the Python object boxed inside a 0-d object ndarray.

    Any other input (including 0-d numeric arrays and n-d object arrays) is
    returned as the very same object.
    """
    boxed = isinstance(x, np.ndarray) and x.dtype == object and x.shape == ()
    return x.item() if boxed else x
25
+
26
+
27
def as_1d(x) -> np.ndarray:
    """Flatten ``x`` into a 1-D ndarray.

    A 0-d object array is first unwrapped, so a list or array stored inside an
    object container is flattened instead of becoming a length-1 object array.
    Scalars become length-1 arrays.
    """
    arr = np.asarray(x)
    if arr.dtype == object and arr.shape == ():
        # Unbox the stored Python object, then convert it for real.
        arr = np.asarray(arr.item())
    return arr.ravel()
34
+
35
+
36
def find_times_box(dec: dict) -> tuple[np.ndarray | None, str | None]:
    """Search ``dec`` for a keep-index / timestamp vector under a known key.

    Keys are tried in a fixed priority order, from ``"times_box"`` down to the
    generic ``"t"``. A candidate value is first unwrapped from a 0-d object
    container. It is accepted if it is 1-D, or 2-D with one axis of length 1
    (a row or column vector).

    Returns
    -------
    tuple
        ``(flattened_vector, key)`` for the first accepted key, else ``(None, None)``.
    """
    candidate_keys = (
        "times_box",
        "timesbox",
        "t_box",
        "idx_box",
        "index_box",
        "indices_box",
        "keep",
        "keep_idx",
        "keep_indices",
        "idx",
        "indices",
        "times",
        "t",
    )
    for name in candidate_keys:
        if name not in dec:
            continue
        vec = np.asarray(unwrap_container(dec[name]))
        is_row_or_col = vec.ndim == 2 and 1 in vec.shape
        if vec.ndim == 1 or is_row_or_col:
            return as_1d(vec), name
    return None, None
60
+
61
+
62
+ def _is_2col(arr) -> bool:
63
+ return isinstance(arr, np.ndarray) and arr.ndim == 2 and arr.shape[1] == 2
64
+
65
+
66
def find_coords_matrix(
    dec: dict,
    coords_key: str | None = None,
    prefer_box_fallback: bool = False,
) -> tuple[np.ndarray, str]:
    """Locate the decoded coordinate matrix inside a decoding dict.

    Values are first unwrapped from 0-d object containers. The search order is:

    1. ``coords_key``, if given. It must exist and be 2-D with >= 2 columns.
    2. An exact (N, 2) array under a base key (``coords``, ``theta``, ...).
    3. An exact (N, 2) array under a box key (``coordsbox``, ``coords_box``, ...).
    4. Any 2-D array with >= 2 columns. Box keys are tried before base keys
       only when ``prefer_box_fallback`` is True.
    5. Two equal-length, non-empty 1-D arrays (e.g. ``theta1`` and ``theta2``),
       stacked column-wise.

    An exact (N, 2) base matrix therefore always wins, whatever the value of
    ``prefer_box_fallback``.

    Returns
    -------
    tuple
        ``(matrix, key)``. For step 5 the key reads ``"<key1>+<key2>"``.

    Raises
    ------
    KeyError
        If ``coords_key`` is missing, or no candidate is found at all.
    ValueError
        If ``coords_key`` does not hold a 2-D array with >= 2 columns.
    """
    table = {name: unwrap_container(value) for name, value in dec.items()}
    available = list(table.keys())

    if coords_key is not None:
        if coords_key not in table:
            raise KeyError(f"--coords-key '{coords_key}' not found. keys={available}")
        chosen = np.asarray(table[coords_key])
        if not (chosen.ndim == 2 and chosen.shape[1] >= 2):
            raise ValueError(
                f"--coords-key '{coords_key}' must be 2D with >=2 cols, got {chosen.shape}"
            )
        return chosen, coords_key

    base_keys = ("coords", "theta", "thetas", "circular_coords", "cc", "decoded_coords")
    box_keys = ("coordsbox", "coords_box", "coordsBox", "coords_box_full")

    def first_match(names, accept):
        # First key (in order) whose value, as an ndarray, satisfies ``accept``.
        for name in names:
            if name in table:
                candidate = np.asarray(table[name])
                if accept(candidate):
                    return candidate, name
        return None

    def at_least_two_cols(a):
        return a.ndim == 2 and a.shape[1] >= 2

    fallback_order = box_keys + base_keys if prefer_box_fallback else base_keys + box_keys
    for names, accept in (
        (base_keys, _is_2col),
        (box_keys, _is_2col),
        (fallback_order, at_least_two_cols),
    ):
        hit = first_match(names, accept)
        if hit is not None:
            return hit

    first_names = ("theta1", "th1", "phi1", "u", "circ1")
    second_names = ("theta2", "th2", "phi2", "v", "circ2")
    pairs = (
        (a, b)
        for a in first_names
        for b in second_names
        if a in table and b in table
    )
    for name_a, name_b in pairs:
        col_a = as_1d(table[name_a])
        col_b = as_1d(table[name_b])
        if len(col_a) > 0 and len(col_a) == len(col_b):
            return np.stack([col_a, col_b], axis=1), f"{name_a}+{name_b}"

    raise KeyError(f"Cannot find decoded coords matrix. keys={available}")
123
+
124
+
125
def resolve_time_slice(t: np.ndarray, tmin, tmax, imin, imax) -> tuple[int, int]:
    """Return half-open slice bounds ``[i0, i1)`` into ``t``.

    Index bounds take precedence. If either ``imin`` or ``imax`` is given, the
    time bounds are ignored and a missing index bound defaults to the start or
    end of ``t``. Otherwise the time bounds are used. They are inclusive on
    both ends and located with ``np.searchsorted``, so ``t`` is assumed to be
    sorted ascending. Reversed time bounds are swapped.

    Both results are clamped to ``[0, len(t)]`` with ``i1 >= i0``. As a result,
    ``t[i0:i1]`` never wraps around through Python's negative indexing, and an
    out-of-range request yields an empty slice. An empty ``t`` always gives
    ``(0, 0)``.
    """
    T = len(t)
    if T == 0:
        # t[0] / t[-1] below would raise on an empty timeline.
        return 0, 0

    if imin is not None or imax is not None:
        i0 = 0 if imin is None else int(imin)
        i1 = T if imax is None else int(imax)
    elif tmin is None and tmax is None:
        return 0, T
    else:
        lo = t[0] if tmin is None else float(tmin)
        hi = t[-1] if tmax is None else float(tmax)
        if hi < lo:
            lo, hi = hi, lo
        i0 = int(np.searchsorted(t, lo, side="left"))
        i1 = int(np.searchsorted(t, hi, side="right"))

    # Clamp into [0, T]; a negative bound must not act as a from-the-end index.
    i0 = min(max(i0, 0), T)
    i1 = min(max(i1, i0), T)
    return i0, i1
146
+
147
+
148
def skew_transform(theta_2d: np.ndarray) -> np.ndarray:
    """Shear (theta1, theta2) pairs into the skewed base-parallelogram frame.

    Each row ``(a, b)`` maps to ``(a + b/2, (sqrt(3)/2) * b)``. The unit
    directions (2π, 0) and (0, 2π) therefore land on (2π, 0) and (π, √3·π).
    ``theta_2d`` must support ``[:, 0]`` / ``[:, 1]`` indexing (an (N, 2) ndarray).
    """
    first, second = theta_2d[:, 0], theta_2d[:, 1]
    sheared_x = first + second / 2.0
    sheared_y = second * (np.sqrt(3) / 2.0)
    return np.column_stack((sheared_x, sheared_y))
155
+
156
+
157
def draw_base_parallelogram(ax):
    """Outline the base parallelogram spanned by (2π, 0) and (π, √3·π) on ``ax``.

    The closed polygon is drawn with a single ``ax.plot`` call (thin grey line).
    Nothing is returned.
    """
    span = np.array([2 * np.pi, 0.0])
    lean = np.array([np.pi, np.sqrt(3) * np.pi])
    origin = np.array([0.0, 0.0])
    # Walk the corners counter-clockwise and return to the origin to close it.
    corners = np.array([origin, span, span + lean, lean, origin])
    ax.plot(corners[:, 0], corners[:, 1], lw=1.2, color="0.35")
166
+
167
+
168
def parse_times_box_to_indices(times_box: np.ndarray, t_full: np.ndarray) -> tuple[np.ndarray, str]:
    """Convert ``times_box`` to integer indices into ``t_full``.

    The interpretation chosen is reported in the returned ``kind`` string:

    - ``"index"``: integer dtype, or empty input. Values are used unchanged
      and are not clipped.
    - ``"index(float)"``: float values that are all finite, non-negative,
      integer-valued (within 1e-6) and ``<= len(t_full) + 1``. They are
      rounded to ints and are not clipped.
    - ``"time"``: anything else. Values are treated as timestamps and mapped
      with ``np.searchsorted(side="left")`` to the first sample of ``t_full``
      that is ``>=`` each value (``t_full`` is assumed sorted ascending).
      The result is clipped to ``[0, len(t_full) - 1]``.

    Returns
    -------
    tuple
        ``(indices, kind)``.
    """
    tb = as_1d(times_box)
    T_full = len(t_full)

    # An empty input has nothing to classify, and np.max would raise on it.
    if tb.size == 0 or np.issubdtype(tb.dtype, np.integer):
        return tb.astype(int), "index"

    # Integer-valued, non-negative floats within range are most likely indices
    # that were saved as floats. Negative values cannot be indices.
    # NOTE(review): the ``+ 1`` slack admits values up to len(t_full)+1, which are
    # out of range as 0-based indices; kept for compatibility — confirm intent.
    if (
        np.all(np.isfinite(tb))
        and np.all(np.abs(tb - np.round(tb)) < 1e-6)
        and np.min(tb) >= 0
        and np.max(tb) <= T_full + 1
    ):
        return np.round(tb).astype(int), "index(float)"

    idx = np.searchsorted(t_full, tb, side="left").astype(int)
    idx = np.clip(idx, 0, T_full - 1)
    return idx, "time"
192
+
193
+
194
def interp_coords_to_full(idx_map: np.ndarray, coords2: np.ndarray, T_full: int) -> np.ndarray:
    """Interpolate sparse circular coords onto every index ``0 .. T_full-1``.

    Parameters
    ----------
    idx_map : array-like of int, shape (K,)
        Position on the full timeline of each row of ``coords2``. It may be
        unsorted and may contain duplicates; for a duplicate, the first
        occurrence in input order is kept.
    coords2 : array-like, shape (K, 2)
        Angles in radians. Only the first two columns are used.
    T_full : int
        Length of the output timeline.

    Returns
    -------
    np.ndarray, shape (T_full, 2)
        Interpolated angles wrapped into ``[0, 2π)``. Before interpolation the
        angles are unwrapped along time, so a step across 0/2π follows the short
        arc. Indices outside the sampled range take the nearest endpoint value
        (``np.interp`` behavior).

    Raises
    ------
    ValueError
        If ``idx_map`` is empty or its length differs from ``len(coords2)``.
    """
    idx_map = np.asarray(idx_map).astype(int).ravel()
    coords2 = np.asarray(coords2, float)
    if idx_map.size == 0:
        raise ValueError("cannot interpolate from an empty idx_map")
    if len(coords2) != idx_map.size:
        raise ValueError(f"idx_map length {idx_map.size} != coords length {len(coords2)}")

    # np.unique sorts the indices. return_index points at each value's first
    # occurrence in the ORIGINAL array, so duplicate resolution is deterministic.
    uniq_idx, first_pos = np.unique(idx_map, return_index=True)
    coords_u = coords2[first_pos]

    # Unwrap so interpolation does not sweep the long way round the circle.
    ang = np.unwrap(coords_u, axis=0)
    full_i = np.arange(T_full, dtype=float)
    xp = uniq_idx.astype(float)

    out = np.empty((T_full, 2), float)
    for d in range(2):
        out[:, d] = np.interp(full_i, xp, ang[:, d])

    return np.mod(out, 2 * np.pi)
217
+
218
+
219
def align_coords_to_position(
    t_full: np.ndarray,
    x_full: np.ndarray,
    y_full: np.ndarray,
    coords2: np.ndarray,
    use_box: bool,
    times_box: np.ndarray | None,
    interp_to_full: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, str]:
    """Align decoded coordinates to the original (x, y, t) trajectory.

    Parameters
    ----------
    t_full, x_full, y_full : np.ndarray
        Full-length trajectory arrays of shape (T,). They are flattened, and
        all three must have the same length.
    coords2 : np.ndarray
        Decoded coordinates of shape (K, 2) or (T, 2).
    use_box : bool
        Whether to use ``times_box`` to align to the original trajectory.
    times_box : np.ndarray | None
        Time indices or timestamps corresponding to ``coords2`` when ``use_box=True``.
    interp_to_full : bool
        If True, interpolate decoded coords back to full length; otherwise return a subset.

    Returns
    -------
    tuple
        ``(t_aligned, x_aligned, y_aligned, coords_aligned, tag)``, where ``tag``
        describes the alignment path used.

    Raises
    ------
    ValueError
        If ``x_full``/``y_full`` differ in length from ``t_full``; if ``use_box``
        is False and ``coords2`` is not full length; or if ``times_box`` and
        ``coords2`` differ in length.
    KeyError
        If ``use_box`` is True, ``times_box`` is None and ``coords2`` is not
        full length.

    Notes
    -----
    In the subset path (``interp_to_full=False``), indices outside ``[0, T-1]``
    are clipped to the nearest valid sample rather than rejected.

    Examples
    --------
    >>> t, x, y, coords2, tag = align_coords_to_position(  # doctest: +SKIP
    ...     t_full, x_full, y_full, coords2,
    ...     use_box=True, times_box=decoding["times_box"], interp_to_full=True
    ... )
    >>> coords2.shape[1]
    2
    """
    t_full = np.asarray(t_full).ravel()
    x_full = np.asarray(x_full).ravel()
    y_full = np.asarray(y_full).ravel()
    coords2 = np.asarray(coords2, float)

    T_full = len(t_full)
    # Reject mismatched trajectory arrays up front. Otherwise the full paths
    # would return misaligned arrays and the subset path would raise an IndexError.
    if len(x_full) != T_full or len(y_full) != T_full:
        raise ValueError(
            f"trajectory length mismatch: len(t)={T_full}, "
            f"len(x)={len(x_full)}, len(y)={len(y_full)}"
        )

    if not use_box:
        if len(coords2) != T_full:
            raise ValueError(
                f"coords length {len(coords2)} != t length {T_full} "
                f"(set --use-box if you have times_box)"
            )
        return t_full, x_full, y_full, coords2, "full(no-box)"

    if times_box is None:
        if len(coords2) == T_full:
            return (
                t_full,
                x_full,
                y_full,
                coords2,
                "full(use-box but no times_box; treated as full)",
            )
        raise KeyError("use_box=True but times_box not found, and coords is not full-length.")

    idx_map, kind = parse_times_box_to_indices(times_box, t_full)

    # If coords already full and times_box also full, keep full
    if len(coords2) == T_full and len(idx_map) == T_full:
        return (
            t_full,
            x_full,
            y_full,
            coords2,
            f"full(coords already full; times_box kind={kind} ignored)",
        )

    if len(idx_map) != len(coords2):
        raise ValueError(f"times_box length {len(idx_map)} != coords length {len(coords2)}")

    # A stable sort keeps rows with duplicate indices in their original order.
    order = np.argsort(idx_map, kind="stable")
    idx_map = idx_map[order]
    coords2 = coords2[order]

    if interp_to_full:
        coords_full = interp_coords_to_full(idx_map, coords2, T_full)
        return (
            t_full,
            x_full,
            y_full,
            coords_full,
            f"interp_to_full(times_box kind={kind}, K={len(idx_map)})",
        )

    idx_map = np.clip(idx_map, 0, T_full - 1)
    return (
        t_full[idx_map],
        x_full[idx_map],
        y_full[idx_map],
        coords2,
        f"subset(times_box kind={kind}, K={len(idx_map)})",
    )
321
+
322
+
323
def snake_wrap_trail_in_parallelogram(
    xy_base: np.ndarray, e1: np.ndarray, e2: np.ndarray
) -> np.ndarray:
    """Insert NaN rows where a trail wraps across the torus fundamental domain.

    For each consecutive pair ``(prev, cur)``, the lattice shift
    ``i*e1 + j*e2`` with ``i, j in {-1, 0, 1}`` that brings ``cur`` closest to
    ``prev`` is found (ties go to the first shift in that order). A non-zero
    best shift means the step crossed the domain boundary. In that case a
    ``[nan, nan]`` row is inserted before ``cur``, so that line plots (e.g.
    matplotlib, which breaks lines at NaN) do not draw a spurious segment
    across the domain.

    Parameters
    ----------
    xy_base : array-like, shape (T, 2)
        Trail points in the parallelogram frame.
    e1, e2 : array-like, shape (2,)
        Lattice vectors spanning the fundamental domain. Lists are accepted.

    Returns
    -------
    np.ndarray, shape (T + n_wraps, 2)
        A new float array. Inputs with fewer than two points are returned as a copy.

    Raises
    ------
    ValueError
        If ``xy_base`` is not (T, 2).
    """
    xy_base = np.asarray(xy_base, float)
    if xy_base.ndim != 2 or xy_base.shape[1] != 2:
        raise ValueError(f"xy_base must be (T,2), got {xy_base.shape}")
    # Convert so that ``i * e1`` is vector scaling, not list repetition.
    e1 = np.asarray(e1, float)
    e2 = np.asarray(e2, float)

    if len(xy_base) < 2:
        return xy_base.copy()

    shifts = np.asarray([i * e1 + j * e2 for i in (-1, 0, 1) for j in (-1, 0, 1)])  # (9, 2)

    prev = xy_base[:-1]
    cur = xy_base[1:]
    # Squared step length under each candidate shift: shape (9, T-1).
    d2 = np.empty((len(shifts), len(cur)))
    for n, shift in enumerate(shifts):
        disp = (cur + shift) - prev
        d2[n] = np.sum(disp * disp, axis=1)

    best = shifts[np.argmin(d2, axis=0)]  # (T-1, 2)
    wrapped = (best[:, 0] != 0.0) | (best[:, 1] != 0.0)

    # Step k -> k+1 wrapped, so the NaN separator goes before original row k+1.
    return np.insert(xy_base, np.flatnonzero(wrapped) + 1, np.nan, axis=0)
351
+
352
+
353
def apply_angle_scale(coords2: np.ndarray, scale: str) -> np.ndarray:
    """Convert angle values to radians (no wrapping into [0, 2π) is applied).

    Parameters
    ----------
    coords2 : np.ndarray
        Angle array of shape (T, 2) in the given ``scale``. It is converted to float.
    scale : {"rad", "deg", "unit", "auto"}
        ``rad`` : already in radians (returned as-is).
        ``deg`` : degrees -> radians.
        ``unit`` : unit-circle values (one turn == 1) -> radians.
        ``auto`` : treat as unit-circle values if every non-NaN value lies in
        ``[-0.2, 1.2]``; otherwise assume radians. Empty or all-NaN input is
        returned unchanged.

    Returns
    -------
    np.ndarray
        Angles in radians.

    Raises
    ------
    ValueError
        If ``scale`` is not one of the options above.

    Examples
    --------
    >>> apply_angle_scale([[0.25, 0.5]], "unit")  # doctest: +SKIP
    """
    coords2 = np.asarray(coords2, float)
    if scale == "rad":
        return coords2
    if scale == "unit":
        return coords2 * (2 * np.pi)
    if scale == "deg":
        return np.deg2rad(coords2)
    if scale == "auto":
        # Heuristic: values within [-0.2, 1.2] look like unit-circle coords.
        # Genuine radian data confined to that range cannot be told apart and
        # will be rescaled.
        valid = coords2[~np.isnan(coords2)]
        if valid.size == 0:
            # nanmin/nanmax would raise (empty) or warn (all-NaN); nothing to infer.
            return coords2
        if valid.max() <= 1.2 and valid.min() >= -0.2:
            return coords2 * (2 * np.pi)
        return coords2
    raise ValueError(f"Unknown --scale option: {scale}")