canns 0.12.7__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. canns/analyzer/data/__init__.py +3 -11
  2. canns/analyzer/data/asa/__init__.py +74 -0
  3. canns/analyzer/data/asa/cohospace.py +905 -0
  4. canns/analyzer/data/asa/config.py +246 -0
  5. canns/analyzer/data/asa/decode.py +448 -0
  6. canns/analyzer/data/asa/embedding.py +269 -0
  7. canns/analyzer/data/asa/filters.py +208 -0
  8. canns/analyzer/data/asa/fr.py +439 -0
  9. canns/analyzer/data/asa/path.py +389 -0
  10. canns/analyzer/data/asa/plotting.py +1276 -0
  11. canns/analyzer/data/asa/tda.py +901 -0
  12. canns/analyzer/data/legacy/__init__.py +6 -0
  13. canns/analyzer/data/{cann1d.py → legacy/cann1d.py} +2 -2
  14. canns/analyzer/data/{cann2d.py → legacy/cann2d.py} +3 -3
  15. canns/analyzer/visualization/core/backend.py +1 -1
  16. canns/analyzer/visualization/core/config.py +77 -0
  17. canns/analyzer/visualization/core/rendering.py +10 -6
  18. canns/analyzer/visualization/energy_plots.py +22 -8
  19. canns/analyzer/visualization/spatial_plots.py +31 -11
  20. canns/analyzer/visualization/theta_sweep_plots.py +15 -6
  21. canns/pipeline/__init__.py +4 -8
  22. canns/pipeline/asa/__init__.py +21 -0
  23. canns/pipeline/asa/__main__.py +11 -0
  24. canns/pipeline/asa/app.py +1000 -0
  25. canns/pipeline/asa/runner.py +1095 -0
  26. canns/pipeline/asa/screens.py +215 -0
  27. canns/pipeline/asa/state.py +248 -0
  28. canns/pipeline/asa/styles.tcss +221 -0
  29. canns/pipeline/asa/widgets.py +233 -0
  30. canns/pipeline/gallery/__init__.py +7 -0
  31. canns/task/open_loop_navigation.py +3 -1
  32. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/METADATA +6 -3
  33. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/RECORD +36 -17
  34. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/entry_points.txt +1 -0
  35. canns/pipeline/theta_sweep.py +0 -573
  36. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/WHEEL +0 -0
  37. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,905 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ from scipy.stats import circvar
8
+
9
+ from ...visualization.core import PlotConfig, finalize_figure
10
+
11
+
12
def _ensure_plot_config(
    config: PlotConfig | None,
    factory,
    *,
    kwargs: dict | None = None,
    **defaults,
) -> PlotConfig:
    """Return a usable plot config, building one from defaults when none is given.

    Parameters
    ----------
    config : PlotConfig or None
        Caller-supplied configuration. When None, a new config is created with
        ``factory(**defaults, kwargs=kwargs or {})``.
    factory : callable
        Constructor used when ``config`` is None (e.g. ``PlotConfig.for_static_plot``).
    kwargs : dict, optional
        Extra entries for ``config.kwargs``. When ``config`` is given, these entries
        override existing keys of ``config.kwargs``.
    **defaults
        Keyword arguments forwarded to ``factory``; ignored when ``config`` is given.

    Returns
    -------
    PlotConfig
        The newly built config, or the supplied ``config`` object itself
        (with ``config.kwargs`` replaced by the merged dict when ``kwargs`` is non-empty).
    """
    if config is None:
        defaults.update({"kwargs": kwargs or {}})
        return factory(**defaults)

    if kwargs:
        # Build a fresh dict rather than updating ``config.kwargs`` in place, so a
        # dict the caller shares with other configs is not silently modified.
        config.kwargs = {**(config.kwargs or {}), **kwargs}
    return config
28
+
29
+
30
def _ensure_parent_dir(save_path: str | None) -> None:
    """Create the directory that will hold ``save_path`` if it does not exist yet.

    Nothing happens when ``save_path`` is None/empty or is a bare filename
    without a directory component.
    """
    if not save_path:
        return
    directory = os.path.dirname(save_path)
    if directory:
        os.makedirs(directory, exist_ok=True)
35
+
36
+
37
+ # =====================================================================
38
+ # CohoSpace visualization and selectivity metrics (CohoScore)
39
+ # =====================================================================
40
+
41
+
42
def _coho_coords_to_degrees(coords: np.ndarray) -> np.ndarray:
    """
    Map decoded coho angles (radians) to degrees in the half-open range [0, 360).

    Angles are first wrapped modulo 2*pi, so negative values or values beyond a
    full turn land in the canonical range.
    """
    wrapped = coords % (2 * np.pi)
    return np.rad2deg(wrapped)
47
+
48
+
49
def _align_activity_to_coords(
    coords: np.ndarray,
    activity: np.ndarray,
    times: np.ndarray | None = None,
    *,
    label: str = "activity",
    auto_filter: bool = True,
) -> np.ndarray:
    """
    Return ``activity`` restricted to the rows that correspond to ``coords``.

    Parameters
    ----------
    coords : array_like
        Decoded coordinates; only ``coords.shape[0]`` is used.
    activity : array_like
        Activity whose first axis is time.
    times : array_like, optional
        Index (or mask) into the original time axis of ``activity``; when given,
        ``activity[times]`` is taken before the length check.
    label : str
        Name of the array used in error messages.
    auto_filter : bool
        When ``times`` is None, ``activity`` is 2-D and the lengths disagree, keep only
        rows with at least one positive entry (mimicking decode's zero-spike filtering)
        and accept that if the lengths then agree.

    Returns
    -------
    ndarray
        Activity with ``shape[0] == coords.shape[0]``.

    Raises
    ------
    ValueError
        If ``times`` cannot index ``activity`` or the lengths cannot be reconciled.
    """
    coords = np.asarray(coords)
    activity = np.asarray(activity)

    if times is not None:
        times = np.asarray(times)
        try:
            activity = activity[times]
        except Exception as exc:
            raise ValueError(
                f"Failed to index {label} with `times`. Ensure `times` indexes the original time axis."
            ) from exc

    n_coords = coords.shape[0]
    if activity.shape[0] == n_coords:
        return activity

    # Lengths disagree: optionally retry with decode-style zero-row filtering.
    if auto_filter and times is None and activity.ndim == 2:
        nonzero_rows = np.sum(activity > 0, axis=1) >= 1
        if nonzero_rows.sum() == n_coords:
            return activity[nonzero_rows]

    raise ValueError(
        f"{label} length must match coords length. Got {activity.shape[0]} vs {coords.shape[0]}. "
        "If coords are computed on a subset of timepoints (e.g., decode['times']), pass "
        "`times=decoding['times']` or slice the activity accordingly."
    )
92
+
93
+
94
def plot_cohospace_trajectory(
    coords: np.ndarray,
    times: np.ndarray | None = None,
    subsample: int = 1,
    figsize: tuple[int, int] = (6, 6),
    cmap: str = "viridis",
    save_path: str | None = None,
    show: bool = False,
    config: PlotConfig | None = None,
) -> plt.Axes:
    """
    Plot a trajectory in cohomology space.

    Parameters
    ----------
    coords : ndarray, shape (T, 2)
        Decoded cohomology angles (theta1, theta2) in radians. They are wrapped modulo
        2*pi and converted to degrees in [0, 360) for plotting. Inputs in another
        convention (e.g. fractions of a turn in [0, 1)) must be rescaled to radians first.
    times : ndarray, optional, shape (T,)
        Values used to color the points. If None, uses arange of the (subsampled) length.
    subsample : int
        Plot every ``subsample``-th point. Values that cannot be converted to int,
        or are < 1, fall back to 1 (no subsampling).
    figsize : tuple
        Matplotlib figure size (ignored when ``config`` is given).
    cmap : str
        Matplotlib colormap name.
    save_path : str, optional
        If provided, saves the figure to this path; missing parent directories are
        created (ignored when ``config`` is given).
    show : bool
        Stored on the plot config (ignored when ``config`` is given); displaying or
        closing the figure is delegated to ``finalize_figure``.
    config : PlotConfig, optional
        Full plot configuration. When given, its title, labels, figsize, save_path
        and show are used instead of the individual arguments above.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The Axes containing the plot.

    Raises
    ------
    ValueError
        If ``coords`` is not (T, 2) or ``times`` has a different length than ``coords``.

    Examples
    --------
    >>> ax = plot_cohospace_trajectory(coords, subsample=2, show=False)  # doctest: +SKIP
    """
    try:
        subsample_i = int(subsample)
    except (TypeError, ValueError, OverflowError):
        # Non-numeric or non-finite input: fall back to plotting every point.
        subsample_i = 1
    subsample_i = max(subsample_i, 1)

    coords = np.asarray(coords)
    if coords.ndim != 2 or coords.shape[1] != 2:
        raise ValueError(f"`coords` must have shape (T, 2). Got {coords.shape}.")

    theta_deg = _coho_coords_to_degrees(coords)[::subsample_i]

    if times is None:
        times_vis = np.arange(theta_deg.shape[0])
    else:
        times_vis = np.asarray(times)
        # Validate against the full (pre-subsampling) length, then subsample identically.
        if times_vis.shape[0] != coords.shape[0]:
            raise ValueError(
                f"`times` length must match coords length. Got times={times_vis.shape[0]}, coords={coords.shape[0]}."
            )
        times_vis = times_vis[::subsample_i]

    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title="CohoSpace trajectory",
        xlabel="theta1 (deg)",
        ylabel="theta2 (deg)",
        figsize=figsize,
        save_path=save_path,
        show=show,
    )

    fig, ax = plt.subplots(figsize=config.figsize)
    sc = ax.scatter(
        theta_deg[:, 0],
        theta_deg[:, 1],
        c=times_vis,
        cmap=cmap,
        s=3,
        alpha=0.8,
    )
    cbar = plt.colorbar(sc, ax=ax)
    cbar.set_label("Time")

    ax.set_xlim(0, 360)
    ax.set_ylim(0, 360)
    ax.set_xlabel(config.xlabel)
    ax.set_ylabel(config.ylabel)
    ax.set_title(config.title)
    ax.set_aspect("equal", adjustable="box")
    ax.grid(True, alpha=0.2)

    _ensure_parent_dir(config.save_path)
    finalize_figure(fig, config)
    return ax
195
+
196
+
197
def plot_cohospace_neuron(
    coords: np.ndarray,
    activity: np.ndarray,
    neuron_id: int,
    mode: str = "fr",  # "fr" or "spike"
    top_percent: float = 5.0,  # Used in FR mode
    times: np.ndarray | None = None,
    auto_filter: bool = True,
    figsize: tuple = (6, 6),
    cmap: str = "hot",
    save_path: str | None = None,
    show: bool = True,
    config: PlotConfig | None = None,
) -> plt.Figure:
    """
    Overlay a single neuron's activity on the cohomology-space trajectory.

    This is a visualization helper:
    - mode="fr": marks time points whose activity is at or above the neuron's
      (100 - top_percent)-th percentile, colored by activity. Ties at the threshold
      are all included, so sparse signals may show more than top_percent% of points.
    - mode="spike": marks all time points where spike > 0 for the given neuron, in red.

    Parameters
    ----------
    coords : ndarray, shape (T, 2)
        Decoded cohomology angles (theta1, theta2), in radians.
    activity : ndarray, shape (T, N)
        Activity matrix (continuous firing rate or binned spikes).
    neuron_id : int
        Neuron index (column of ``activity``) to visualize.
    mode : {"fr", "spike"}
        Selection rule, see above.
    top_percent : float
        Used only when mode="fr". For example, 5.0 means "top 5%" time points.
    times : ndarray, optional, shape (T_coords,)
        Optional indices to align activity to coords when coords are computed on a subset of timepoints.
    auto_filter : bool
        If True and lengths mismatch, auto-filter activity with activity>0 to mimic decode filtering.
    figsize, cmap, save_path, show : see `plot_cohospace_trajectory`.
    config : PlotConfig, optional
        Full plot configuration; when given, its title, labels, figsize, save_path
        and show are used instead of the individual arguments.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The Figure containing the plot.

    Raises
    ------
    ValueError
        If ``mode`` is not "fr"/"spike" or ``activity`` cannot be aligned to ``coords``.

    Examples
    --------
    >>> fig = plot_cohospace_neuron(coords, spikes, neuron_id=0, show=False)  # doctest: +SKIP
    """
    # Validate cheaply before aligning/converting potentially large arrays.
    if mode not in ("fr", "spike"):
        raise ValueError("mode must be 'fr' or 'spike'")

    coords = np.asarray(coords)
    activity = _align_activity_to_coords(
        coords, activity, times, label="activity", auto_filter=auto_filter
    )
    theta_deg = _coho_coords_to_degrees(coords)

    signal = activity[:, neuron_id]

    if mode == "fr":
        # Select the neuron's top `top_percent`% time points
        threshold = np.percentile(signal, 100 - top_percent)
        idx = signal >= threshold
        color = signal[idx]
        title = f"Neuron {neuron_id} FR top {top_percent:.1f}% on coho-space"
        use_cmap = cmap
    else:
        idx = signal > 0
        color = "red"
        title = f"Neuron {neuron_id} spikes on coho-space"
        use_cmap = None

    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title=title,
        xlabel="Theta 1 (°)",
        ylabel="Theta 2 (°)",
        figsize=figsize,
        save_path=save_path,
        show=show,
    )

    fig, ax = plt.subplots(figsize=config.figsize)
    sc = ax.scatter(
        theta_deg[idx, 0],
        theta_deg[idx, 1],
        c=color,
        cmap=use_cmap,
        s=5,
        alpha=0.9,
    )

    if mode == "fr":
        cbar = plt.colorbar(sc, ax=ax)
        cbar.set_label("Firing rate")

    ax.set_xlim(0, 360)
    ax.set_ylim(0, 360)
    ax.set_xlabel(config.xlabel)
    ax.set_ylabel(config.ylabel)
    ax.set_title(config.title)

    _ensure_parent_dir(config.save_path)
    finalize_figure(fig, config)

    return fig
301
+
302
+
303
def plot_cohospace_population(
    coords: np.ndarray,
    activity: np.ndarray,
    neuron_ids: list[int] | np.ndarray,
    mode: str = "fr",  # "fr" or "spike"
    top_percent: float = 5.0,  # Used in FR mode
    times: np.ndarray | None = None,
    auto_filter: bool = True,
    figsize: tuple = (6, 6),
    cmap: str = "hot",
    save_path: str | None = None,
    show: bool = True,
    config: PlotConfig | None = None,
) -> plt.Figure:
    """
    Plot aggregated activity from multiple neurons in cohomology space.

    For mode="fr":
    - For each neuron, select time points at or above its (100 - top_percent)-th
      percentile (ties included).
    - Sum the selected firing rates across neurons per time point and plot as colors.

    For mode="spike":
    - For each neuron, mark time points with spike > 0.
    - Count, per time point, how many neurons spiked and plot as colors.

    Only time points selected by at least one neuron are drawn.

    Parameters
    ----------
    coords : ndarray, shape (T, 2)
        Decoded cohomology angles (theta1, theta2), in radians.
    activity : ndarray, shape (T, N)
        Activity matrix (firing rate or binned spikes).
    neuron_ids : iterable[int]
        Neuron indices to include (use range(N) to include all).
    mode : {"fr", "spike"}
        Aggregation rule, see above.
    top_percent : float
        Used only when mode="fr" (e.g. 5.0 means top 5%).
    times : ndarray, optional, shape (T_coords,)
        Optional indices to align activity to coords when coords are computed on a subset of timepoints.
    auto_filter : bool
        If True and lengths mismatch, auto-filter activity with activity>0 to mimic decode filtering.
    figsize, cmap, save_path, show : see `plot_cohospace_trajectory`.
    config : PlotConfig, optional
        Full plot configuration; when given, its title, labels, figsize, save_path
        and show are used instead of the individual arguments.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The Figure containing the plot.

    Raises
    ------
    ValueError
        If ``mode`` is not "fr"/"spike" or ``activity`` cannot be aligned to ``coords``.

    Examples
    --------
    >>> fig = plot_cohospace_population(coords, spikes, neuron_ids=[0, 1, 2], show=False)  # doctest: +SKIP
    """
    # Validate up front: inside the loop an invalid mode went unnoticed when
    # `neuron_ids` was empty.
    if mode not in ("fr", "spike"):
        raise ValueError("mode must be 'fr' or 'spike'")

    coords = np.asarray(coords)
    activity = _align_activity_to_coords(
        coords, activity, times, label="activity", auto_filter=auto_filter
    )
    neuron_ids = np.asarray(neuron_ids, dtype=int)

    theta_deg = _coho_coords_to_degrees(coords)

    n_time = activity.shape[0]
    mask = np.zeros(n_time, dtype=bool)
    agg_color = np.zeros(n_time, dtype=float)

    for n in neuron_ids:
        signal = activity[:, n]
        if mode == "fr":
            threshold = np.percentile(signal, 100 - top_percent)
            idx = signal >= threshold
            agg_color[idx] += signal[idx]
        else:
            idx = signal > 0
            agg_color[idx] += 1.0
        mask |= idx

    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title=f"{len(neuron_ids)} neurons on coho-space",
        xlabel="Theta 1 (°)",
        ylabel="Theta 2 (°)",
        figsize=figsize,
        save_path=save_path,
        show=show,
    )

    fig, ax = plt.subplots(figsize=config.figsize)
    sc = ax.scatter(
        theta_deg[mask, 0],
        theta_deg[mask, 1],
        c=agg_color[mask],
        cmap=cmap,
        s=5,
        alpha=0.9,
    )
    cbar = plt.colorbar(sc, ax=ax)
    label = "Aggregate FR" if mode == "fr" else "Spike count"
    cbar.set_label(label)

    ax.set_xlim(0, 360)
    ax.set_ylim(0, 360)
    ax.set_xlabel(config.xlabel)
    ax.set_ylabel(config.ylabel)
    ax.set_title(config.title)

    _ensure_parent_dir(config.save_path)
    finalize_figure(fig, config)

    return fig
412
+
413
+
414
def compute_cohoscore(
    coords: np.ndarray,
    activity: np.ndarray,
    top_percent: float = 2.0,
    times: np.ndarray | None = None,
    auto_filter: bool = True,
    min_points: int = 5,
) -> np.ndarray:
    """
    Compute a simple cohomology-space selectivity score (CohoScore) for each neuron.

    For each neuron:
    - Select "active" time points:
        - If top_percent is None: all time points with activity > 0.
        - Else: time points with activity >= the (100 - top_percent)-th percentile.
          If that percentile equals the neuron's minimum activity (e.g. a sparse
          spike train that is zero in more than (100 - top_percent)% of bins),
          ``>=`` would select every time point, so only points strictly above
          the threshold are used instead.
    - Compute the circular variance (``scipy.stats.circvar``) of theta1 and theta2
      on the selected points.
    - CohoScore = 0.5 * (var(theta1) + var(theta2))

    Interpretation:
    - Smaller score => points are more concentrated in coho space => higher selectivity.

    Parameters
    ----------
    coords : ndarray, shape (T, 2)
        Decoded cohomology angles (theta1, theta2), in radians.
    activity : ndarray, shape (T, N)
        Activity matrix (FR or spikes).
    top_percent : float | None
        Percentage for selecting active points (e.g., 2.0 means top 2%). If None, use activity>0.
    times : ndarray, optional, shape (T_coords,)
        Optional indices to align activity to coords when coords are computed on a subset of timepoints.
    auto_filter : bool
        If True and lengths mismatch, auto-filter activity with activity>0 to mimic decode filtering.
    min_points : int
        Minimum number of selected time points required to score a neuron (default 5).

    Returns
    -------
    scores : ndarray, shape (N,)
        CohoScore per neuron (NaN for neurons with fewer than ``min_points`` selected points).

    Raises
    ------
    ValueError
        If ``activity`` cannot be aligned to ``coords`` or is not 2-D.

    Examples
    --------
    >>> scores = compute_cohoscore(coords, spikes)  # doctest: +SKIP
    >>> scores.shape == (spikes.shape[1],)  # doctest: +SKIP
    True
    """
    coords = np.asarray(coords)
    activity = _align_activity_to_coords(
        coords, activity, times, label="activity", auto_filter=auto_filter
    )
    _, n_neurons = activity.shape  # raises ValueError for non-2-D activity

    theta = coords % (2 * np.pi)  # Ensure values are in [0, 2π)
    scores = np.full(n_neurons, np.nan, dtype=float)

    for n in range(n_neurons):
        signal = activity[:, n]

        if top_percent is None:
            idx = signal > 0  # Use all time points with spikes
        else:
            threshold = np.percentile(signal, 100 - top_percent)
            if threshold <= np.min(signal):
                # The percentile collapsed onto the minimum (heavy ties, typically
                # zeros); `>=` would select everything, so require strictly above.
                idx = signal > threshold
            else:
                idx = signal >= threshold

        if np.sum(idx) < min_points:
            continue  # Too sparse; unreliable -> stays NaN

        var1 = circvar(theta[idx, 0], high=2 * np.pi, low=0)
        var2 = circvar(theta[idx, 1], high=2 * np.pi, low=0)
        scores[n] = 0.5 * (var1 + var2)

    return scores
488
+
489
+
490
def skew_transform_torus(coords):
    """
    Map torus angles (theta1, theta2) into the skewed 60-degree parallelogram domain.

    The map is the linear change of basis
        x = theta1 + theta2 / 2
        y = (sqrt(3) / 2) * theta2
    which sends the square [0, 2π)×[0, 2π) to the parallelogram spanned by
    (2π, 0) and (π, √3 π), a convenient picture of a 2-torus with wrap-around.

    Parameters
    ----------
    coords : ndarray, shape (T, 2)
        Angles (theta1, theta2) in radians.

    Returns
    -------
    xy : ndarray, shape (T, 2)
        Skewed planar coordinates.

    Raises
    ------
    ValueError
        If ``coords`` is not two-dimensional with exactly two columns.
    """
    angles = np.asarray(coords)
    if angles.ndim != 2 or angles.shape[1] != 2:
        raise ValueError(f"coords must be (T,2), got {angles.shape}")

    t1, t2 = angles[:, 0], angles[:, 1]
    # Purely linear: no per-axis nonlinear rescaling.
    return np.column_stack((t1 + 0.5 * t2, (np.sqrt(3) / 2.0) * t2))
523
+
524
+
525
def draw_torus_parallelogram_grid(ax, n_tiles=1, color="0.7", lw=1.0, alpha=0.8):
    """
    Outline copies of the skewed torus fundamental domain on ``ax``.

    The domain is the parallelogram spanned by
        e1 = (2π, 0)
        e2 = (π, √3 π)
    and one closed outline is drawn for every shift i*e1 + j*e2 with
    i, j in [-n_tiles, n_tiles] (i outer, j inner).

    Parameters
    ----------
    ax : matplotlib axis
        Target axes; only its ``plot`` method is called.
    n_tiles : int
        Number of copies to draw in each +/- direction (visual aid);
        n_tiles=1 draws shifts -1, 0, 1 along each basis vector (9 outlines).
    color, lw, alpha :
        Forwarded to ``ax.plot``.
    """
    step_1 = np.array([2 * np.pi, 0.0])
    step_2 = np.array([np.pi, np.sqrt(3) * np.pi])
    offsets = list(range(-n_tiles, n_tiles + 1))

    for i in offsets:
        for j in offsets:
            o = i * step_1 + j * step_2
            # Closed loop: origin -> +e1 -> +e1+e2 -> +e2 -> origin.
            ring = np.array([o, o + step_1, o + step_1 + step_2, o + step_2, o])
            ax.plot(ring[:, 0], ring[:, 1], color=color, lw=lw, alpha=alpha)
550
+
551
+
552
def tile_parallelogram_points(xy, n_tiles=1):
    """
    Tile points in the skewed (parallelogram) torus fundamental domain.

    This is mainly for static visualizations so you can visually inspect continuity
    across domain boundaries.

    Parameters
    ----------
    xy : array_like, shape (T, 2)
        Points in the skewed plane (same coordinates as returned by `skew_transform_torus`).
    n_tiles : int
        Number of tiles to extend around the base domain.
        - n_tiles=0 returns a float copy of the points
        - n_tiles=1 produces a 3x3 tiling
        - n_tiles=2 produces a 5x5 tiling
        Negative values behave like 0.

    Returns
    -------
    tiled : ndarray, shape ((2 * n_tiles + 1) ** 2 * T, 2)
        Copies of ``xy`` shifted by ``i * e1 + j * e2`` for i, j in [-n_tiles, n_tiles],
        with e1 = (2π, 0) and e2 = (π, √3 π). Copies are ordered by i (outer) then
        j (inner), each holding all T points in their original order, so per-point
        values can be matched with ``np.tile(values, (2 * n_tiles + 1) ** 2)``.
    """
    xy = np.asarray(xy, dtype=float)
    n = max(n_tiles, 0)

    e1 = np.array([2 * np.pi, 0.0])
    e2 = np.array([np.pi, np.sqrt(3) * np.pi])

    offsets = range(-n, n + 1)
    return np.vstack([xy + i * e1 + j * e2 for i in offsets for j in offsets])
584
+
585
+
586
def plot_cohospace_neuron_skewed(
    coords,
    activity,
    neuron_id,
    mode="spike",
    top_percent=2.0,
    times: np.ndarray | None = None,
    auto_filter: bool = True,
    save_path=None,
    show=None,
    ax=None,
    show_grid=True,
    n_tiles=1,
    s=6,
    alpha=0.8,
    config: PlotConfig | None = None,
    cmap="viridis",
):
    """
    Plot single-neuron CohoSpace on skewed torus domain.

    Angles are wrapped to [0, 2π), mapped with `skew_transform_torus`, optionally
    tiled with `tile_parallelogram_points`, and drawn together with the outline of
    the parallelogram fundamental domain spanned by e1 = (2π, 0) and e2 = (π, √3 π).

    Parameters
    ----------
    coords : ndarray, shape (T, 2)
        Decoded circular coordinates (theta1, theta2), in radians.
    activity : ndarray, shape (T, N)
        Activity matrix aligned with coords (see ``times`` / ``auto_filter``).
    neuron_id : int
        Neuron index (column of ``activity``).
    mode : {"spike", "fr"}
        spike: mark time points with activity > 0 (single color).
        fr: mark time points with activity >= the (100 - top_percent)-th percentile,
        colored by activity (ties at the threshold are all included).
    top_percent : float
        Percentage of time points kept in "fr" mode (e.g. 2.0 = top 2%).
    times : ndarray, optional
        Indices into the time axis of ``activity`` used to align it with ``coords``.
    auto_filter : bool
        If True and lengths mismatch, auto-filter activity with activity>0 to mimic decode filtering.
    save_path : str, optional
        Output path; overrides ``config.save_path`` when a config is given.
    show : bool, optional
        Whether to show the figure. None keeps the config's value (False for a new config).
    ax : matplotlib.axes.Axes, optional
        Axes to draw into. If None, a new figure is created and passed to
        ``finalize_figure``; otherwise the existing figure is only saved/shown as requested.
    show_grid : bool
        Also draw the tiled parallelogram grid (the base domain outline is always drawn).
    n_tiles : int
        Number of tile rings around the base domain (0 = base domain only); also
        determines the view limits.
    s, alpha : float
        Marker size and transparency.
    config : PlotConfig, optional
        Plot configuration. Empty title/xlabel/ylabel are filled with defaults.
        Note: the passed object is modified in place.
    cmap : str
        Colormap used in "fr" mode.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The Axes containing the plot.

    Raises
    ------
    ValueError
        If ``mode`` is not "spike"/"fr" or ``activity`` cannot be aligned to ``coords``.
    """
    coords = np.asarray(coords)
    activity = _align_activity_to_coords(
        coords, activity, times, label="activity", auto_filter=auto_filter
    )

    # --- normalize angles to [0, 2π)
    coords = coords % (2 * np.pi)

    # --- select neuron activity
    a = activity[:, neuron_id]

    if mode == "spike":
        mask = a > 0
    elif mode == "fr":
        thr = np.percentile(a, 100 - top_percent)
        mask = a >= thr
    else:
        raise ValueError(f"Unknown mode: {mode}")

    val = a[mask]  # Used for FR-mode coloring

    default_title = f"Neuron {neuron_id} – CohoSpace (skewed, mode={mode})"
    default_xlabel = r"$\theta_1 + \frac{1}{2}\theta_2$"
    default_ylabel = r"$\frac{\sqrt{3}}{2}\theta_2$"

    if config is None:
        config = PlotConfig.for_static_plot(
            title=default_title,
            xlabel=default_xlabel,
            ylabel=default_ylabel,
            figsize=(5, 5),
            save_path=save_path,
            show=bool(show) if show is not None else False,
        )
    else:
        # The caller's config object is updated in place.
        if save_path is not None:
            config.save_path = save_path
        if show is not None:
            config.show = show
        if not config.title:
            config.title = default_title
        if not config.xlabel:
            config.xlabel = default_xlabel
        if not config.ylabel:
            config.ylabel = default_ylabel

    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(figsize=config.figsize)
    else:
        fig = ax.figure

    # --- fundamental domain vectors in skew plane
    e1 = np.array([2 * np.pi, 0.0])
    e2 = np.array([np.pi, np.sqrt(3) * np.pi])

    def _draw_single_domain(ax):
        # Outline of the base parallelogram (no shift).
        P00 = np.array([0.0, 0.0])
        P10 = e1
        P01 = e2
        P11 = e1 + e2
        poly = np.vstack([P00, P10, P11, P01, P00])
        ax.plot(poly[:, 0], poly[:, 1], lw=1.2, color="0.35")

    def _annotate_corners(ax):
        # Label the base-domain corners with their (theta1, theta2) values.
        P00 = np.array([0.0, 0.0])
        P10 = e1
        P01 = e2
        P11 = e1 + e2

        corners = np.vstack([P00, P10, P01, P11])
        xmin, ymin = corners.min(axis=0)
        xmax, ymax = corners.max(axis=0)
        padx = 0.02 * (xmax - xmin)
        pady = 0.02 * (ymax - ymin)

        bbox = dict(facecolor="white", edgecolor="none", alpha=0.7, pad=1.0)

        ax.text(
            P00[0] + padx, P00[1] + pady, "(0,0)", fontsize=10, ha="left", va="bottom", bbox=bbox
        )
        ax.text(
            P10[0] - padx, P10[1] + pady, "(2π,0)", fontsize=10, ha="right", va="bottom", bbox=bbox
        )
        ax.text(P01[0] + padx, P01[1] - pady, "(0,2π)", fontsize=10, ha="left", va="top", bbox=bbox)
        ax.text(
            P11[0] - padx, P11[1] - pady, "(2π,2π)", fontsize=10, ha="right", va="top", bbox=bbox
        )

    # --- skew transform
    xy = skew_transform_torus(coords[mask])

    # Tiling: if points are tiled, values must be tiled too (FR mode) to keep lengths
    # consistent; tile_parallelogram_points emits whole copies of the point set in order.
    if n_tiles and n_tiles > 0:
        xy = tile_parallelogram_points(xy, n_tiles=n_tiles)
        if mode == "fr":
            val = np.tile(val, (2 * n_tiles + 1) ** 2)

    # --- scatter
    if mode == "fr":
        sc = ax.scatter(xy[:, 0], xy[:, 1], c=val, s=s, alpha=alpha, cmap=cmap)
        fig.colorbar(sc, ax=ax, shrink=0.85, pad=0.02, label="activity")
    else:
        ax.scatter(xy[:, 0], xy[:, 1], s=s, alpha=alpha, color="tab:blue")

    # Always draw the base domain boundary
    _draw_single_domain(ax)

    # Grid is optional (debug aid); when tiles=0 only the base domain is drawn
    if show_grid:
        draw_torus_parallelogram_grid(ax, n_tiles=n_tiles)

    _annotate_corners(ax)

    # Fix view limits: tiles=0 shows base domain; tiles>0 shows the tiled extent
    base = np.vstack([[0, 0], e1, e2, e1 + e2])

    if n_tiles and n_tiles > 0:
        # Translation vectors for tiling are i*e1 + j*e2
        shifts = []
        for i in range(-n_tiles, n_tiles + 1):
            for j in range(-n_tiles, n_tiles + 1):
                shifts.append(i * e1 + j * e2)
        shifts = np.asarray(shifts)  # ((2n+1)^2, 2)

        all_corners = (base[None, :, :] + shifts[:, None, :]).reshape(-1, 2)
        xmin, ymin = all_corners.min(axis=0)
        xmax, ymax = all_corners.max(axis=0)
    else:
        xmin, ymin = base.min(axis=0)
        xmax, ymax = base.max(axis=0)

    padx = 0.03 * (xmax - xmin)
    pady = 0.03 * (ymax - ymin)
    ax.set_xlim(xmin - padx, xmax + padx)
    ax.set_ylim(ymin - pady, ymax + pady)

    ax.set_aspect("equal")
    ax.set_xlabel(config.xlabel)
    ax.set_ylabel(config.ylabel)
    ax.set_title(config.title)

    if created_fig:
        _ensure_parent_dir(config.save_path)
        finalize_figure(fig, config)
    else:
        # Caller owns the figure: only save/show on request, never close it here.
        if config.save_path is not None:
            _ensure_parent_dir(config.save_path)
            fig.savefig(config.save_path, **config.to_savefig_kwargs())
        if config.show:
            plt.show()

    return ax
772
+
773
+
774
def plot_cohospace_population_skewed(
    coords,
    activity,
    neuron_ids,
    mode="spike",
    top_percent=2.0,
    times: np.ndarray | None = None,
    auto_filter: bool = True,
    save_path=None,
    show=None,
    ax=None,
    show_grid=True,
    n_tiles=1,
    s=4,
    alpha=0.5,
    config: PlotConfig | None = None,
):
    """
    Plot population CohoSpace on skewed torus domain.

    Each neuron's selected time points are mapped with `skew_transform_torus`,
    optionally tiled, and drawn as a separate scatter call, so neurons take
    successive colors from the axes' color cycle.

    Parameters
    ----------
    coords : ndarray, shape (T, 2)
        Decoded circular coordinates (theta1, theta2), in radians.
    activity : ndarray, shape (T, N)
        Activity matrix aligned with coords (see ``times`` / ``auto_filter``).
    neuron_ids : list or ndarray
        Neurons to include (e.g. top-K by CohoScore).
    mode : {"spike", "fr"}
        spike: use activity > 0; fr: use activity >= the (100 - top_percent)-th percentile.
    top_percent : float
        Percentage of time points kept per neuron in "fr" mode.
    times : ndarray, optional
        Indices into the time axis of ``activity`` used to align it with ``coords``.
    auto_filter : bool
        If True and lengths mismatch, auto-filter activity with activity>0 to mimic decode filtering.
    save_path : str, optional
        Output path; overrides ``config.save_path`` when a config is given.
    show : bool, optional
        Whether to show the figure. None keeps the config's value (False for a new
        config). Previously the default ``False`` always overwrote ``config.show``.
    ax : matplotlib.axes.Axes, optional
        Axes to draw into. If None, a new figure is created and passed to
        ``finalize_figure``; otherwise the existing figure is only saved/shown as requested.
    show_grid : bool
        Also draw the tiled parallelogram grid (the base domain outline is always drawn).
    n_tiles : int
        Number of tile rings around the base domain (0 = base domain only).
    s, alpha : float
        Marker size and transparency.
    config : PlotConfig, optional
        Plot configuration. Empty title/xlabel/ylabel are filled with defaults.
        Note: the passed object is modified in place.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The Axes containing the plot.

    Raises
    ------
    ValueError
        If ``mode`` is not "spike"/"fr" or ``activity`` cannot be aligned to ``coords``.
    """
    # Reject unknown modes instead of silently treating them as "fr"
    # (matches plot_cohospace_neuron_skewed).
    if mode not in ("spike", "fr"):
        raise ValueError(f"Unknown mode: {mode}")

    coords = np.asarray(coords)
    activity = _align_activity_to_coords(
        coords, activity, times, label="activity", auto_filter=auto_filter
    )
    coords = coords % (2 * np.pi)

    default_title = f"Population CohoSpace (skewed, n={len(neuron_ids)}, mode={mode})"
    default_xlabel = r"$\theta_1 + \frac{1}{2}\theta_2$"
    default_ylabel = r"$\frac{\sqrt{3}}{2}\theta_2$"

    if config is None:
        config = PlotConfig.for_static_plot(
            title=default_title,
            xlabel=default_xlabel,
            ylabel=default_ylabel,
            figsize=(5, 5),
            save_path=save_path,
            show=bool(show) if show is not None else False,
        )
    else:
        # The caller's config object is updated in place.
        if save_path is not None:
            config.save_path = save_path
        if show is not None:
            config.show = show
        if not config.title:
            config.title = default_title
        if not config.xlabel:
            config.xlabel = default_xlabel
        if not config.ylabel:
            config.ylabel = default_ylabel

    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(figsize=config.figsize)
    else:
        fig = ax.figure

    # --- fundamental domain vectors in skew plane
    e1 = np.array([2 * np.pi, 0.0])
    e2 = np.array([np.pi, np.sqrt(3) * np.pi])

    def _draw_single_domain(ax):
        # Outline of the base parallelogram (no shift).
        P00 = np.array([0.0, 0.0])
        P10 = e1
        P01 = e2
        P11 = e1 + e2
        poly = np.vstack([P00, P10, P11, P01, P00])
        ax.plot(poly[:, 0], poly[:, 1], lw=1.2, color="0.35")

    # --- scatter each neuron
    for nid in neuron_ids:
        a = activity[:, nid]
        if mode == "spike":
            mask = a > 0
        else:
            thr = np.percentile(a, 100 - top_percent)
            mask = a >= thr

        xy = skew_transform_torus(coords[mask])

        if n_tiles and n_tiles > 0:
            xy = tile_parallelogram_points(xy, n_tiles=n_tiles)

        ax.scatter(xy[:, 0], xy[:, 1], s=s, alpha=alpha)

    # Always draw the base domain boundary
    _draw_single_domain(ax)

    if show_grid:
        draw_torus_parallelogram_grid(ax, n_tiles=n_tiles)

    # Fix view limits: tiles=0 shows base domain; tiles>0 shows the tiled extent
    base = np.vstack([[0, 0], e1, e2, e1 + e2])

    if n_tiles and n_tiles > 0:
        # Translation vectors for tiling are i*e1 + j*e2
        shifts = []
        for i in range(-n_tiles, n_tiles + 1):
            for j in range(-n_tiles, n_tiles + 1):
                shifts.append(i * e1 + j * e2)
        shifts = np.asarray(shifts)  # ((2n+1)^2, 2)

        all_corners = (base[None, :, :] + shifts[:, None, :]).reshape(-1, 2)
        xmin, ymin = all_corners.min(axis=0)
        xmax, ymax = all_corners.max(axis=0)
    else:
        xmin, ymin = base.min(axis=0)
        xmax, ymax = base.max(axis=0)

    padx = 0.03 * (xmax - xmin)
    pady = 0.03 * (ymax - ymin)
    ax.set_xlim(xmin - padx, xmax + padx)
    ax.set_ylim(ymin - pady, ymax + pady)

    ax.set_aspect("equal")
    ax.set_xlabel(config.xlabel)
    ax.set_ylabel(config.ylabel)
    ax.set_title(config.title)

    if created_fig:
        _ensure_parent_dir(config.save_path)
        finalize_figure(fig, config)
    else:
        # Caller owns the figure: only save/show on request, never close it here.
        if config.save_path is not None:
            _ensure_parent_dir(config.save_path)
            fig.savefig(config.save_path, **config.to_savefig_kwargs())
        if config.show:
            plt.show()

    return ax