canns 0.12.7__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. canns/analyzer/data/__init__.py +3 -11
  2. canns/analyzer/data/asa/__init__.py +74 -0
  3. canns/analyzer/data/asa/cohospace.py +905 -0
  4. canns/analyzer/data/asa/config.py +246 -0
  5. canns/analyzer/data/asa/decode.py +448 -0
  6. canns/analyzer/data/asa/embedding.py +269 -0
  7. canns/analyzer/data/asa/filters.py +208 -0
  8. canns/analyzer/data/asa/fr.py +439 -0
  9. canns/analyzer/data/asa/path.py +389 -0
  10. canns/analyzer/data/asa/plotting.py +1276 -0
  11. canns/analyzer/data/asa/tda.py +901 -0
  12. canns/analyzer/data/legacy/__init__.py +6 -0
  13. canns/analyzer/data/{cann1d.py → legacy/cann1d.py} +2 -2
  14. canns/analyzer/data/{cann2d.py → legacy/cann2d.py} +3 -3
  15. canns/analyzer/visualization/core/backend.py +1 -1
  16. canns/analyzer/visualization/core/config.py +77 -0
  17. canns/analyzer/visualization/core/rendering.py +10 -6
  18. canns/analyzer/visualization/energy_plots.py +22 -8
  19. canns/analyzer/visualization/spatial_plots.py +31 -11
  20. canns/analyzer/visualization/theta_sweep_plots.py +15 -6
  21. canns/pipeline/__init__.py +4 -8
  22. canns/pipeline/asa/__init__.py +21 -0
  23. canns/pipeline/asa/__main__.py +11 -0
  24. canns/pipeline/asa/app.py +1000 -0
  25. canns/pipeline/asa/runner.py +1095 -0
  26. canns/pipeline/asa/screens.py +215 -0
  27. canns/pipeline/asa/state.py +248 -0
  28. canns/pipeline/asa/styles.tcss +221 -0
  29. canns/pipeline/asa/widgets.py +233 -0
  30. canns/pipeline/gallery/__init__.py +7 -0
  31. canns/task/open_loop_navigation.py +3 -1
  32. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/METADATA +6 -3
  33. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/RECORD +36 -17
  34. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/entry_points.txt +1 -0
  35. canns/pipeline/theta_sweep.py +0 -573
  36. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/WHEEL +0 -0
  37. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,439 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from dataclasses import dataclass
5
+
6
+ import numpy as np
7
+
8
+ from ...visualization.core import PlotConfig, finalize_figure
9
+
10
+
11
+ def _ensure_parent_dir(save_path: str | None) -> None:
12
+ if save_path:
13
+ parent = os.path.dirname(save_path)
14
+ if parent:
15
+ os.makedirs(parent, exist_ok=True)
16
+
17
+
18
+ def _slice_range(r: tuple[int, int] | None, length: int) -> slice:
19
+ """Convert (start, end) into a safe Python slice within [0, length]."""
20
+ if r is None:
21
+ return slice(0, length)
22
+ s, e = r
23
+ s = 0 if s is None else int(s)
24
+ e = length if e is None else int(e)
25
+ s = max(0, min(length, s))
26
+ e = max(0, min(length, e))
27
+ if e < s:
28
+ e = s
29
+ return slice(s, e)
30
+
31
+
32
def compute_fr_heatmap_matrix(
    spike: np.ndarray,
    neuron_range: tuple[int, int] | None = None,
    time_range: tuple[int, int] | None = None,
    *,
    transpose: bool = True,
    normalize: str | None = None,
) -> np.ndarray:
    """
    Compute a matrix for FR heatmap display from spike-like data.

    Parameters
    ----------
    spike : np.ndarray
        Shape (T, N). Can be continuous (float) or binned (int/float).
    neuron_range : (start, end) or None
        Neuron index range in [0, N]. End is exclusive. Out-of-range endpoints
        are clipped; ``None`` (or a ``None`` endpoint) means unbounded.
    time_range : (start, end) or None
        Time index range in [0, T]. End is exclusive. Clipped like
        ``neuron_range``.
    transpose : bool
        If True, returns (N_sel, T_sel), which suits ``imshow`` with neurons on
        the Y axis and time on the X axis. If False, returns (T_sel, N_sel).
    normalize : {'zscore_per_neuron', 'minmax_per_neuron', None}
        Optional display normalization along time for each neuron. Neurons
        with zero standard deviation / zero range use a divisor of 1 instead
        of producing NaN/inf.

    Returns
    -------
    M : np.ndarray
        Heatmap matrix. Default shape (N_sel, T_sel) if transpose=True.
        When ``normalize`` is None the result may be a view of ``spike``
        (no copy is made), so writing into it can modify the input.

    Raises
    ------
    ValueError
        If ``spike`` is not 2D or ``normalize`` is not a recognised option.

    Examples
    --------
    >>> M = compute_fr_heatmap_matrix(spikes, transpose=True)  # doctest: +SKIP
    >>> M.ndim
    2
    """
    spike = np.asarray(spike)
    if spike.ndim != 2:
        raise ValueError(f"spike must be 2D (T,N), got shape={spike.shape}")
    # Validate up front so a bad option is reported regardless of the data.
    if normalize is not None and normalize not in ("zscore_per_neuron", "minmax_per_neuron"):
        raise ValueError(f"Unknown normalize={normalize!r}")

    T, N = spike.shape
    t_sl = _slice_range(time_range, T)
    n_sl = _slice_range(neuron_range, N)

    sub = spike[t_sl, n_sl]  # (T_sel, N_sel)

    # normalization is only for display; does NOT change any downstream indices
    if normalize is not None:
        X = sub.astype(float, copy=False)
        if X.shape[0] == 0:
            # Empty time window: per-neuron statistics are undefined and
            # np.min/np.max would raise on the zero-size reduction.
            sub = X
        elif normalize == "zscore_per_neuron":
            mu = np.mean(X, axis=0, keepdims=True)
            sd = np.std(X, axis=0, keepdims=True)
            sd = np.where(sd == 0, 1.0, sd)
            sub = (X - mu) / sd
        else:  # "minmax_per_neuron"
            mn = np.min(X, axis=0, keepdims=True)
            mx = np.max(X, axis=0, keepdims=True)
            span = mx - mn
            den = np.where(span == 0, 1.0, span)
            sub = (X - mn) / den

    return sub.T if transpose else sub
96
+
97
+
98
def save_fr_heatmap_png(
    M: np.ndarray,
    *,
    title: str = "Firing Rate Heatmap",
    xlabel: str = "Time",
    ylabel: str = "Neuron",
    cmap: str | None = None,
    interpolation: str | None = "nearest",
    origin: str | None = "lower",
    aspect: str | None = "auto",
    clabel: str | None = None,
    colorbar: bool = True,
    dpi: int = 200,
    show: bool | None = None,
    config: PlotConfig | None = None,
    **kwargs,
) -> None:
    """
    Save a heatmap PNG from a matrix (typically output of compute_fr_heatmap_matrix).

    Parameters
    ----------
    M : np.ndarray
        Heatmap matrix (2D).
    title, xlabel, ylabel : str
        Plot labels (used when ``config`` is None or has empty fields).
    cmap, interpolation, origin, aspect : str, optional
        Matplotlib ``imshow`` options. Each is applied only when it is not None
        and ``config.to_matplotlib_kwargs()`` does not already set that key.
    clabel : str, optional
        Colorbar label (defaults to ``config.clabel``).
    colorbar : bool
        Whether to draw a colorbar.
    dpi : int
        Save DPI. Always written to ``config.save_dpi``, replacing any value
        already set there.
    show : bool | None
        Whether to show the plot (overrides ``config.show`` if not None).
        Defaults to False when ``config`` is None.
    config : PlotConfig, optional
        Plot configuration. Use ``config.save_path`` to specify output file.
        NOTE: a caller-supplied config is modified in place (``save_path``,
        ``show``, empty labels and ``save_dpi`` may be overwritten).
    **kwargs : Any
        Additional ``imshow`` keyword arguments; they take precedence over the
        options above. ``save_path`` may also be passed here; when given it
        overrides ``config.save_path``.

    Raises
    ------
    ValueError
        If no save path is available or ``M`` is not 2D.

    Notes
    -----
    - Does not reorder neurons.
    - Uses matplotlib only here (ASA core stays compute-friendly).
    - If anything fails after the figure is created, the figure is closed
      before the exception propagates, so failed calls do not leave open
      pyplot figures behind.

    Examples
    --------
    >>> config = PlotConfig.for_static_plot(save_path="fr.png", show=False)  # doctest: +SKIP
    >>> save_fr_heatmap_png(M, config=config)  # doctest: +SKIP
    """
    import matplotlib.pyplot as plt  # local import to keep ASA light

    save_path = kwargs.pop("save_path", None)

    if config is None:
        show_val = False if show is None else show
        config = PlotConfig.for_static_plot(
            title=title,
            xlabel=xlabel,
            ylabel=ylabel,
            save_path=str(save_path) if save_path is not None else None,
            show=show_val,
        )
    else:
        if save_path is not None:
            config.save_path = str(save_path)
        if show is not None:
            config.show = show
        if not config.title:
            config.title = title
        if not config.xlabel:
            config.xlabel = xlabel
        if not config.ylabel:
            config.ylabel = ylabel

    if config.save_path is None:
        raise ValueError(
            "save_path must be provided via config.save_path or as a keyword argument."
        )

    config.save_dpi = dpi

    M = np.asarray(M)
    if M.ndim != 2:
        raise ValueError(f"M must be 2D for heatmap display, got shape={M.shape}")

    fig, ax = plt.subplots(figsize=config.figsize)
    try:
        plot_kwargs = config.to_matplotlib_kwargs()
        # Explicit arguments only fill gaps left by the config; **kwargs win over both.
        for key, value in (
            ("cmap", cmap),
            ("interpolation", interpolation),
            ("origin", origin),
            ("aspect", aspect),
        ):
            if value is not None and key not in plot_kwargs:
                plot_kwargs[key] = value
        plot_kwargs.update(kwargs)

        im = ax.imshow(M, **plot_kwargs)
        ax.set_title(config.title)
        ax.set_xlabel(config.xlabel)
        ax.set_ylabel(config.ylabel)
        if colorbar:
            label = clabel if clabel is not None else config.clabel
            fig.colorbar(im, ax=ax, label=label)
        fig.tight_layout()
        _ensure_parent_dir(config.save_path)
        finalize_figure(fig, config)
    except Exception:
        # Don't leak an open pyplot figure when plotting or saving fails
        # (e.g. an invalid imshow kwarg or an unwritable path).
        plt.close(fig)
        raise
208
+
209
+
210
+ @dataclass
211
+ class FRMResult:
212
+ """Return object for firing-rate map computation.
213
+
214
+ Attributes
215
+ ----------
216
+ frm : np.ndarray
217
+ Firing rate map (bins_x, bins_y).
218
+ occupancy : np.ndarray
219
+ Occupancy counts per spatial bin.
220
+ spike_sum : np.ndarray
221
+ Spike counts per spatial bin.
222
+ x_edges, y_edges : np.ndarray
223
+ Bin edges used for the FRM computation.
224
+
225
+ Examples
226
+ --------
227
+ >>> res = FRMResult(frm=None, occupancy=None, spike_sum=None, x_edges=None, y_edges=None) # doctest: +SKIP
228
+ """
229
+
230
+ frm: np.ndarray
231
+ occupancy: np.ndarray
232
+ spike_sum: np.ndarray
233
+ x_edges: np.ndarray
234
+ y_edges: np.ndarray
235
+
236
+
237
+ def compute_frm(
238
+ spike: np.ndarray,
239
+ x: np.ndarray,
240
+ y: np.ndarray,
241
+ neuron_id: int,
242
+ *,
243
+ bins: int = 50,
244
+ x_range: tuple[float, float] | None = None,
245
+ y_range: tuple[float, float] | None = None,
246
+ min_occupancy: int = 1,
247
+ smoothing: bool = False,
248
+ sigma: float = 1.0,
249
+ nan_for_empty: bool = True,
250
+ ) -> FRMResult:
251
+ """
252
+ Compute a single-neuron firing rate map (FRM) on 2D space.
253
+
254
+ Parameters
255
+ ----------
256
+ spike : np.ndarray
257
+ Shape (T, N). Can be continuous (float) or binned counts (int/float).
258
+ x, y : np.ndarray
259
+ Shape (T,). Position samples aligned with spike rows.
260
+ neuron_id : int
261
+ Neuron index in [0, N).
262
+ bins : int
263
+ Number of spatial bins per dimension.
264
+ x_range, y_range : (min, max) or None
265
+ Explicit ranges. If None, uses data min/max.
266
+ min_occupancy : int
267
+ Bins with occupancy < min_occupancy are treated as empty.
268
+ smoothing : bool
269
+ If True, apply Gaussian smoothing to frm (and optionally to occupancy/spike_sum if you want later).
270
+ sigma : float
271
+ Gaussian sigma for smoothing (in bin units).
272
+ nan_for_empty : bool
273
+ If True, empty bins become NaN; else 0.
274
+
275
+ Returns
276
+ -------
277
+ FRMResult
278
+ frm: 2D array (bins_x, bins_y) in Hz-like units per sample (relative scale).
279
+
280
+ Examples
281
+ --------
282
+ >>> res = compute_frm(spikes, x, y, neuron_id=0) # doctest: +SKIP
283
+ >>> res.frm.shape # doctest: +SKIP
284
+ """
285
+ spike = np.asarray(spike)
286
+ x = np.asarray(x).ravel()
287
+ y = np.asarray(y).ravel()
288
+
289
+ if spike.ndim != 2:
290
+ raise ValueError(f"spike must be 2D (T,N), got shape={spike.shape}")
291
+ T, N = spike.shape
292
+ if len(x) != T or len(y) != T:
293
+ raise ValueError(
294
+ f"x/y length must match spike rows T={T}, got len(x)={len(x)}, len(y)={len(y)}"
295
+ )
296
+ if not (0 <= int(neuron_id) < N):
297
+ raise ValueError(f"neuron_id out of range: {neuron_id} for N={N}")
298
+
299
+ fr = spike[:, int(neuron_id)].astype(float, copy=False)
300
+
301
+ # ranges
302
+ if x_range is None:
303
+ x_min, x_max = float(np.min(x)), float(np.max(x))
304
+ else:
305
+ x_min, x_max = float(x_range[0]), float(x_range[1])
306
+
307
+ if y_range is None:
308
+ y_min, y_max = float(np.min(y)), float(np.max(y))
309
+ else:
310
+ y_min, y_max = float(y_range[0]), float(y_range[1])
311
+
312
+ # Edges (bins+1)
313
+ x_edges = np.linspace(x_min, x_max, bins + 1)
314
+ y_edges = np.linspace(y_min, y_max, bins + 1)
315
+
316
+ # Bin indices
317
+ xi = np.searchsorted(x_edges, x, side="right") - 1
318
+ yi = np.searchsorted(y_edges, y, side="right") - 1
319
+
320
+ # Keep only points inside range
321
+ valid = (xi >= 0) & (xi < bins) & (yi >= 0) & (yi < bins)
322
+ xi = xi[valid]
323
+ yi = yi[valid]
324
+ frv = fr[valid]
325
+
326
+ occupancy = np.zeros((bins, bins), dtype=np.int64)
327
+ spike_sum = np.zeros((bins, bins), dtype=np.float64)
328
+
329
+ # accumulate
330
+ # (Use np.add.at to avoid Python loops)
331
+ np.add.at(occupancy, (xi, yi), 1)
332
+ np.add.at(spike_sum, (xi, yi), frv)
333
+
334
+ # rate = sum / occupancy
335
+ with np.errstate(divide="ignore", invalid="ignore"):
336
+ frm = spike_sum / occupancy
337
+
338
+ # empty bins handling
339
+ empty = occupancy < int(min_occupancy)
340
+ if nan_for_empty:
341
+ frm = frm.astype(np.float64, copy=False)
342
+ frm[empty] = np.nan
343
+ else:
344
+ frm = np.where(empty, 0.0, frm)
345
+
346
+ if smoothing:
347
+ try:
348
+ from scipy.ndimage import gaussian_filter
349
+
350
+ # Smooth while respecting NaNs (simple approach: fill NaN -> 0, smooth weights)
351
+ if np.any(np.isnan(frm)):
352
+ val = np.nan_to_num(frm, nan=0.0)
353
+ w = (~np.isnan(frm)).astype(np.float64)
354
+ val_s = gaussian_filter(val, sigma=float(sigma))
355
+ w_s = gaussian_filter(w, sigma=float(sigma))
356
+ frm = np.divide(val_s, w_s, out=np.full_like(val_s, np.nan), where=(w_s > 1e-12))
357
+ else:
358
+ frm = gaussian_filter(frm, sigma=float(sigma))
359
+ except Exception:
360
+ # If scipy not available, just skip smoothing
361
+ pass
362
+
363
+ return FRMResult(
364
+ frm=frm, occupancy=occupancy, spike_sum=spike_sum, x_edges=x_edges, y_edges=y_edges
365
+ )
366
+
367
+
368
def plot_frm(
    frm: np.ndarray,
    *,
    title: str = "Firing Rate Map",
    dpi: int = 200,
    show: bool | None = None,
    config: PlotConfig | None = None,
    **kwargs,
) -> None:
    """
    Save FRM as PNG. Expects frm as 2D array (bins, bins).

    Parameters
    ----------
    frm : np.ndarray
        Firing rate map (2D).
    title : str
        Figure title (used when ``config`` is None or has an empty title).
    dpi : int
        Save DPI. Always written to ``config.save_dpi``, replacing any value
        already set there.
    show : bool | None
        Whether to show the plot (overrides ``config.show`` if not None).
        Defaults to False when ``config`` is None.
    config : PlotConfig, optional
        Plot configuration. Use ``config.save_path`` to specify output file.
        NOTE: a caller-supplied config is modified in place (``save_path``,
        ``show``, empty labels and ``save_dpi`` may be overwritten).
    **kwargs : Any
        Additional keyword arguments forwarded to ``plot_firing_field_heatmap``
        (``origin`` defaults to ``"lower"`` but may be overridden here).
        ``save_path`` may also be passed here; when given it overrides
        ``config.save_path``.

    Raises
    ------
    ValueError
        If no save path is available or ``frm`` is not 2D.

    Examples
    --------
    >>> cfg = PlotConfig.for_static_plot(save_path="frm.png", show=False)  # doctest: +SKIP
    >>> plot_frm(frm, config=cfg)  # doctest: +SKIP
    """
    from ...visualization import plot_firing_field_heatmap

    save_path = kwargs.pop("save_path", None)
    # Default origin to "lower" but let callers override it; passing it both as
    # a literal and through **kwargs would raise a duplicate-keyword TypeError.
    kwargs.setdefault("origin", "lower")

    if config is None:
        show_val = False if show is None else show
        config = PlotConfig.for_static_plot(
            title=title,
            xlabel="X bin",
            ylabel="Y bin",
            save_path=str(save_path) if save_path is not None else None,
            show=show_val,
        )
    else:
        if save_path is not None:
            config.save_path = str(save_path)
        if show is not None:
            config.show = show
        if not config.title:
            config.title = title
        if not config.xlabel:
            config.xlabel = "X bin"
        if not config.ylabel:
            config.ylabel = "Y bin"

    if config.save_path is None:
        raise ValueError(
            "save_path must be provided via config.save_path or as a keyword argument."
        )

    config.save_dpi = dpi

    frm = np.asarray(frm)
    if frm.ndim != 2:
        raise ValueError(f"frm must be 2D for heatmap display, got shape={frm.shape}")
    # NOTE(review): compute_frm fills maps as [x_bin, y_bin]; confirm whether
    # plot_firing_field_heatmap transposes so that X bins land on the x axis.
    plot_firing_field_heatmap(
        frm,
        config=config,
        **kwargs,
    )