canns 0.12.7__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/analyzer/data/__init__.py +3 -11
- canns/analyzer/data/asa/__init__.py +74 -0
- canns/analyzer/data/asa/cohospace.py +905 -0
- canns/analyzer/data/asa/config.py +246 -0
- canns/analyzer/data/asa/decode.py +448 -0
- canns/analyzer/data/asa/embedding.py +269 -0
- canns/analyzer/data/asa/filters.py +208 -0
- canns/analyzer/data/asa/fr.py +439 -0
- canns/analyzer/data/asa/path.py +389 -0
- canns/analyzer/data/asa/plotting.py +1276 -0
- canns/analyzer/data/asa/tda.py +901 -0
- canns/analyzer/data/legacy/__init__.py +6 -0
- canns/analyzer/data/{cann1d.py → legacy/cann1d.py} +2 -2
- canns/analyzer/data/{cann2d.py → legacy/cann2d.py} +3 -3
- canns/analyzer/visualization/core/backend.py +1 -1
- canns/analyzer/visualization/core/config.py +77 -0
- canns/analyzer/visualization/core/rendering.py +10 -6
- canns/analyzer/visualization/energy_plots.py +22 -8
- canns/analyzer/visualization/spatial_plots.py +31 -11
- canns/analyzer/visualization/theta_sweep_plots.py +15 -6
- canns/pipeline/__init__.py +4 -8
- canns/pipeline/asa/__init__.py +21 -0
- canns/pipeline/asa/__main__.py +11 -0
- canns/pipeline/asa/app.py +1000 -0
- canns/pipeline/asa/runner.py +1095 -0
- canns/pipeline/asa/screens.py +215 -0
- canns/pipeline/asa/state.py +248 -0
- canns/pipeline/asa/styles.tcss +221 -0
- canns/pipeline/asa/widgets.py +233 -0
- canns/pipeline/gallery/__init__.py +7 -0
- canns/task/open_loop_navigation.py +3 -1
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/METADATA +6 -3
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/RECORD +36 -17
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/entry_points.txt +1 -0
- canns/pipeline/theta_sweep.py +0 -573
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/WHEEL +0 -0
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,389 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def load_npz_any(path: str) -> dict:
    """Read a NumPy archive at ``path`` and return its entries as a plain dict.

    The file is opened with ``allow_pickle=True`` so that object arrays can be
    restored. The loaded handle is closed (best effort) before returning.
    """
    # NOTE(security): allow_pickle=True can execute arbitrary code while
    # unpickling; only call this on files from a trusted source.
    archive = np.load(path, allow_pickle=True)
    try:
        if not hasattr(archive, "files"):
            # Not an NpzFile-like object: fall back to a generic dict conversion.
            return dict(archive)
        contents = {}
        for name in archive.files:
            contents[name] = archive[name]
        return contents
    finally:
        # Closing is best effort: an object without close() (or a failing
        # close) must not mask the result or the original error.
        try:
            archive.close()
        except Exception:
            pass
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def unwrap_container(x):
    """Return the Python object held by a 0-d object array.

    Anything that is not a zero-dimensional ``np.ndarray`` of dtype ``object``
    is returned unchanged.
    """
    if not isinstance(x, np.ndarray):
        return x
    is_boxed_scalar = x.dtype == object and x.shape == ()
    return x.item() if is_boxed_scalar else x
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def as_1d(x) -> np.ndarray:
    """Flatten ``x`` into a 1-D numpy array.

    A 0-d object array is first replaced by the object it stores, so a boxed
    list or array is flattened rather than returned as a single element.
    """
    arr = np.asarray(x)
    if arr.shape == () and arr.dtype == object:
        # Unbox the stored Python object and convert that instead.
        arr = np.asarray(arr.item())
    return arr.ravel()
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def find_times_box(dec: dict) -> tuple[np.ndarray | None, str | None]:
    """Look up a 'times_box' / keep-index vector in a decoding dict.

    A fixed list of candidate key names is tried in order. The first entry that
    is 1-D, or 2-D with a singleton axis, is returned flattened together with
    the key it was found under. ``(None, None)`` is returned when no candidate
    key holds such a vector.
    """
    candidate_names = (
        "times_box",
        "timesbox",
        "t_box",
        "idx_box",
        "index_box",
        "indices_box",
        "keep",
        "keep_idx",
        "keep_indices",
        "idx",
        "indices",
        "times",
        "t",
    )
    for name in candidate_names:
        if name not in dec:
            continue
        values = np.asarray(unwrap_container(dec[name]))
        is_vector = values.ndim == 1 or (values.ndim == 2 and 1 in values.shape)
        if is_vector:
            return as_1d(values), name
    return None, None
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _is_2col(arr) -> bool:
|
|
63
|
+
return isinstance(arr, np.ndarray) and arr.ndim == 2 and arr.shape[1] == 2
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def find_coords_matrix(
    dec: dict,
    coords_key: str | None = None,
    prefer_box_fallback: bool = False,
) -> tuple[np.ndarray, str]:
    """Locate a decoded coordinate matrix (N, D>=2) inside a decoding dict.

    Search order:

    1. If ``coords_key`` is given, only that entry is considered; it must be
       2-D with at least two columns.
    2. Otherwise an exact (N, 2) matrix is always preferred: the base keys
       (``coords``, ``theta``, ...) are tried first, then the box keys
       (``coordsbox``, ...). ``prefer_box_fallback`` has no effect here.
    3. If no (N, 2) matrix exists, any 2-D matrix with >=2 columns is accepted.
       ``prefer_box_fallback`` decides whether box keys or base keys are tried
       first in this step.
    4. As a last resort, a pair of equal-length, non-empty 1-D arrays
       (e.g. ``theta1`` + ``theta2``) is stacked into an (N, 2) matrix.

    Returns
    -------
    tuple
        ``(matrix, key_name)``; for the 1-D pair fallback ``key_name`` is
        ``"<k1>+<k2>"``.

    Raises
    ------
    KeyError
        If ``coords_key`` is missing from ``dec`` or nothing suitable is found.
    ValueError
        If the entry under ``coords_key`` is not 2-D with >=2 columns.
    """
    unwrapped = {name: unwrap_container(value) for name, value in dec.items()}

    # Explicit key requested by the caller: validate it and stop.
    if coords_key is not None:
        if coords_key not in unwrapped:
            raise KeyError(f"--coords-key '{coords_key}' not found. keys={list(unwrapped.keys())}")
        chosen = np.asarray(unwrapped[coords_key])
        if not (chosen.ndim == 2 and chosen.shape[1] >= 2):
            raise ValueError(
                f"--coords-key '{coords_key}' must be 2D with >=2 cols, got {chosen.shape}"
            )
        return chosen, coords_key

    base_keys = ["coords", "theta", "thetas", "circular_coords", "cc", "decoded_coords"]
    box_keys = ["coordsbox", "coords_box", "coordsBox", "coords_box_full"]

    # Exact (N, 2) matrices win: base keys first, then box keys.
    for name in (*base_keys, *box_keys):
        if name not in unwrapped:
            continue
        matrix = np.asarray(unwrapped[name])
        if _is_2col(matrix):
            return matrix, name

    # Otherwise accept any 2-D matrix with at least two columns.
    if prefer_box_fallback:
        fallback_order = box_keys + base_keys
    else:
        fallback_order = base_keys + box_keys
    for name in fallback_order:
        if name not in unwrapped:
            continue
        matrix = np.asarray(unwrapped[name])
        if matrix.ndim == 2 and matrix.shape[1] >= 2:
            return matrix, name

    # Last resort: assemble the matrix from two separate 1-D arrays.
    first_names = ["theta1", "th1", "phi1", "u", "circ1"]
    second_names = ["theta2", "th2", "phi2", "v", "circ2"]
    name_pairs = ((n1, n2) for n1 in first_names for n2 in second_names)
    for n1, n2 in name_pairs:
        if n1 not in unwrapped or n2 not in unwrapped:
            continue
        col_a = as_1d(unwrapped[n1])
        col_b = as_1d(unwrapped[n2])
        if len(col_a) == len(col_b) and len(col_a) > 0:
            return np.stack([col_a, col_b], axis=1), f"{n1}+{n2}"

    raise KeyError(f"Cannot find decoded coords matrix. keys={list(unwrapped.keys())}")
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def resolve_time_slice(t: np.ndarray, tmin, tmax, imin, imax) -> tuple[int, int]:
    """Return ``[i0, i1)`` slice bounds from either index bounds or time bounds.

    Parameters
    ----------
    t : np.ndarray
        Time vector. The time-bound path uses ``np.searchsorted``, which
        expects ``t`` to be sorted in ascending order.
    tmin, tmax : float-like or None
        Inclusive time window. A missing bound defaults to the first / last
        sample of ``t``; if ``tmax < tmin`` the two are swapped.
    imin, imax : int-like or None
        Index bounds. If either is given they take precedence over the time
        bounds: ``imin`` is clamped below at 0 and ``imax`` above at ``len(t)``.

    Returns
    -------
    tuple[int, int]
        ``(i0, i1)`` suitable for ``t[i0:i1]``.
    """
    T = len(t)
    if imin is not None or imax is not None:
        i0 = 0 if imin is None else max(0, int(imin))
        i1 = T if imax is None else min(T, int(imax))
        return i0, i1

    # Nothing to select from (or nothing requested): the whole range is the
    # answer. The T == 0 guard avoids indexing t[0] / t[-1] on an empty array
    # when only one time bound is supplied.
    if T == 0 or (tmin is None and tmax is None):
        return 0, T

    tmin = t[0] if tmin is None else float(tmin)
    tmax = t[-1] if tmax is None else float(tmax)
    if tmax < tmin:
        tmin, tmax = tmax, tmin

    # side="left"/"right" makes both window edges inclusive.
    i0 = int(np.searchsorted(t, tmin, side="left"))
    i1 = int(np.searchsorted(t, tmax, side="right"))
    i0 = max(0, min(T, i0))
    i1 = max(0, min(T, i1))
    return i0, i1
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def skew_transform(theta_2d: np.ndarray) -> np.ndarray:
    """Map ``(theta1, theta2)`` rows to skew coordinates of the base parallelogram.

    Each row becomes ``(theta1 + theta2 / 2, (sqrt(3) / 2) * theta2)``.
    """
    first, second = theta_2d[:, 0], theta_2d[:, 1]
    half_sqrt3 = np.sqrt(3) / 2.0
    skew_x = first + 0.5 * second
    skew_y = half_sqrt3 * second
    return np.stack([skew_x, skew_y], axis=1)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def draw_base_parallelogram(ax):
    """Outline the parallelogram spanned by ``(2*pi, 0)`` and ``(pi, sqrt(3)*pi)``.

    The closed outline is drawn with a single ``ax.plot`` call, starting and
    ending at the origin.
    """
    origin = np.array([0.0, 0.0])
    edge_a = np.array([2 * np.pi, 0.0])
    edge_b = np.array([np.pi, np.sqrt(3) * np.pi])
    # Walk the corners in order and return to the origin to close the outline.
    outline = np.vstack([origin, edge_a, edge_a + edge_b, edge_b, origin])
    xs, ys = outline[:, 0], outline[:, 1]
    ax.plot(xs, ys, lw=1.2, color="0.35")
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def parse_times_box_to_indices(times_box: np.ndarray, t_full: np.ndarray) -> tuple[np.ndarray, str]:
    """Convert ``times_box`` to integer indices into ``t_full``.

    Parameters
    ----------
    times_box : array-like
        Either sample indices (integer dtype, or floats that are whole numbers)
        or timestamps.
    t_full : np.ndarray
        Full-length time vector; used for its length and, for timestamps, as
        the lookup table for ``np.searchsorted`` (which expects ascending order).

    Returns
    -------
    tuple
        ``(idx, kind)`` where ``kind`` is ``"index"`` for integer input,
        ``"index(float)"`` for whole-number floats not exceeding
        ``len(t_full) + 1``, and ``"time"`` for timestamps mapped through
        ``np.searchsorted``. Only the ``"time"`` result is clipped to
        ``[0, len(t_full) - 1]``; index inputs are returned as given.
    """
    tb = as_1d(times_box)
    T_full = len(t_full)

    if np.issubdtype(tb.dtype, np.integer):
        return tb.astype(int), "index"

    # If values are basically integers and within range -> treat as indices.
    # The size check keeps np.nanmax from raising on an empty vector; an empty
    # input is vacuously index-like and yields an empty index array.
    looks_like_indices = (
        np.all(np.isfinite(tb))
        and np.all(np.abs(tb - np.round(tb)) < 1e-6)
        and (tb.size == 0 or np.nanmax(tb) <= T_full + 1)
    )
    if looks_like_indices:
        return np.round(tb).astype(int), "index(float)"

    # Treat as timestamps -> map via searchsorted
    idx = np.searchsorted(t_full, tb, side="left").astype(int)
    idx = np.clip(idx, 0, T_full - 1)
    return idx, "time"
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def interp_coords_to_full(idx_map: np.ndarray, coords2: np.ndarray, T_full: int) -> np.ndarray:
    """Interpolate (K, 2) circular coords back to full length (T_full, 2).

    Samples are ordered by index and de-duplicated (one sample is kept per
    index). Each column is phase-unwrapped, linearly interpolated onto
    ``0 .. T_full - 1``, and the result is wrapped back into ``[0, 2*pi)``.
    """
    positions = np.asarray(idx_map).astype(int).ravel()
    samples = np.asarray(coords2, float)

    # Order samples by index, then keep one sample per distinct index.
    ranking = np.argsort(positions)
    positions, samples = positions[ranking], samples[ranking]
    knots, first_seen = np.unique(positions, return_index=True)
    samples = samples[first_seen]

    # Unwrap so that interpolation does not cut across the 0 / 2*pi seam.
    unwrapped = np.unwrap(samples, axis=0)
    knots_f = knots.astype(float)
    grid = np.arange(T_full, dtype=float)

    columns = [np.interp(grid, knots_f, unwrapped[:, col]) for col in range(2)]
    dense = np.stack(columns, axis=1)
    return np.mod(dense, 2 * np.pi)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def align_coords_to_position(
    t_full: np.ndarray,
    x_full: np.ndarray,
    y_full: np.ndarray,
    coords2: np.ndarray,
    use_box: bool,
    times_box: np.ndarray | None,
    interp_to_full: bool,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, str]:
    """Align decoded coordinates with the original (t, x, y) trajectory.

    Parameters
    ----------
    t_full, x_full, y_full : np.ndarray
        Full-length trajectory arrays; each is flattened to shape (T,).
    coords2 : np.ndarray
        Decoded coordinates of shape (K, 2) or (T, 2).
    use_box : bool
        If False, ``coords2`` must already have T rows and is returned as is.
        If True, ``times_box`` is used to place ``coords2`` on the trajectory.
    times_box : np.ndarray | None
        Time indices or timestamps matching the rows of ``coords2`` when
        ``use_box=True``.
    interp_to_full : bool
        If True, interpolate the decoded coords back to full length; otherwise
        return the trajectory restricted to the ``times_box`` samples.

    Returns
    -------
    tuple
        ``(t_aligned, x_aligned, y_aligned, coords_aligned, tag)`` where ``tag``
        is a short string naming the alignment path that was taken.

    Raises
    ------
    ValueError
        If ``use_box`` is False and the lengths differ, or if ``times_box`` and
        ``coords2`` have different lengths.
    KeyError
        If ``use_box`` is True, ``times_box`` is None and ``coords2`` is not
        full-length.

    Examples
    --------
    >>> t, x, y, coords2, tag = align_coords_to_position(  # doctest: +SKIP
    ...     t_full, x_full, y_full, coords2,
    ...     use_box=True, times_box=decoding["times_box"], interp_to_full=True
    ... )
    >>> coords2.shape[1]
    2
    """
    t_full = np.asarray(t_full).ravel()
    x_full = np.asarray(x_full).ravel()
    y_full = np.asarray(y_full).ravel()
    coords2 = np.asarray(coords2, float)

    n_samples = len(t_full)
    n_coords = len(coords2)
    coords_are_full = n_coords == n_samples

    # Path 1: no box -> coords must already line up with the trajectory.
    if not use_box:
        if coords_are_full:
            return t_full, x_full, y_full, coords2, "full(no-box)"
        raise ValueError(
            f"coords length {n_coords} != t length {n_samples} "
            f"(set --use-box if you have times_box)"
        )

    # Path 2: box requested but no times_box -> only full-length coords work.
    if times_box is None:
        if not coords_are_full:
            raise KeyError("use_box=True but times_box not found, and coords is not full-length.")
        tag = "full(use-box but no times_box; treated as full)"
        return t_full, x_full, y_full, coords2, tag

    idx_map, kind = parse_times_box_to_indices(times_box, t_full)
    n_idx = len(idx_map)

    # Path 3: both coords and times_box are full-length -> nothing to align.
    if coords_are_full and n_idx == n_samples:
        tag = f"full(coords already full; times_box kind={kind} ignored)"
        return t_full, x_full, y_full, coords2, tag

    if n_idx != n_coords:
        raise ValueError(f"times_box length {n_idx} != coords length {n_coords}")

    # Put the decoded samples in ascending index order.
    ranking = np.argsort(idx_map)
    idx_map = idx_map[ranking]
    coords2 = coords2[ranking]

    # Path 4: stretch the decoded samples over the whole trajectory.
    if interp_to_full:
        tag = f"interp_to_full(times_box kind={kind}, K={n_idx})"
        dense = interp_coords_to_full(idx_map, coords2, n_samples)
        return t_full, x_full, y_full, dense, tag

    # Path 5: keep only the trajectory samples addressed by times_box.
    idx_map = np.clip(idx_map, 0, n_samples - 1)
    tag = f"subset(times_box kind={kind}, K={n_idx})"
    return t_full[idx_map], x_full[idx_map], y_full[idx_map], coords2, tag
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
def snake_wrap_trail_in_parallelogram(
    xy_base: np.ndarray, e1: np.ndarray, e2: np.ndarray
) -> np.ndarray:
    """Insert NaN rows where the trail wraps across the torus fundamental domain.

    For each consecutive pair of points, the next point is compared against its
    nine lattice images ``cur + i*e1 + j*e2`` with ``i, j`` in ``{-1, 0, 1}``.
    If the image closest to the previous point is not the unshifted one, the
    step is treated as a wrap and a ``(nan, nan)`` row is placed between the two
    points, so that a line plot breaks there.

    Parameters
    ----------
    xy_base : np.ndarray
        Trail points of shape (T, 2).
    e1, e2 : np.ndarray
        Lattice vectors of shape (2,) spanning the fundamental domain.

    Returns
    -------
    np.ndarray
        New array of shape (T + number_of_wraps, 2). Trails with fewer than two
        points (including an empty trail) are returned as an unchanged copy.

    Raises
    ------
    ValueError
        If ``xy_base`` is not of shape (T, 2).
    """
    xy_base = np.asarray(xy_base, float)
    if xy_base.ndim != 2 or xy_base.shape[1] != 2:
        raise ValueError(f"xy_base must be (T,2), got {xy_base.shape}")

    # Fewer than two points means there is no step that could wrap; this also
    # covers the empty trail, which previously failed on xy_base[0].
    if len(xy_base) < 2:
        return xy_base.copy()

    e1 = np.asarray(e1, float)
    e2 = np.asarray(e2, float)
    shifts = np.array([i * e1 + j * e2 for i in (-1, 0, 1) for j in (-1, 0, 1)])  # (9,2)

    # Displacement from each point to every lattice image of its successor.
    disp = (xy_base[1:, None, :] + shifts[None, :, :]) - xy_base[:-1, None, :]  # (T-1,9,2)
    d2 = np.sum(disp * disp, axis=2)
    best = shifts[np.argmin(d2, axis=1)]  # (T-1,2)

    # A step wraps when the nearest image is reached through a non-zero shift.
    wrapped = (best[:, 0] != 0.0) | (best[:, 1] != 0.0)
    break_before = np.flatnonzero(wrapped) + 1
    return np.insert(xy_base, break_before, np.nan, axis=0)
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
def apply_angle_scale(coords2: np.ndarray, scale: str) -> np.ndarray:
    """Convert angle units to radians before wrapping.

    Parameters
    ----------
    coords2 : np.ndarray
        Angle array of shape (T, 2) in the given ``scale``.
    scale : {"rad", "deg", "unit", "auto"}
        ``rad``  : already in radians; returned unchanged.
        ``deg``  : degrees -> radians.
        ``unit`` : unit circle in [0, 1] -> radians (multiplied by ``2*pi``).
        ``auto`` : treat the data as ``unit`` when all non-NaN values lie in
        ``[-0.2, 1.2]``, otherwise return it unchanged. An empty array is
        returned unchanged.

    Returns
    -------
    np.ndarray
        Angles in radians.

    Raises
    ------
    ValueError
        If ``scale`` is not one of the supported options.

    Examples
    --------
    >>> apply_angle_scale([[0.25, 0.5]], "unit")  # doctest: +SKIP
    """
    coords2 = np.asarray(coords2, float)
    if scale == "rad":
        return coords2
    if scale == "unit":
        return coords2 * (2 * np.pi)
    if scale == "deg":
        return np.deg2rad(coords2)
    if scale == "auto":
        # np.nanmin / np.nanmax raise on zero-size input; there is nothing to
        # infer (or rescale) for an empty array.
        if coords2.size == 0:
            return coords2
        # Heuristic: if values look like [0,1] or [-0.2,1.2], treat as unit circle coords.
        mn = float(np.nanmin(coords2))
        mx = float(np.nanmax(coords2))
        if mx <= 1.2 and mn >= -0.2:
            return coords2 * (2 * np.pi)
        return coords2
    raise ValueError(f"Unknown --scale option: {scale}")
|