canns 0.14.2__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. canns/analyzer/data/asa/__init__.py +77 -21
  2. canns/analyzer/data/asa/coho.py +97 -0
  3. canns/analyzer/data/asa/cohomap.py +408 -0
  4. canns/analyzer/data/asa/cohomap_scatter.py +10 -0
  5. canns/analyzer/data/asa/cohomap_vectors.py +311 -0
  6. canns/analyzer/data/asa/cohospace.py +173 -1153
  7. canns/analyzer/data/asa/cohospace_phase_centers.py +137 -0
  8. canns/analyzer/data/asa/cohospace_scatter.py +1220 -0
  9. canns/analyzer/data/asa/embedding.py +3 -4
  10. canns/analyzer/data/asa/plotting.py +4 -4
  11. canns/analyzer/data/cell_classification/__init__.py +10 -0
  12. canns/analyzer/data/cell_classification/core/__init__.py +4 -0
  13. canns/analyzer/data/cell_classification/core/btn.py +272 -0
  14. canns/analyzer/data/cell_classification/visualization/__init__.py +3 -0
  15. canns/analyzer/data/cell_classification/visualization/btn_plots.py +258 -0
  16. canns/analyzer/visualization/__init__.py +2 -0
  17. canns/analyzer/visualization/core/config.py +20 -0
  18. canns/analyzer/visualization/theta_sweep_plots.py +142 -0
  19. canns/pipeline/asa/runner.py +19 -19
  20. canns/pipeline/asa_gui/__init__.py +5 -3
  21. canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +32 -4
  22. canns/pipeline/asa_gui/core/runner.py +23 -23
  23. canns/pipeline/asa_gui/views/pages/preprocess_page.py +250 -8
  24. {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/METADATA +2 -1
  25. {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/RECORD +28 -20
  26. {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/WHEEL +0 -0
  27. {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/entry_points.txt +0 -0
  28. {canns-0.14.2.dist-info → canns-0.15.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,1231 +1,251 @@
1
1
  from __future__ import annotations
2
2
 
3
- import os
3
+ from typing import Any
4
4
 
5
5
  import matplotlib.pyplot as plt
6
6
  import numpy as np
7
- from scipy.stats import circvar
7
+ from scipy.ndimage import gaussian_filter
8
+ from scipy.stats import binned_statistic_2d
8
9
 
9
10
  from ...visualization.core import PlotConfig, finalize_figure
11
+ from .coho import _ensure_parent_dir, _ensure_plot_config, _extract_coords_and_times
10
12
 
11
13
 
12
- def _ensure_plot_config(
13
- config: PlotConfig | None,
14
- factory,
14
+ def cohospace(
15
+ coords: np.ndarray | dict[str, Any],
16
+ spikes: np.ndarray,
15
17
  *,
16
- kwargs: dict | None = None,
17
- **defaults,
18
- ) -> PlotConfig:
19
- if config is None:
20
- defaults.update({"kwargs": kwargs or {}})
21
- return factory(**defaults)
22
-
23
- if kwargs:
24
- config_kwargs = config.kwargs or {}
25
- config_kwargs.update(kwargs)
26
- config.kwargs = config_kwargs
27
- return config
28
-
29
-
30
- def _ensure_parent_dir(save_path: str | None) -> None:
31
- if save_path:
32
- parent = os.path.dirname(save_path)
33
- if parent:
34
- os.makedirs(parent, exist_ok=True)
35
-
36
-
37
- # =====================================================================
38
- # CohoSpace visualization and selectivity metrics (CohoScore)
39
- # =====================================================================
40
-
41
-
42
- def _coho_coords_to_degrees(coords: np.ndarray) -> np.ndarray:
43
- """
44
- Convert decoded coho coordinates (T x 2, radians) into degrees in [0, 360).
45
- """
46
- return np.degrees(coords % (2 * np.pi))
47
-
48
-
49
- def _align_activity_to_coords(
50
- coords: np.ndarray,
51
- activity: np.ndarray,
52
18
  times: np.ndarray | None = None,
53
- *,
54
- label: str = "activity",
55
- auto_filter: bool = True,
56
- ) -> np.ndarray:
57
- """
58
- Align activity to coords by optional time indices and validate lengths.
19
+ coords_key: str | None = None,
20
+ bins: int = 51,
21
+ coords_in_unit: bool = False,
22
+ smooth_sigma: float = 0.0,
23
+ ) -> dict[str, Any]:
59
24
  """
60
- coords = np.asarray(coords)
61
- activity = np.asarray(activity)
62
-
63
- if times is not None:
64
- times = np.asarray(times)
65
- try:
66
- activity = activity[times]
67
- except Exception as exc:
68
- raise ValueError(
69
- f"Failed to index {label} with `times`. Ensure `times` indexes the original time axis."
70
- ) from exc
71
-
72
- if activity.shape[0] != coords.shape[0]:
73
- # Try to reproduce decode's zero-spike filtering if lengths mismatch.
74
- if auto_filter and times is None and activity.ndim == 2:
75
- mask = np.sum(activity > 0, axis=1) >= 1
76
- if mask.sum() == coords.shape[0]:
77
- activity = activity[mask]
78
- else:
79
- raise ValueError(
80
- f"{label} length must match coords length. Got {activity.shape[0]} vs {coords.shape[0]}. "
81
- "If coords are computed on a subset of timepoints (e.g., decode['times']), pass "
82
- "`times=decoding['times']` or slice the activity accordingly."
83
- )
84
- else:
85
- raise ValueError(
86
- f"{label} length must match coords length. Got {activity.shape[0]} vs {coords.shape[0]}. "
87
- "If coords are computed on a subset of timepoints (e.g., decode['times']), pass "
88
- "`times=decoding['times']` or slice the activity accordingly."
89
- )
90
-
91
- return activity
92
-
25
+ Compute EcohoSpace rate maps and phase centers.
93
26
 
94
- def plot_cohospace_trajectory_2d(
95
- coords: np.ndarray,
96
- times: np.ndarray | None = None,
97
- subsample: int = 1,
98
- figsize: tuple[int, int] = (6, 6),
99
- cmap: str = "viridis",
100
- save_path: str | None = None,
101
- show: bool = False,
102
- config: PlotConfig | None = None,
103
- ) -> plt.Axes:
104
- """
105
- Plot a trajectory in cohomology space.
106
-
107
- Parameters
108
- ----------
109
- coords : ndarray, shape (T, 2)
110
- Decoded cohomology angles (theta1, theta2). Values may be in radians or in [0, 1] "unit circle"
111
- convention depending on upstream decoding; this function will convert to degrees for plotting.
112
- times : ndarray, optional, shape (T,)
113
- Optional time array used to color points. If None, uses arange(T).
114
- subsample : int
115
- Downsampling step (>1 reduces the number of plotted points).
116
- figsize : tuple
117
- Matplotlib figure size.
118
- cmap : str
119
- Matplotlib colormap name.
120
- save_path : str, optional
121
- If provided, saves the figure to this path.
122
- show : bool
123
- If True, calls plt.show(). If False, closes the figure and returns the Axes.
124
-
125
- Returns
126
- -------
127
- ax : matplotlib.axes.Axes
128
- The Axes containing the plot.
129
-
130
- Examples
131
- --------
132
- >>> fig = plot_cohospace_trajectory_2d(coords, subsample=2, show=False) # doctest: +SKIP
27
+ Mirrors GridCellTorus get_ratemaps: mean activity in coho-space bins and
28
+ a circular-mean center for each neuron. Optionally smooths the rate maps.
133
29
  """
30
+ if isinstance(coords, dict):
31
+ coords, times_box = _extract_coords_and_times(coords, coords_key)
32
+ if times is None:
33
+ times = times_box
134
34
 
135
- try:
136
- subsample_i = int(subsample)
137
- except Exception:
138
- subsample_i = 1
139
- if subsample_i < 1:
140
- subsample_i = 1
35
+ coords = np.asarray(coords, float)
36
+ if coords.ndim != 2 or coords.shape[1] < 2:
37
+ raise ValueError(f"coords must be (N,2+) array, got {coords.shape}")
141
38
 
142
- coords = np.asarray(coords)
143
- if coords.ndim != 2 or coords.shape[1] != 2:
144
- raise ValueError(f"`coords` must have shape (T, 2). Got {coords.shape}.")
39
+ if coords_in_unit:
40
+ coords = coords * (2 * np.pi)
145
41
 
146
- theta_deg = _coho_coords_to_degrees(coords)
147
- if subsample_i > 1:
148
- theta_deg = theta_deg[::subsample_i]
149
-
150
- if times is None:
151
- times_vis = np.arange(theta_deg.shape[0])
152
- else:
153
- times_vis = np.asarray(times)
154
- if times_vis.shape[0] != coords.shape[0]:
155
- raise ValueError(
156
- f"`times` length must match coords length. Got times={times_vis.shape[0]}, coords={coords.shape[0]}."
157
- )
158
- if subsample_i > 1:
159
- times_vis = times_vis[::subsample_i]
160
-
161
- config = _ensure_plot_config(
162
- config,
163
- PlotConfig.for_static_plot,
164
- title="CohoSpace trajectory",
165
- xlabel="theta1 (deg)",
166
- ylabel="theta2 (deg)",
167
- figsize=figsize,
168
- save_path=save_path,
169
- show=show,
170
- )
171
-
172
- fig, ax = plt.subplots(figsize=config.figsize)
173
- sc = ax.scatter(
174
- theta_deg[:, 0],
175
- theta_deg[:, 1],
176
- c=times_vis,
177
- cmap=cmap,
178
- s=3,
179
- alpha=0.8,
180
- )
181
- cbar = plt.colorbar(sc, ax=ax)
182
- cbar.set_label("Time")
183
-
184
- ax.set_xlim(0, 360)
185
- ax.set_ylim(0, 360)
186
- ax.set_xlabel(config.xlabel)
187
- ax.set_ylabel(config.ylabel)
188
- ax.set_title(config.title)
189
- ax.set_aspect("equal", adjustable="box")
190
- ax.grid(True, alpha=0.2)
191
-
192
- _ensure_parent_dir(config.save_path)
193
- finalize_figure(fig, config)
194
- return ax
195
-
196
-
197
- def plot_cohospace_trajectory_1d(
198
- coords: np.ndarray,
199
- times: np.ndarray | None = None,
200
- subsample: int = 1,
201
- figsize: tuple[int, int] = (6, 6),
202
- cmap: str = "viridis",
203
- save_path: str | None = None,
204
- show: bool = False,
205
- config: PlotConfig | None = None,
206
- ) -> plt.Axes:
207
- """
208
- Plot a 1D cohomology trajectory on the unit circle.
209
-
210
- Parameters
211
- ----------
212
- coords : ndarray, shape (T,) or (T, 1)
213
- Decoded cohomology angles (theta). Values may be in radians or in [0, 1] "unit circle"
214
- convention depending on upstream decoding; this function will plot on the unit circle.
215
- times : ndarray, optional, shape (T,)
216
- Optional time array used to color points. If None, uses arange(T).
217
- subsample : int
218
- Downsampling step (>1 reduces the number of plotted points).
219
- figsize : tuple
220
- Matplotlib figure size.
221
- cmap : str
222
- Matplotlib colormap name.
223
- save_path : str, optional
224
- If provided, saves the figure to this path.
225
- show : bool
226
- If True, calls plt.show(). If False, closes the figure and returns the Axes.
227
- """
228
- try:
229
- subsample_i = int(subsample)
230
- except Exception:
231
- subsample_i = 1
232
- if subsample_i < 1:
233
- subsample_i = 1
234
-
235
- coords = np.asarray(coords)
236
- if coords.ndim == 2 and coords.shape[1] == 1:
237
- coords = coords[:, 0]
238
- if coords.ndim != 1:
239
- raise ValueError(f"`coords` must have shape (T,) or (T, 1). Got {coords.shape}.")
240
-
241
- if times is None:
242
- times_vis = np.arange(coords.shape[0])
243
- else:
244
- times_vis = np.asarray(times)
245
- if times_vis.shape[0] != coords.shape[0]:
246
- raise ValueError(
247
- f"`times` length must match coords length. Got times={times_vis.shape[0]}, coords={coords.shape[0]}."
248
- )
249
-
250
- if subsample_i > 1:
251
- coords = coords[::subsample_i]
252
- times_vis = times_vis[::subsample_i]
253
-
254
- theta = coords % (2 * np.pi)
255
- x = np.cos(theta)
256
- y = np.sin(theta)
42
+ spikes = np.asarray(spikes)
43
+ if times is not None:
44
+ spikes = spikes[np.asarray(times).astype(int)]
257
45
 
258
- config = _ensure_plot_config(
259
- config,
260
- PlotConfig.for_static_plot,
261
- title="CohoSpace trajectory (1D)",
262
- xlabel="cos(theta)",
263
- ylabel="sin(theta)",
264
- figsize=figsize,
265
- save_path=save_path,
266
- show=show,
267
- )
46
+ if spikes.ndim == 1:
47
+ spikes = spikes[:, np.newaxis]
268
48
 
269
- fig, ax = plt.subplots(figsize=config.figsize)
270
- circle = np.linspace(0, 2 * np.pi, 200)
271
- ax.plot(np.cos(circle), np.sin(circle), color="0.85", lw=1.0, zorder=0)
272
- sc = ax.scatter(
273
- x,
274
- y,
275
- c=times_vis,
276
- cmap=cmap,
277
- s=5,
278
- alpha=0.8,
279
- )
280
- cbar = plt.colorbar(sc, ax=ax)
281
- cbar.set_label("Time")
49
+ if spikes.shape[0] != coords.shape[0]:
50
+ raise ValueError(
51
+ f"spikes length must match coords length. Got {spikes.shape[0]} vs {coords.shape[0]}"
52
+ )
282
53
 
283
- ax.set_xlim(-1.2, 1.2)
284
- ax.set_ylim(-1.2, 1.2)
285
- ax.set_xlabel(config.xlabel)
286
- ax.set_ylabel(config.ylabel)
287
- ax.set_title(config.title)
288
- ax.set_aspect("equal", adjustable="box")
289
- ax.grid(True, alpha=0.2)
54
+ edges = np.linspace(0, 2 * np.pi, bins)
55
+ bin_centers = edges[:-1] + (edges[1:] - edges[:-1]) / 2.0
56
+ xv, yv = np.meshgrid(bin_centers, bin_centers)
57
+ pos = np.stack([xv.ravel(), yv.ravel()], axis=1)
58
+ ccos = np.cos(pos)
59
+ csin = np.sin(pos)
60
+
61
+ num_neurons = spikes.shape[1]
62
+ maps = np.zeros((num_neurons, bins - 1, bins - 1))
63
+ centers = np.zeros((num_neurons, 2))
64
+
65
+ for n in range(num_neurons):
66
+ mtot_tmp, x_edge, y_edge, _ = binned_statistic_2d(
67
+ coords[:, 0],
68
+ coords[:, 1],
69
+ spikes[:, n],
70
+ statistic="mean",
71
+ bins=edges,
72
+ range=None,
73
+ expand_binnumbers=True,
74
+ )
75
+ mtot_tmp = np.rot90(mtot_tmp, 1).T
76
+ if smooth_sigma and smooth_sigma > 0:
77
+ nan_mask = np.isnan(mtot_tmp)
78
+ mtot_tmp = np.nan_to_num(mtot_tmp, nan=0.0)
79
+ mtot_tmp = gaussian_filter(mtot_tmp, smooth_sigma)
80
+ mtot_tmp[nan_mask] = np.nan
81
+ maps[n, :, :] = mtot_tmp.copy()
82
+
83
+ flat = mtot_tmp.flatten()
84
+ nans = ~np.isnan(flat)
85
+ if np.any(nans):
86
+ centcos = np.sum(ccos[nans, :] * flat[nans, np.newaxis], axis=0)
87
+ centsin = np.sum(csin[nans, :] * flat[nans, np.newaxis], axis=0)
88
+ centers[n, :] = np.arctan2(centsin, centcos) % (2 * np.pi)
89
+ else:
90
+ centers[n, :] = np.nan
290
91
 
291
- _ensure_parent_dir(config.save_path)
292
- finalize_figure(fig, config)
293
- return ax
92
+ return {
93
+ "rate_maps": maps,
94
+ "centers": centers,
95
+ "x_edge": x_edge,
96
+ "y_edge": y_edge,
97
+ "bins": bins,
98
+ "smooth_sigma": smooth_sigma,
99
+ }
294
100
 
295
101
 
296
- def plot_cohospace_neuron_2d(
297
- coords: np.ndarray,
298
- activity: np.ndarray,
299
- neuron_id: int,
300
- mode: str = "fr", # "fr" or "spike"
301
- top_percent: float = 5.0, # Used in FR mode
302
- times: np.ndarray | None = None,
303
- auto_filter: bool = True,
304
- figsize: tuple = (6, 6),
305
- cmap: str = "hot",
306
- save_path: str | None = None,
307
- show: bool = True,
102
+ def plot_cohospace(
103
+ cohospace_result: dict[str, Any],
104
+ *,
105
+ neuron_id: int = 0,
308
106
  config: PlotConfig | None = None,
309
- ) -> plt.Figure:
310
- """
311
- Overlay a single neuron's activity on the cohomology-space trajectory.
312
-
313
- This is a visualization helper:
314
- - mode="fr": marks the top `top_percent`%% time points by firing rate for the given neuron.
315
- - mode="spike": marks all time points where spike > 0 for the given neuron.
316
-
317
- Parameters
318
- ----------
319
- coords : ndarray, shape (T, 2)
320
- Decoded cohomology angles (theta1, theta2), in radians.
321
- activity : ndarray, shape (T, N)
322
- Activity matrix (continuous firing rate or binned spikes).
323
- times : ndarray, optional, shape (T_coords,)
324
- Optional indices to align activity to coords when coords are computed on a subset of timepoints.
325
- auto_filter : bool
326
- If True and lengths mismatch, auto-filter activity with activity>0 to mimic decode filtering.
327
- neuron_id : int
328
- Neuron index to visualize.
329
- mode : {"fr", "spike"}
330
- top_percent : float
331
- Used only when mode="fr". For example, 5.0 means "top 5%%" time points.
332
- figsize, cmap, save_path, show : see `plot_cohospace_trajectory_2d`.
333
-
334
- Returns
335
- -------
336
- ax : matplotlib.axes.Axes
337
-
338
- Examples
339
- --------
340
- >>> plot_cohospace_neuron_2d(coords, spikes, neuron_id=0, show=False) # doctest: +SKIP
341
- """
342
- coords = np.asarray(coords)
343
- activity = _align_activity_to_coords(
344
- coords, activity, times, label="activity", auto_filter=auto_filter
345
- )
346
- theta_deg = _coho_coords_to_degrees(coords)
347
-
348
- signal = activity[:, neuron_id]
349
-
350
- if mode == "fr":
351
- # Select the neuron's top `top_percent`% time points
352
- threshold = np.percentile(signal, 100 - top_percent)
353
- idx = signal >= threshold
354
- color = signal[idx]
355
- title = f"Neuron {neuron_id} FR top {top_percent:.1f}% on coho-space"
356
- use_cmap = cmap
357
- elif mode == "spike":
358
- idx = signal > 0
359
- color = None
360
- title = f"Neuron {neuron_id} spikes on coho-space"
361
- use_cmap = None
362
- else:
363
- raise ValueError("mode must be 'fr' or 'spike'")
364
-
365
- config = _ensure_plot_config(
366
- config,
367
- PlotConfig.for_static_plot,
368
- title=title,
369
- xlabel="Theta 1 (°)",
370
- ylabel="Theta 2 (°)",
371
- figsize=figsize,
372
- save_path=save_path,
373
- show=show,
374
- )
375
-
376
- fig, ax = plt.subplots(figsize=config.figsize)
377
- sc = ax.scatter(
378
- theta_deg[idx, 0],
379
- theta_deg[idx, 1],
380
- c=color if mode == "fr" else "red",
381
- cmap=use_cmap,
382
- s=5,
383
- alpha=0.9,
384
- )
385
-
386
- if mode == "fr":
387
- cbar = plt.colorbar(sc, ax=ax)
388
- cbar.set_label("Firing rate")
389
-
390
- ax.set_xlim(0, 360)
391
- ax.set_ylim(0, 360)
392
- ax.set_xlabel(config.xlabel)
393
- ax.set_ylabel(config.ylabel)
394
- ax.set_title(config.title)
395
-
396
- _ensure_parent_dir(config.save_path)
397
- finalize_figure(fig, config)
398
-
399
- return fig
400
-
401
-
402
- def plot_cohospace_neuron_1d(
403
- coords: np.ndarray,
404
- activity: np.ndarray,
405
- neuron_id: int,
406
- mode: str = "fr",
407
- top_percent: float = 5.0,
408
- times: np.ndarray | None = None,
409
- auto_filter: bool = True,
410
- figsize: tuple = (6, 6),
411
- cmap: str = "hot",
412
107
  save_path: str | None = None,
413
- show: bool = True,
414
- config: PlotConfig | None = None,
108
+ show: bool = False,
109
+ figsize: tuple[int, int] = (5, 5),
110
+ cmap: str = "viridis",
415
111
  ) -> plt.Figure:
416
112
  """
417
- Overlay a single neuron's activity on the 1D cohomology trajectory (unit circle).
113
+ Plot a single-neuron EcohoSpace rate map.
418
114
  """
419
- coords = np.asarray(coords)
420
- if coords.ndim == 2 and coords.shape[1] == 1:
421
- coords = coords[:, 0]
422
- if coords.ndim != 1:
423
- raise ValueError(f"coords must have shape (T,) or (T, 1), got {coords.shape}")
424
-
425
- activity = _align_activity_to_coords(
426
- coords[:, None], activity, times, label="activity", auto_filter=auto_filter
427
- )
428
-
429
- signal = activity[:, neuron_id]
430
-
431
- if mode == "fr":
432
- threshold = np.percentile(signal, 100 - top_percent)
433
- idx = signal >= threshold
434
- color = signal[idx]
435
- title = f"Neuron {neuron_id} FR top {top_percent:.1f}% on coho-space (1D)"
436
- use_cmap = cmap
437
- elif mode == "spike":
438
- idx = signal > 0
439
- color = None
440
- title = f"Neuron {neuron_id} spikes on coho-space (1D)"
441
- use_cmap = None
442
- else:
443
- raise ValueError("mode must be 'fr' or 'spike'")
444
-
445
- theta = coords % (2 * np.pi)
446
- x = np.cos(theta)
447
- y = np.sin(theta)
448
-
449
115
  config = _ensure_plot_config(
450
116
  config,
451
117
  PlotConfig.for_static_plot,
452
- title=title,
453
- xlabel="cos(theta)",
454
- ylabel="sin(theta)",
118
+ title="EcohoSpace",
119
+ xlabel="",
120
+ ylabel="",
455
121
  figsize=figsize,
456
122
  save_path=save_path,
457
123
  show=show,
458
124
  )
459
125
 
460
- fig, ax = plt.subplots(figsize=config.figsize)
461
- circle = np.linspace(0, 2 * np.pi, 200)
462
- ax.plot(np.cos(circle), np.sin(circle), color="0.85", lw=1.0, zorder=0)
463
- sc = ax.scatter(
464
- x[idx],
465
- y[idx],
466
- c=color if mode == "fr" else "red",
467
- cmap=use_cmap,
468
- s=8,
469
- alpha=0.9,
470
- )
471
-
472
- if mode == "fr":
473
- cbar = plt.colorbar(sc, ax=ax)
474
- cbar.set_label("Firing rate")
475
-
476
- ax.set_xlim(-1.2, 1.2)
477
- ax.set_ylim(-1.2, 1.2)
478
- ax.set_xlabel(config.xlabel)
479
- ax.set_ylabel(config.ylabel)
480
- ax.set_title(config.title)
481
- ax.set_aspect("equal", adjustable="box")
482
-
483
- _ensure_parent_dir(config.save_path)
484
- finalize_figure(fig, config)
485
-
486
- return fig
487
-
488
-
489
- def plot_cohospace_population_2d(
490
- coords: np.ndarray,
491
- activity: np.ndarray,
492
- neuron_ids: list[int] | np.ndarray,
493
- mode: str = "fr", # "fr" or "spike"
494
- top_percent: float = 5.0, # Used in FR mode
495
- times: np.ndarray | None = None,
496
- auto_filter: bool = True,
497
- figsize: tuple = (6, 6),
498
- cmap: str = "hot",
499
- save_path: str | None = None,
500
- show: bool = True,
501
- config: PlotConfig | None = None,
502
- ) -> plt.Figure:
503
- """
504
- Plot aggregated activity from multiple neurons in cohomology space.
505
-
506
- For mode="fr":
507
- - For each neuron, select its top `top_percent`%% time points by firing rate.
508
- - Aggregate (sum) firing rates over the selected points and plot as colors.
509
-
510
- For mode="spike":
511
- - For each neuron, count spikes at each time point (spike > 0).
512
- - Aggregate counts over neurons and plot as colors.
513
-
514
- Parameters
515
- ----------
516
- coords : ndarray, shape (T, 2)
517
- activity : ndarray, shape (T, N)
518
- times : ndarray, optional, shape (T_coords,)
519
- Optional indices to align activity to coords when coords are computed on a subset of timepoints.
520
- auto_filter : bool
521
- If True and lengths mismatch, auto-filter activity with activity>0 to mimic decode filtering.
522
- neuron_ids : iterable[int]
523
- Neuron indices to include (use range(N) to include all).
524
- mode : {"fr", "spike"}
525
- top_percent : float
526
- Used only when mode="fr".
527
- figsize, cmap, save_path, show : see `plot_cohospace_trajectory_2d`.
528
-
529
- Returns
530
- -------
531
- ax : matplotlib.axes.Axes
532
-
533
- Examples
534
- --------
535
- >>> plot_cohospace_population_2d(coords, spikes, neuron_ids=[0, 1, 2], show=False) # doctest: +SKIP
536
- """
537
- coords = np.asarray(coords)
538
- activity = _align_activity_to_coords(
539
- coords, activity, times, label="activity", auto_filter=auto_filter
540
- )
541
- neuron_ids = np.asarray(neuron_ids, dtype=int)
542
-
543
- theta_deg = _coho_coords_to_degrees(coords)
544
-
545
- T = activity.shape[0]
546
- mask = np.zeros(T, dtype=bool)
547
- agg_color = np.zeros(T, dtype=float)
126
+ maps = cohospace_result["rate_maps"]
127
+ x_edge = cohospace_result["x_edge"]
128
+ y_edge = cohospace_result["y_edge"]
548
129
 
549
- for n in neuron_ids:
550
- signal = activity[:, n]
551
-
552
- if mode == "fr":
553
- threshold = np.percentile(signal, 100 - top_percent)
554
- idx = signal >= threshold
555
- agg_color[idx] += signal[idx]
556
- mask |= idx
557
- elif mode == "spike":
558
- idx = signal > 0
559
- agg_color[idx] += 1.0
560
- mask |= idx
561
- else:
562
- raise ValueError("mode must be 'fr' or 'spike'")
563
-
564
- config = _ensure_plot_config(
565
- config,
566
- PlotConfig.for_static_plot,
567
- title=f"{len(neuron_ids)} neurons on coho-space",
568
- xlabel="Theta 1 (°)",
569
- ylabel="Theta 2 (°)",
570
- figsize=figsize,
571
- save_path=save_path,
572
- show=show,
573
- )
130
+ if neuron_id < 0 or neuron_id >= maps.shape[0]:
131
+ raise ValueError(f"neuron_id out of range: {neuron_id}")
574
132
 
575
- fig, ax = plt.subplots(figsize=config.figsize)
576
- sc = ax.scatter(
577
- theta_deg[mask, 0],
578
- theta_deg[mask, 1],
579
- c=agg_color[mask],
133
+ fig, ax = plt.subplots(1, 1, figsize=config.figsize)
134
+ im = ax.imshow(
135
+ maps[neuron_id],
136
+ origin="lower",
137
+ extent=[x_edge[0], x_edge[-1], y_edge[0], y_edge[-1]],
580
138
  cmap=cmap,
581
- s=5,
582
- alpha=0.9,
583
139
  )
584
- cbar = plt.colorbar(sc, ax=ax)
585
- label = "Aggregate FR" if mode == "fr" else "Spike count"
586
- cbar.set_label(label)
587
-
588
- ax.set_xlim(0, 360)
589
- ax.set_ylim(0, 360)
590
- ax.set_xlabel(config.xlabel)
591
- ax.set_ylabel(config.ylabel)
592
- ax.set_title(config.title)
140
+ ax.set_aspect("equal", "box")
141
+ ax.set_xticks([])
142
+ ax.set_yticks([])
143
+ ax.set_title(f"EcohoSpace Rate Map (neuron {neuron_id})", fontsize=10)
144
+ plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04, label="Mean activity")
593
145
 
146
+ fig.tight_layout()
594
147
  _ensure_parent_dir(config.save_path)
595
148
  finalize_figure(fig, config)
596
-
597
149
  return fig
598
150
 
599
151
 
600
- def plot_cohospace_population_1d(
601
- coords: np.ndarray,
602
- activity: np.ndarray,
603
- neuron_ids: list[int] | np.ndarray,
604
- mode: str = "fr",
605
- top_percent: float = 5.0,
606
- times: np.ndarray | None = None,
607
- auto_filter: bool = True,
608
- figsize: tuple = (6, 6),
609
- cmap: str = "hot",
610
- save_path: str | None = None,
611
- show: bool = True,
152
+ def plot_cohospace_skewed(
153
+ cohospace_result: dict[str, Any],
154
+ *,
155
+ neuron_id: int = 0,
612
156
  config: PlotConfig | None = None,
157
+ save_path: str | None = None,
158
+ show: bool = False,
159
+ figsize: tuple[int, int] = (5, 5),
160
+ cmap: str = "viridis",
161
+ show_grid: bool = True,
613
162
  ) -> plt.Figure:
614
163
  """
615
- Plot aggregated activity from multiple neurons on the 1D cohomology trajectory.
164
+ Plot a single-neuron EcohoSpace rate map in skewed torus coordinates.
616
165
  """
617
- coords = np.asarray(coords)
618
- if coords.ndim == 2 and coords.shape[1] == 1:
619
- coords = coords[:, 0]
620
- if coords.ndim != 1:
621
- raise ValueError(f"coords must have shape (T,) or (T, 1), got {coords.shape}")
622
-
623
- activity = _align_activity_to_coords(
624
- coords[:, None], activity, times, label="activity", auto_filter=auto_filter
625
- )
626
- neuron_ids = np.asarray(neuron_ids, dtype=int)
627
-
628
- T = activity.shape[0]
629
- mask = np.zeros(T, dtype=bool)
630
- agg_color = np.zeros(T, dtype=float)
631
-
632
- for n in neuron_ids:
633
- signal = activity[:, n]
634
-
635
- if mode == "fr":
636
- threshold = np.percentile(signal, 100 - top_percent)
637
- idx = signal >= threshold
638
- agg_color[idx] += signal[idx]
639
- mask |= idx
640
- elif mode == "spike":
641
- idx = signal > 0
642
- agg_color[idx] += 1.0
643
- mask |= idx
644
- else:
645
- raise ValueError("mode must be 'fr' or 'spike'")
646
-
647
- theta = coords % (2 * np.pi)
648
- x = np.cos(theta)
649
- y = np.sin(theta)
650
-
651
166
  config = _ensure_plot_config(
652
167
  config,
653
168
  PlotConfig.for_static_plot,
654
- title=f"{len(neuron_ids)} neurons on coho-space (1D)",
655
- xlabel="cos(theta)",
656
- ylabel="sin(theta)",
169
+ title="EcohoSpace (Skewed)",
170
+ xlabel=r"$\theta_1 + \frac{1}{2}\theta_2$",
171
+ ylabel=r"$\frac{\sqrt{3}}{2}\theta_2$",
657
172
  figsize=figsize,
658
173
  save_path=save_path,
659
174
  show=show,
660
175
  )
661
176
 
662
- fig, ax = plt.subplots(figsize=config.figsize)
663
- circle = np.linspace(0, 2 * np.pi, 200)
664
- ax.plot(np.cos(circle), np.sin(circle), color="0.85", lw=1.0, zorder=0)
665
- sc = ax.scatter(
666
- x[mask],
667
- y[mask],
668
- c=agg_color[mask],
669
- cmap=cmap,
670
- s=6,
671
- alpha=0.9,
672
- )
673
- cbar = plt.colorbar(sc, ax=ax)
674
- label = "Aggregate FR" if mode == "fr" else "Spike count"
675
- cbar.set_label(label)
676
-
677
- ax.set_xlim(-1.2, 1.2)
678
- ax.set_ylim(-1.2, 1.2)
679
- ax.set_xlabel(config.xlabel)
680
- ax.set_ylabel(config.ylabel)
681
- ax.set_title(config.title)
682
- ax.set_aspect("equal", adjustable="box")
683
-
684
- _ensure_parent_dir(config.save_path)
685
- finalize_figure(fig, config)
686
-
687
- return fig
688
-
689
-
690
- def compute_cohoscore_2d(
691
- coords: np.ndarray,
692
- activity: np.ndarray,
693
- top_percent: float = 2.0,
694
- times: np.ndarray | None = None,
695
- auto_filter: bool = True,
696
- ) -> np.ndarray:
697
- """
698
- Compute a simple cohomology-space selectivity score (CohoScore) for each neuron.
699
-
700
- For each neuron:
701
- - Select "active" time points:
702
- - If top_percent is None: all time points with activity > 0
703
- - Else: top `top_percent`%% time points by activity value
704
- - Compute circular variance for theta1 and theta2 on the selected points.
705
- - CohoScore = 0.5 * (var(theta1) + var(theta2))
177
+ maps = cohospace_result["rate_maps"]
178
+ x_edge = cohospace_result["x_edge"]
179
+ y_edge = cohospace_result["y_edge"]
706
180
 
707
- Interpretation:
708
- - Smaller score => points are more concentrated in coho space => higher selectivity.
181
+ if neuron_id < 0 or neuron_id >= maps.shape[0]:
182
+ raise ValueError(f"neuron_id out of range: {neuron_id}")
709
183
 
710
- Parameters
711
- ----------
712
- coords : ndarray, shape (T, 2)
713
- Decoded cohomology angles (theta1, theta2), in radians.
714
- activity : ndarray, shape (T, N)
715
- times : ndarray, optional, shape (T_coords,)
716
- Optional indices to align activity to coords when coords are computed on a subset of timepoints.
717
- auto_filter : bool
718
- If True and lengths mismatch, auto-filter activity with activity>0 to mimic decode filtering.
719
- Activity matrix (FR or spikes).
720
- top_percent : float | None
721
- Percentage for selecting active points (e.g., 2.0 means top 2%%). If None, use activity>0.
722
-
723
- Returns
724
- -------
725
- scores : ndarray, shape (N,)
726
- CohoScore per neuron (NaN for neurons with too few points).
727
-
728
- Examples
729
- --------
730
- >>> scores = compute_cohoscore_2d(coords, spikes) # doctest: +SKIP
731
- >>> scores.shape[0] # doctest: +SKIP
732
- """
733
- coords = np.asarray(coords)
734
- activity = _align_activity_to_coords(
735
- coords, activity, times, label="activity", auto_filter=auto_filter
736
- )
737
- T, N = activity.shape
184
+ th1, th2 = np.meshgrid(x_edge, y_edge, indexing="xy")
185
+ X = th1 + 0.5 * th2
186
+ Y = (np.sqrt(3) / 2.0) * th2
738
187
 
739
- theta = coords % (2 * np.pi) # Ensure values are in [0, 2π)
740
- scores = np.zeros(N, dtype=float)
188
+ fig, ax = plt.subplots(1, 1, figsize=config.figsize)
189
+ im = ax.pcolormesh(X, Y, maps[neuron_id], shading="auto", cmap=cmap)
741
190
 
742
- for n in range(N):
743
- signal = activity[:, n]
744
-
745
- if top_percent is None:
746
- idx = signal > 0 # Use all time points with spikes
747
- else:
748
- threshold = np.percentile(signal, 100 - top_percent)
749
- idx = signal >= threshold
750
-
751
- if np.sum(idx) < 5:
752
- scores[n] = np.nan # Too sparse; unreliable
753
- continue
754
-
755
- theta1 = theta[idx, 0]
756
- theta2 = theta[idx, 1]
757
-
758
- var1 = circvar(theta1, high=2 * np.pi, low=0)
759
- var2 = circvar(theta2, high=2 * np.pi, low=0)
760
-
761
- scores[n] = 0.5 * (var1 + var2)
762
-
763
- return scores
764
-
765
-
766
- def compute_cohoscore_1d(
767
- coords: np.ndarray,
768
- activity: np.ndarray,
769
- top_percent: float = 2.0,
770
- times: np.ndarray | None = None,
771
- auto_filter: bool = True,
772
- ) -> np.ndarray:
773
- """
774
- Compute 1D cohomology-space selectivity score (CohoScore) for each neuron.
775
-
776
- For each neuron:
777
- - Select "active" time points:
778
- - If top_percent is None: all time points with activity > 0
779
- - Else: top `top_percent`%% time points by activity value
780
- - Compute circular variance for theta on the selected points.
781
- - CohoScore = var(theta)
782
- """
783
- coords = np.asarray(coords)
784
- if coords.ndim == 2 and coords.shape[1] == 1:
785
- coords = coords[:, 0]
786
- if coords.ndim != 1:
787
- raise ValueError(f"coords must have shape (T,) or (T, 1), got {coords.shape}")
788
-
789
- activity = _align_activity_to_coords(
790
- coords[:, None], activity, times, label="activity", auto_filter=auto_filter
791
- )
792
- _, n_neurons = activity.shape
793
-
794
- theta = coords % (2 * np.pi)
795
- scores = np.zeros(n_neurons, dtype=float)
796
-
797
- for n in range(n_neurons):
798
- signal = activity[:, n]
799
-
800
- if top_percent is None:
801
- idx = signal > 0
802
- else:
803
- threshold = np.percentile(signal, 100 - top_percent)
804
- idx = signal >= threshold
805
-
806
- if np.sum(idx) < 5:
807
- scores[n] = np.nan
808
- continue
809
-
810
- var1 = circvar(theta[idx], high=2 * np.pi, low=0)
811
- scores[n] = var1
812
-
813
- return scores
814
-
815
-
816
- def skew_transform_torus(coords):
817
- """
818
- Convert torus angles (theta1, theta2) into coordinates in a skewed parallelogram fundamental domain.
819
-
820
- Given theta1, theta2 in radians, map:
821
- x = theta1 + 0.5 * theta2
822
- y = (sqrt(3)/2) * theta2
823
-
824
- This is a linear change of basis that turns the square [0, 2π)×[0, 2π) into a 60-degree
825
- parallelogram, which is convenient for visualizing wrap-around behavior on a 2-torus.
826
-
827
- Parameters
828
- ----------
829
- coords : ndarray, shape (T, 2)
830
- Angles (theta1, theta2) in radians.
831
-
832
- Returns
833
- -------
834
- xy : ndarray, shape (T, 2)
835
- Skewed planar coordinates.
836
- """
837
- coords = np.asarray(coords)
838
- if coords.ndim != 2 or coords.shape[1] != 2:
839
- raise ValueError(f"coords must be (T,2), got {coords.shape}")
840
-
841
- theta1 = coords[:, 0]
842
- theta2 = coords[:, 1]
843
-
844
- # Linear change of basis (NO nonlinear scaling)
845
- x = theta1 + 0.5 * theta2
846
- y = (np.sqrt(3) / 2.0) * theta2
847
-
848
- return np.stack([x, y], axis=1)
849
-
850
-
851
- def draw_torus_parallelogram_grid(ax, n_tiles=1, color="0.7", lw=1.0, alpha=0.8):
852
- """
853
- Draw parallelogram grid corresponding to torus fundamental domain.
854
-
855
- Fundamental vectors:
856
- e1 = (2π, 0)
857
- e2 = (π, √3 π)
858
-
859
- Parameters
860
- ----------
861
- ax : matplotlib axis
862
- n_tiles : int
863
- How many tiles to draw in +/- directions (visual aid).
864
- n_tiles=1 means draw [-1, 0, 1] shifts.
865
- """
866
- e1 = np.array([2 * np.pi, 0.0])
867
- e2 = np.array([np.pi, np.sqrt(3) * np.pi])
868
-
869
- shifts = range(-n_tiles, n_tiles + 1)
870
-
871
- for i in shifts:
872
- for j in shifts:
873
- origin = i * e1 + j * e2
874
- corners = np.array([origin, origin + e1, origin + e1 + e2, origin + e2, origin])
875
- ax.plot(corners[:, 0], corners[:, 1], color=color, lw=lw, alpha=alpha)
876
-
877
-
878
- def tile_parallelogram_points(xy, n_tiles=1):
879
- """
880
- Tile points in the skewed (parallelogram) torus fundamental domain.
881
-
882
- This is mainly for static visualizations so you can visually inspect continuity
883
- across domain boundaries.
884
-
885
- Parameters
886
- ----------
887
- points : ndarray, shape (T, 2)
888
- Points in the skewed plane (same coordinates as returned by `skew_transform_torus`).
889
- n_tiles : int
890
- Number of tiles to extend around the base domain.
891
- - n_tiles=1 produces a 3x3 tiling
892
- - n_tiles=2 produces a 5x5 tiling
893
-
894
- Returns
895
- -------
896
- tiled : ndarray
897
- Tiled points.
898
- """
899
- xy = np.asarray(xy, dtype=float)
900
-
901
- e1 = np.array([2 * np.pi, 0.0])
902
- e2 = np.array([np.pi, np.sqrt(3) * np.pi])
903
-
904
- out = []
905
- for i in range(-n_tiles, n_tiles + 1):
906
- for j in range(-n_tiles, n_tiles + 1):
907
- out.append(xy + i * e1 + j * e2)
908
-
909
- return np.vstack(out) if len(out) else xy
910
-
911
-
912
- def plot_cohospace_neuron_skewed(
913
- coords,
914
- activity,
915
- neuron_id,
916
- mode="spike",
917
- top_percent=2.0,
918
- times: np.ndarray | None = None,
919
- auto_filter: bool = True,
920
- save_path=None,
921
- show=None,
922
- ax=None,
923
- show_grid=True,
924
- n_tiles=1,
925
- s=6,
926
- alpha=0.8,
927
- config: PlotConfig | None = None,
928
- ):
929
- """
930
- Plot single-neuron CohoSpace on skewed torus domain.
931
-
932
- Parameters
933
- ----------
934
- coords : ndarray, shape (T, 2)
935
- Decoded circular coordinates (theta1, theta2), in radians.
936
- activity : ndarray, shape (T, N)
937
- Activity matrix aligned with coords.
938
- neuron_id : int
939
- Neuron index.
940
- mode : {"spike", "fr"}
941
- spike: use activity > 0
942
- fr: use top_percent threshold
943
- top_percent : float
944
- Percentile for FR thresholding.
945
- auto_filter : bool
946
- If True and lengths mismatch, auto-filter activity with activity>0 to mimic decode filtering.
947
- """
948
- coords = np.asarray(coords)
949
- activity = _align_activity_to_coords(
950
- coords, activity, times, label="activity", auto_filter=auto_filter
951
- )
952
-
953
- # --- normalize angles to [0, 2π)
954
- coords = coords % (2 * np.pi)
955
-
956
- # --- select neuron activity
957
- a = activity[:, neuron_id]
958
-
959
- if mode == "spike":
960
- mask = a > 0
961
- elif mode == "fr":
962
- thr = np.percentile(a, 100 - top_percent)
963
- mask = a >= thr
964
- else:
965
- raise ValueError(f"Unknown mode: {mode}")
966
-
967
- val = a[mask] # Used for FR-mode coloring
968
-
969
- if config is None:
970
- config = PlotConfig.for_static_plot(
971
- title=f"Neuron {neuron_id} – CohoSpace (skewed, mode={mode})",
972
- xlabel=r"$\theta_1 + \frac{1}{2}\theta_2$",
973
- ylabel=r"$\frac{\sqrt{3}}{2}\theta_2$",
974
- figsize=(5, 5),
975
- save_path=save_path,
976
- show=bool(show) if show is not None else False,
977
- )
978
- else:
979
- if save_path is not None:
980
- config.save_path = save_path
981
- if show is not None:
982
- config.show = show
983
- if not config.title:
984
- config.title = f"Neuron {neuron_id} – CohoSpace (skewed, mode={mode})"
985
- if not config.xlabel:
986
- config.xlabel = r"$\theta_1 + \frac{1}{2}\theta_2$"
987
- if not config.ylabel:
988
- config.ylabel = r"$\frac{\sqrt{3}}{2}\theta_2$"
989
-
990
- created_fig = ax is None
991
- if created_fig:
992
- fig, ax = plt.subplots(figsize=config.figsize)
993
- else:
994
- fig = ax.figure
995
-
996
- # --- fundamental domain vectors in skew plane
997
- e1 = np.array([2 * np.pi, 0.0])
998
- e2 = np.array([np.pi, np.sqrt(3) * np.pi])
999
-
1000
- def _draw_single_domain(ax):
1001
- P00 = np.array([0.0, 0.0])
1002
- P10 = e1
1003
- P01 = e2
1004
- P11 = e1 + e2
1005
- poly = np.vstack([P00, P10, P11, P01, P00])
1006
- ax.plot(poly[:, 0], poly[:, 1], lw=1.2, color="0.35")
1007
-
1008
- def _annotate_corners(ax):
1009
- P00 = np.array([0.0, 0.0])
1010
- P10 = e1
1011
- P01 = e2
1012
- P11 = e1 + e2
1013
-
1014
- corners = np.vstack([P00, P10, P01, P11])
1015
- xmin, ymin = corners.min(axis=0)
1016
- xmax, ymax = corners.max(axis=0)
1017
- padx = 0.02 * (xmax - xmin)
1018
- pady = 0.02 * (ymax - ymin)
1019
-
1020
- bbox = dict(facecolor="white", edgecolor="none", alpha=0.7, pad=1.0)
1021
-
1022
- ax.text(
1023
- P00[0] + padx, P00[1] + pady, "(0,0)", fontsize=10, ha="left", va="bottom", bbox=bbox
1024
- )
1025
- ax.text(
1026
- P10[0] - padx, P10[1] + pady, "(2π,0)", fontsize=10, ha="right", va="bottom", bbox=bbox
1027
- )
1028
- ax.text(P01[0] + padx, P01[1] - pady, "(0,2π)", fontsize=10, ha="left", va="top", bbox=bbox)
1029
- ax.text(
1030
- P11[0] - padx, P11[1] - pady, "(2π,2π)", fontsize=10, ha="right", va="top", bbox=bbox
1031
- )
1032
-
1033
- # --- skew transform
1034
- xy = skew_transform_torus(coords[mask])
1035
-
1036
- # Tiling: if points are tiled, values must be tiled too (FR mode) to keep lengths consistent
1037
- if n_tiles and n_tiles > 0:
1038
- xy = tile_parallelogram_points(xy, n_tiles=n_tiles)
1039
- if mode == "fr":
1040
- val = np.tile(val, (2 * n_tiles + 1) ** 2)
1041
-
1042
- # --- scatter
1043
- if mode == "fr":
1044
- sc = ax.scatter(xy[:, 0], xy[:, 1], c=val, s=s, alpha=alpha, cmap="viridis")
1045
- fig.colorbar(sc, ax=ax, shrink=0.85, pad=0.02, label="activity")
1046
- else:
1047
- ax.scatter(xy[:, 0], xy[:, 1], s=s, alpha=alpha, color="tab:blue")
1048
-
1049
- # Always draw the base domain boundary
1050
- _draw_single_domain(ax)
1051
-
1052
- # Grid is optional (debug aid); when tiles=0 only the base domain is drawn
1053
191
  if show_grid:
1054
- draw_torus_parallelogram_grid(ax, n_tiles=n_tiles)
1055
-
1056
- _annotate_corners(ax)
1057
-
1058
- # Fix view limits: tiles=0 shows base domain; tiles>0 shows the tiled extent
1059
- base = np.vstack([[0, 0], e1, e2, e1 + e2])
1060
-
1061
- if n_tiles and n_tiles > 0:
1062
- # Expand view by n_tiles rings around the base domain
1063
- # Translation vectors for tiling are i*e1 + j*e2
1064
- shifts = []
1065
- for i in range(-n_tiles, n_tiles + 1):
1066
- for j in range(-n_tiles, n_tiles + 1):
1067
- shifts.append(i * e1 + j * e2)
1068
- shifts = np.asarray(shifts) # ((2n+1)^2, 2)
1069
-
1070
- all_corners = (base[None, :, :] + shifts[:, None, :]).reshape(-1, 2)
1071
- xmin, ymin = all_corners.min(axis=0)
1072
- xmax, ymax = all_corners.max(axis=0)
1073
- else:
1074
- xmin, ymin = base.min(axis=0)
1075
- xmax, ymax = base.max(axis=0)
1076
-
1077
- padx = 0.03 * (xmax - xmin)
1078
- pady = 0.03 * (ymax - ymin)
192
+ e1 = np.array([2 * np.pi, 0.0])
193
+ e2 = np.array([np.pi, np.sqrt(3) * np.pi])
194
+ poly = np.vstack([np.zeros(2), e1, e1 + e2, e2, np.zeros(2)])
195
+ ax.plot(poly[:, 0], poly[:, 1], lw=1.1, color="0.35")
196
+
197
+ ax.set_aspect("equal", "box")
198
+ ax.set_xticks([])
199
+ ax.set_yticks([])
200
+ ax.set_title(f"EcohoSpace Rate Map (skewed, neuron {neuron_id})", fontsize=10)
201
+ plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04, label="Mean activity")
202
+
203
+ corners = np.vstack(
204
+ [
205
+ [0.0, 0.0],
206
+ [2 * np.pi, 0.0],
207
+ [np.pi, np.sqrt(3) * np.pi],
208
+ [3 * np.pi, np.sqrt(3) * np.pi],
209
+ ]
210
+ )
211
+ xmin, ymin = corners.min(axis=0)
212
+ xmax, ymax = corners.max(axis=0)
213
+ padx = 0.02 * (xmax - xmin)
214
+ pady = 0.02 * (ymax - ymin)
1079
215
  ax.set_xlim(xmin - padx, xmax + padx)
1080
216
  ax.set_ylim(ymin - pady, ymax + pady)
1081
217
 
1082
- ax.set_aspect("equal")
1083
- ax.set_xlabel(config.xlabel)
1084
- ax.set_ylabel(config.ylabel)
1085
- ax.set_title(config.title)
1086
-
1087
- if created_fig:
1088
- _ensure_parent_dir(config.save_path)
1089
- finalize_figure(fig, config)
1090
- else:
1091
- if config.save_path is not None:
1092
- _ensure_parent_dir(config.save_path)
1093
- fig.savefig(config.save_path, **config.to_savefig_kwargs())
1094
- if config.show:
1095
- plt.show()
1096
-
1097
- return ax
1098
-
1099
-
1100
- def plot_cohospace_population_skewed(
1101
- coords,
1102
- activity,
1103
- neuron_ids,
1104
- mode="spike",
1105
- top_percent=2.0,
1106
- times: np.ndarray | None = None,
1107
- auto_filter: bool = True,
1108
- save_path=None,
1109
- show=False,
1110
- ax=None,
1111
- show_grid=True,
1112
- n_tiles=1,
1113
- s=4,
1114
- alpha=0.5,
1115
- config: PlotConfig | None = None,
1116
- ):
1117
- """
1118
- Plot population CohoSpace on skewed torus domain.
1119
-
1120
- neuron_ids : list or ndarray
1121
- Neurons to include (e.g. top-K by CohoScore).
1122
- auto_filter : bool
1123
- If True and lengths mismatch, auto-filter activity with activity>0 to mimic decode filtering.
1124
- """
1125
- coords = np.asarray(coords)
1126
- activity = _align_activity_to_coords(
1127
- coords, activity, times, label="activity", auto_filter=auto_filter
1128
- )
1129
- coords = coords % (2 * np.pi)
1130
-
1131
- if config is None:
1132
- config = PlotConfig.for_static_plot(
1133
- title=f"Population CohoSpace (skewed, n={len(neuron_ids)}, mode={mode})",
1134
- xlabel=r"$\theta_1 + \frac{1}{2}\theta_2$",
1135
- ylabel=r"$\frac{\sqrt{3}}{2}\theta_2$",
1136
- figsize=(5, 5),
1137
- save_path=save_path,
1138
- show=show,
1139
- )
1140
- else:
1141
- if save_path is not None:
1142
- config.save_path = save_path
1143
- if show is not None:
1144
- config.show = show
1145
- if not config.title:
1146
- config.title = f"Population CohoSpace (skewed, n={len(neuron_ids)}, mode={mode})"
1147
- if not config.xlabel:
1148
- config.xlabel = r"$\theta_1 + \frac{1}{2}\theta_2$"
1149
- if not config.ylabel:
1150
- config.ylabel = r"$\frac{\sqrt{3}}{2}\theta_2$"
1151
-
1152
- created_fig = ax is None
1153
- if created_fig:
1154
- fig, ax = plt.subplots(figsize=config.figsize)
1155
- else:
1156
- fig = ax.figure
1157
-
1158
- # --- fundamental domain vectors in skew plane
1159
- e1 = np.array([2 * np.pi, 0.0])
1160
- e2 = np.array([np.pi, np.sqrt(3) * np.pi])
1161
-
1162
- def _draw_single_domain(ax):
1163
- P00 = np.array([0.0, 0.0])
1164
- P10 = e1
1165
- P01 = e2
1166
- P11 = e1 + e2
1167
- poly = np.vstack([P00, P10, P11, P01, P00])
1168
- ax.plot(poly[:, 0], poly[:, 1], lw=1.2, color="0.35")
1169
-
1170
- # --- scatter each neuron
1171
- for nid in neuron_ids:
1172
- a = activity[:, nid]
1173
- if mode == "spike":
1174
- mask = a > 0
1175
- else:
1176
- thr = np.percentile(a, 100 - top_percent)
1177
- mask = a >= thr
218
+ fig.tight_layout()
219
+ _ensure_parent_dir(config.save_path)
220
+ finalize_figure(fig, config)
221
+ return fig
1178
222
 
1179
- xy = skew_transform_torus(coords[mask])
1180
223
 
1181
- if n_tiles and n_tiles > 0:
1182
- xy = tile_parallelogram_points(xy, n_tiles=n_tiles)
224
+ def cohospace_upgrade(*args, **kwargs) -> dict[str, Any]:
225
+ """Legacy alias for EcohoSpace (formerly cohospace_upgrade)."""
226
+ return cohospace(*args, **kwargs)
1183
227
 
1184
- ax.scatter(xy[:, 0], xy[:, 1], s=s, alpha=alpha)
1185
228
 
1186
- # Always draw the base domain boundary
1187
- _draw_single_domain(ax)
229
+ def ecohospace(*args, **kwargs) -> dict[str, Any]:
230
+ """Alias for EcohoSpace (GridCellTorus-style)."""
231
+ return cohospace(*args, **kwargs)
1188
232
 
1189
- if show_grid:
1190
- draw_torus_parallelogram_grid(ax, n_tiles=n_tiles)
1191
233
 
1192
- # Fix view limits: tiles=0 shows base domain; tiles>0 shows the tiled extent
1193
- base = np.vstack([[0, 0], e1, e2, e1 + e2])
234
+ def plot_cohospace_upgrade(*args, **kwargs) -> plt.Figure:
235
+ """Legacy alias for EcohoSpace plotting (formerly plot_cohospace_upgrade)."""
236
+ return plot_cohospace(*args, **kwargs)
1194
237
 
1195
- if n_tiles and n_tiles > 0:
1196
- # Expand view by n_tiles rings around the base domain
1197
- # Translation vectors for tiling are i*e1 + j*e2
1198
- shifts = []
1199
- for i in range(-n_tiles, n_tiles + 1):
1200
- for j in range(-n_tiles, n_tiles + 1):
1201
- shifts.append(i * e1 + j * e2)
1202
- shifts = np.asarray(shifts) # ((2n+1)^2, 2)
1203
238
 
1204
- all_corners = (base[None, :, :] + shifts[:, None, :]).reshape(-1, 2)
1205
- xmin, ymin = all_corners.min(axis=0)
1206
- xmax, ymax = all_corners.max(axis=0)
1207
- else:
1208
- xmin, ymin = base.min(axis=0)
1209
- xmax, ymax = base.max(axis=0)
239
+ def plot_cohospace_upgrade_skewed(*args, **kwargs) -> plt.Figure:
240
+ """Legacy alias for EcohoSpace skewed plotting."""
241
+ return plot_cohospace_skewed(*args, **kwargs)
1210
242
 
1211
- padx = 0.03 * (xmax - xmin)
1212
- pady = 0.03 * (ymax - ymin)
1213
- ax.set_xlim(xmin - padx, xmax + padx)
1214
- ax.set_ylim(ymin - pady, ymax + pady)
1215
243
 
1216
- ax.set_aspect("equal")
1217
- ax.set_xlabel(config.xlabel)
1218
- ax.set_ylabel(config.ylabel)
1219
- ax.set_title(config.title)
244
+ def plot_ecohospace(*args, **kwargs) -> plt.Figure:
245
+ """Alias for EcohoSpace plotting."""
246
+ return plot_cohospace(*args, **kwargs)
1220
247
 
1221
- if created_fig:
1222
- _ensure_parent_dir(config.save_path)
1223
- finalize_figure(fig, config)
1224
- else:
1225
- if config.save_path is not None:
1226
- _ensure_parent_dir(config.save_path)
1227
- fig.savefig(config.save_path, **config.to_savefig_kwargs())
1228
- if config.show:
1229
- plt.show()
1230
248
 
1231
- return ax
249
+ def plot_ecohospace_skewed(*args, **kwargs) -> plt.Figure:
250
+ """Alias for EcohoSpace skewed plotting."""
251
+ return plot_cohospace_skewed(*args, **kwargs)