canns 0.14.2__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28) hide show
  1. canns/analyzer/data/asa/__init__.py +77 -21
  2. canns/analyzer/data/asa/coho.py +97 -0
  3. canns/analyzer/data/asa/cohomap.py +408 -0
  4. canns/analyzer/data/asa/cohomap_scatter.py +10 -0
  5. canns/analyzer/data/asa/cohomap_vectors.py +311 -0
  6. canns/analyzer/data/asa/cohospace.py +173 -1153
  7. canns/analyzer/data/asa/cohospace_phase_centers.py +137 -0
  8. canns/analyzer/data/asa/cohospace_scatter.py +1220 -0
  9. canns/analyzer/data/asa/embedding.py +3 -4
  10. canns/analyzer/data/asa/plotting.py +4 -4
  11. canns/analyzer/data/cell_classification/__init__.py +10 -0
  12. canns/analyzer/data/cell_classification/core/__init__.py +4 -0
  13. canns/analyzer/data/cell_classification/core/btn.py +272 -0
  14. canns/analyzer/data/cell_classification/visualization/__init__.py +3 -0
  15. canns/analyzer/data/cell_classification/visualization/btn_plots.py +258 -0
  16. canns/analyzer/visualization/__init__.py +2 -0
  17. canns/analyzer/visualization/core/config.py +20 -0
  18. canns/analyzer/visualization/theta_sweep_plots.py +142 -0
  19. canns/pipeline/asa/runner.py +19 -19
  20. canns/pipeline/asa_gui/__init__.py +5 -3
  21. canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +32 -4
  22. canns/pipeline/asa_gui/core/runner.py +23 -23
  23. canns/pipeline/asa_gui/views/pages/preprocess_page.py +250 -8
  24. {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/METADATA +2 -1
  25. {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/RECORD +28 -20
  26. {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/WHEEL +0 -0
  27. {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/entry_points.txt +0 -0
  28. {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1220 @@
1
+ """Scatter-style CohoSpace plots and cohoscore utilities."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ from scipy.stats import circvar
10
+
11
+ from ...visualization.core import PlotConfig, finalize_figure
12
+
13
+
14
def _ensure_plot_config(
    config: PlotConfig | None,
    factory,
    *,
    kwargs: dict | None = None,
    **defaults,
) -> PlotConfig:
    """
    Return a plot configuration, creating one through ``factory`` when none is supplied.

    Parameters
    ----------
    config : PlotConfig or None
        Existing configuration. When given, it is returned (the same object), and
        ``defaults`` are ignored.
    factory : callable
        Called as ``factory(**defaults, kwargs=...)`` only when ``config`` is None.
    kwargs : dict, optional
        Extra plotting keyword arguments. For a new config they are passed to the factory
        under the ``kwargs`` key (an empty dict when omitted). For an existing config they
        are merged into ``config.kwargs``, overriding entries with the same key.
    **defaults
        Keyword arguments forwarded to ``factory`` when a new config is built.

    Returns
    -------
    PlotConfig
        The supplied ``config`` (possibly with updated ``kwargs``) or the factory's result.
    """
    if config is not None:
        if kwargs:
            # Start from the config's own mapping (or a new one when it is unset/empty),
            # then let the caller-supplied entries win.
            merged = config.kwargs or {}
            merged.update(kwargs)
            config.kwargs = merged
        return config

    defaults.update({"kwargs": kwargs or {}})
    return factory(**defaults)
30
+
31
+
32
def _ensure_parent_dir(save_path: str | None) -> None:
    """
    Create the directory that will hold ``save_path`` if it does not exist yet.

    Nothing happens when ``save_path`` is empty/None or when it has no directory
    component (a bare file name). Existing directories are left untouched.
    """
    if not save_path:
        return

    parent_dir = os.path.dirname(save_path)
    if not parent_dir:
        return

    os.makedirs(parent_dir, exist_ok=True)
37
+
38
+
39
+ # =====================================================================
40
+ # CohoSpace visualization and selectivity metrics (CohoScore)
41
+ # =====================================================================
42
+
43
+
44
def _coho_coords_to_degrees(coords: np.ndarray) -> np.ndarray:
    """
    Wrap angles given in radians into one full turn and express them in degrees.

    The input is reduced modulo 2*pi first, so the result lies in [0, 360) and has the
    same shape as ``coords``.
    """
    full_turn = 2 * np.pi
    wrapped = coords % full_turn
    return np.degrees(wrapped)
49
+
50
+
51
def _align_activity_to_coords(
    coords: np.ndarray,
    activity: np.ndarray,
    times: np.ndarray | None = None,
    *,
    label: str = "activity",
    auto_filter: bool = True,
) -> np.ndarray:
    """
    Return ``activity`` restricted to the time points that ``coords`` was computed on.

    Parameters
    ----------
    coords : ndarray
        Decoded coordinates; only the length of the first axis is used.
    activity : ndarray
        Activity with time on the first axis.
    times : ndarray, optional
        Index (or mask) applied to the first axis of ``activity`` before the length check.
    label : str
        Name used for ``activity`` in error messages.
    auto_filter : bool
        When lengths still differ, ``times`` is None and ``activity`` is 2-D, keep only the
        rows that contain at least one positive entry, provided that yields exactly
        ``len(coords)`` rows.

    Returns
    -------
    ndarray
        Activity whose first axis has the same length as ``coords``.

    Raises
    ------
    ValueError
        If indexing with ``times`` fails or the lengths cannot be reconciled.
    """
    coords = np.asarray(coords)
    activity = np.asarray(activity)
    explicit_times = times is not None

    if explicit_times:
        times = np.asarray(times)
        try:
            activity = activity[times]
        except Exception as exc:
            raise ValueError(
                f"Failed to index {label} with `times`. Ensure `times` indexes the original time axis."
            ) from exc

    if activity.shape[0] == coords.shape[0]:
        return activity

    # Fallback: drop rows with no positive entry (per the original note, this mimics the
    # decoder's zero-spike filtering) and accept the result only if it fits exactly.
    if auto_filter and not explicit_times and activity.ndim == 2:
        has_activity = np.sum(activity > 0, axis=1) >= 1
        if has_activity.sum() == coords.shape[0]:
            return activity[has_activity]

    raise ValueError(
        f"{label} length must match coords length. Got {activity.shape[0]} vs {coords.shape[0]}. "
        "If coords are computed on a subset of timepoints (e.g., decode['times']), pass "
        "`times=decoding['times']` or slice the activity accordingly."
    )
94
+
95
+
96
def plot_cohospace_scatter_trajectory_2d(
    coords: np.ndarray,
    times: np.ndarray | None = None,
    subsample: int = 1,
    figsize: tuple[int, int] = (6, 6),
    cmap: str = "viridis",
    save_path: str | None = None,
    show: bool = False,
    config: PlotConfig | None = None,
) -> plt.Axes:
    """
    Plot a trajectory in cohomology space as a time-colored scatter.

    Parameters
    ----------
    coords : ndarray, shape (T, 2)
        Decoded cohomology angles (theta1, theta2). They are wrapped modulo 2*pi and
        converted to degrees for plotting, i.e. they are interpreted as radians.
    times : ndarray, optional, shape (T,)
        Values used to color the points. If None, ``arange(T)`` is used, so the colorbar
        refers to indices into ``coords`` even when ``subsample > 1`` (this matches
        `plot_cohospace_scatter_trajectory_1d`).
    subsample : int
        Keep every ``subsample``-th point. Values that cannot be converted to ``int`` or
        that are smaller than 1 fall back to 1.
    figsize : tuple
        Matplotlib figure size. Only used when ``config`` is None.
    cmap : str
        Matplotlib colormap name.
    save_path : str, optional
        Output path stored in the plot configuration. Only used when ``config`` is None.
    show : bool
        Display flag stored in the plot configuration. Only used when ``config`` is None.
    config : PlotConfig, optional
        Pre-built plot configuration. When given it takes precedence over the title,
        labels, ``figsize``, ``save_path`` and ``show`` defaults of this function.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The Axes containing the plot.

    Raises
    ------
    ValueError
        If ``coords`` is not of shape (T, 2) or ``times`` has a different length.

    Notes
    -----
    Saving and showing are delegated to ``finalize_figure(fig, config)``; what happens to
    the figure afterwards depends on that helper.

    Examples
    --------
    >>> ax = plot_cohospace_scatter_trajectory_2d(coords, subsample=2, show=False)  # doctest: +SKIP
    """
    # Best-effort parsing: an unusable stride silently degrades to "no subsampling".
    try:
        step = int(subsample)
    except (TypeError, ValueError, OverflowError):
        step = 1
    step = max(step, 1)

    coords = np.asarray(coords)
    if coords.ndim != 2 or coords.shape[1] != 2:
        raise ValueError(f"`coords` must have shape (T, 2). Got {coords.shape}.")

    # Build the time axis on the full-length data first so that subsampling keeps the
    # original indices/timestamps instead of renumbering the surviving points.
    if times is None:
        times_vis = np.arange(coords.shape[0])
    else:
        times_vis = np.asarray(times)
        if times_vis.shape[0] != coords.shape[0]:
            raise ValueError(
                f"`times` length must match coords length. Got times={times_vis.shape[0]}, coords={coords.shape[0]}."
            )

    theta_deg = _coho_coords_to_degrees(coords)
    if step > 1:
        theta_deg = theta_deg[::step]
        times_vis = times_vis[::step]

    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title="CohoSpace trajectory",
        xlabel="theta1 (deg)",
        ylabel="theta2 (deg)",
        figsize=figsize,
        save_path=save_path,
        show=show,
    )

    fig, ax = plt.subplots(figsize=config.figsize)
    sc = ax.scatter(
        theta_deg[:, 0],
        theta_deg[:, 1],
        c=times_vis,
        cmap=cmap,
        s=3,
        alpha=0.8,
    )
    cbar = plt.colorbar(sc, ax=ax)
    cbar.set_label("Time")

    ax.set_xlim(0, 360)
    ax.set_ylim(0, 360)
    ax.set_xlabel(config.xlabel)
    ax.set_ylabel(config.ylabel)
    ax.set_title(config.title)
    ax.set_aspect("equal", adjustable="box")
    ax.grid(True, alpha=0.2)

    _ensure_parent_dir(config.save_path)
    finalize_figure(fig, config)
    return ax
197
+
198
+
199
def plot_cohospace_scatter_trajectory_1d(
    coords: np.ndarray,
    times: np.ndarray | None = None,
    subsample: int = 1,
    figsize: tuple[int, int] = (6, 6),
    cmap: str = "viridis",
    save_path: str | None = None,
    show: bool = False,
    config: PlotConfig | None = None,
) -> plt.Axes:
    """
    Draw a 1D cohomology trajectory as time-colored points on the unit circle.

    Parameters
    ----------
    coords : ndarray, shape (T,) or (T, 1)
        Decoded angle per time point. Values are wrapped modulo 2*pi and placed at
        (cos(theta), sin(theta)), i.e. they are treated as radians.
    times : ndarray, optional, shape (T,)
        Values used to color the points. Defaults to ``arange(T)``.
    subsample : int
        Keep every ``subsample``-th point; unusable or < 1 values fall back to 1.
    figsize : tuple
        Matplotlib figure size (used when ``config`` is None).
    cmap : str
        Matplotlib colormap name.
    save_path : str, optional
        Output path placed in the plot configuration (used when ``config`` is None).
    show : bool
        Display flag placed in the plot configuration (used when ``config`` is None).
    config : PlotConfig, optional
        Pre-built plot configuration; takes precedence over the defaults above.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The Axes containing the plot.

    Raises
    ------
    ValueError
        If ``coords`` is not 1-D (after squeezing a trailing singleton axis) or ``times``
        has a different length.
    """
    # Tolerant stride parsing: anything unusable means "plot every point".
    try:
        stride = int(subsample)
    except Exception:
        stride = 1
    stride = max(stride, 1)

    angles = np.asarray(coords)
    if angles.ndim == 2 and angles.shape[1] == 1:
        angles = angles[:, 0]
    if angles.ndim != 1:
        raise ValueError(f"`coords` must have shape (T,) or (T, 1). Got {angles.shape}.")

    if times is not None:
        color_values = np.asarray(times)
        if color_values.shape[0] != angles.shape[0]:
            raise ValueError(
                f"`times` length must match coords length. Got times={color_values.shape[0]}, coords={angles.shape[0]}."
            )
    else:
        color_values = np.arange(angles.shape[0])

    if stride > 1:
        angles = angles[::stride]
        color_values = color_values[::stride]

    wrapped = angles % (2 * np.pi)
    px = np.cos(wrapped)
    py = np.sin(wrapped)

    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title="CohoSpace trajectory (1D)",
        xlabel="cos(theta)",
        ylabel="sin(theta)",
        figsize=figsize,
        save_path=save_path,
        show=show,
    )

    fig, ax = plt.subplots(figsize=config.figsize)

    # Light reference ring underneath the data points.
    ring = np.linspace(0, 2 * np.pi, 200)
    ax.plot(np.cos(ring), np.sin(ring), color="0.85", lw=1.0, zorder=0)

    points = ax.scatter(px, py, c=color_values, cmap=cmap, s=5, alpha=0.8)
    colorbar = plt.colorbar(points, ax=ax)
    colorbar.set_label("Time")

    ax.set_xlim(-1.2, 1.2)
    ax.set_ylim(-1.2, 1.2)
    ax.set_xlabel(config.xlabel)
    ax.set_ylabel(config.ylabel)
    ax.set_title(config.title)
    ax.set_aspect("equal", adjustable="box")
    ax.grid(True, alpha=0.2)

    _ensure_parent_dir(config.save_path)
    finalize_figure(fig, config)
    return ax
296
+
297
+
298
def plot_cohospace_scatter_neuron_2d(
    coords: np.ndarray,
    activity: np.ndarray,
    neuron_id: int,
    mode: str = "fr",  # "fr" or "spike"
    top_percent: float = 5.0,  # Used in FR mode
    times: np.ndarray | None = None,
    auto_filter: bool = True,
    figsize: tuple = (6, 6),
    cmap: str = "hot",
    save_path: str | None = None,
    show: bool = True,
    config: PlotConfig | None = None,
) -> plt.Figure:
    """
    Scatter the time points at which one neuron is active, positioned in cohomology space.

    Parameters
    ----------
    coords : ndarray, shape (T, 2)
        Decoded cohomology angles (theta1, theta2), in radians.
    activity : ndarray, shape (T, N)
        Activity matrix (continuous firing rate or binned spikes).
    neuron_id : int
        Column of ``activity`` to visualize.
    mode : {"fr", "spike"}
        "fr": keep time points whose value is at or above the neuron's
        ``(100 - top_percent)``-th percentile and color them by value.
        "spike": keep time points with value > 0 and draw them in red.
    top_percent : float
        Used only when mode="fr". For example, 5.0 means "top 5%" time points.
    times : ndarray, optional, shape (T_coords,)
        Optional indices to align activity to coords when coords are computed on a subset
        of timepoints.
    auto_filter : bool
        If True and lengths mismatch, auto-filter activity with activity>0 to mimic decode
        filtering.
    figsize, cmap, save_path, show : see `plot_cohospace_scatter_trajectory_2d`.
    config : PlotConfig, optional
        Pre-built plot configuration; takes precedence over the defaults built here.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure that was created for the plot.

    Raises
    ------
    ValueError
        If ``mode`` is not "fr" or "spike", or activity cannot be aligned to ``coords``.

    Examples
    --------
    >>> plot_cohospace_scatter_neuron_2d(coords, spikes, neuron_id=0, show=False)  # doctest: +SKIP
    """
    coords = np.asarray(coords)
    activity = _align_activity_to_coords(
        coords, activity, times, label="activity", auto_filter=auto_filter
    )
    theta_deg = _coho_coords_to_degrees(coords)

    trace = activity[:, neuron_id]

    if mode not in ("fr", "spike"):
        raise ValueError("mode must be 'fr' or 'spike'")
    fr_mode = mode == "fr"

    if fr_mode:
        # Percentile cut-off: keeps roughly the top `top_percent`% of time points.
        cutoff = np.percentile(trace, 100 - top_percent)
        selected = trace >= cutoff
        point_colors = trace[selected]
        colormap = cmap
        title = f"Neuron {neuron_id} FR top {top_percent:.1f}% on coho-space"
    else:
        selected = trace > 0
        point_colors = "red"
        colormap = None
        title = f"Neuron {neuron_id} spikes on coho-space"

    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title=title,
        xlabel="Theta 1 (°)",
        ylabel="Theta 2 (°)",
        figsize=figsize,
        save_path=save_path,
        show=show,
    )

    fig, ax = plt.subplots(figsize=config.figsize)
    points = ax.scatter(
        theta_deg[selected, 0],
        theta_deg[selected, 1],
        c=point_colors,
        cmap=colormap,
        s=5,
        alpha=0.9,
    )

    # A colorbar is only meaningful when points are colored by firing rate.
    if fr_mode:
        colorbar = plt.colorbar(points, ax=ax)
        colorbar.set_label("Firing rate")

    ax.set_xlim(0, 360)
    ax.set_ylim(0, 360)
    ax.set_xlabel(config.xlabel)
    ax.set_ylabel(config.ylabel)
    ax.set_title(config.title)

    _ensure_parent_dir(config.save_path)
    finalize_figure(fig, config)

    return fig
401
+
402
+
403
def plot_cohospace_scatter_neuron_1d(
    coords: np.ndarray,
    activity: np.ndarray,
    neuron_id: int,
    mode: str = "fr",
    top_percent: float = 5.0,
    times: np.ndarray | None = None,
    auto_filter: bool = True,
    figsize: tuple = (6, 6),
    cmap: str = "hot",
    save_path: str | None = None,
    show: bool = True,
    config: PlotConfig | None = None,
) -> plt.Figure:
    """
    Scatter one neuron's active time points on the unit circle (1D cohomology space).

    Parameters
    ----------
    coords : ndarray, shape (T,) or (T, 1)
        Decoded angle per time point; wrapped modulo 2*pi and drawn at
        (cos(theta), sin(theta)).
    activity : ndarray, shape (T, N)
        Activity matrix; column ``neuron_id`` is visualized.
    neuron_id : int
        Neuron index.
    mode : {"fr", "spike"}
        "fr": keep points at or above the ``(100 - top_percent)``-th percentile and color
        them by value. "spike": keep points with value > 0 and draw them in red.
    top_percent : float
        Used only when mode="fr".
    times : ndarray, optional
        Indices used to align ``activity`` to ``coords``.
    auto_filter : bool
        Allow the positive-activity fallback of the alignment helper on length mismatch.
    figsize, cmap, save_path, show : see `plot_cohospace_scatter_trajectory_1d`.
    config : PlotConfig, optional
        Pre-built plot configuration; takes precedence over the defaults built here.

    Returns
    -------
    fig : matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``coords`` has the wrong shape, ``mode`` is unknown, or alignment fails.
    """
    angles = np.asarray(coords)
    if angles.ndim == 2 and angles.shape[1] == 1:
        angles = angles[:, 0]
    if angles.ndim != 1:
        raise ValueError(f"coords must have shape (T,) or (T, 1), got {angles.shape}")

    # The alignment helper only needs the length, but is handed a (T, 1) column.
    activity = _align_activity_to_coords(
        angles[:, None], activity, times, label="activity", auto_filter=auto_filter
    )

    trace = activity[:, neuron_id]

    if mode not in ("fr", "spike"):
        raise ValueError("mode must be 'fr' or 'spike'")
    fr_mode = mode == "fr"

    if fr_mode:
        cutoff = np.percentile(trace, 100 - top_percent)
        selected = trace >= cutoff
        point_colors = trace[selected]
        colormap = cmap
        title = f"Neuron {neuron_id} FR top {top_percent:.1f}% on coho-space (1D)"
    else:
        selected = trace > 0
        point_colors = "red"
        colormap = None
        title = f"Neuron {neuron_id} spikes on coho-space (1D)"

    wrapped = angles % (2 * np.pi)
    px = np.cos(wrapped)
    py = np.sin(wrapped)

    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title=title,
        xlabel="cos(theta)",
        ylabel="sin(theta)",
        figsize=figsize,
        save_path=save_path,
        show=show,
    )

    fig, ax = plt.subplots(figsize=config.figsize)

    # Light reference ring underneath the data points.
    ring = np.linspace(0, 2 * np.pi, 200)
    ax.plot(np.cos(ring), np.sin(ring), color="0.85", lw=1.0, zorder=0)

    points = ax.scatter(
        px[selected],
        py[selected],
        c=point_colors,
        cmap=colormap,
        s=8,
        alpha=0.9,
    )

    if fr_mode:
        colorbar = plt.colorbar(points, ax=ax)
        colorbar.set_label("Firing rate")

    ax.set_xlim(-1.2, 1.2)
    ax.set_ylim(-1.2, 1.2)
    ax.set_xlabel(config.xlabel)
    ax.set_ylabel(config.ylabel)
    ax.set_title(config.title)
    ax.set_aspect("equal", adjustable="box")

    _ensure_parent_dir(config.save_path)
    finalize_figure(fig, config)

    return fig
488
+
489
+
490
def plot_cohospace_scatter_population_2d(
    coords: np.ndarray,
    activity: np.ndarray,
    neuron_ids: list[int] | np.ndarray,
    mode: str = "fr",  # "fr" or "spike"
    top_percent: float = 5.0,  # Used in FR mode
    times: np.ndarray | None = None,
    auto_filter: bool = True,
    figsize: tuple = (6, 6),
    cmap: str = "hot",
    save_path: str | None = None,
    show: bool = True,
    config: PlotConfig | None = None,
) -> plt.Figure:
    """
    Plot aggregated activity from multiple neurons in cohomology space.

    In "fr" mode, each neuron's time points at or above its ``(100 - top_percent)``-th
    percentile are selected, and the firing rates of the selected points are summed over
    neurons for coloring. In "spike" mode, a time point is selected for a neuron when its
    value is > 0, and the color is the number of selected neurons at that time point.
    Only time points selected by at least one neuron are drawn.

    Parameters
    ----------
    coords : ndarray, shape (T, 2)
        Decoded cohomology angles (theta1, theta2), in radians.
    activity : ndarray, shape (T, N)
        Activity matrix (firing rate or binned spikes).
    neuron_ids : iterable[int]
        Neuron indices to include (use range(N) to include all).
    mode : {"fr", "spike"}
        Selection/aggregation rule, see above.
    top_percent : float
        Used only when mode="fr".
    times : ndarray, optional, shape (T_coords,)
        Optional indices to align activity to coords when coords are computed on a subset
        of timepoints.
    auto_filter : bool
        If True and lengths mismatch, auto-filter activity with activity>0 to mimic decode
        filtering.
    figsize, cmap, save_path, show : see `plot_cohospace_scatter_trajectory_2d`.
    config : PlotConfig, optional
        Pre-built plot configuration; takes precedence over the defaults built here.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure that was created for the plot.

    Raises
    ------
    ValueError
        If ``mode`` is not "fr" or "spike" (checked up front, also when ``neuron_ids`` is
        empty), or activity cannot be aligned to ``coords``.

    Examples
    --------
    >>> plot_cohospace_scatter_population_2d(coords, spikes, neuron_ids=[0, 1, 2], show=False)  # doctest: +SKIP
    """
    # Validate once instead of on every loop iteration; this also rejects a bad mode
    # when there are no neurons to iterate over.
    if mode not in ("fr", "spike"):
        raise ValueError("mode must be 'fr' or 'spike'")
    use_fr = mode == "fr"

    coords = np.asarray(coords)
    activity = _align_activity_to_coords(
        coords, activity, times, label="activity", auto_filter=auto_filter
    )
    neuron_ids = np.asarray(neuron_ids, dtype=int)

    theta_deg = _coho_coords_to_degrees(coords)

    n_timepoints = activity.shape[0]
    mask = np.zeros(n_timepoints, dtype=bool)  # time points selected by any neuron
    agg_color = np.zeros(n_timepoints, dtype=float)  # summed FR or spike counts

    for n in neuron_ids:
        signal = activity[:, n]
        if use_fr:
            idx = signal >= np.percentile(signal, 100 - top_percent)
            agg_color[idx] += signal[idx]
        else:
            idx = signal > 0
            agg_color[idx] += 1.0
        mask |= idx

    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title=f"{len(neuron_ids)} neurons on coho-space",
        xlabel="Theta 1 (°)",
        ylabel="Theta 2 (°)",
        figsize=figsize,
        save_path=save_path,
        show=show,
    )

    fig, ax = plt.subplots(figsize=config.figsize)
    sc = ax.scatter(
        theta_deg[mask, 0],
        theta_deg[mask, 1],
        c=agg_color[mask],
        cmap=cmap,
        s=5,
        alpha=0.9,
    )
    cbar = plt.colorbar(sc, ax=ax)
    cbar.set_label("Aggregate FR" if use_fr else "Spike count")

    ax.set_xlim(0, 360)
    ax.set_ylim(0, 360)
    ax.set_xlabel(config.xlabel)
    ax.set_ylabel(config.ylabel)
    ax.set_title(config.title)

    _ensure_parent_dir(config.save_path)
    finalize_figure(fig, config)

    return fig
595
+
596
+
597
def plot_cohospace_scatter_population_1d(
    coords: np.ndarray,
    activity: np.ndarray,
    neuron_ids: list[int] | np.ndarray,
    mode: str = "fr",
    top_percent: float = 5.0,
    times: np.ndarray | None = None,
    auto_filter: bool = True,
    figsize: tuple = (6, 6),
    cmap: str = "hot",
    save_path: str | None = None,
    show: bool = True,
    config: PlotConfig | None = None,
) -> plt.Figure:
    """
    Plot aggregated activity from multiple neurons on the 1D cohomology trajectory.

    In "fr" mode, each neuron's time points at or above its ``(100 - top_percent)``-th
    percentile are selected and their firing rates summed over neurons for coloring. In
    "spike" mode, time points with value > 0 are selected and the color is the number of
    selected neurons. Only time points selected by at least one neuron are drawn, at
    (cos(theta), sin(theta)) on the unit circle.

    Parameters
    ----------
    coords : ndarray, shape (T,) or (T, 1)
        Decoded angle per time point, wrapped modulo 2*pi (treated as radians).
    activity : ndarray, shape (T, N)
        Activity matrix (firing rate or binned spikes).
    neuron_ids : iterable[int]
        Neuron indices to include.
    mode : {"fr", "spike"}
        Selection/aggregation rule, see above.
    top_percent : float
        Used only when mode="fr".
    times : ndarray, optional
        Indices used to align ``activity`` to ``coords``.
    auto_filter : bool
        Allow the positive-activity fallback of the alignment helper on length mismatch.
    figsize, cmap, save_path, show : see `plot_cohospace_scatter_trajectory_1d`.
    config : PlotConfig, optional
        Pre-built plot configuration; takes precedence over the defaults built here.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure that was created for the plot.

    Raises
    ------
    ValueError
        If ``mode`` is not "fr" or "spike" (checked up front, also when ``neuron_ids`` is
        empty), ``coords`` has the wrong shape, or alignment fails.
    """
    # Validate once instead of on every loop iteration; this also rejects a bad mode
    # when there are no neurons to iterate over.
    if mode not in ("fr", "spike"):
        raise ValueError("mode must be 'fr' or 'spike'")
    use_fr = mode == "fr"

    coords = np.asarray(coords)
    if coords.ndim == 2 and coords.shape[1] == 1:
        coords = coords[:, 0]
    if coords.ndim != 1:
        raise ValueError(f"coords must have shape (T,) or (T, 1), got {coords.shape}")

    activity = _align_activity_to_coords(
        coords[:, None], activity, times, label="activity", auto_filter=auto_filter
    )
    neuron_ids = np.asarray(neuron_ids, dtype=int)

    n_timepoints = activity.shape[0]
    mask = np.zeros(n_timepoints, dtype=bool)  # time points selected by any neuron
    agg_color = np.zeros(n_timepoints, dtype=float)  # summed FR or spike counts

    for n in neuron_ids:
        signal = activity[:, n]
        if use_fr:
            idx = signal >= np.percentile(signal, 100 - top_percent)
            agg_color[idx] += signal[idx]
        else:
            idx = signal > 0
            agg_color[idx] += 1.0
        mask |= idx

    theta = coords % (2 * np.pi)
    x = np.cos(theta)
    y = np.sin(theta)

    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title=f"{len(neuron_ids)} neurons on coho-space (1D)",
        xlabel="cos(theta)",
        ylabel="sin(theta)",
        figsize=figsize,
        save_path=save_path,
        show=show,
    )

    fig, ax = plt.subplots(figsize=config.figsize)

    # Light reference ring underneath the data points.
    circle = np.linspace(0, 2 * np.pi, 200)
    ax.plot(np.cos(circle), np.sin(circle), color="0.85", lw=1.0, zorder=0)

    sc = ax.scatter(
        x[mask],
        y[mask],
        c=agg_color[mask],
        cmap=cmap,
        s=6,
        alpha=0.9,
    )
    cbar = plt.colorbar(sc, ax=ax)
    cbar.set_label("Aggregate FR" if use_fr else "Spike count")

    ax.set_xlim(-1.2, 1.2)
    ax.set_ylim(-1.2, 1.2)
    ax.set_xlabel(config.xlabel)
    ax.set_ylabel(config.ylabel)
    ax.set_title(config.title)
    ax.set_aspect("equal", adjustable="box")

    _ensure_parent_dir(config.save_path)
    finalize_figure(fig, config)

    return fig
685
+
686
+
687
def compute_cohoscore_scatter_2d(
    coords: np.ndarray,
    activity: np.ndarray,
    top_percent: float | None = 2.0,
    times: np.ndarray | None = None,
    auto_filter: bool = True,
    min_points: int = 5,
) -> np.ndarray:
    """
    Compute a simple cohomology-space selectivity score (CohoScore) for each neuron.

    For each neuron, select active time points (top_percent or activity > 0), compute
    circular variance for theta1 and theta2 on the selected points, and average them.

    Interpretation: smaller score means points are more concentrated in coho space
    and the neuron is more selective.

    Parameters
    ----------
    coords : ndarray, shape (T, 2)
        Decoded cohomology angles (theta1, theta2), in radians. Only the first two columns
        are used.
    activity : ndarray, shape (T, N)
        Activity matrix (FR or spikes).
    top_percent : float | None
        Percentage for selecting active points (e.g., 2.0 means top 2%): points at or above
        the neuron's ``(100 - top_percent)``-th percentile are kept. If None, use
        activity > 0. Note that if that percentile equals the neuron's minimum value
        (e.g. a mostly silent spike train), every time point passes the ``>=`` test.
    times : ndarray, optional, shape (T_coords,)
        Optional indices to align activity to coords when coords are computed on a subset
        of timepoints.
    auto_filter : bool
        If True and lengths mismatch, auto-filter activity with activity>0 to mimic decode
        filtering.
    min_points : int
        Minimum number of selected time points required to score a neuron. Defaults to 5.

    Returns
    -------
    scores : ndarray, shape (N,)
        CohoScore per neuron (NaN for neurons with fewer than ``min_points`` points).

    Raises
    ------
    ValueError
        If ``coords`` is not 2-D with at least two columns, or activity cannot be aligned.

    Examples
    --------
    >>> scores = compute_cohoscore_scatter_2d(coords, spikes)  # doctest: +SKIP
    >>> scores.shape[0]  # doctest: +SKIP
    """
    coords = np.asarray(coords)
    # Fail early with a clear message instead of an IndexError inside the loop.
    if coords.ndim != 2 or coords.shape[1] < 2:
        raise ValueError(f"coords must have shape (T, 2), got {coords.shape}")

    activity = _align_activity_to_coords(
        coords, activity, times, label="activity", auto_filter=auto_filter
    )
    _, n_neurons = activity.shape

    theta = coords % (2 * np.pi)  # Ensure values are in [0, 2π)
    # Start from NaN so that skipped (too sparse) neurons need no special write.
    scores = np.full(n_neurons, np.nan, dtype=float)

    for n in range(n_neurons):
        signal = activity[:, n]

        if top_percent is None:
            idx = signal > 0  # Use all time points with spikes
        else:
            idx = signal >= np.percentile(signal, 100 - top_percent)

        if np.sum(idx) < min_points:
            continue  # Too sparse; unreliable

        var1 = circvar(theta[idx, 0], high=2 * np.pi, low=0)
        var2 = circvar(theta[idx, 1], high=2 * np.pi, low=0)
        scores[n] = 0.5 * (var1 + var2)

    return scores
757
+
758
+
759
def compute_cohoscore_scatter_1d(
    coords: np.ndarray,
    activity: np.ndarray,
    top_percent: float | None = 2.0,
    times: np.ndarray | None = None,
    auto_filter: bool = True,
    min_points: int = 5,
) -> np.ndarray:
    """
    Compute 1D cohomology-space selectivity score (CohoScore) for each neuron.

    For each neuron, select active time points (top_percent or activity > 0), compute
    circular variance for theta on the selected points, and use it as the score.
    Smaller scores mean the selected points are more concentrated on the circle.

    Parameters
    ----------
    coords : ndarray, shape (T,) or (T, 1)
        Decoded angle per time point, in radians (wrapped modulo 2*pi internally).
    activity : ndarray, shape (T, N)
        Activity matrix (FR or spikes).
    top_percent : float | None
        Percentage for selecting active points: points at or above the neuron's
        ``(100 - top_percent)``-th percentile are kept. If None, use activity > 0.
    times : ndarray, optional
        Indices used to align ``activity`` to ``coords``.
    auto_filter : bool
        Allow the positive-activity fallback of the alignment helper on length mismatch.
    min_points : int
        Minimum number of selected time points required to score a neuron. Defaults to 5.

    Returns
    -------
    scores : ndarray, shape (N,)
        CohoScore per neuron (NaN for neurons with fewer than ``min_points`` points).

    Raises
    ------
    ValueError
        If ``coords`` has the wrong shape or activity cannot be aligned.
    """
    coords = np.asarray(coords)
    if coords.ndim == 2 and coords.shape[1] == 1:
        coords = coords[:, 0]
    if coords.ndim != 1:
        raise ValueError(f"coords must have shape (T,) or (T, 1), got {coords.shape}")

    activity = _align_activity_to_coords(
        coords[:, None], activity, times, label="activity", auto_filter=auto_filter
    )
    _, n_neurons = activity.shape

    theta = coords % (2 * np.pi)
    # Start from NaN so that skipped (too sparse) neurons need no special write.
    scores = np.full(n_neurons, np.nan, dtype=float)

    for n in range(n_neurons):
        signal = activity[:, n]

        if top_percent is None:
            idx = signal > 0
        else:
            idx = signal >= np.percentile(signal, 100 - top_percent)

        if np.sum(idx) < min_points:
            continue  # Too sparse; unreliable

        scores[n] = circvar(theta[idx], high=2 * np.pi, low=0)

    return scores
803
+
804
+
805
def skew_transform_torus_scatter(coords):
    """
    Map torus angles (theta1, theta2) into a skewed (60-degree parallelogram) plane.

    The mapping is the linear change of basis

        x = theta1 + 0.5 * theta2
        y = (sqrt(3)/2) * theta2

    which sends the square [0, 2π)×[0, 2π) onto a parallelogram, a convenient layout for
    inspecting wrap-around behavior on a 2-torus. No wrapping or scaling is applied.

    Parameters
    ----------
    coords : ndarray, shape (T, 2)
        Angles (theta1, theta2) in radians.

    Returns
    -------
    xy : ndarray, shape (T, 2)
        Skewed planar coordinates.

    Raises
    ------
    ValueError
        If ``coords`` is not of shape (T, 2).
    """
    angles = np.asarray(coords)
    if not (angles.ndim == 2 and angles.shape[1] == 2):
        raise ValueError(f"coords must be (T,2), got {angles.shape}")

    first, second = angles[:, 0], angles[:, 1]

    # Purely linear: shear the second angle into x and shrink it in y.
    skew_x = first + 0.5 * second
    skew_y = (np.sqrt(3) / 2.0) * second

    return np.column_stack((skew_x, skew_y))
838
+
839
+
840
def draw_torus_parallelogram_grid_scatter(ax, n_tiles=1, color="0.7", lw=1.0, alpha=0.8):
    """
    Outline the parallelogram cells of the skewed torus domain on ``ax``.

    Each cell is spanned by the vectors e1 = (2π, 0) and e2 = (π, √3 π) and is drawn as a
    closed outline with one ``ax.plot`` call.

    Parameters
    ----------
    ax : matplotlib axis
        Target axis; only its ``plot`` method is used.
    n_tiles : int
        Number of cells drawn in each +/- direction around the base cell
        (n_tiles=1 draws the shifts [-1, 0, 1] along both vectors, i.e. 3x3 cells).
    color, lw, alpha
        Line style forwarded to ``ax.plot``.
    """
    basis_a = np.array([2 * np.pi, 0.0])
    basis_b = np.array([np.pi, np.sqrt(3) * np.pi])
    line_style = dict(color=color, lw=lw, alpha=alpha)

    offsets = range(-n_tiles, n_tiles + 1)
    cell_origins = [i * basis_a + j * basis_b for i in offsets for j in offsets]

    for origin in cell_origins:
        far_a = origin + basis_a
        # Walk the four corners and return to the start to close the outline.
        outline = np.array([origin, far_a, far_a + basis_b, origin + basis_b, origin])
        ax.plot(outline[:, 0], outline[:, 1], **line_style)
865
+
866
+
867
def tile_parallelogram_points_scatter(xy, n_tiles=1):
    """
    Replicate points across neighboring cells of the skewed (parallelogram) torus domain.

    Mainly useful for static figures, where the copies make continuity across the domain
    boundaries visible.

    Parameters
    ----------
    xy : ndarray, shape (T, 2)
        Points in the skewed plane (same coordinates as returned by
        `skew_transform_torus_scatter`).
    n_tiles : int
        Number of cells to extend around the base domain.
        - n_tiles=1 produces a 3x3 tiling
        - n_tiles=2 produces a 5x5 tiling

    Returns
    -------
    tiled : ndarray
        All shifted copies stacked vertically (one block of T rows per cell). If no cell is
        generated (negative ``n_tiles``), the input is returned as a float array.
    """
    base = np.asarray(xy, dtype=float)

    basis_a = np.array([2 * np.pi, 0.0])
    basis_b = np.array([np.pi, np.sqrt(3) * np.pi])

    offsets = range(-n_tiles, n_tiles + 1)
    copies = [base + i * basis_a + j * basis_b for i in offsets for j in offsets]

    if not copies:
        return base
    return np.vstack(copies)
899
+
900
+
901
def plot_cohospace_scatter_neuron_skewed(
    coords,
    activity,
    neuron_id,
    mode="spike",
    top_percent=2.0,
    times: np.ndarray | None = None,
    auto_filter: bool = True,
    save_path=None,
    show=None,
    ax=None,
    show_grid=True,
    n_tiles=1,
    s=6,
    alpha=0.8,
    config: PlotConfig | None = None,
):
    """
    Scatter one neuron's selected samples on the skewed (parallelogram) torus domain.

    Parameters
    ----------
    coords : ndarray, shape (T, 2)
        Decoded circular coordinates (theta1, theta2), in radians. They are
        wrapped into [0, 2π) before plotting.
    activity : ndarray, shape (T, N)
        Activity matrix; it is passed through ``_align_activity_to_coords``
        before use.
    neuron_id : int
        Column of ``activity`` to plot.
    mode : {"spike", "fr"}
        "spike" keeps samples with activity > 0 and draws them in one colour.
        "fr" keeps samples at or above the ``100 - top_percent`` percentile and
        colours them by activity (with a colorbar).
    top_percent : float
        Percentage used for the "fr" threshold.
    times : ndarray or None
        Forwarded to ``_align_activity_to_coords``.
    auto_filter : bool
        Forwarded to ``_align_activity_to_coords``.
    save_path, show
        When not None they override the corresponding ``config`` fields; with
        no ``config`` they seed the default one (``show=None`` means False).
    ax : matplotlib Axes or None
        Axes to draw on; a new figure is created when None.
    show_grid : bool
        If True, ``draw_torus_parallelogram_grid_scatter`` is called on the axes.
    n_tiles : int
        Number of rings of translated copies around the base domain; a falsy
        or non-positive value plots the base domain only.
    s, alpha
        Marker size and opacity passed to ``Axes.scatter``.
    config : PlotConfig or None
        Plot configuration. A supplied object is updated in place (save path,
        show flag, and any empty title/xlabel/ylabel).

    Returns
    -------
    matplotlib Axes
        The axes that were drawn on.

    Raises
    ------
    ValueError
        If ``mode`` is neither "spike" nor "fr".
    """
    coords = np.asarray(coords)
    activity = _align_activity_to_coords(
        coords, activity, times, label="activity", auto_filter=auto_filter
    )
    # Wrap both angles into [0, 2π).
    coords = coords % (2 * np.pi)

    # Pick the samples of this neuron that will be drawn.
    trace = activity[:, neuron_id]
    if mode == "spike":
        selected = trace > 0
    elif mode == "fr":
        selected = trace >= np.percentile(trace, 100 - top_percent)
    else:
        raise ValueError(f"Unknown mode: {mode}")
    color_values = trace[selected]  # only consumed in "fr" mode

    # Texts used whenever the configuration does not provide its own.
    fallback_text = {
        "title": f"Neuron {neuron_id} – CohoSpace (skewed, mode={mode})",
        "xlabel": r"$\theta_1 + \frac{1}{2}\theta_2$",
        "ylabel": r"$\frac{\sqrt{3}}{2}\theta_2$",
    }
    if config is not None:
        # Explicit arguments win over the supplied configuration.
        if save_path is not None:
            config.save_path = save_path
        if show is not None:
            config.show = show
        for field_name, text in fallback_text.items():
            if not getattr(config, field_name):
                setattr(config, field_name, text)
    else:
        config = PlotConfig.for_static_plot(
            figsize=(5, 5),
            save_path=save_path,
            show=show is not None and bool(show),
            **fallback_text,
        )

    owns_figure = ax is None
    if owns_figure:
        fig, ax = plt.subplots(figsize=config.figsize)
    else:
        fig = ax.figure

    # Edge vectors and corners of the base parallelogram in the skew plane.
    side_a = np.array([2 * np.pi, 0.0])
    side_b = np.array([np.pi, np.sqrt(3) * np.pi])
    origin = np.array([0.0, 0.0])
    far_corner = side_a + side_b
    cell = np.vstack([origin, side_a, side_b, far_corner])

    tiled = bool(n_tiles) and n_tiles > 0

    # Map the selected angles into the skew plane, optionally replicating them.
    points = skew_transform_torus_scatter(coords[selected])
    if tiled:
        points = tile_parallelogram_points_scatter(points, n_tiles=n_tiles)
        if mode == "fr":
            # Keep one colour value per replicated point.
            # NOTE(review): assumes the tiling helper emits (2*n_tiles+1)**2
            # whole copies of the point set, one after another - confirm.
            color_values = np.tile(color_values, (2 * n_tiles + 1) ** 2)

    if mode == "fr":
        dots = ax.scatter(
            points[:, 0], points[:, 1], c=color_values, s=s, alpha=alpha, cmap="viridis"
        )
        fig.colorbar(dots, ax=ax, shrink=0.85, pad=0.02, label="activity")
    else:
        ax.scatter(points[:, 0], points[:, 1], s=s, alpha=alpha, color="tab:blue")

    # The outline of the base domain is drawn regardless of show_grid.
    outline = np.vstack([origin, side_a, far_corner, side_b, origin])
    ax.plot(outline[:, 0], outline[:, 1], lw=1.2, color="0.35")

    if show_grid:
        draw_torus_parallelogram_grid_scatter(ax, n_tiles=n_tiles)

    # Corner labels, nudged towards the inside of the base domain's bounding box.
    inset_x, inset_y = 0.02 * (cell.max(axis=0) - cell.min(axis=0))
    label_box = dict(facecolor="white", edgecolor="none", alpha=0.7, pad=1.0)
    corner_labels = (
        (origin[0] + inset_x, origin[1] + inset_y, "(0,0)", "left", "bottom"),
        (side_a[0] - inset_x, side_a[1] + inset_y, "(2π,0)", "right", "bottom"),
        (side_b[0] + inset_x, side_b[1] - inset_y, "(0,2π)", "left", "top"),
        (far_corner[0] - inset_x, far_corner[1] - inset_y, "(2π,2π)", "right", "top"),
    )
    for label_x, label_y, label, h_align, v_align in corner_labels:
        ax.text(
            label_x, label_y, label, fontsize=10, ha=h_align, va=v_align, bbox=label_box
        )

    # View limits: the base domain alone, or the base domain shifted by every
    # lattice translation i*side_a + j*side_b with |i|, |j| <= n_tiles.
    if tiled:
        steps = range(-n_tiles, n_tiles + 1)
        offsets = np.asarray([i * side_a + j * side_b for i in steps for j in steps])
        extent_pts = (cell[None, :, :] + offsets[:, None, :]).reshape(-1, 2)
    else:
        extent_pts = cell
    lower = extent_pts.min(axis=0)
    upper = extent_pts.max(axis=0)
    margin_x, margin_y = 0.03 * (upper - lower)
    ax.set_xlim(lower[0] - margin_x, upper[0] + margin_x)
    ax.set_ylim(lower[1] - margin_y, upper[1] + margin_y)

    ax.set_aspect("equal")
    ax.set_xlabel(config.xlabel)
    ax.set_ylabel(config.ylabel)
    ax.set_title(config.title)

    if owns_figure:
        _ensure_parent_dir(config.save_path)
        finalize_figure(fig, config)
        return ax

    # Caller-owned axes: save/show only, without finalize_figure.
    if config.save_path is not None:
        _ensure_parent_dir(config.save_path)
        fig.savefig(config.save_path, **config.to_savefig_kwargs())
    if config.show:
        plt.show()
    return ax
1087
+
1088
+
1089
def plot_cohospace_scatter_population_skewed(
    coords,
    activity,
    neuron_ids,
    mode="spike",
    top_percent=2.0,
    times: np.ndarray | None = None,
    auto_filter: bool = True,
    save_path=None,
    show=None,
    ax=None,
    show_grid=True,
    n_tiles=1,
    s=4,
    alpha=0.5,
    config: PlotConfig | None = None,
):
    """
    Plot population CohoSpace on skewed torus domain.

    Each neuron in ``neuron_ids`` is drawn as its own scatter layer, so the
    layers take successive colours from the axes' colour cycle.

    Parameters
    ----------
    coords : ndarray, shape (T, 2)
        Decoded circular coordinates (theta1, theta2), in radians. They are
        wrapped into [0, 2π) before plotting.
    activity : ndarray, shape (T, N)
        Activity matrix; it is passed through ``_align_activity_to_coords``
        before use.
    neuron_ids : list or ndarray
        Neurons to include (e.g. top-K by CohoScore).
    mode : {"spike", "fr"}
        "spike" keeps samples with activity > 0; "fr" keeps samples at or
        above the ``100 - top_percent`` percentile of each neuron.
    top_percent : float
        Percentage used for the "fr" threshold.
    times : ndarray or None
        Forwarded to ``_align_activity_to_coords``.
    auto_filter : bool
        If True and lengths mismatch, auto-filter activity with activity>0 to
        mimic decode filtering.
    save_path, show
        When not None they override the corresponding ``config`` fields; with
        no ``config`` they seed the default one. ``show=None`` (the default)
        means "do not show" without a config and "keep ``config.show``" with one.
    ax : matplotlib Axes or None
        Axes to draw on; a new figure is created when None.
    show_grid : bool
        If True, ``draw_torus_parallelogram_grid_scatter`` is called on the axes.
    n_tiles : int
        Number of rings of translated copies around the base domain; a falsy
        or non-positive value plots the base domain only.
    s, alpha
        Marker size and opacity passed to ``Axes.scatter``.
    config : PlotConfig or None
        Plot configuration. A supplied object is updated in place (save path,
        show flag, and any empty title/xlabel/ylabel).

    Returns
    -------
    matplotlib Axes
        The axes that were drawn on.

    Raises
    ------
    ValueError
        If ``mode`` is neither "spike" nor "fr".
    """
    # Fail fast on a bad mode, matching the single-neuron plot, instead of
    # silently treating every non-"spike" value as "fr".
    if mode not in ("spike", "fr"):
        raise ValueError(f"Unknown mode: {mode}")

    coords = np.asarray(coords)
    activity = _align_activity_to_coords(
        coords, activity, times, label="activity", auto_filter=auto_filter
    )
    # Wrap both angles into [0, 2π).
    coords = coords % (2 * np.pi)

    if config is None:
        config = PlotConfig.for_static_plot(
            title=f"Population CohoSpace (skewed, n={len(neuron_ids)}, mode={mode})",
            xlabel=r"$\theta_1 + \frac{1}{2}\theta_2$",
            ylabel=r"$\frac{\sqrt{3}}{2}\theta_2$",
            figsize=(5, 5),
            save_path=save_path,
            show=bool(show) if show is not None else False,
        )
    else:
        if save_path is not None:
            config.save_path = save_path
        # Only an explicitly passed `show` may override the caller's config.
        if show is not None:
            config.show = show
        if not config.title:
            config.title = f"Population CohoSpace (skewed, n={len(neuron_ids)}, mode={mode})"
        if not config.xlabel:
            config.xlabel = r"$\theta_1 + \frac{1}{2}\theta_2$"
        if not config.ylabel:
            config.ylabel = r"$\frac{\sqrt{3}}{2}\theta_2$"

    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(figsize=config.figsize)
    else:
        fig = ax.figure

    # --- fundamental domain vectors in skew plane
    e1 = np.array([2 * np.pi, 0.0])
    e2 = np.array([np.pi, np.sqrt(3) * np.pi])

    use_tiles = bool(n_tiles) and n_tiles > 0

    # --- scatter each neuron as a separate layer
    for nid in neuron_ids:
        a = activity[:, nid]
        if mode == "spike":
            mask = a > 0
        else:  # mode == "fr" (validated above)
            mask = a >= np.percentile(a, 100 - top_percent)

        xy = skew_transform_torus_scatter(coords[mask])
        if use_tiles:
            xy = tile_parallelogram_points_scatter(xy, n_tiles=n_tiles)

        ax.scatter(xy[:, 0], xy[:, 1], s=s, alpha=alpha)

    # Always draw the base domain boundary
    boundary = np.vstack([np.array([0.0, 0.0]), e1, e1 + e2, e2, np.array([0.0, 0.0])])
    ax.plot(boundary[:, 0], boundary[:, 1], lw=1.2, color="0.35")

    if show_grid:
        draw_torus_parallelogram_grid_scatter(ax, n_tiles=n_tiles)

    # Fix view limits: tiles=0 shows base domain; tiles>0 shows the tiled extent
    base = np.vstack([[0, 0], e1, e2, e1 + e2])

    if use_tiles:
        # Expand the view by n_tiles rings: every translation i*e1 + j*e2
        # with |i|, |j| <= n_tiles is applied to the base corners.
        ring = range(-n_tiles, n_tiles + 1)
        shifts = np.asarray([i * e1 + j * e2 for i in ring for j in ring])  # ((2n+1)^2, 2)
        extent = (base[None, :, :] + shifts[:, None, :]).reshape(-1, 2)
    else:
        extent = base

    xmin, ymin = extent.min(axis=0)
    xmax, ymax = extent.max(axis=0)

    padx = 0.03 * (xmax - xmin)
    pady = 0.03 * (ymax - ymin)
    ax.set_xlim(xmin - padx, xmax + padx)
    ax.set_ylim(ymin - pady, ymax + pady)

    ax.set_aspect("equal")
    ax.set_xlabel(config.xlabel)
    ax.set_ylabel(config.ylabel)
    ax.set_title(config.title)

    if created_fig:
        _ensure_parent_dir(config.save_path)
        finalize_figure(fig, config)
    else:
        # Caller-owned axes: save/show only, without finalize_figure.
        if config.save_path is not None:
            _ensure_parent_dir(config.save_path)
            fig.savefig(config.save_path, **config.to_savefig_kwargs())
        if config.show:
            plt.show()

    return ax