canns 0.12.7__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/analyzer/data/__init__.py +3 -11
- canns/analyzer/data/asa/__init__.py +74 -0
- canns/analyzer/data/asa/cohospace.py +905 -0
- canns/analyzer/data/asa/config.py +246 -0
- canns/analyzer/data/asa/decode.py +448 -0
- canns/analyzer/data/asa/embedding.py +269 -0
- canns/analyzer/data/asa/filters.py +208 -0
- canns/analyzer/data/asa/fr.py +439 -0
- canns/analyzer/data/asa/path.py +389 -0
- canns/analyzer/data/asa/plotting.py +1276 -0
- canns/analyzer/data/asa/tda.py +901 -0
- canns/analyzer/data/legacy/__init__.py +6 -0
- canns/analyzer/data/{cann1d.py → legacy/cann1d.py} +2 -2
- canns/analyzer/data/{cann2d.py → legacy/cann2d.py} +3 -3
- canns/analyzer/visualization/core/backend.py +1 -1
- canns/analyzer/visualization/core/config.py +77 -0
- canns/analyzer/visualization/core/rendering.py +10 -6
- canns/analyzer/visualization/energy_plots.py +22 -8
- canns/analyzer/visualization/spatial_plots.py +31 -11
- canns/analyzer/visualization/theta_sweep_plots.py +15 -6
- canns/pipeline/__init__.py +4 -8
- canns/pipeline/asa/__init__.py +21 -0
- canns/pipeline/asa/__main__.py +11 -0
- canns/pipeline/asa/app.py +1000 -0
- canns/pipeline/asa/runner.py +1095 -0
- canns/pipeline/asa/screens.py +215 -0
- canns/pipeline/asa/state.py +248 -0
- canns/pipeline/asa/styles.tcss +221 -0
- canns/pipeline/asa/widgets.py +233 -0
- canns/pipeline/gallery/__init__.py +7 -0
- canns/task/open_loop_navigation.py +3 -1
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/METADATA +6 -3
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/RECORD +36 -17
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/entry_points.txt +1 -0
- canns/pipeline/theta_sweep.py +0 -573
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/WHEEL +0 -0
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
|
|
5
|
+
from ...visualization import PlotConfig
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class SpikeEmbeddingConfig:
    """Settings that control how spike trains are binned and smoothed.

    The time-like quantities ``dt`` and ``sigma`` are expressed in the scaled
    integer time base defined by ``res``.

    Attributes
    ----------
    res : int
        Multiplier that converts times in seconds into integer time units.
    dt : int
        Width of one time bin, in the same scaled units as ``res``.
    sigma : int
        Width of the Gaussian smoothing kernel, in scaled units.
    smooth : bool
        If True, spike counts are smoothed over time.
    speed_filter : bool
        If True, samples are filtered by animal speed; this requires
        ``x``/``y``/``t`` in the input data.
    min_speed : float
        Speed threshold applied when ``speed_filter`` is enabled
        (cm/s by convention).

    Examples
    --------
    >>> from canns.analyzer.data import SpikeEmbeddingConfig
    >>> cfg = SpikeEmbeddingConfig(smooth=False, speed_filter=False)
    >>> cfg.min_speed
    2.5
    """

    # Time base, binning and smoothing width (scaled integer units).
    res: int = 100000
    dt: int = 1000
    sigma: int = 5000

    # Processing switches and the speed threshold they use.
    smooth: bool = True
    speed_filter: bool = True
    min_speed: float = 2.5
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass
class TDAConfig:
    """Parameters for the Topological Data Analysis (TDA) step.

    Attributes
    ----------
    dim : int
        Number of PCA dimensions kept before running TDA.
    num_times : int
        Stride used to downsample time points.
    active_times : int
        How many of the most active time points are retained.
    k : int
        Neighbour count used during denoising.
    n_points : int
        Number of points sampled for persistent homology.
    metric : str
        Point-cloud distance metric (for example ``"cosine"``).
    nbs : int
        Neighbour count used when building the distance matrix.
    maxdim : int
        Highest homology dimension computed.
    coeff : int
        Field coefficient for persistent homology.
    show : bool
        If True, barcode plots are displayed.
    do_shuffle : bool
        If True, the shuffle analysis is run.
    num_shuffles : int
        Number of shuffles used for the null distribution.
    progress_bar : bool
        If True, progress bars are displayed.

    Examples
    --------
    >>> from canns.analyzer.data import TDAConfig
    >>> cfg = TDAConfig(maxdim=1, do_shuffle=False, show=False)
    >>> cfg.maxdim
    1
    """

    # Preprocessing of the point cloud.
    dim: int = 6
    num_times: int = 5
    active_times: int = 15000
    k: int = 1000
    n_points: int = 1200
    metric: str = "cosine"
    nbs: int = 800

    # Persistent homology settings.
    maxdim: int = 1
    coeff: int = 47

    # Output and shuffle control.
    show: bool = True
    do_shuffle: bool = False
    num_shuffles: int = 1000
    progress_bar: bool = True
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
@dataclass
class CANN2DPlotConfig(PlotConfig):
    """Specialized PlotConfig for CANN2D visualizations.

    Extends :class:`canns.analyzer.visualization.PlotConfig` with fields that
    control 3D projection and torus animation parameters.

    Examples
    --------
    >>> from canns.analyzer.data import CANN2DPlotConfig
    >>> cfg = CANN2DPlotConfig.for_projection_3d(title="Projection")
    >>> cfg.zlabel
    'Component 3'
    """

    # 3D projection specific parameters
    zlabel: str = "Component 3"
    dpi: int = 300

    # Torus animation specific parameters
    numangsint: int = 51
    r1: float = 1.5  # Major radius
    r2: float = 1.0  # Minor radius
    window_size: int = 300
    frame_step: int = 5
    n_frames: int = 20

    @classmethod
    def for_projection_3d(cls, **kwargs) -> CANN2DPlotConfig:
        """Create configuration for 3D projection plots.

        Keyword arguments override the defaults below and are forwarded to
        ``cls.for_static_plot``.

        Examples
        --------
        >>> cfg = CANN2DPlotConfig.for_projection_3d(figsize=(6, 5))
        >>> cfg.figsize
        (6, 5)
        """
        defaults = {
            "title": "3D Data Projection",
            "xlabel": "Component 1",
            "ylabel": "Component 2",
            "zlabel": "Component 3",
            "figsize": (10, 8),
            "dpi": 300,
        }
        defaults.update(kwargs)
        return cls.for_static_plot(**defaults)

    @classmethod
    def for_torus_animation(cls, **kwargs) -> CANN2DPlotConfig:
        """Create configuration for 3D torus bump animations.

        Keyword arguments override the defaults below. ``time_steps_per_second``
        (default 1000) is passed positionally to ``cls.for_animation``. It is
        removed from the keyword overrides so it is never supplied twice.

        Examples
        --------
        >>> cfg = CANN2DPlotConfig.for_torus_animation(fps=10, n_frames=50)
        >>> cfg.fps, cfg.n_frames
        (10, 50)
        """
        defaults = {
            "title": "3D Bump on Torus",
            "figsize": (8, 8),
            "fps": 5,
            "repeat": True,
            "show_progress_bar": True,
            "numangsint": 51,
            "r1": 1.5,
            "r2": 1.0,
            "window_size": 300,
            "frame_step": 5,
            "n_frames": 20,
        }
        defaults.update(kwargs)
        # Pop rather than get: leaving the key in ``defaults`` would pass it
        # both positionally and by keyword. That raises TypeError if
        # for_animation's first parameter is named ``time_steps_per_second``.
        # NOTE(review): confirm the parameter name against PlotConfig.
        time_steps = defaults.pop("time_steps_per_second", 1000)
        config = cls.for_animation(time_steps, **defaults)
        # Re-apply torus-specific attributes explicitly, in case for_animation
        # does not forward them to the constructor.
        config.numangsint = defaults["numangsint"]
        config.r1 = defaults["r1"]
        config.r2 = defaults["r2"]
        config.window_size = defaults["window_size"]
        config.frame_step = defaults["frame_step"]
        config.n_frames = defaults["n_frames"]
        return config
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
# ==================== Constants ====================
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
class Constants:
    """Shared numeric constants used by the CANN2D analysis helpers.

    Examples
    --------
    >>> from canns.analyzer.data import Constants
    >>> Constants.DEFAULT_DPI
    300
    """

    # Figure defaults.
    DEFAULT_FIGSIZE = (10, 8)
    DEFAULT_DPI = 300

    # Scale factors.
    GAUSSIAN_SIGMA_FACTOR = 100
    SPEED_CONVERSION_FACTOR = 100
    TIME_CONVERSION_FACTOR = 0.01

    # Multiprocessing setting.
    MULTIPROCESSING_CORES = 4
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
# ==================== Custom Exceptions ====================
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
class CANN2DError(Exception):
    """Root of the CANN2D analysis exception hierarchy.

    Catch this type to handle both data-loading and processing failures.

    Examples
    --------
    >>> try:  # doctest: +SKIP
    ...     raise CANN2DError("boom")
    ... except CANN2DError:
    ...     pass
    """
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
class DataLoadError(CANN2DError):
    """Signals that input data could not be loaded.

    Examples
    --------
    >>> try:  # doctest: +SKIP
    ...     raise DataLoadError("missing data")
    ... except DataLoadError:
    ...     pass
    """
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
class ProcessingError(CANN2DError):
    """Signals a failure while processing already-loaded data.

    Examples
    --------
    >>> try:  # doctest: +SKIP
    ...     raise ProcessingError("processing failed")
    ... except ProcessingError:
    ...     pass
    """
|
|
@@ -0,0 +1,448 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from scipy.sparse.linalg import lsmr
|
|
8
|
+
from sklearn import preprocessing
|
|
9
|
+
|
|
10
|
+
from .config import SpikeEmbeddingConfig
|
|
11
|
+
from .embedding import embed_spike_trains
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def decode_circular_coordinates(
    persistence_result: dict[str, Any],
    spike_data: dict[str, Any],
    real_ground: bool = True,
    real_of: bool = True,
    save_path: str | None = None,
) -> dict[str, Any]:
    """
    Decode circular coordinates (bump positions) from cohomology.

    Parameters
    ----------
    persistence_result : dict
        Output from :func:`canns.analyzer.data.tda_vis`, containing keys:
        ``persistence``, ``indstemp``, ``movetimes``, ``n_points``.
        Its persistence diagrams are not modified.
    spike_data : dict
        Spike data dictionary containing ``'spike'``, ``'t'`` and optionally ``'x'``/``'y'``.
    real_ground : bool
        Whether x/y/t ground-truth exists (controls whether speed filtering is applied).
    real_of : bool
        Whether the experiment is open-field (controls box coordinate handling).
        When False, ``coordsbox`` is decoded from speed-filtered embeddings.
    save_path : str, optional
        Path to save decoding results. Defaults to ``Results/spikes_decoding.npz``.
        A bare filename is written to the current working directory.

    Returns
    -------
    dict
        Dictionary containing:
        - ``coords``: decoded coordinates for all timepoints.
        - ``coordsbox``: decoded coordinates for box timepoints.
        - ``times``: time indices for ``coords``.
        - ``times_box``: time indices for ``coordsbox``.
        - ``centcosall`` / ``centsinall``: cosine/sine centroids.

    Raises
    ------
    ValueError
        If fewer than two H1 features are available, or if the circular
        coordinate reconstruction does not cover every sampled point.

    Examples
    --------
    >>> from canns.analyzer.data import tda_vis, decode_circular_coordinates
    >>> persistence = tda_vis(embed_spikes, config=tda_cfg)  # doctest: +SKIP
    >>> decoding = decode_circular_coordinates(persistence, spike_data)  # doctest: +SKIP
    >>> decoding["coords"].shape  # doctest: +SKIP
    """
    ph_classes = [0, 1]  # Decode the two most persistent H1 cohomology classes
    num_circ = len(ph_classes)
    dec_tresh = 0.99
    coeff = 47

    # Extract persistence analysis results
    persistence = persistence_result["persistence"]
    indstemp = persistence_result["indstemp"]
    movetimes = persistence_result["movetimes"]
    n_points = persistence_result["n_points"]

    diagrams = persistence["dgms"]  # the multiset describing the lives of the persistence classes
    cocycles = persistence["cocycles"][1]  # the cocycle representatives for the 1-dim classes
    dists_land = persistence["dperm2all"]  # the pairwise distance between the points
    births1 = diagrams[1][:, 0]  # the time of birth for the 1-dim classes
    # diagrams[1][:, 1] is a view; copy before zeroing infinite deaths so the
    # caller's persistence diagrams are left untouched.
    deaths1 = diagrams[1][:, 1].copy()
    deaths1[np.isinf(deaths1)] = 0
    lives1 = deaths1 - births1  # the lifetime for the 1-dim classes

    if lives1.size < num_circ or len(cocycles) < num_circ:
        raise ValueError(
            f"Decoding needs {num_circ} H1 features, but only {lives1.size} are available. "
            "This usually means the chosen time window is too short, the data are too sparse, "
            "or the embedding parameters are not appropriate."
        )

    iMax = np.argsort(lives1)
    coords1 = np.zeros((num_circ, len(indstemp)))
    # Both classes share the edge threshold derived from the second most
    # persistent class.
    threshold = births1[iMax[-2]] + (deaths1[iMax[-2]] - births1[iMax[-2]]) * dec_tresh

    for c in ph_classes:
        cocycle = cocycles[iMax[-(c + 1)]]
        f, verts = _get_coords(cocycle, threshold, len(indstemp), dists_land, coeff)
        if len(verts) != len(indstemp):
            raise ValueError(
                "Circular coordinate reconstruction returned fewer vertices than sampled points. "
                "Increase n_points/active_times or use denser data."
            )
        coords1[c, :] = f

    # Smoothed rates (used for centroids and decoding) and raw counts (used only
    # to select time bins with at least one spike). Speed filtering needs the
    # ground-truth x/y/t. Each embedding is computed once and reused.
    speed_filter = bool(real_ground)
    sspikes, _, _, _ = embed_spike_trains(
        spike_data, config=SpikeEmbeddingConfig(smooth=True, speed_filter=speed_filter)
    )
    spikes, _, _, _ = embed_spike_trains(
        spike_data, config=SpikeEmbeddingConfig(smooth=False, speed_filter=speed_filter)
    )

    num_neurons = sspikes.shape[1]
    centcosall = np.zeros((num_neurons, num_circ, n_points))
    centsinall = np.zeros((num_neurons, num_circ, n_points))
    dspk = preprocessing.scale(sspikes[movetimes[indstemp], :])
    cos_phase = np.cos(coords1 * 2 * np.pi)
    sin_phase = np.sin(coords1 * 2 * np.pi)
    for neurid in range(num_neurons):
        spktemp = dspk[:, neurid]
        centcosall[neurid, :, :] = cos_phase * spktemp
        centsinall[neurid, :, :] = sin_phase * spktemp

    # Per-neuron centroid weights, shape (num_neurons, num_circ).
    cos_weights = np.sum(centcosall, axis=2)
    sin_weights = np.sum(centsinall, axis=2)

    def _decode(smoothed, raw):
        """Return (angles, time indices) for bins containing at least one spike."""
        active = np.where(np.sum(raw > 0, 1) >= 1)[0]
        z = preprocessing.scale(smoothed)[active, :]
        # z @ weights equals summing z[:, n:n+1] * weights[n] over neurons,
        # without materializing a (T, num_circ, num_neurons) intermediate.
        return np.arctan2(z @ sin_weights, z @ cos_weights) % (2 * np.pi), active

    coords, times = _decode(sspikes, spikes)

    # Whether the dataset comes from a real open-field (OF) environment.
    if real_of:
        coordsbox = coords.copy()
        times_box = times.copy()
    else:
        if speed_filter:
            # The speed-filtered embeddings were already computed above.
            sspikes_box, spikes_box = sspikes, spikes
        else:
            sspikes_box, _, _, _ = embed_spike_trains(
                spike_data, config=SpikeEmbeddingConfig(smooth=True, speed_filter=True)
            )
            spikes_box, _, _, _ = embed_spike_trains(
                spike_data, config=SpikeEmbeddingConfig(smooth=False, speed_filter=True)
            )
        coordsbox, times_box = _decode(sspikes_box, spikes_box)

    # Prepare results dictionary
    results = {
        "coords": coords,
        "coordsbox": coordsbox,
        "times": times,
        "times_box": times_box,
        "centcosall": centcosall,
        "centsinall": centsinall,
    }

    # Save results
    if save_path is None:
        save_path = "Results/spikes_decoding.npz"
    # os.makedirs("") raises, so only create a directory when the path has one.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    np.savez_compressed(save_path, **results)

    return results
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def decode_circular_coordinates1(
    persistence_result: dict[str, Any],
    spike_data: dict[str, Any],
    save_path: str | None = None,
) -> dict[str, Any]:
    """Legacy helper kept for backward compatibility.

    Decodes the two most persistent H1 cohomology classes. It uses
    ``spike_data["spike"]`` directly (no re-embedding) and sets
    ``coordsbox``/``times_box`` to copies of ``coords``/``times``.

    Parameters
    ----------
    persistence_result : dict
        Must contain ``persistence``, ``indstemp``, ``movetimes``, ``n_points``.
        Its persistence diagrams are not modified.
    spike_data : dict
        Must contain ``'spike'``, a 2-D array indexed as (time, neuron).
    save_path : str, optional
        Path to save decoding results. Defaults to ``Results/spikes_decoding.npz``.
        A bare filename is written to the current working directory.

    Returns
    -------
    dict
        Keys ``coords``, ``coordsbox``, ``times``, ``times_box``,
        ``centcosall``, ``centsinall``.

    Raises
    ------
    ValueError
        If fewer than two H1 features are available, or if the circular
        coordinate reconstruction does not cover every sampled point.
    """
    ph_classes = [0, 1]  # Decode the two most persistent H1 cohomology classes
    num_circ = len(ph_classes)
    dec_tresh = 0.99
    coeff = 47

    # Extract persistence analysis results
    persistence = persistence_result["persistence"]
    indstemp = persistence_result["indstemp"]
    movetimes = persistence_result["movetimes"]
    n_points = persistence_result["n_points"]

    diagrams = persistence["dgms"]
    cocycles = persistence["cocycles"][1]
    dists_land = persistence["dperm2all"]
    births1 = diagrams[1][:, 0]
    # Copy the view so zeroing infinite deaths does not alter the caller's diagrams.
    deaths1 = diagrams[1][:, 1].copy()
    deaths1[np.isinf(deaths1)] = 0
    lives1 = deaths1 - births1

    if lives1.size < num_circ or len(cocycles) < num_circ:
        raise ValueError(
            f"Decoding needs {num_circ} H1 features, but only {lives1.size} are available. "
            "This usually means the chosen time window is too short, the data are too sparse, "
            "or the embedding parameters are not appropriate."
        )

    iMax = np.argsort(lives1)
    coords1 = np.zeros((num_circ, len(indstemp)))
    # Both classes share the threshold derived from the second most persistent class.
    threshold = births1[iMax[-2]] + (deaths1[iMax[-2]] - births1[iMax[-2]]) * dec_tresh

    for c in ph_classes:
        cocycle = cocycles[iMax[-(c + 1)]]
        f, verts = _get_coords(cocycle, threshold, len(indstemp), dists_land, coeff)
        if len(verts) != len(indstemp):
            raise ValueError(
                "Circular coordinate reconstruction returned fewer vertices than sampled points. "
                "Increase n_points/active_times or use denser data."
            )
        coords1[c, :] = f

    sspikes = spike_data["spike"]
    num_neurons = sspikes.shape[1]
    centcosall = np.zeros((num_neurons, num_circ, n_points))
    centsinall = np.zeros((num_neurons, num_circ, n_points))
    dspk = preprocessing.scale(sspikes[movetimes[indstemp], :])
    cos_phase = np.cos(coords1 * 2 * np.pi)
    sin_phase = np.sin(coords1 * 2 * np.pi)
    for neurid in range(num_neurons):
        spktemp = dspk[:, neurid]
        centcosall[neurid, :, :] = cos_phase * spktemp
        centsinall[neurid, :, :] = sin_phase * spktemp

    # Keep only time bins in which at least one neuron is active.
    times = np.where(np.sum(sspikes > 0, 1) >= 1)[0]
    dspk = preprocessing.scale(sspikes)[times, :]

    # Summing dspk[:, n:n+1] * centroid_sum[n] over neurons is a matrix product;
    # this avoids building (T, num_circ, num_neurons) intermediates.
    mtot1 = dspk @ np.sum(centcosall, axis=2)
    mtot2 = dspk @ np.sum(centsinall, axis=2)
    coords = np.arctan2(mtot2, mtot1) % (2 * np.pi)

    # Prepare results dictionary
    results = {
        "coords": coords,
        "coordsbox": coords.copy(),
        "times": times,
        "times_box": times.copy(),
        "centcosall": centcosall,
        "centsinall": centsinall,
    }

    # Save results
    if save_path is None:
        save_path = "Results/spikes_decoding.npz"
    # os.makedirs("") raises, so only create a directory when the path has one.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    np.savez_compressed(save_path, **results)

    return results
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def decode_circular_coordinates_multi(
    persistence_result: dict,
    spike_data: dict,
    save_path: str | None = None,
    num_circ: int = 2,  # Number of H1 cocycles/circular coordinates to decode
) -> dict:
    """Decode multiple circular coordinates from TDA persistence.

    Parameters
    ----------
    persistence_result : dict
        Output from :func:`canns.analyzer.data.tda_vis`, containing keys:
        ``persistence``, ``indstemp``, ``movetimes``, ``n_points``.
        Its persistence diagrams are not modified.
    spike_data : dict
        Spike data dictionary containing ``'spike'``, ``'t'`` and optionally ``'x'``/``'y'``.
    save_path : str, optional
        Path to save decoding results. Defaults to ``Results/spikes_decoding.npz``.
        A bare filename is written to the current working directory.
    num_circ : int
        Number of H1 cocycles/circular coordinates to decode (must be >= 1).

    Returns
    -------
    dict
        Dictionary with ``coords``, ``coordsbox``, ``times``, ``times_box`` and centroid terms.

    Raises
    ------
    ValueError
        If ``num_circ`` < 1, if fewer than ``num_circ`` H1 features exist, or
        if the circular coordinate reconstruction does not cover every sampled
        point.

    Examples
    --------
    >>> decoding = decode_circular_coordinates_multi(persistence, spike_data, num_circ=2)  # doctest: +SKIP
    >>> decoding["coords"].shape  # doctest: +SKIP
    """
    if num_circ < 1:
        raise ValueError(f"num_circ must be a positive integer, got {num_circ}.")

    dec_tresh = 0.99
    coeff = 47

    persistence = persistence_result["persistence"]
    indstemp = persistence_result["indstemp"]
    movetimes = persistence_result["movetimes"]
    n_points = persistence_result["n_points"]

    diagrams = persistence["dgms"]
    cocycles = persistence["cocycles"][1]
    dists_land = persistence["dperm2all"]

    births1 = diagrams[1][:, 0]
    # Copy the view so zeroing infinite deaths does not alter the caller's diagrams.
    deaths1 = diagrams[1][:, 1].copy()
    deaths1[np.isinf(deaths1)] = 0
    lives1 = deaths1 - births1

    if lives1.size < num_circ or len(cocycles) < num_circ:
        raise ValueError(
            f"Requested num_circ={num_circ}, but only {lives1.size} H1 feature(s) are available. "
            "This usually means the chosen time window is too short, the data are too sparse, "
            "or the embedding parameters are not appropriate."
        )

    iMax = np.argsort(lives1)
    coords1 = np.zeros((num_circ, len(indstemp)))

    for i in range(num_circ):
        # Unlike the two-class decoder, each class uses its own threshold.
        idx = iMax[-(i + 1)]
        threshold = births1[idx] + (deaths1[idx] - births1[idx]) * dec_tresh
        cocycle = cocycles[idx]
        f, verts = _get_coords(cocycle, threshold, len(indstemp), dists_land, coeff)
        if len(verts) != len(indstemp):
            raise ValueError(
                "Circular coordinate reconstruction returned fewer vertices than sampled points. "
                "Increase n_points/active_times or use denser data."
            )
        coords1[i, :] = f

    sspikes = spike_data["spike"]
    num_neurons = sspikes.shape[1]

    centcosall = np.zeros((num_neurons, num_circ, n_points))
    centsinall = np.zeros((num_neurons, num_circ, n_points))
    dspk = preprocessing.scale(sspikes[movetimes[indstemp], :])
    cos_phase = np.cos(coords1 * 2 * np.pi)
    sin_phase = np.sin(coords1 * 2 * np.pi)

    for n in range(num_neurons):
        spktemp = dspk[:, n]
        centcosall[n, :, :] = cos_phase * spktemp
        centsinall[n, :, :] = sin_phase * spktemp

    # Keep only time bins in which at least one neuron is active.
    times = np.where(np.sum(sspikes > 0, 1) >= 1)[0]
    dspk = preprocessing.scale(sspikes)[times, :]

    # Summing dspk[:, n:n+1] * centroid_sum[n] over neurons is a matrix product;
    # this avoids building (T, num_circ, num_neurons) intermediates.
    mtot1 = dspk @ np.sum(centcosall, axis=2)
    mtot2 = dspk @ np.sum(centsinall, axis=2)
    coords = np.arctan2(mtot2, mtot1) % (2 * np.pi)

    results = {
        "coords": coords,
        "coordsbox": coords.copy(),
        "times": times,
        "times_box": times.copy(),
        "centcosall": centcosall,
        "centsinall": centsinall,
    }

    if save_path is None:
        save_path = "Results/spikes_decoding.npz"
    # os.makedirs("") raises, so only create a directory when the path has one.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    np.savez_compressed(save_path, **results)
    return results
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
def _get_coords(cocycle, threshold, num_sampled, dists, coeff):
    """
    Reconstruct circular coordinates from cocycle information.

    Edges are the non-NaN entries of a ``(num_sampled, num_sampled)`` matrix.
    The matrix starts as the strict upper triangle (value 0) and receives the
    cocycle values at ``(cocycle[:, 1], cocycle[:, 0])``. Pairs farther apart
    than ``threshold`` or at zero distance are then dropped. Cocycle values
    greater than ``coeff / 2`` are shifted down by ``coeff`` first. The
    least-squares system ``A f = values`` is solved with ``lsmr``, where
    ``A`` is the edge/vertex coboundary matrix. The result is taken modulo 1.

    Parameters:
        cocycle (ndarray): Persistent cocycle representative; rows of
            ``(vertex, vertex, value)``. It is not modified.
        threshold (float): Maximum allowable edge distance.
        num_sampled (int): Number of sampled points.
        dists (ndarray): Pairwise distance matrix, ``(num_sampled, num_sampled)``.
        coeff (int): Finite field modulus for cohomology.

    Returns:
        f (ndarray): Circular coordinate values (in [0,1)), one per entry of ``verts``.
        verts (ndarray): Sorted indices of vertices that touch at least one edge.
            If there are no edges, both arrays are empty.
    """
    from scipy.sparse import csr_matrix

    # Work on a copy so the caller's persistence result is not mutated.
    cocycle = np.array(cocycle)
    # Lift values from Z/coeff to integers centred on 0 (e.g. coeff - 1 -> -1).
    wrap = coeff - cocycle[:, 2] < cocycle[:, 2]
    cocycle[wrap, 2] -= coeff

    d = np.zeros((num_sampled, num_sampled))
    d[np.tril_indices(num_sampled)] = np.nan
    d[cocycle[:, 1], cocycle[:, 0]] = cocycle[:, 2]
    d[dists > threshold] = np.nan
    d[dists == 0] = np.nan
    edges = np.where(~np.isnan(d))
    verts = np.array(np.unique(edges))
    num_edges = edges[0].size
    if num_edges == 0:
        return np.zeros(0), verts

    values = d[edges]
    # verts is sorted and contains every edge endpoint, so searchsorted maps
    # each vertex id to its column index exactly (replacing a per-edge scan).
    tail = np.searchsorted(verts, edges[0])
    head = np.searchsorted(verts, edges[1])
    rows = np.arange(num_edges)
    # Sparse coboundary matrix: -1 at each edge's first vertex, +1 at its
    # second. A dense (num_edges x num_verts) matrix can need gigabytes.
    A = csr_matrix(
        (
            np.concatenate([-np.ones(num_edges), np.ones(num_edges)]),
            (np.concatenate([rows, rows]), np.concatenate([tail, head])),
        ),
        shape=(num_edges, verts.size),
    )
    f = lsmr(A, values)[0] % 1
    return f, verts
|