canns 0.13.1__py3-none-any.whl → 0.14.0__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (99)
  1. canns/analyzer/data/__init__.py +5 -1
  2. canns/analyzer/data/asa/__init__.py +27 -12
  3. canns/analyzer/data/asa/cohospace.py +336 -10
  4. canns/analyzer/data/asa/config.py +3 -0
  5. canns/analyzer/data/asa/embedding.py +48 -45
  6. canns/analyzer/data/asa/path.py +104 -2
  7. canns/analyzer/data/asa/plotting.py +88 -19
  8. canns/analyzer/data/asa/tda.py +11 -4
  9. canns/analyzer/data/cell_classification/__init__.py +97 -0
  10. canns/analyzer/data/cell_classification/core/__init__.py +26 -0
  11. canns/analyzer/data/cell_classification/core/grid_cells.py +633 -0
  12. canns/analyzer/data/cell_classification/core/grid_modules_leiden.py +288 -0
  13. canns/analyzer/data/cell_classification/core/head_direction.py +347 -0
  14. canns/analyzer/data/cell_classification/core/spatial_analysis.py +431 -0
  15. canns/analyzer/data/cell_classification/io/__init__.py +5 -0
  16. canns/analyzer/data/cell_classification/io/matlab_loader.py +417 -0
  17. canns/analyzer/data/cell_classification/utils/__init__.py +39 -0
  18. canns/analyzer/data/cell_classification/utils/circular_stats.py +383 -0
  19. canns/analyzer/data/cell_classification/utils/correlation.py +318 -0
  20. canns/analyzer/data/cell_classification/utils/geometry.py +442 -0
  21. canns/analyzer/data/cell_classification/utils/image_processing.py +416 -0
  22. canns/analyzer/data/cell_classification/visualization/__init__.py +19 -0
  23. canns/analyzer/data/cell_classification/visualization/grid_plots.py +292 -0
  24. canns/analyzer/data/cell_classification/visualization/hd_plots.py +200 -0
  25. canns/analyzer/metrics/__init__.py +2 -1
  26. canns/analyzer/visualization/core/config.py +46 -4
  27. canns/data/__init__.py +6 -1
  28. canns/data/datasets.py +154 -1
  29. canns/data/loaders.py +37 -0
  30. canns/pipeline/__init__.py +13 -9
  31. canns/pipeline/__main__.py +6 -0
  32. canns/pipeline/asa/runner.py +105 -41
  33. canns/pipeline/asa_gui/__init__.py +68 -0
  34. canns/pipeline/asa_gui/__main__.py +6 -0
  35. canns/pipeline/asa_gui/analysis_modes/__init__.py +42 -0
  36. canns/pipeline/asa_gui/analysis_modes/base.py +39 -0
  37. canns/pipeline/asa_gui/analysis_modes/batch_mode.py +21 -0
  38. canns/pipeline/asa_gui/analysis_modes/cohomap_mode.py +56 -0
  39. canns/pipeline/asa_gui/analysis_modes/cohospace_mode.py +194 -0
  40. canns/pipeline/asa_gui/analysis_modes/decode_mode.py +52 -0
  41. canns/pipeline/asa_gui/analysis_modes/fr_mode.py +81 -0
  42. canns/pipeline/asa_gui/analysis_modes/frm_mode.py +92 -0
  43. canns/pipeline/asa_gui/analysis_modes/gridscore_mode.py +123 -0
  44. canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +199 -0
  45. canns/pipeline/asa_gui/analysis_modes/tda_mode.py +112 -0
  46. canns/pipeline/asa_gui/app.py +29 -0
  47. canns/pipeline/asa_gui/controllers/__init__.py +6 -0
  48. canns/pipeline/asa_gui/controllers/analysis_controller.py +59 -0
  49. canns/pipeline/asa_gui/controllers/preprocess_controller.py +89 -0
  50. canns/pipeline/asa_gui/core/__init__.py +15 -0
  51. canns/pipeline/asa_gui/core/cache.py +14 -0
  52. canns/pipeline/asa_gui/core/runner.py +1936 -0
  53. canns/pipeline/asa_gui/core/state.py +324 -0
  54. canns/pipeline/asa_gui/core/worker.py +260 -0
  55. canns/pipeline/asa_gui/main_window.py +184 -0
  56. canns/pipeline/asa_gui/models/__init__.py +7 -0
  57. canns/pipeline/asa_gui/models/config.py +14 -0
  58. canns/pipeline/asa_gui/models/job.py +31 -0
  59. canns/pipeline/asa_gui/models/presets.py +21 -0
  60. canns/pipeline/asa_gui/resources/__init__.py +16 -0
  61. canns/pipeline/asa_gui/resources/dark.qss +167 -0
  62. canns/pipeline/asa_gui/resources/light.qss +163 -0
  63. canns/pipeline/asa_gui/resources/styles.qss +130 -0
  64. canns/pipeline/asa_gui/utils/__init__.py +1 -0
  65. canns/pipeline/asa_gui/utils/formatters.py +15 -0
  66. canns/pipeline/asa_gui/utils/io_adapters.py +40 -0
  67. canns/pipeline/asa_gui/utils/validators.py +41 -0
  68. canns/pipeline/asa_gui/views/__init__.py +1 -0
  69. canns/pipeline/asa_gui/views/help_content.py +171 -0
  70. canns/pipeline/asa_gui/views/pages/__init__.py +6 -0
  71. canns/pipeline/asa_gui/views/pages/analysis_page.py +565 -0
  72. canns/pipeline/asa_gui/views/pages/preprocess_page.py +492 -0
  73. canns/pipeline/asa_gui/views/panels/__init__.py +1 -0
  74. canns/pipeline/asa_gui/views/widgets/__init__.py +21 -0
  75. canns/pipeline/asa_gui/views/widgets/artifacts_tab.py +44 -0
  76. canns/pipeline/asa_gui/views/widgets/drop_zone.py +80 -0
  77. canns/pipeline/asa_gui/views/widgets/file_list.py +27 -0
  78. canns/pipeline/asa_gui/views/widgets/gridscore_tab.py +308 -0
  79. canns/pipeline/asa_gui/views/widgets/help_dialog.py +27 -0
  80. canns/pipeline/asa_gui/views/widgets/image_tab.py +50 -0
  81. canns/pipeline/asa_gui/views/widgets/image_viewer.py +97 -0
  82. canns/pipeline/asa_gui/views/widgets/log_box.py +16 -0
  83. canns/pipeline/asa_gui/views/widgets/pathcompare_tab.py +200 -0
  84. canns/pipeline/asa_gui/views/widgets/popup_combo.py +25 -0
  85. canns/pipeline/gallery/__init__.py +15 -5
  86. canns/pipeline/gallery/__main__.py +11 -0
  87. canns/pipeline/gallery/app.py +705 -0
  88. canns/pipeline/gallery/runner.py +790 -0
  89. canns/pipeline/gallery/state.py +51 -0
  90. canns/pipeline/gallery/styles.tcss +123 -0
  91. canns/pipeline/launcher.py +81 -0
  92. {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/METADATA +11 -1
  93. canns-0.14.0.dist-info/RECORD +163 -0
  94. canns-0.14.0.dist-info/entry_points.txt +5 -0
  95. canns/pipeline/_base.py +0 -50
  96. canns-0.13.1.dist-info/RECORD +0 -89
  97. canns-0.13.1.dist-info/entry_points.txt +0 -3
  98. {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/WHEEL +0 -0
  99. {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,5 +1,9 @@
1
1
  """Data analysis utilities for experimental and synthetic neural data."""
2
2
 
3
+ from . import asa, cell_classification
3
4
  from .asa import * # noqa: F401,F403
5
+ from .cell_classification import * # noqa: F401,F403
4
6
 
5
- __all__ = list(locals().get("__all__", []))
7
+ __all__ = ["asa", "cell_classification"]
8
+ __all__ += list(getattr(asa, "__all__", []))
9
+ __all__ += list(getattr(cell_classification, "__all__", []))
@@ -2,10 +2,14 @@ from __future__ import annotations
2
2
 
3
3
  # Coho-space analysis + visualization
4
4
  from .cohospace import (
5
- compute_cohoscore,
6
- plot_cohospace_neuron,
7
- plot_cohospace_population,
8
- plot_cohospace_trajectory,
5
+ compute_cohoscore_1d,
6
+ compute_cohoscore_2d,
7
+ plot_cohospace_neuron_1d,
8
+ plot_cohospace_neuron_2d,
9
+ plot_cohospace_population_1d,
10
+ plot_cohospace_population_2d,
11
+ plot_cohospace_trajectory_1d,
12
+ plot_cohospace_trajectory_2d,
9
13
  )
10
14
  from .config import (
11
15
  CANN2DError,
@@ -33,7 +37,11 @@ from .fr import (
33
37
  )
34
38
 
35
39
  # Path utilities
36
- from .path import align_coords_to_position, apply_angle_scale
40
+ from .path import (
41
+ align_coords_to_position_1d,
42
+ align_coords_to_position_2d,
43
+ apply_angle_scale,
44
+ )
37
45
 
38
46
  # Higher-level plotting helpers
39
47
  from .plotting import (
@@ -41,7 +49,8 @@ from .plotting import (
41
49
  plot_3d_bump_on_torus,
42
50
  plot_cohomap,
43
51
  plot_cohomap_multi,
44
- plot_path_compare,
52
+ plot_path_compare_1d,
53
+ plot_path_compare_2d,
45
54
  plot_projection,
46
55
  )
47
56
 
@@ -61,7 +70,8 @@ __all__ = [
61
70
  "decode_circular_coordinates",
62
71
  "decode_circular_coordinates_multi",
63
72
  "plot_projection",
64
- "plot_path_compare",
73
+ "plot_path_compare_1d",
74
+ "plot_path_compare_2d",
65
75
  "plot_cohomap",
66
76
  "plot_cohomap_multi",
67
77
  "plot_3d_bump_on_torus",
@@ -75,10 +85,15 @@ __all__ = [
75
85
  "FRMResult",
76
86
  "compute_frm",
77
87
  "plot_frm",
78
- "plot_cohospace_trajectory",
79
- "plot_cohospace_neuron",
80
- "plot_cohospace_population",
81
- "compute_cohoscore",
82
- "align_coords_to_position",
88
+ "plot_cohospace_trajectory_1d",
89
+ "plot_cohospace_trajectory_2d",
90
+ "plot_cohospace_neuron_1d",
91
+ "plot_cohospace_neuron_2d",
92
+ "plot_cohospace_population_1d",
93
+ "plot_cohospace_population_2d",
94
+ "compute_cohoscore_1d",
95
+ "compute_cohoscore_2d",
96
+ "align_coords_to_position_1d",
97
+ "align_coords_to_position_2d",
83
98
  "apply_angle_scale",
84
99
  ]
@@ -91,7 +91,7 @@ def _align_activity_to_coords(
91
91
  return activity
92
92
 
93
93
 
94
- def plot_cohospace_trajectory(
94
+ def plot_cohospace_trajectory_2d(
95
95
  coords: np.ndarray,
96
96
  times: np.ndarray | None = None,
97
97
  subsample: int = 1,
@@ -129,7 +129,7 @@ def plot_cohospace_trajectory(
129
129
 
130
130
  Examples
131
131
  --------
132
- >>> fig = plot_cohospace_trajectory(coords, subsample=2, show=False) # doctest: +SKIP
132
+ >>> fig = plot_cohospace_trajectory_2d(coords, subsample=2, show=False) # doctest: +SKIP
133
133
  """
134
134
 
135
135
  try:
@@ -194,7 +194,106 @@ def plot_cohospace_trajectory(
194
194
  return ax
195
195
 
196
196
 
197
- def plot_cohospace_neuron(
197
+ def plot_cohospace_trajectory_1d(
198
+ coords: np.ndarray,
199
+ times: np.ndarray | None = None,
200
+ subsample: int = 1,
201
+ figsize: tuple[int, int] = (6, 6),
202
+ cmap: str = "viridis",
203
+ save_path: str | None = None,
204
+ show: bool = False,
205
+ config: PlotConfig | None = None,
206
+ ) -> plt.Axes:
207
+ """
208
+ Plot a 1D cohomology trajectory on the unit circle.
209
+
210
+ Parameters
211
+ ----------
212
+ coords : ndarray, shape (T,) or (T, 1)
213
+ Decoded cohomology angles (theta). Values may be in radians or in [0, 1] "unit circle"
214
+ convention depending on upstream decoding; this function will plot on the unit circle.
215
+ times : ndarray, optional, shape (T,)
216
+ Optional time array used to color points. If None, uses arange(T).
217
+ subsample : int
218
+ Downsampling step (>1 reduces the number of plotted points).
219
+ figsize : tuple
220
+ Matplotlib figure size.
221
+ cmap : str
222
+ Matplotlib colormap name.
223
+ save_path : str, optional
224
+ If provided, saves the figure to this path.
225
+ show : bool
226
+ If True, calls plt.show(). If False, closes the figure and returns the Axes.
227
+ """
228
+ try:
229
+ subsample_i = int(subsample)
230
+ except Exception:
231
+ subsample_i = 1
232
+ if subsample_i < 1:
233
+ subsample_i = 1
234
+
235
+ coords = np.asarray(coords)
236
+ if coords.ndim == 2 and coords.shape[1] == 1:
237
+ coords = coords[:, 0]
238
+ if coords.ndim != 1:
239
+ raise ValueError(f"`coords` must have shape (T,) or (T, 1). Got {coords.shape}.")
240
+
241
+ if times is None:
242
+ times_vis = np.arange(coords.shape[0])
243
+ else:
244
+ times_vis = np.asarray(times)
245
+ if times_vis.shape[0] != coords.shape[0]:
246
+ raise ValueError(
247
+ f"`times` length must match coords length. Got times={times_vis.shape[0]}, coords={coords.shape[0]}."
248
+ )
249
+
250
+ if subsample_i > 1:
251
+ coords = coords[::subsample_i]
252
+ times_vis = times_vis[::subsample_i]
253
+
254
+ theta = coords % (2 * np.pi)
255
+ x = np.cos(theta)
256
+ y = np.sin(theta)
257
+
258
+ config = _ensure_plot_config(
259
+ config,
260
+ PlotConfig.for_static_plot,
261
+ title="CohoSpace trajectory (1D)",
262
+ xlabel="cos(theta)",
263
+ ylabel="sin(theta)",
264
+ figsize=figsize,
265
+ save_path=save_path,
266
+ show=show,
267
+ )
268
+
269
+ fig, ax = plt.subplots(figsize=config.figsize)
270
+ circle = np.linspace(0, 2 * np.pi, 200)
271
+ ax.plot(np.cos(circle), np.sin(circle), color="0.85", lw=1.0, zorder=0)
272
+ sc = ax.scatter(
273
+ x,
274
+ y,
275
+ c=times_vis,
276
+ cmap=cmap,
277
+ s=5,
278
+ alpha=0.8,
279
+ )
280
+ cbar = plt.colorbar(sc, ax=ax)
281
+ cbar.set_label("Time")
282
+
283
+ ax.set_xlim(-1.2, 1.2)
284
+ ax.set_ylim(-1.2, 1.2)
285
+ ax.set_xlabel(config.xlabel)
286
+ ax.set_ylabel(config.ylabel)
287
+ ax.set_title(config.title)
288
+ ax.set_aspect("equal", adjustable="box")
289
+ ax.grid(True, alpha=0.2)
290
+
291
+ _ensure_parent_dir(config.save_path)
292
+ finalize_figure(fig, config)
293
+ return ax
294
+
295
+
296
+ def plot_cohospace_neuron_2d(
198
297
  coords: np.ndarray,
199
298
  activity: np.ndarray,
200
299
  neuron_id: int,
@@ -230,7 +329,7 @@ def plot_cohospace_neuron(
230
329
  mode : {"fr", "spike"}
231
330
  top_percent : float
232
331
  Used only when mode="fr". For example, 5.0 means "top 5%%" time points.
233
- figsize, cmap, save_path, show : see `plot_cohospace_trajectory`.
332
+ figsize, cmap, save_path, show : see `plot_cohospace_trajectory_2d`.
234
333
 
235
334
  Returns
236
335
  -------
@@ -238,7 +337,7 @@ def plot_cohospace_neuron(
238
337
 
239
338
  Examples
240
339
  --------
241
- >>> plot_cohospace_neuron(coords, spikes, neuron_id=0, show=False) # doctest: +SKIP
340
+ >>> plot_cohospace_neuron_2d(coords, spikes, neuron_id=0, show=False) # doctest: +SKIP
242
341
  """
243
342
  coords = np.asarray(coords)
244
343
  activity = _align_activity_to_coords(
@@ -300,7 +399,94 @@ def plot_cohospace_neuron(
300
399
  return fig
301
400
 
302
401
 
303
- def plot_cohospace_population(
402
+ def plot_cohospace_neuron_1d(
403
+ coords: np.ndarray,
404
+ activity: np.ndarray,
405
+ neuron_id: int,
406
+ mode: str = "fr",
407
+ top_percent: float = 5.0,
408
+ times: np.ndarray | None = None,
409
+ auto_filter: bool = True,
410
+ figsize: tuple = (6, 6),
411
+ cmap: str = "hot",
412
+ save_path: str | None = None,
413
+ show: bool = True,
414
+ config: PlotConfig | None = None,
415
+ ) -> plt.Figure:
416
+ """
417
+ Overlay a single neuron's activity on the 1D cohomology trajectory (unit circle).
418
+ """
419
+ coords = np.asarray(coords)
420
+ if coords.ndim == 2 and coords.shape[1] == 1:
421
+ coords = coords[:, 0]
422
+ if coords.ndim != 1:
423
+ raise ValueError(f"coords must have shape (T,) or (T, 1), got {coords.shape}")
424
+
425
+ activity = _align_activity_to_coords(
426
+ coords[:, None], activity, times, label="activity", auto_filter=auto_filter
427
+ )
428
+
429
+ signal = activity[:, neuron_id]
430
+
431
+ if mode == "fr":
432
+ threshold = np.percentile(signal, 100 - top_percent)
433
+ idx = signal >= threshold
434
+ color = signal[idx]
435
+ title = f"Neuron {neuron_id} FR top {top_percent:.1f}% on coho-space (1D)"
436
+ use_cmap = cmap
437
+ elif mode == "spike":
438
+ idx = signal > 0
439
+ color = None
440
+ title = f"Neuron {neuron_id} spikes on coho-space (1D)"
441
+ use_cmap = None
442
+ else:
443
+ raise ValueError("mode must be 'fr' or 'spike'")
444
+
445
+ theta = coords % (2 * np.pi)
446
+ x = np.cos(theta)
447
+ y = np.sin(theta)
448
+
449
+ config = _ensure_plot_config(
450
+ config,
451
+ PlotConfig.for_static_plot,
452
+ title=title,
453
+ xlabel="cos(theta)",
454
+ ylabel="sin(theta)",
455
+ figsize=figsize,
456
+ save_path=save_path,
457
+ show=show,
458
+ )
459
+
460
+ fig, ax = plt.subplots(figsize=config.figsize)
461
+ circle = np.linspace(0, 2 * np.pi, 200)
462
+ ax.plot(np.cos(circle), np.sin(circle), color="0.85", lw=1.0, zorder=0)
463
+ sc = ax.scatter(
464
+ x[idx],
465
+ y[idx],
466
+ c=color if mode == "fr" else "red",
467
+ cmap=use_cmap,
468
+ s=8,
469
+ alpha=0.9,
470
+ )
471
+
472
+ if mode == "fr":
473
+ cbar = plt.colorbar(sc, ax=ax)
474
+ cbar.set_label("Firing rate")
475
+
476
+ ax.set_xlim(-1.2, 1.2)
477
+ ax.set_ylim(-1.2, 1.2)
478
+ ax.set_xlabel(config.xlabel)
479
+ ax.set_ylabel(config.ylabel)
480
+ ax.set_title(config.title)
481
+ ax.set_aspect("equal", adjustable="box")
482
+
483
+ _ensure_parent_dir(config.save_path)
484
+ finalize_figure(fig, config)
485
+
486
+ return fig
487
+
488
+
489
+ def plot_cohospace_population_2d(
304
490
  coords: np.ndarray,
305
491
  activity: np.ndarray,
306
492
  neuron_ids: list[int] | np.ndarray,
@@ -338,7 +524,7 @@ def plot_cohospace_population(
338
524
  mode : {"fr", "spike"}
339
525
  top_percent : float
340
526
  Used only when mode="fr".
341
- figsize, cmap, save_path, show : see `plot_cohospace_trajectory`.
527
+ figsize, cmap, save_path, show : see `plot_cohospace_trajectory_2d`.
342
528
 
343
529
  Returns
344
530
  -------
@@ -346,7 +532,7 @@ def plot_cohospace_population(
346
532
 
347
533
  Examples
348
534
  --------
349
- >>> plot_cohospace_population(coords, spikes, neuron_ids=[0, 1, 2], show=False) # doctest: +SKIP
535
+ >>> plot_cohospace_population_2d(coords, spikes, neuron_ids=[0, 1, 2], show=False) # doctest: +SKIP
350
536
  """
351
537
  coords = np.asarray(coords)
352
538
  activity = _align_activity_to_coords(
@@ -411,7 +597,97 @@ def plot_cohospace_population(
411
597
  return fig
412
598
 
413
599
 
414
- def compute_cohoscore(
600
+ def plot_cohospace_population_1d(
601
+ coords: np.ndarray,
602
+ activity: np.ndarray,
603
+ neuron_ids: list[int] | np.ndarray,
604
+ mode: str = "fr",
605
+ top_percent: float = 5.0,
606
+ times: np.ndarray | None = None,
607
+ auto_filter: bool = True,
608
+ figsize: tuple = (6, 6),
609
+ cmap: str = "hot",
610
+ save_path: str | None = None,
611
+ show: bool = True,
612
+ config: PlotConfig | None = None,
613
+ ) -> plt.Figure:
614
+ """
615
+ Plot aggregated activity from multiple neurons on the 1D cohomology trajectory.
616
+ """
617
+ coords = np.asarray(coords)
618
+ if coords.ndim == 2 and coords.shape[1] == 1:
619
+ coords = coords[:, 0]
620
+ if coords.ndim != 1:
621
+ raise ValueError(f"coords must have shape (T,) or (T, 1), got {coords.shape}")
622
+
623
+ activity = _align_activity_to_coords(
624
+ coords[:, None], activity, times, label="activity", auto_filter=auto_filter
625
+ )
626
+ neuron_ids = np.asarray(neuron_ids, dtype=int)
627
+
628
+ T = activity.shape[0]
629
+ mask = np.zeros(T, dtype=bool)
630
+ agg_color = np.zeros(T, dtype=float)
631
+
632
+ for n in neuron_ids:
633
+ signal = activity[:, n]
634
+
635
+ if mode == "fr":
636
+ threshold = np.percentile(signal, 100 - top_percent)
637
+ idx = signal >= threshold
638
+ agg_color[idx] += signal[idx]
639
+ mask |= idx
640
+ elif mode == "spike":
641
+ idx = signal > 0
642
+ agg_color[idx] += 1.0
643
+ mask |= idx
644
+ else:
645
+ raise ValueError("mode must be 'fr' or 'spike'")
646
+
647
+ theta = coords % (2 * np.pi)
648
+ x = np.cos(theta)
649
+ y = np.sin(theta)
650
+
651
+ config = _ensure_plot_config(
652
+ config,
653
+ PlotConfig.for_static_plot,
654
+ title=f"{len(neuron_ids)} neurons on coho-space (1D)",
655
+ xlabel="cos(theta)",
656
+ ylabel="sin(theta)",
657
+ figsize=figsize,
658
+ save_path=save_path,
659
+ show=show,
660
+ )
661
+
662
+ fig, ax = plt.subplots(figsize=config.figsize)
663
+ circle = np.linspace(0, 2 * np.pi, 200)
664
+ ax.plot(np.cos(circle), np.sin(circle), color="0.85", lw=1.0, zorder=0)
665
+ sc = ax.scatter(
666
+ x[mask],
667
+ y[mask],
668
+ c=agg_color[mask],
669
+ cmap=cmap,
670
+ s=6,
671
+ alpha=0.9,
672
+ )
673
+ cbar = plt.colorbar(sc, ax=ax)
674
+ label = "Aggregate FR" if mode == "fr" else "Spike count"
675
+ cbar.set_label(label)
676
+
677
+ ax.set_xlim(-1.2, 1.2)
678
+ ax.set_ylim(-1.2, 1.2)
679
+ ax.set_xlabel(config.xlabel)
680
+ ax.set_ylabel(config.ylabel)
681
+ ax.set_title(config.title)
682
+ ax.set_aspect("equal", adjustable="box")
683
+
684
+ _ensure_parent_dir(config.save_path)
685
+ finalize_figure(fig, config)
686
+
687
+ return fig
688
+
689
+
690
+ def compute_cohoscore_2d(
415
691
  coords: np.ndarray,
416
692
  activity: np.ndarray,
417
693
  top_percent: float = 2.0,
@@ -451,7 +727,7 @@ def compute_cohoscore(
451
727
 
452
728
  Examples
453
729
  --------
454
- >>> scores = compute_cohoscore(coords, spikes) # doctest: +SKIP
730
+ >>> scores = compute_cohoscore_2d(coords, spikes) # doctest: +SKIP
455
731
  >>> scores.shape[0] # doctest: +SKIP
456
732
  """
457
733
  coords = np.asarray(coords)
@@ -487,6 +763,56 @@ def compute_cohoscore(
487
763
  return scores
488
764
 
489
765
 
766
+ def compute_cohoscore_1d(
767
+ coords: np.ndarray,
768
+ activity: np.ndarray,
769
+ top_percent: float = 2.0,
770
+ times: np.ndarray | None = None,
771
+ auto_filter: bool = True,
772
+ ) -> np.ndarray:
773
+ """
774
+ Compute 1D cohomology-space selectivity score (CohoScore) for each neuron.
775
+
776
+ For each neuron:
777
+ - Select "active" time points:
778
+ - If top_percent is None: all time points with activity > 0
779
+ - Else: top `top_percent`%% time points by activity value
780
+ - Compute circular variance for theta on the selected points.
781
+ - CohoScore = var(theta)
782
+ """
783
+ coords = np.asarray(coords)
784
+ if coords.ndim == 2 and coords.shape[1] == 1:
785
+ coords = coords[:, 0]
786
+ if coords.ndim != 1:
787
+ raise ValueError(f"coords must have shape (T,) or (T, 1), got {coords.shape}")
788
+
789
+ activity = _align_activity_to_coords(
790
+ coords[:, None], activity, times, label="activity", auto_filter=auto_filter
791
+ )
792
+ _, n_neurons = activity.shape
793
+
794
+ theta = coords % (2 * np.pi)
795
+ scores = np.zeros(n_neurons, dtype=float)
796
+
797
+ for n in range(n_neurons):
798
+ signal = activity[:, n]
799
+
800
+ if top_percent is None:
801
+ idx = signal > 0
802
+ else:
803
+ threshold = np.percentile(signal, 100 - top_percent)
804
+ idx = signal >= threshold
805
+
806
+ if np.sum(idx) < 5:
807
+ scores[n] = np.nan
808
+ continue
809
+
810
+ var1 = circvar(theta[idx], high=2 * np.pi, low=0)
811
+ scores[n] = var1
812
+
813
+ return scores
814
+
815
+
490
816
  def skew_transform_torus(coords):
491
817
  """
492
818
  Convert torus angles (theta1, theta2) into coordinates in a skewed parallelogram fundamental domain.
@@ -72,6 +72,8 @@ class TDAConfig:
72
72
  Number of shuffles for null distribution.
73
73
  progress_bar : bool
74
74
  Whether to show progress bars.
75
+ standardize : bool
76
+ Whether to standardize data before PCA (z-score).
75
77
 
76
78
  Examples
77
79
  --------
@@ -94,6 +96,7 @@ class TDAConfig:
94
96
  do_shuffle: bool = False
95
97
  num_shuffles: int = 1000
96
98
  progress_bar: bool = True
99
+ standardize: bool = True
97
100
 
98
101
 
99
102
  @dataclass
@@ -59,11 +59,11 @@ def embed_spike_trains(spike_trains, config: SpikeEmbeddingConfig | None = None,
59
59
  # Step 1: Extract and filter spike data
60
60
  spikes_filtered = _extract_spike_data(spike_trains, config)
61
61
 
62
- # Step 2: Create time bins
63
- time_bins = _create_time_bins(spike_trains["t"], config)
62
+ # Step 2: Create time bins metadata
63
+ min_time, max_time, n_bins = _create_time_bins(spike_trains["t"], config)
64
64
 
65
65
  # Step 3: Bin spike data
66
- spikes_bin = _bin_spike_data(spikes_filtered, time_bins, config)
66
+ spikes_bin = _bin_spike_data(spikes_filtered, min_time, max_time, n_bins, config)
67
67
 
68
68
  # Step 4: Apply temporal smoothing if requested
69
69
  if config.smooth:
@@ -73,7 +73,7 @@ def embed_spike_trains(spike_trains, config: SpikeEmbeddingConfig | None = None,
73
73
  if config.speed_filter:
74
74
  return _apply_speed_filtering(spikes_bin, spike_trains, config)
75
75
 
76
- return spikes_bin, None, None, None
76
+ return spikes_bin, spike_trains["x"], spike_trains["y"], spike_trains["t"]
77
77
 
78
78
  except Exception as e:
79
79
  raise ProcessingError(f"Failed to embed spike trains: {e}") from e
@@ -132,36 +132,44 @@ def _extract_spike_data(
132
132
  raise ProcessingError(f"Error extracting spike data: {e}") from e
133
133
 
134
134
 
135
- def _create_time_bins(t: np.ndarray, config: SpikeEmbeddingConfig) -> np.ndarray:
136
- """Create time bins for spike discretization."""
135
+ def _create_time_bins(t: np.ndarray, config: SpikeEmbeddingConfig) -> tuple[int, int, int]:
136
+ """Create time-bin metadata for spike discretization."""
137
137
  min_time0 = np.min(t)
138
138
  max_time0 = np.max(t)
139
139
 
140
- min_time = min_time0 * config.res
141
- max_time = max_time0 * config.res
140
+ min_time = int(np.floor(min_time0 * config.res))
141
+ max_time = int(np.ceil(max_time0 * config.res)) + 1
142
+ n_bins = max(1, int(np.ceil((max_time - min_time) / config.dt)))
143
+ last_time = min_time + config.dt * (n_bins - 1)
142
144
 
143
- return np.arange(np.floor(min_time), np.ceil(max_time) + 1, config.dt)
145
+ return min_time, last_time, n_bins
144
146
 
145
147
 
146
148
  def _bin_spike_data(
147
- spikes: dict[int, np.ndarray], time_bins: np.ndarray, config: SpikeEmbeddingConfig
149
+ spikes: dict[int, np.ndarray],
150
+ min_time: int,
151
+ max_time: int,
152
+ n_bins: int,
153
+ config: SpikeEmbeddingConfig,
148
154
  ) -> np.ndarray:
149
155
  """Convert spike times to binned spike matrix."""
150
- min_time = time_bins[0]
151
- max_time = time_bins[-1]
152
-
153
- spikes_bin = np.zeros((len(time_bins), len(spikes)), dtype=int)
156
+ spikes_bin = np.zeros((n_bins, len(spikes)), dtype=np.int32)
157
+ max_time_offset = max_time - min_time
154
158
 
155
159
  for n in spikes:
156
- spike_times = np.array(spikes[n] * config.res - min_time, dtype=int)
160
+ spike_times = np.asarray(spikes[n])
161
+ if spike_times.size == 0:
162
+ continue
163
+ spike_times = (spike_times * config.res - min_time).astype(np.int64, copy=False)
157
164
  # Filter valid spike times
158
- spike_times = spike_times[(spike_times < (max_time - min_time)) & (spike_times > 0)]
159
- spike_times = np.array(spike_times / config.dt, int)
165
+ valid = (spike_times < max_time_offset) & (spike_times > 0)
166
+ if not np.any(valid):
167
+ continue
168
+ spike_times = spike_times[valid]
169
+ spike_bins = np.floor_divide(spike_times, config.dt).astype(np.int64, copy=False)
160
170
 
161
- # Bin spikes
162
- for j in spike_times:
163
- if j < len(time_bins):
164
- spikes_bin[j, n] += 1
171
+ # Bin spikes (vectorized)
172
+ np.add.at(spikes_bin[:, n], spike_bins, 1)
165
173
 
166
174
  return spikes_bin
167
175
 
@@ -171,21 +179,22 @@ def _apply_temporal_smoothing(spikes_bin: np.ndarray, config: SpikeEmbeddingConf
171
179
  # Calculate smoothing parameters (legacy implementation used custom kernel)
172
180
  # Current implementation uses scipy's gaussian_filter1d for better performance
173
181
 
174
- # Apply smoothing (simplified version - could be further optimized)
175
- smoothed = np.zeros((spikes_bin.shape[0], spikes_bin.shape[1]))
182
+ # Convert to float once to avoid holding both int and float arrays.
183
+ spikes_bin = spikes_bin.astype(np.float32, copy=False)
176
184
 
177
185
  # Use scipy's gaussian_filter1d for better performance
178
186
 
179
187
  sigma_bins = config.sigma / config.dt
180
188
 
181
189
  for n in range(spikes_bin.shape[1]):
182
- smoothed[:, n] = gaussian_filter1d(
183
- spikes_bin[:, n].astype(float), sigma=sigma_bins, mode="constant"
190
+ gaussian_filter1d(
191
+ spikes_bin[:, n], sigma=sigma_bins, mode="constant", output=spikes_bin[:, n]
184
192
  )
185
193
 
186
194
  # Normalize
187
195
  normalization_factor = 1 / np.sqrt(2 * np.pi * (config.sigma / config.res) ** 2)
188
- return smoothed * normalization_factor
196
+ spikes_bin *= normalization_factor
197
+ return spikes_bin
189
198
 
190
199
 
191
200
  def _apply_speed_filtering(
@@ -240,25 +249,19 @@ def _load_pos(t, x, y, res=100000, dt=1000):
240
249
 
241
250
  tt = np.arange(np.floor(min_time), np.ceil(max_time) + 1, dt) / res
242
251
 
243
- idt = np.concatenate(([0], np.digitize(t[1:-1], tt[:]) - 1, [len(tt) + 1]))
244
- idtt = np.digitize(np.arange(len(tt)), idt) - 1
245
-
246
- idx = np.concatenate((np.unique(idtt), [np.max(idtt) + 1]))
247
- divisor = np.bincount(idtt)
248
- steps = 1.0 / divisor[divisor > 0]
249
- N = np.max(divisor)
250
- ranges = np.multiply(np.arange(N)[np.newaxis, :], steps[:, np.newaxis])
251
- ranges[ranges >= 1] = np.nan
252
-
253
- rangesx = x[idx[:-1], np.newaxis] + np.multiply(
254
- ranges, (x[idx[1:]] - x[idx[:-1]])[:, np.newaxis]
255
- )
256
- xx = rangesx[~np.isnan(ranges)]
257
-
258
- rangesy = y[idx[:-1], np.newaxis] + np.multiply(
259
- ranges, (y[idx[1:]] - y[idx[:-1]])[:, np.newaxis]
260
- )
261
- yy = rangesy[~np.isnan(ranges)]
252
+ if t.size == 0:
253
+ return np.array([]), np.array([]), tt, np.array([])
254
+
255
+ # Ensure monotonically increasing time for interpolation.
256
+ if t.size > 1 and np.any(np.diff(t) < 0):
257
+ order = np.argsort(t)
258
+ t = t[order]
259
+ x = x[order]
260
+ y = y[order]
261
+
262
+ # Interpolate positions onto the spike time bins.
263
+ xx = np.interp(tt, t, x)
264
+ yy = np.interp(tt, t, y)
262
265
 
263
266
  xxs = _gaussian_filter1d(xx - np.min(xx), sigma=100)
264
267
  yys = _gaussian_filter1d(yy - np.min(yy), sigma=100)