canns 0.13.1__py3-none-any.whl → 0.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/analyzer/data/__init__.py +5 -1
- canns/analyzer/data/asa/__init__.py +27 -12
- canns/analyzer/data/asa/cohospace.py +336 -10
- canns/analyzer/data/asa/config.py +3 -0
- canns/analyzer/data/asa/embedding.py +48 -45
- canns/analyzer/data/asa/path.py +104 -2
- canns/analyzer/data/asa/plotting.py +88 -19
- canns/analyzer/data/asa/tda.py +11 -4
- canns/analyzer/data/cell_classification/__init__.py +97 -0
- canns/analyzer/data/cell_classification/core/__init__.py +26 -0
- canns/analyzer/data/cell_classification/core/grid_cells.py +633 -0
- canns/analyzer/data/cell_classification/core/grid_modules_leiden.py +288 -0
- canns/analyzer/data/cell_classification/core/head_direction.py +347 -0
- canns/analyzer/data/cell_classification/core/spatial_analysis.py +431 -0
- canns/analyzer/data/cell_classification/io/__init__.py +5 -0
- canns/analyzer/data/cell_classification/io/matlab_loader.py +417 -0
- canns/analyzer/data/cell_classification/utils/__init__.py +39 -0
- canns/analyzer/data/cell_classification/utils/circular_stats.py +383 -0
- canns/analyzer/data/cell_classification/utils/correlation.py +318 -0
- canns/analyzer/data/cell_classification/utils/geometry.py +442 -0
- canns/analyzer/data/cell_classification/utils/image_processing.py +416 -0
- canns/analyzer/data/cell_classification/visualization/__init__.py +19 -0
- canns/analyzer/data/cell_classification/visualization/grid_plots.py +292 -0
- canns/analyzer/data/cell_classification/visualization/hd_plots.py +200 -0
- canns/analyzer/metrics/__init__.py +2 -1
- canns/analyzer/visualization/core/config.py +46 -4
- canns/data/__init__.py +6 -1
- canns/data/datasets.py +154 -1
- canns/data/loaders.py +37 -0
- canns/pipeline/__init__.py +13 -9
- canns/pipeline/__main__.py +6 -0
- canns/pipeline/asa/runner.py +105 -41
- canns/pipeline/asa_gui/__init__.py +68 -0
- canns/pipeline/asa_gui/__main__.py +6 -0
- canns/pipeline/asa_gui/analysis_modes/__init__.py +42 -0
- canns/pipeline/asa_gui/analysis_modes/base.py +39 -0
- canns/pipeline/asa_gui/analysis_modes/batch_mode.py +21 -0
- canns/pipeline/asa_gui/analysis_modes/cohomap_mode.py +56 -0
- canns/pipeline/asa_gui/analysis_modes/cohospace_mode.py +194 -0
- canns/pipeline/asa_gui/analysis_modes/decode_mode.py +52 -0
- canns/pipeline/asa_gui/analysis_modes/fr_mode.py +81 -0
- canns/pipeline/asa_gui/analysis_modes/frm_mode.py +92 -0
- canns/pipeline/asa_gui/analysis_modes/gridscore_mode.py +123 -0
- canns/pipeline/asa_gui/analysis_modes/pathcompare_mode.py +199 -0
- canns/pipeline/asa_gui/analysis_modes/tda_mode.py +112 -0
- canns/pipeline/asa_gui/app.py +29 -0
- canns/pipeline/asa_gui/controllers/__init__.py +6 -0
- canns/pipeline/asa_gui/controllers/analysis_controller.py +59 -0
- canns/pipeline/asa_gui/controllers/preprocess_controller.py +89 -0
- canns/pipeline/asa_gui/core/__init__.py +15 -0
- canns/pipeline/asa_gui/core/cache.py +14 -0
- canns/pipeline/asa_gui/core/runner.py +1936 -0
- canns/pipeline/asa_gui/core/state.py +324 -0
- canns/pipeline/asa_gui/core/worker.py +260 -0
- canns/pipeline/asa_gui/main_window.py +184 -0
- canns/pipeline/asa_gui/models/__init__.py +7 -0
- canns/pipeline/asa_gui/models/config.py +14 -0
- canns/pipeline/asa_gui/models/job.py +31 -0
- canns/pipeline/asa_gui/models/presets.py +21 -0
- canns/pipeline/asa_gui/resources/__init__.py +16 -0
- canns/pipeline/asa_gui/resources/dark.qss +167 -0
- canns/pipeline/asa_gui/resources/light.qss +163 -0
- canns/pipeline/asa_gui/resources/styles.qss +130 -0
- canns/pipeline/asa_gui/utils/__init__.py +1 -0
- canns/pipeline/asa_gui/utils/formatters.py +15 -0
- canns/pipeline/asa_gui/utils/io_adapters.py +40 -0
- canns/pipeline/asa_gui/utils/validators.py +41 -0
- canns/pipeline/asa_gui/views/__init__.py +1 -0
- canns/pipeline/asa_gui/views/help_content.py +171 -0
- canns/pipeline/asa_gui/views/pages/__init__.py +6 -0
- canns/pipeline/asa_gui/views/pages/analysis_page.py +565 -0
- canns/pipeline/asa_gui/views/pages/preprocess_page.py +492 -0
- canns/pipeline/asa_gui/views/panels/__init__.py +1 -0
- canns/pipeline/asa_gui/views/widgets/__init__.py +21 -0
- canns/pipeline/asa_gui/views/widgets/artifacts_tab.py +44 -0
- canns/pipeline/asa_gui/views/widgets/drop_zone.py +80 -0
- canns/pipeline/asa_gui/views/widgets/file_list.py +27 -0
- canns/pipeline/asa_gui/views/widgets/gridscore_tab.py +308 -0
- canns/pipeline/asa_gui/views/widgets/help_dialog.py +27 -0
- canns/pipeline/asa_gui/views/widgets/image_tab.py +50 -0
- canns/pipeline/asa_gui/views/widgets/image_viewer.py +97 -0
- canns/pipeline/asa_gui/views/widgets/log_box.py +16 -0
- canns/pipeline/asa_gui/views/widgets/pathcompare_tab.py +200 -0
- canns/pipeline/asa_gui/views/widgets/popup_combo.py +25 -0
- canns/pipeline/gallery/__init__.py +15 -5
- canns/pipeline/gallery/__main__.py +11 -0
- canns/pipeline/gallery/app.py +705 -0
- canns/pipeline/gallery/runner.py +790 -0
- canns/pipeline/gallery/state.py +51 -0
- canns/pipeline/gallery/styles.tcss +123 -0
- canns/pipeline/launcher.py +81 -0
- {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/METADATA +11 -1
- canns-0.14.0.dist-info/RECORD +163 -0
- canns-0.14.0.dist-info/entry_points.txt +5 -0
- canns/pipeline/_base.py +0 -50
- canns-0.13.1.dist-info/RECORD +0 -89
- canns-0.13.1.dist-info/entry_points.txt +0 -3
- {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/WHEEL +0 -0
- {canns-0.13.1.dist-info → canns-0.14.0.dist-info}/licenses/LICENSE +0 -0
canns/analyzer/data/__init__.py
CHANGED
|
@@ -1,5 +1,9 @@
|
|
|
1
1
|
"""Data analysis utilities for experimental and synthetic neural data."""
|
|
2
2
|
|
|
3
|
+
from . import asa, cell_classification
|
|
3
4
|
from .asa import * # noqa: F401,F403
|
|
5
|
+
from .cell_classification import * # noqa: F401,F403
|
|
4
6
|
|
|
5
|
-
__all__ =
|
|
7
|
+
__all__ = ["asa", "cell_classification"]
|
|
8
|
+
__all__ += list(getattr(asa, "__all__", []))
|
|
9
|
+
__all__ += list(getattr(cell_classification, "__all__", []))
|
|
@@ -2,10 +2,14 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
# Coho-space analysis + visualization
|
|
4
4
|
from .cohospace import (
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
5
|
+
compute_cohoscore_1d,
|
|
6
|
+
compute_cohoscore_2d,
|
|
7
|
+
plot_cohospace_neuron_1d,
|
|
8
|
+
plot_cohospace_neuron_2d,
|
|
9
|
+
plot_cohospace_population_1d,
|
|
10
|
+
plot_cohospace_population_2d,
|
|
11
|
+
plot_cohospace_trajectory_1d,
|
|
12
|
+
plot_cohospace_trajectory_2d,
|
|
9
13
|
)
|
|
10
14
|
from .config import (
|
|
11
15
|
CANN2DError,
|
|
@@ -33,7 +37,11 @@ from .fr import (
|
|
|
33
37
|
)
|
|
34
38
|
|
|
35
39
|
# Path utilities
|
|
36
|
-
from .path import
|
|
40
|
+
from .path import (
|
|
41
|
+
align_coords_to_position_1d,
|
|
42
|
+
align_coords_to_position_2d,
|
|
43
|
+
apply_angle_scale,
|
|
44
|
+
)
|
|
37
45
|
|
|
38
46
|
# Higher-level plotting helpers
|
|
39
47
|
from .plotting import (
|
|
@@ -41,7 +49,8 @@ from .plotting import (
|
|
|
41
49
|
plot_3d_bump_on_torus,
|
|
42
50
|
plot_cohomap,
|
|
43
51
|
plot_cohomap_multi,
|
|
44
|
-
|
|
52
|
+
plot_path_compare_1d,
|
|
53
|
+
plot_path_compare_2d,
|
|
45
54
|
plot_projection,
|
|
46
55
|
)
|
|
47
56
|
|
|
@@ -61,7 +70,8 @@ __all__ = [
|
|
|
61
70
|
"decode_circular_coordinates",
|
|
62
71
|
"decode_circular_coordinates_multi",
|
|
63
72
|
"plot_projection",
|
|
64
|
-
"
|
|
73
|
+
"plot_path_compare_1d",
|
|
74
|
+
"plot_path_compare_2d",
|
|
65
75
|
"plot_cohomap",
|
|
66
76
|
"plot_cohomap_multi",
|
|
67
77
|
"plot_3d_bump_on_torus",
|
|
@@ -75,10 +85,15 @@ __all__ = [
|
|
|
75
85
|
"FRMResult",
|
|
76
86
|
"compute_frm",
|
|
77
87
|
"plot_frm",
|
|
78
|
-
"
|
|
79
|
-
"
|
|
80
|
-
"
|
|
81
|
-
"
|
|
82
|
-
"
|
|
88
|
+
"plot_cohospace_trajectory_1d",
|
|
89
|
+
"plot_cohospace_trajectory_2d",
|
|
90
|
+
"plot_cohospace_neuron_1d",
|
|
91
|
+
"plot_cohospace_neuron_2d",
|
|
92
|
+
"plot_cohospace_population_1d",
|
|
93
|
+
"plot_cohospace_population_2d",
|
|
94
|
+
"compute_cohoscore_1d",
|
|
95
|
+
"compute_cohoscore_2d",
|
|
96
|
+
"align_coords_to_position_1d",
|
|
97
|
+
"align_coords_to_position_2d",
|
|
83
98
|
"apply_angle_scale",
|
|
84
99
|
]
|
|
@@ -91,7 +91,7 @@ def _align_activity_to_coords(
|
|
|
91
91
|
return activity
|
|
92
92
|
|
|
93
93
|
|
|
94
|
-
def
|
|
94
|
+
def plot_cohospace_trajectory_2d(
|
|
95
95
|
coords: np.ndarray,
|
|
96
96
|
times: np.ndarray | None = None,
|
|
97
97
|
subsample: int = 1,
|
|
@@ -129,7 +129,7 @@ def plot_cohospace_trajectory(
|
|
|
129
129
|
|
|
130
130
|
Examples
|
|
131
131
|
--------
|
|
132
|
-
>>> fig =
|
|
132
|
+
>>> fig = plot_cohospace_trajectory_2d(coords, subsample=2, show=False) # doctest: +SKIP
|
|
133
133
|
"""
|
|
134
134
|
|
|
135
135
|
try:
|
|
@@ -194,7 +194,106 @@ def plot_cohospace_trajectory(
|
|
|
194
194
|
return ax
|
|
195
195
|
|
|
196
196
|
|
|
197
|
-
def
|
|
197
|
+
def plot_cohospace_trajectory_1d(
    coords: np.ndarray,
    times: np.ndarray | None = None,
    subsample: int = 1,
    figsize: tuple[int, int] = (6, 6),
    cmap: str = "viridis",
    save_path: str | None = None,
    show: bool = False,
    config: PlotConfig | None = None,
) -> plt.Axes:
    """
    Plot a 1D cohomology trajectory on the unit circle.

    Each time point is drawn at ``(cos(theta), sin(theta))`` on top of a light
    reference circle and coloured by its time value.

    Parameters
    ----------
    coords : ndarray, shape (T,) or (T, 1)
        Decoded cohomology angles (theta) in radians. Values are wrapped into
        ``[0, 2*pi)`` before plotting. Angles expressed in another convention
        (e.g. fractions of a turn in ``[0, 1]``) must be converted to radians
        first; otherwise they collapse onto a short arc of the circle.
    times : ndarray, optional, shape (T,)
        Time values used to colour the points. If None, uses ``arange(T)``.
    subsample : int
        Keep every ``subsample``-th point (>1 reduces the number of plotted
        points). Values that cannot be converted to ``int`` or are < 1 fall
        back to 1 (no subsampling).
    figsize : tuple
        Matplotlib figure size (used when ``config`` is not provided).
    cmap : str
        Matplotlib colormap name used for the time colouring.
    save_path : str, optional
        If provided, saves the figure to this path.
    show : bool
        If True, calls plt.show(). If False, closes the figure and returns the Axes.
    config : PlotConfig, optional
        Pre-built plot configuration. When None, one is built from
        ``figsize``, ``save_path`` and ``show``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the trajectory was drawn on.

    Raises
    ------
    ValueError
        If ``coords`` is not of shape (T,) or (T, 1), or if ``times`` does not
        have the same length as ``coords``.
    """
    try:
        subsample_i = int(subsample)
    except (TypeError, ValueError, OverflowError):
        # Best-effort: an unusable subsample value simply means "no subsampling".
        subsample_i = 1
    subsample_i = max(subsample_i, 1)

    coords = np.asarray(coords)
    if coords.ndim == 2 and coords.shape[1] == 1:
        coords = coords[:, 0]
    if coords.ndim != 1:
        raise ValueError(f"`coords` must have shape (T,) or (T, 1). Got {coords.shape}.")

    if times is None:
        times_vis = np.arange(coords.shape[0])
    else:
        # atleast_1d so a scalar `times` produces the length error below
        # instead of an IndexError on `.shape[0]`.
        times_vis = np.atleast_1d(np.asarray(times))
        if times_vis.shape[0] != coords.shape[0]:
            raise ValueError(
                f"`times` length must match coords length. Got times={times_vis.shape[0]}, coords={coords.shape[0]}."
            )

    if subsample_i > 1:
        coords = coords[::subsample_i]
        times_vis = times_vis[::subsample_i]

    # Coords are interpreted as radians and wrapped onto a single turn.
    theta = coords % (2 * np.pi)
    x = np.cos(theta)
    y = np.sin(theta)

    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title="CohoSpace trajectory (1D)",
        xlabel="cos(theta)",
        ylabel="sin(theta)",
        figsize=figsize,
        save_path=save_path,
        show=show,
    )

    fig, ax = plt.subplots(figsize=config.figsize)
    # Light reference circle drawn beneath the data points.
    circle = np.linspace(0, 2 * np.pi, 200)
    ax.plot(np.cos(circle), np.sin(circle), color="0.85", lw=1.0, zorder=0)
    sc = ax.scatter(
        x,
        y,
        c=times_vis,
        cmap=cmap,
        s=5,
        alpha=0.8,
    )
    cbar = plt.colorbar(sc, ax=ax)
    cbar.set_label("Time")

    ax.set_xlim(-1.2, 1.2)
    ax.set_ylim(-1.2, 1.2)
    ax.set_xlabel(config.xlabel)
    ax.set_ylabel(config.ylabel)
    ax.set_title(config.title)
    ax.set_aspect("equal", adjustable="box")
    ax.grid(True, alpha=0.2)

    _ensure_parent_dir(config.save_path)
    finalize_figure(fig, config)
    return ax
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def plot_cohospace_neuron_2d(
|
|
198
297
|
coords: np.ndarray,
|
|
199
298
|
activity: np.ndarray,
|
|
200
299
|
neuron_id: int,
|
|
@@ -230,7 +329,7 @@ def plot_cohospace_neuron(
|
|
|
230
329
|
mode : {"fr", "spike"}
|
|
231
330
|
top_percent : float
|
|
232
331
|
Used only when mode="fr". For example, 5.0 means "top 5%%" time points.
|
|
233
|
-
figsize, cmap, save_path, show : see `
|
|
332
|
+
figsize, cmap, save_path, show : see `plot_cohospace_trajectory_2d`.
|
|
234
333
|
|
|
235
334
|
Returns
|
|
236
335
|
-------
|
|
@@ -238,7 +337,7 @@ def plot_cohospace_neuron(
|
|
|
238
337
|
|
|
239
338
|
Examples
|
|
240
339
|
--------
|
|
241
|
-
>>>
|
|
340
|
+
>>> plot_cohospace_neuron_2d(coords, spikes, neuron_id=0, show=False) # doctest: +SKIP
|
|
242
341
|
"""
|
|
243
342
|
coords = np.asarray(coords)
|
|
244
343
|
activity = _align_activity_to_coords(
|
|
@@ -300,7 +399,94 @@ def plot_cohospace_neuron(
|
|
|
300
399
|
return fig
|
|
301
400
|
|
|
302
401
|
|
|
303
|
-
def
|
|
402
|
+
def plot_cohospace_neuron_1d(
    coords: np.ndarray,
    activity: np.ndarray,
    neuron_id: int,
    mode: str = "fr",
    top_percent: float = 5.0,
    times: np.ndarray | None = None,
    auto_filter: bool = True,
    figsize: tuple = (6, 6),
    cmap: str = "hot",
    save_path: str | None = None,
    show: bool = True,
    config: PlotConfig | None = None,
) -> plt.Figure:
    """
    Overlay a single neuron's activity on the 1D cohomology trajectory (unit circle).

    Parameters
    ----------
    coords : ndarray, shape (T,) or (T, 1)
        Decoded cohomology angles in radians; wrapped into ``[0, 2*pi)``.
    activity : ndarray
        Neuron activity. It is aligned to ``coords`` with
        ``_align_activity_to_coords`` and then read as ``activity[:, neuron_id]``.
    neuron_id : int
        Column of the aligned activity to plot.
    mode : {"fr", "spike"}
        "fr": plot time points whose activity is in the top ``top_percent``
        percent, coloured by activity with ``cmap``.
        "spike": plot every time point with activity > 0, in red.
    top_percent : float
        Used only when mode="fr".
    times : ndarray, optional
        Forwarded to ``_align_activity_to_coords``.
    auto_filter : bool
        Forwarded to ``_align_activity_to_coords``.
    figsize, save_path, show, config
        Plot options, used to build the plot configuration when ``config`` is None.
    cmap : str
        Colormap for mode="fr" (ignored for mode="spike").

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``mode`` is not "fr"/"spike", or ``coords`` is not of shape (T,) or (T, 1).
    """
    # Validate the mode before aligning activity so a bad mode fails fast.
    if mode not in ("fr", "spike"):
        raise ValueError("mode must be 'fr' or 'spike'")

    coords = np.asarray(coords)
    if coords.ndim == 2 and coords.shape[1] == 1:
        coords = coords[:, 0]
    if coords.ndim != 1:
        raise ValueError(f"coords must have shape (T,) or (T, 1), got {coords.shape}")

    activity = _align_activity_to_coords(
        coords[:, None], activity, times, label="activity", auto_filter=auto_filter
    )

    signal = activity[:, neuron_id]

    if mode == "fr":
        # Keep time points at or above the (100 - top_percent)th percentile.
        threshold = np.percentile(signal, 100 - top_percent)
        idx = signal >= threshold
        color = signal[idx]
        title = f"Neuron {neuron_id} FR top {top_percent:.1f}% on coho-space (1D)"
        use_cmap = cmap
    else:  # mode == "spike"
        idx = signal > 0
        color = "red"
        title = f"Neuron {neuron_id} spikes on coho-space (1D)"
        use_cmap = None

    theta = coords % (2 * np.pi)
    x = np.cos(theta)
    y = np.sin(theta)

    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title=title,
        xlabel="cos(theta)",
        ylabel="sin(theta)",
        figsize=figsize,
        save_path=save_path,
        show=show,
    )

    fig, ax = plt.subplots(figsize=config.figsize)
    # Light reference circle drawn beneath the data points.
    circle = np.linspace(0, 2 * np.pi, 200)
    ax.plot(np.cos(circle), np.sin(circle), color="0.85", lw=1.0, zorder=0)
    sc = ax.scatter(
        x[idx],
        y[idx],
        c=color,
        cmap=use_cmap,
        s=8,
        alpha=0.9,
    )

    if mode == "fr":
        cbar = plt.colorbar(sc, ax=ax)
        cbar.set_label("Firing rate")

    ax.set_xlim(-1.2, 1.2)
    ax.set_ylim(-1.2, 1.2)
    ax.set_xlabel(config.xlabel)
    ax.set_ylabel(config.ylabel)
    ax.set_title(config.title)
    ax.set_aspect("equal", adjustable="box")

    _ensure_parent_dir(config.save_path)
    finalize_figure(fig, config)

    return fig
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
def plot_cohospace_population_2d(
|
|
304
490
|
coords: np.ndarray,
|
|
305
491
|
activity: np.ndarray,
|
|
306
492
|
neuron_ids: list[int] | np.ndarray,
|
|
@@ -338,7 +524,7 @@ def plot_cohospace_population(
|
|
|
338
524
|
mode : {"fr", "spike"}
|
|
339
525
|
top_percent : float
|
|
340
526
|
Used only when mode="fr".
|
|
341
|
-
figsize, cmap, save_path, show : see `
|
|
527
|
+
figsize, cmap, save_path, show : see `plot_cohospace_trajectory_2d`.
|
|
342
528
|
|
|
343
529
|
Returns
|
|
344
530
|
-------
|
|
@@ -346,7 +532,7 @@ def plot_cohospace_population(
|
|
|
346
532
|
|
|
347
533
|
Examples
|
|
348
534
|
--------
|
|
349
|
-
>>>
|
|
535
|
+
>>> plot_cohospace_population_2d(coords, spikes, neuron_ids=[0, 1, 2], show=False) # doctest: +SKIP
|
|
350
536
|
"""
|
|
351
537
|
coords = np.asarray(coords)
|
|
352
538
|
activity = _align_activity_to_coords(
|
|
@@ -411,7 +597,97 @@ def plot_cohospace_population(
|
|
|
411
597
|
return fig
|
|
412
598
|
|
|
413
599
|
|
|
414
|
-
def
|
|
600
|
+
def plot_cohospace_population_1d(
    coords: np.ndarray,
    activity: np.ndarray,
    neuron_ids: list[int] | np.ndarray,
    mode: str = "fr",
    top_percent: float = 5.0,
    times: np.ndarray | None = None,
    auto_filter: bool = True,
    figsize: tuple = (6, 6),
    cmap: str = "hot",
    save_path: str | None = None,
    show: bool = True,
    config: PlotConfig | None = None,
) -> plt.Figure:
    """
    Plot aggregated activity from multiple neurons on the 1D cohomology trajectory.

    For each neuron, a set of "active" time points is selected (see ``mode``).
    The union of these points is plotted on the unit circle, coloured by the
    per-time-point aggregate over the selected neurons.

    Parameters
    ----------
    coords : ndarray, shape (T,) or (T, 1)
        Decoded cohomology angles in radians; wrapped into ``[0, 2*pi)``.
    activity : ndarray
        Neuron activity, aligned to ``coords`` with ``_align_activity_to_coords``.
    neuron_ids : sequence of int or int
        Columns of the aligned activity to include. A single int is accepted.
    mode : {"fr", "spike"}
        "fr": per neuron, select time points in the top ``top_percent`` percent
        and add their activity to the aggregate colour.
        "spike": per neuron, select time points with activity > 0 and add 1
        to the aggregate colour (a spike count).
    top_percent : float
        Used only when mode="fr".
    times : ndarray, optional
        Forwarded to ``_align_activity_to_coords``.
    auto_filter : bool
        Forwarded to ``_align_activity_to_coords``.
    figsize, cmap, save_path, show, config
        Plot options, used to build the plot configuration when ``config`` is None.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``mode`` is not "fr"/"spike", or ``coords`` is not of shape (T,) or (T, 1).
    """
    # Validate once up front: previously an invalid mode was only detected
    # inside the per-neuron loop, so it slipped through for empty neuron_ids.
    if mode not in ("fr", "spike"):
        raise ValueError("mode must be 'fr' or 'spike'")

    coords = np.asarray(coords)
    if coords.ndim == 2 and coords.shape[1] == 1:
        coords = coords[:, 0]
    if coords.ndim != 1:
        raise ValueError(f"coords must have shape (T,) or (T, 1), got {coords.shape}")

    activity = _align_activity_to_coords(
        coords[:, None], activity, times, label="activity", auto_filter=auto_filter
    )
    # ravel() lets a single int behave like a one-element list.
    neuron_ids = np.asarray(neuron_ids, dtype=int).ravel()

    T = activity.shape[0]
    mask = np.zeros(T, dtype=bool)
    agg_color = np.zeros(T, dtype=float)

    for n in neuron_ids:
        signal = activity[:, n]
        if mode == "fr":
            threshold = np.percentile(signal, 100 - top_percent)
            idx = signal >= threshold
            agg_color[idx] += signal[idx]
        else:  # mode == "spike"
            idx = signal > 0
            agg_color[idx] += 1.0
        mask |= idx

    theta = coords % (2 * np.pi)
    x = np.cos(theta)
    y = np.sin(theta)

    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title=f"{len(neuron_ids)} neurons on coho-space (1D)",
        xlabel="cos(theta)",
        ylabel="sin(theta)",
        figsize=figsize,
        save_path=save_path,
        show=show,
    )

    fig, ax = plt.subplots(figsize=config.figsize)
    # Light reference circle drawn beneath the data points.
    circle = np.linspace(0, 2 * np.pi, 200)
    ax.plot(np.cos(circle), np.sin(circle), color="0.85", lw=1.0, zorder=0)
    sc = ax.scatter(
        x[mask],
        y[mask],
        c=agg_color[mask],
        cmap=cmap,
        s=6,
        alpha=0.9,
    )
    cbar = plt.colorbar(sc, ax=ax)
    label = "Aggregate FR" if mode == "fr" else "Spike count"
    cbar.set_label(label)

    ax.set_xlim(-1.2, 1.2)
    ax.set_ylim(-1.2, 1.2)
    ax.set_xlabel(config.xlabel)
    ax.set_ylabel(config.ylabel)
    ax.set_title(config.title)
    ax.set_aspect("equal", adjustable="box")

    _ensure_parent_dir(config.save_path)
    finalize_figure(fig, config)

    return fig
|
|
688
|
+
|
|
689
|
+
|
|
690
|
+
def compute_cohoscore_2d(
|
|
415
691
|
coords: np.ndarray,
|
|
416
692
|
activity: np.ndarray,
|
|
417
693
|
top_percent: float = 2.0,
|
|
@@ -451,7 +727,7 @@ def compute_cohoscore(
|
|
|
451
727
|
|
|
452
728
|
Examples
|
|
453
729
|
--------
|
|
454
|
-
>>> scores =
|
|
730
|
+
>>> scores = compute_cohoscore_2d(coords, spikes) # doctest: +SKIP
|
|
455
731
|
>>> scores.shape[0] # doctest: +SKIP
|
|
456
732
|
"""
|
|
457
733
|
coords = np.asarray(coords)
|
|
@@ -487,6 +763,56 @@ def compute_cohoscore(
|
|
|
487
763
|
return scores
|
|
488
764
|
|
|
489
765
|
|
|
766
|
+
def compute_cohoscore_1d(
    coords: np.ndarray,
    activity: np.ndarray,
    top_percent: float | None = 2.0,
    times: np.ndarray | None = None,
    auto_filter: bool = True,
    min_points: int = 5,
) -> np.ndarray:
    """
    Compute 1D cohomology-space selectivity score (CohoScore) for each neuron.

    For each neuron:
      - Select "active" time points:
          - If top_percent is None: all time points with activity > 0
          - Else: top `top_percent`% time points by activity value
      - Compute circular variance for theta on the selected points.
      - CohoScore = var(theta)

    Lower scores mean the neuron's active time points are concentrated in a
    narrower range of angles.

    Parameters
    ----------
    coords : ndarray, shape (T,) or (T, 1)
        Decoded cohomology angles in radians; wrapped into ``[0, 2*pi)``.
    activity : ndarray
        Neuron activity, aligned to ``coords`` with ``_align_activity_to_coords``.
    top_percent : float or None
        Percentage of highest-activity time points to use, or None to use all
        time points with activity > 0.
    times : ndarray, optional
        Forwarded to ``_align_activity_to_coords``.
    auto_filter : bool
        Forwarded to ``_align_activity_to_coords``.
    min_points : int
        Minimum number of selected time points required to score a neuron
        (default 5, the previously hard-coded value).

    Returns
    -------
    ndarray, shape (n_neurons,)
        One score per column of the aligned activity; NaN for neurons with
        fewer than ``min_points`` selected time points.

    Raises
    ------
    ValueError
        If ``coords`` is not of shape (T,) or (T, 1).
    """
    coords = np.asarray(coords)
    if coords.ndim == 2 and coords.shape[1] == 1:
        coords = coords[:, 0]
    if coords.ndim != 1:
        raise ValueError(f"coords must have shape (T,) or (T, 1), got {coords.shape}")

    activity = _align_activity_to_coords(
        coords[:, None], activity, times, label="activity", auto_filter=auto_filter
    )
    _, n_neurons = activity.shape

    theta = coords % (2 * np.pi)
    # Neurons that cannot be scored keep NaN.
    scores = np.full(n_neurons, np.nan, dtype=float)

    for n in range(n_neurons):
        signal = activity[:, n]

        if top_percent is None:
            idx = signal > 0
        else:
            threshold = np.percentile(signal, 100 - top_percent)
            idx = signal >= threshold

        if np.count_nonzero(idx) < min_points:
            continue  # too few active points for a meaningful variance

        scores[n] = circvar(theta[idx], high=2 * np.pi, low=0)

    return scores
|
|
814
|
+
|
|
815
|
+
|
|
490
816
|
def skew_transform_torus(coords):
|
|
491
817
|
"""
|
|
492
818
|
Convert torus angles (theta1, theta2) into coordinates in a skewed parallelogram fundamental domain.
|
|
@@ -72,6 +72,8 @@ class TDAConfig:
|
|
|
72
72
|
Number of shuffles for null distribution.
|
|
73
73
|
progress_bar : bool
|
|
74
74
|
Whether to show progress bars.
|
|
75
|
+
standardize : bool
|
|
76
|
+
Whether to standardize data before PCA (z-score).
|
|
75
77
|
|
|
76
78
|
Examples
|
|
77
79
|
--------
|
|
@@ -94,6 +96,7 @@ class TDAConfig:
|
|
|
94
96
|
do_shuffle: bool = False
|
|
95
97
|
num_shuffles: int = 1000
|
|
96
98
|
progress_bar: bool = True
|
|
99
|
+
standardize: bool = True
|
|
97
100
|
|
|
98
101
|
|
|
99
102
|
@dataclass
|
|
@@ -59,11 +59,11 @@ def embed_spike_trains(spike_trains, config: SpikeEmbeddingConfig | None = None,
|
|
|
59
59
|
# Step 1: Extract and filter spike data
|
|
60
60
|
spikes_filtered = _extract_spike_data(spike_trains, config)
|
|
61
61
|
|
|
62
|
-
# Step 2: Create time bins
|
|
63
|
-
|
|
62
|
+
# Step 2: Create time bins metadata
|
|
63
|
+
min_time, max_time, n_bins = _create_time_bins(spike_trains["t"], config)
|
|
64
64
|
|
|
65
65
|
# Step 3: Bin spike data
|
|
66
|
-
spikes_bin = _bin_spike_data(spikes_filtered,
|
|
66
|
+
spikes_bin = _bin_spike_data(spikes_filtered, min_time, max_time, n_bins, config)
|
|
67
67
|
|
|
68
68
|
# Step 4: Apply temporal smoothing if requested
|
|
69
69
|
if config.smooth:
|
|
@@ -73,7 +73,7 @@ def embed_spike_trains(spike_trains, config: SpikeEmbeddingConfig | None = None,
|
|
|
73
73
|
if config.speed_filter:
|
|
74
74
|
return _apply_speed_filtering(spikes_bin, spike_trains, config)
|
|
75
75
|
|
|
76
|
-
return spikes_bin,
|
|
76
|
+
return spikes_bin, spike_trains["x"], spike_trains["y"], spike_trains["t"]
|
|
77
77
|
|
|
78
78
|
except Exception as e:
|
|
79
79
|
raise ProcessingError(f"Failed to embed spike trains: {e}") from e
|
|
@@ -132,36 +132,44 @@ def _extract_spike_data(
|
|
|
132
132
|
raise ProcessingError(f"Error extracting spike data: {e}") from e
|
|
133
133
|
|
|
134
134
|
|
|
135
|
-
def _create_time_bins(t: np.ndarray, config: SpikeEmbeddingConfig) ->
|
|
136
|
-
"""Create time
|
|
135
|
+
def _create_time_bins(t: np.ndarray, config: SpikeEmbeddingConfig) -> tuple[int, int, int]:
    """Create time-bin metadata for spike discretization.

    Times in ``t`` are scaled by ``config.res`` into ticks. The covered span is
    widened outward to whole ticks (plus one extra tick at the end) and divided
    into bins of ``config.dt`` ticks.

    Returns
    -------
    tuple
        ``(min_time, last_time, n_bins)``: the first bin start (floored tick),
        the start of the final bin (``min_time + config.dt * (n_bins - 1)``;
        an int when ``config.dt`` is an int), and the number of bins (>= 1).
    """
    earliest, latest = np.min(t), np.max(t)

    start_tick = int(np.floor(earliest * config.res))
    stop_tick = int(np.ceil(latest * config.res)) + 1

    # Always produce at least one bin, even for a degenerate time span.
    bin_count = max(1, int(np.ceil((stop_tick - start_tick) / config.dt)))
    final_bin_start = start_tick + config.dt * (bin_count - 1)

    return start_tick, final_bin_start, bin_count
|
|
144
146
|
|
|
145
147
|
|
|
146
148
|
def _bin_spike_data(
    spikes: dict[int, np.ndarray],
    min_time: int,
    max_time: int,
    n_bins: int,
    config: SpikeEmbeddingConfig,
) -> np.ndarray:
    """Convert spike times to binned spike matrix.

    Each neuron's spike times are scaled by ``config.res`` and shifted by
    ``min_time`` into integer tick offsets. Offsets strictly between 0 and
    ``max_time - min_time`` are kept and counted into bins of ``config.dt``
    ticks. The result has shape ``(n_bins, len(spikes))`` with dtype int32;
    the dict keys of ``spikes`` are used directly as column indices.
    """
    counts = np.zeros((n_bins, len(spikes)), dtype=np.int32)
    window = max_time - min_time

    for neuron in spikes:
        ticks = np.asarray(spikes[neuron])
        if ticks.size == 0:
            continue

        ticks = (ticks * config.res - min_time).astype(np.int64, copy=False)
        # Keep offsets strictly inside (0, window).
        ticks = ticks[(ticks < window) & (ticks > 0)]
        if ticks.size == 0:
            continue

        bin_index = np.floor_divide(ticks, config.dt).astype(np.int64, copy=False)
        # Unbuffered add so repeated bin indices are each counted.
        np.add.at(counts[:, neuron], bin_index, 1)

    return counts
|
|
167
175
|
|
|
@@ -171,21 +179,22 @@ def _apply_temporal_smoothing(spikes_bin: np.ndarray, config: SpikeEmbeddingConf
|
|
|
171
179
|
# Calculate smoothing parameters (legacy implementation used custom kernel)
|
|
172
180
|
# Current implementation uses scipy's gaussian_filter1d for better performance
|
|
173
181
|
|
|
174
|
-
#
|
|
175
|
-
|
|
182
|
+
# Convert to float once to avoid holding both int and float arrays.
|
|
183
|
+
spikes_bin = spikes_bin.astype(np.float32, copy=False)
|
|
176
184
|
|
|
177
185
|
# Use scipy's gaussian_filter1d for better performance
|
|
178
186
|
|
|
179
187
|
sigma_bins = config.sigma / config.dt
|
|
180
188
|
|
|
181
189
|
for n in range(spikes_bin.shape[1]):
|
|
182
|
-
|
|
183
|
-
spikes_bin[:, n]
|
|
190
|
+
gaussian_filter1d(
|
|
191
|
+
spikes_bin[:, n], sigma=sigma_bins, mode="constant", output=spikes_bin[:, n]
|
|
184
192
|
)
|
|
185
193
|
|
|
186
194
|
# Normalize
|
|
187
195
|
normalization_factor = 1 / np.sqrt(2 * np.pi * (config.sigma / config.res) ** 2)
|
|
188
|
-
|
|
196
|
+
spikes_bin *= normalization_factor
|
|
197
|
+
return spikes_bin
|
|
189
198
|
|
|
190
199
|
|
|
191
200
|
def _apply_speed_filtering(
|
|
@@ -240,25 +249,19 @@ def _load_pos(t, x, y, res=100000, dt=1000):
|
|
|
240
249
|
|
|
241
250
|
tt = np.arange(np.floor(min_time), np.ceil(max_time) + 1, dt) / res
|
|
242
251
|
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
)
|
|
256
|
-
xx = rangesx[~np.isnan(ranges)]
|
|
257
|
-
|
|
258
|
-
rangesy = y[idx[:-1], np.newaxis] + np.multiply(
|
|
259
|
-
ranges, (y[idx[1:]] - y[idx[:-1]])[:, np.newaxis]
|
|
260
|
-
)
|
|
261
|
-
yy = rangesy[~np.isnan(ranges)]
|
|
252
|
+
if t.size == 0:
|
|
253
|
+
return np.array([]), np.array([]), tt, np.array([])
|
|
254
|
+
|
|
255
|
+
# Ensure monotonically increasing time for interpolation.
|
|
256
|
+
if t.size > 1 and np.any(np.diff(t) < 0):
|
|
257
|
+
order = np.argsort(t)
|
|
258
|
+
t = t[order]
|
|
259
|
+
x = x[order]
|
|
260
|
+
y = y[order]
|
|
261
|
+
|
|
262
|
+
# Interpolate positions onto the spike time bins.
|
|
263
|
+
xx = np.interp(tt, t, x)
|
|
264
|
+
yy = np.interp(tt, t, y)
|
|
262
265
|
|
|
263
266
|
xxs = _gaussian_filter1d(xx - np.min(xx), sigma=100)
|
|
264
267
|
yys = _gaussian_filter1d(yy - np.min(yy), sigma=100)
|