canns-0.12.7-py3-none-any.whl → canns-0.13.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. canns/analyzer/data/__init__.py +3 -11
  2. canns/analyzer/data/asa/__init__.py +74 -0
  3. canns/analyzer/data/asa/cohospace.py +905 -0
  4. canns/analyzer/data/asa/config.py +246 -0
  5. canns/analyzer/data/asa/decode.py +448 -0
  6. canns/analyzer/data/asa/embedding.py +269 -0
  7. canns/analyzer/data/asa/filters.py +208 -0
  8. canns/analyzer/data/asa/fr.py +439 -0
  9. canns/analyzer/data/asa/path.py +389 -0
  10. canns/analyzer/data/asa/plotting.py +1276 -0
  11. canns/analyzer/data/asa/tda.py +901 -0
  12. canns/analyzer/data/legacy/__init__.py +6 -0
  13. canns/analyzer/data/{cann1d.py → legacy/cann1d.py} +2 -2
  14. canns/analyzer/data/{cann2d.py → legacy/cann2d.py} +3 -3
  15. canns/analyzer/visualization/core/backend.py +1 -1
  16. canns/analyzer/visualization/core/config.py +77 -0
  17. canns/analyzer/visualization/core/rendering.py +10 -6
  18. canns/analyzer/visualization/energy_plots.py +22 -8
  19. canns/analyzer/visualization/spatial_plots.py +31 -11
  20. canns/analyzer/visualization/theta_sweep_plots.py +15 -6
  21. canns/pipeline/__init__.py +4 -8
  22. canns/pipeline/asa/__init__.py +21 -0
  23. canns/pipeline/asa/__main__.py +11 -0
  24. canns/pipeline/asa/app.py +1000 -0
  25. canns/pipeline/asa/runner.py +1095 -0
  26. canns/pipeline/asa/screens.py +215 -0
  27. canns/pipeline/asa/state.py +248 -0
  28. canns/pipeline/asa/styles.tcss +221 -0
  29. canns/pipeline/asa/widgets.py +233 -0
  30. canns/pipeline/gallery/__init__.py +7 -0
  31. canns/task/open_loop_navigation.py +3 -1
  32. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/METADATA +6 -3
  33. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/RECORD +36 -17
  34. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/entry_points.txt +1 -0
  35. canns/pipeline/theta_sweep.py +0 -573
  36. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/WHEEL +0 -0
  37. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/licenses/LICENSE +0 -0
canns/analyzer/data/asa/config.py
@@ -0,0 +1,246 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+ from ...visualization import PlotConfig
+
+
+ @dataclass
+ class SpikeEmbeddingConfig:
+     """Configuration for spike train embedding.
+
+     Attributes
+     ----------
+     res : int
+         Time scaling factor that converts seconds to integer bins.
+     dt : int
+         Bin width in the same scaled units as ``res``.
+     sigma : int
+         Gaussian smoothing width (scaled units).
+     smooth : bool
+         Whether to apply temporal smoothing to spike counts.
+     speed_filter : bool
+         Whether to filter by animal speed (requires x/y/t in the input).
+     min_speed : float
+         Minimum speed threshold for ``speed_filter`` (cm/s by convention).
+
+     Examples
+     --------
+     >>> from canns.analyzer.data import SpikeEmbeddingConfig
+     >>> cfg = SpikeEmbeddingConfig(smooth=False, speed_filter=False)
+     >>> cfg.min_speed
+     2.5
+     """
+
+     res: int = 100000
+     dt: int = 1000
+     sigma: int = 5000
+     smooth: bool = True
+     speed_filter: bool = True
+     min_speed: float = 2.5
+
+
+ @dataclass
+ class TDAConfig:
+     """Configuration for Topological Data Analysis (TDA).
+
+     Attributes
+     ----------
+     dim : int
+         Target PCA dimension before TDA.
+     num_times : int
+         Downsampling stride in time.
+     active_times : int
+         Number of most active time points to keep.
+     k : int
+         Number of neighbors used in denoising.
+     n_points : int
+         Number of points sampled for persistent homology.
+     metric : str
+         Distance metric for point cloud (e.g., "cosine").
+     nbs : int
+         Number of neighbors for distance matrix construction.
+     maxdim : int
+         Maximum homology dimension for persistence.
+     coeff : int
+         Field coefficient for persistent homology.
+     show : bool
+         Whether to show barcode plots.
+     do_shuffle : bool
+         Whether to run shuffle analysis.
+     num_shuffles : int
+         Number of shuffles for null distribution.
+     progress_bar : bool
+         Whether to show progress bars.
+
+     Examples
+     --------
+     >>> from canns.analyzer.data import TDAConfig
+     >>> cfg = TDAConfig(maxdim=1, do_shuffle=False, show=False)
+     >>> cfg.maxdim
+     1
+     """
+
+     dim: int = 6
+     num_times: int = 5
+     active_times: int = 15000
+     k: int = 1000
+     n_points: int = 1200
+     metric: str = "cosine"
+     nbs: int = 800
+     maxdim: int = 1
+     coeff: int = 47
+     show: bool = True
+     do_shuffle: bool = False
+     num_shuffles: int = 1000
+     progress_bar: bool = True
+
+
+ @dataclass
+ class CANN2DPlotConfig(PlotConfig):
+     """Specialized PlotConfig for CANN2D visualizations.
+
+     Extends :class:`canns.analyzer.visualization.PlotConfig` with fields that
+     control 3D projection and torus animation parameters.
+
+     Examples
+     --------
+     >>> from canns.analyzer.data import CANN2DPlotConfig
+     >>> cfg = CANN2DPlotConfig.for_projection_3d(title="Projection")
+     >>> cfg.zlabel
+     'Component 3'
+     """
+
+     # 3D projection specific parameters
+     zlabel: str = "Component 3"
+     dpi: int = 300
+
+     # Torus animation specific parameters
+     numangsint: int = 51
+     r1: float = 1.5  # Major radius
+     r2: float = 1.0  # Minor radius
+     window_size: int = 300
+     frame_step: int = 5
+     n_frames: int = 20
+
+     @classmethod
+     def for_projection_3d(cls, **kwargs) -> CANN2DPlotConfig:
+         """Create configuration for 3D projection plots.
+
+         Examples
+         --------
+         >>> cfg = CANN2DPlotConfig.for_projection_3d(figsize=(6, 5))
+         >>> cfg.figsize
+         (6, 5)
+         """
+         defaults = {
+             "title": "3D Data Projection",
+             "xlabel": "Component 1",
+             "ylabel": "Component 2",
+             "zlabel": "Component 3",
+             "figsize": (10, 8),
+             "dpi": 300,
+         }
+         defaults.update(kwargs)
+         return cls.for_static_plot(**defaults)
+
+     @classmethod
+     def for_torus_animation(cls, **kwargs) -> CANN2DPlotConfig:
+         """Create configuration for 3D torus bump animations.
+
+         Examples
+         --------
+         >>> cfg = CANN2DPlotConfig.for_torus_animation(fps=10, n_frames=50)
+         >>> cfg.fps, cfg.n_frames
+         (10, 50)
+         """
+         defaults = {
+             "title": "3D Bump on Torus",
+             "figsize": (8, 8),
+             "fps": 5,
+             "repeat": True,
+             "show_progress_bar": True,
+             "numangsint": 51,
+             "r1": 1.5,
+             "r2": 1.0,
+             "window_size": 300,
+             "frame_step": 5,
+             "n_frames": 20,
+         }
+         defaults.update(kwargs)
+         time_steps = kwargs.get("time_steps_per_second", 1000)
+         config = cls.for_animation(time_steps, **defaults)
+         # Add torus-specific attributes
+         config.numangsint = defaults["numangsint"]
+         config.r1 = defaults["r1"]
+         config.r2 = defaults["r2"]
+         config.window_size = defaults["window_size"]
+         config.frame_step = defaults["frame_step"]
+         config.n_frames = defaults["n_frames"]
+         return config
+
+
+ # ==================== Constants ====================
+
+
+ class Constants:
+     """Constants used throughout CANN2D analysis.
+
+     Examples
+     --------
+     >>> from canns.analyzer.data import Constants
+     >>> Constants.DEFAULT_DPI
+     300
+     """
+
+     DEFAULT_FIGSIZE = (10, 8)
+     DEFAULT_DPI = 300
+     GAUSSIAN_SIGMA_FACTOR = 100
+     SPEED_CONVERSION_FACTOR = 100
+     TIME_CONVERSION_FACTOR = 0.01
+     MULTIPROCESSING_CORES = 4
+
+
+ # ==================== Custom Exceptions ====================
+
+
+ class CANN2DError(Exception):
+     """Base exception for CANN2D analysis errors.
+
+     Examples
+     --------
+     >>> try:  # doctest: +SKIP
+     ...     raise CANN2DError("boom")
+     ... except CANN2DError:
+     ...     pass
+     """
+
+     pass
+
+
+ class DataLoadError(CANN2DError):
+     """Raised when data loading fails.
+
+     Examples
+     --------
+     >>> try:  # doctest: +SKIP
+     ...     raise DataLoadError("missing data")
+     ... except DataLoadError:
+     ...     pass
+     """
+
+     pass
+
+
+ class ProcessingError(CANN2DError):
+     """Raised when data processing fails.
+
+     Examples
+     --------
+     >>> try:  # doctest: +SKIP
+     ...     raise ProcessingError("processing failed")
+     ... except ProcessingError:
+     ...     pass
+     """
+
+     pass
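
The hunk above shows that the new ASA configs are plain dataclasses: callers override individual defaults by keyword and pass the objects down to the analysis functions, and the CANN2DPlotConfig factory methods fill in plot defaults that remain overridable. A minimal usage sketch based on the docstrings above (the parameter values are illustrative assumptions, not recommendations):

    from canns.analyzer.data import CANN2DPlotConfig, SpikeEmbeddingConfig, TDAConfig

    # Embed without behavioural filtering (e.g. no x/y/t ground truth available).
    embed_cfg = SpikeEmbeddingConfig(smooth=True, speed_filter=False)

    # Cheaper persistence run: H1 only, no shuffle null distribution, no plots.
    tda_cfg = TDAConfig(maxdim=1, do_shuffle=False, show=False)

    # Factory methods start from defaults and accept keyword overrides.
    plot_cfg = CANN2DPlotConfig.for_torus_animation(fps=10, n_frames=50)
    assert plot_cfg.r1 == 1.5  # torus major radius keeps its default
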
canns/analyzer/data/asa/decode.py
@@ -0,0 +1,448 @@
+ from __future__ import annotations
+
+ import os
+ from typing import Any
+
+ import numpy as np
+ from scipy.sparse.linalg import lsmr
+ from sklearn import preprocessing
+
+ from .config import SpikeEmbeddingConfig
+ from .embedding import embed_spike_trains
+
+
+ def decode_circular_coordinates(
+     persistence_result: dict[str, Any],
+     spike_data: dict[str, Any],
+     real_ground: bool = True,
+     real_of: bool = True,
+     save_path: str | None = None,
+ ) -> dict[str, Any]:
+     """
+     Decode circular coordinates (bump positions) from cohomology.
+
+     Parameters
+     ----------
+     persistence_result : dict
+         Output from :func:`canns.analyzer.data.tda_vis`, containing keys:
+         ``persistence``, ``indstemp``, ``movetimes``, ``n_points``.
+     spike_data : dict
+         Spike data dictionary containing ``'spike'``, ``'t'`` and optionally ``'x'``/``'y'``.
+     real_ground : bool
+         Whether x/y/t ground-truth exists (controls whether speed filtering is applied).
+     real_of : bool
+         Whether the experiment is open-field (controls box coordinate handling).
+     save_path : str, optional
+         Path to save decoding results. Defaults to ``Results/spikes_decoding.npz``.
+
+     Returns
+     -------
+     dict
+         Dictionary containing:
+         - ``coords``: decoded coordinates for all timepoints.
+         - ``coordsbox``: decoded coordinates for box timepoints.
+         - ``times``: time indices for ``coords``.
+         - ``times_box``: time indices for ``coordsbox``.
+         - ``centcosall`` / ``centsinall``: cosine/sine centroids.
+
+     Examples
+     --------
+     >>> from canns.analyzer.data import tda_vis, decode_circular_coordinates
+     >>> persistence = tda_vis(embed_spikes, config=tda_cfg)  # doctest: +SKIP
+     >>> decoding = decode_circular_coordinates(persistence, spike_data)  # doctest: +SKIP
+     >>> decoding["coords"].shape  # doctest: +SKIP
+     """
+     ph_classes = [0, 1]  # Decode the ith most persistent cohomology class
+     num_circ = len(ph_classes)
+     dec_tresh = 0.99
+     coeff = 47
+
+     # Extract persistence analysis results
+     persistence = persistence_result["persistence"]
+     indstemp = persistence_result["indstemp"]
+     movetimes = persistence_result["movetimes"]
+     n_points = persistence_result["n_points"]
+
+     diagrams = persistence["dgms"]  # the multiset describing the lives of the persistence classes
+     cocycles = persistence["cocycles"][1]  # the cocycle representatives for the 1-dim classes
+     dists_land = persistence["dperm2all"]  # the pairwise distance between the points
+     births1 = diagrams[1][:, 0]  # the time of birth for the 1-dim classes
+     deaths1 = diagrams[1][:, 1]  # the time of death for the 1-dim classes
+     deaths1[np.isinf(deaths1)] = 0
+     lives1 = deaths1 - births1  # the lifetime for the 1-dim classes
+     iMax = np.argsort(lives1)
+     coords1 = np.zeros((num_circ, len(indstemp)))
+     threshold = births1[iMax[-2]] + (deaths1[iMax[-2]] - births1[iMax[-2]]) * dec_tresh
+
+     for c in ph_classes:
+         cocycle = cocycles[iMax[-(c + 1)]]
+         f, verts = _get_coords(cocycle, threshold, len(indstemp), dists_land, coeff)
+         if len(verts) != len(indstemp):
+             raise ValueError(
+                 "Circular coordinate reconstruction returned fewer vertices than sampled points. "
+                 "Increase n_points/active_times or use denser data."
+             )
+         coords1[c, :] = f
+
+     # Whether the user-provided dataset has ground-truth x/y/t.
+     if real_ground:
+         sspikes, _, _, _ = embed_spike_trains(
+             spike_data, config=SpikeEmbeddingConfig(smooth=True, speed_filter=True)
+         )
+     else:
+         sspikes, _, _, _ = embed_spike_trains(
+             spike_data, config=SpikeEmbeddingConfig(smooth=True, speed_filter=False)
+         )
+
+     num_neurons = sspikes.shape[1]
+     centcosall = np.zeros((num_neurons, 2, n_points))
+     centsinall = np.zeros((num_neurons, 2, n_points))
+     dspk = preprocessing.scale(sspikes[movetimes[indstemp], :])
+
+     for neurid in range(num_neurons):
+         spktemp = dspk[:, neurid].copy()
+         centcosall[neurid, :, :] = np.multiply(np.cos(coords1[:, :] * 2 * np.pi), spktemp)
+         centsinall[neurid, :, :] = np.multiply(np.sin(coords1[:, :] * 2 * np.pi), spktemp)
+
+     # Whether the user-provided dataset has ground-truth x/y/t.
+     if real_ground:
+         sspikes, _, _, _ = embed_spike_trains(
+             spike_data, config=SpikeEmbeddingConfig(smooth=True, speed_filter=True)
+         )
+         spikes, __, __, __ = embed_spike_trains(
+             spike_data, config=SpikeEmbeddingConfig(smooth=False, speed_filter=True)
+         )
+     else:
+         sspikes, _, _, _ = embed_spike_trains(
+             spike_data, config=SpikeEmbeddingConfig(smooth=True, speed_filter=False)
+         )
+         spikes, _, _, _ = embed_spike_trains(
+             spike_data, config=SpikeEmbeddingConfig(smooth=False, speed_filter=False)
+         )
+
+     times = np.where(np.sum(spikes > 0, 1) >= 1)[0]
+     dspk = preprocessing.scale(sspikes)
+     sspikes = sspikes[times, :]
+     dspk = dspk[times, :]
+
+     a = np.zeros((len(sspikes[:, 0]), 2, num_neurons))
+     for n in range(num_neurons):
+         a[:, :, n] = np.multiply(dspk[:, n : n + 1], np.sum(centcosall[n, :, :], 1))
+
+     c = np.zeros((len(sspikes[:, 0]), 2, num_neurons))
+     for n in range(num_neurons):
+         c[:, :, n] = np.multiply(dspk[:, n : n + 1], np.sum(centsinall[n, :, :], 1))
+
+     mtot2 = np.sum(c, 2)
+     mtot1 = np.sum(a, 2)
+     coords = np.arctan2(mtot2, mtot1) % (2 * np.pi)
+
+     # Whether the dataset comes from a real open-field (OF) environment.
+     if real_of:
+         coordsbox = coords.copy()
+         times_box = times.copy()
+     else:
+         sspikes, _, _, _ = embed_spike_trains(
+             spike_data, config=SpikeEmbeddingConfig(smooth=True, speed_filter=True)
+         )
+         spikes, __, __, __ = embed_spike_trains(
+             spike_data, config=SpikeEmbeddingConfig(smooth=False, speed_filter=True)
+         )
+         dspk = preprocessing.scale(sspikes)
+         times_box = np.where(np.sum(spikes > 0, 1) >= 1)[0]
+         dspk = dspk[times_box, :]
+
+         a = np.zeros((len(times_box), 2, num_neurons))
+         for n in range(num_neurons):
+             a[:, :, n] = np.multiply(dspk[:, n : n + 1], np.sum(centcosall[n, :, :], 1))
+
+         c = np.zeros((len(times_box), 2, num_neurons))
+         for n in range(num_neurons):
+             c[:, :, n] = np.multiply(dspk[:, n : n + 1], np.sum(centsinall[n, :, :], 1))
+
+         mtot2 = np.sum(c, 2)
+         mtot1 = np.sum(a, 2)
+         coordsbox = np.arctan2(mtot2, mtot1) % (2 * np.pi)
+
+     # Prepare results dictionary
+     results = {
+         "coords": coords,
+         "coordsbox": coordsbox,
+         "times": times,
+         "times_box": times_box,
+         "centcosall": centcosall,
+         "centsinall": centsinall,
+     }
+
+     # Save results
+     if save_path is None:
+         os.makedirs("Results", exist_ok=True)
+         save_path = "Results/spikes_decoding.npz"
+
+     os.makedirs(os.path.dirname(save_path), exist_ok=True)
+     np.savez_compressed(save_path, **results)
+
+     return results
+
+
+ def decode_circular_coordinates1(
+     persistence_result: dict[str, Any],
+     spike_data: dict[str, Any],
+     save_path: str | None = None,
+ ) -> dict[str, Any]:
+     """Legacy helper kept for backward compatibility."""
+     ph_classes = [0, 1]  # Decode the ith most persistent cohomology class
+     num_circ = len(ph_classes)
+     dec_tresh = 0.99
+     coeff = 47
+
+     # Extract persistence analysis results
+     persistence = persistence_result["persistence"]
+     indstemp = persistence_result["indstemp"]
+     movetimes = persistence_result["movetimes"]
+     n_points = persistence_result["n_points"]
+
+     diagrams = persistence["dgms"]  # the multiset describing the lives of the persistence classes
+     cocycles = persistence["cocycles"][1]  # the cocycle representatives for the 1-dim classes
+     dists_land = persistence["dperm2all"]  # the pairwise distance between the points
+     births1 = diagrams[1][:, 0]  # the time of birth for the 1-dim classes
+     deaths1 = diagrams[1][:, 1]  # the time of death for the 1-dim classes
+     deaths1[np.isinf(deaths1)] = 0
+     lives1 = deaths1 - births1  # the lifetime for the 1-dim classes
+     iMax = np.argsort(lives1)
+     coords1 = np.zeros((num_circ, len(indstemp)))
+     threshold = births1[iMax[-2]] + (deaths1[iMax[-2]] - births1[iMax[-2]]) * dec_tresh
+
+     for c in ph_classes:
+         cocycle = cocycles[iMax[-(c + 1)]]
+         f, verts = _get_coords(cocycle, threshold, len(indstemp), dists_land, coeff)
+         if len(verts) != len(indstemp):
+             raise ValueError(
+                 "Circular coordinate reconstruction returned fewer vertices than sampled points. "
+                 "Increase n_points/active_times or use denser data."
+             )
+         coords1[c, :] = f
+
+     sspikes = spike_data["spike"]
+     num_neurons = sspikes.shape[1]
+     centcosall = np.zeros((num_neurons, 2, n_points))
+     centsinall = np.zeros((num_neurons, 2, n_points))
+     dspk = preprocessing.scale(sspikes[movetimes[indstemp], :])
+
+     for neurid in range(num_neurons):
+         spktemp = dspk[:, neurid].copy()
+         centcosall[neurid, :, :] = np.multiply(np.cos(coords1[:, :] * 2 * np.pi), spktemp)
+         centsinall[neurid, :, :] = np.multiply(np.sin(coords1[:, :] * 2 * np.pi), spktemp)
+
+     times = np.where(np.sum(sspikes > 0, 1) >= 1)[0]
+     dspk = preprocessing.scale(sspikes)
+     sspikes = sspikes[times, :]
+     dspk = dspk[times, :]
+
+     a = np.zeros((len(sspikes[:, 0]), 2, num_neurons))
+     for n in range(num_neurons):
+         a[:, :, n] = np.multiply(dspk[:, n : n + 1], np.sum(centcosall[n, :, :], 1))
+
+     c = np.zeros((len(sspikes[:, 0]), 2, num_neurons))
+     for n in range(num_neurons):
+         c[:, :, n] = np.multiply(dspk[:, n : n + 1], np.sum(centsinall[n, :, :], 1))
+
+     mtot2 = np.sum(c, 2)
+     mtot1 = np.sum(a, 2)
+     coords = np.arctan2(mtot2, mtot1) % (2 * np.pi)
+
+     coordsbox = coords.copy()
+     times_box = times.copy()
+
+     # Prepare results dictionary
+     results = {
+         "coords": coords,
+         "coordsbox": coordsbox,
+         "times": times,
+         "times_box": times_box,
+         "centcosall": centcosall,
+         "centsinall": centsinall,
+     }
+
+     # Save results
+     if save_path is None:
+         os.makedirs("Results", exist_ok=True)
+         save_path = "Results/spikes_decoding.npz"
+
+     os.makedirs(os.path.dirname(save_path), exist_ok=True)
+     np.savez_compressed(save_path, **results)
+
+     return results
+
+
+ def decode_circular_coordinates_multi(
+     persistence_result: dict,
+     spike_data: dict,
+     save_path: str | None = None,
+     num_circ: int = 2,  # Number of H1 cocycles/circular coordinates to decode
+ ) -> dict:
+     """Decode multiple circular coordinates from TDA persistence.
+
+     Parameters
+     ----------
+     persistence_result : dict
+         Output from :func:`canns.analyzer.data.tda_vis`, containing keys:
+         ``persistence``, ``indstemp``, ``movetimes``, ``n_points``.
+     spike_data : dict
+         Spike data dictionary containing ``'spike'``, ``'t'`` and optionally ``'x'``/``'y'``.
+     save_path : str, optional
+         Path to save decoding results. Defaults to ``Results/spikes_decoding.npz``.
+     num_circ : int
+         Number of H1 cocycles/circular coordinates to decode.
+
+     Returns
+     -------
+     dict
+         Dictionary with ``coords``, ``coordsbox``, ``times``, ``times_box`` and centroid terms.
+
+     Examples
+     --------
+     >>> decoding = decode_circular_coordinates_multi(persistence, spike_data, num_circ=2)  # doctest: +SKIP
+     >>> decoding["coords"].shape  # doctest: +SKIP
+     """
+     from sklearn import preprocessing
+
+     dec_tresh = 0.99
+     coeff = 47
+
+     persistence = persistence_result["persistence"]
+     indstemp = persistence_result["indstemp"]
+     movetimes = persistence_result["movetimes"]
+     n_points = persistence_result["n_points"]
+
+     diagrams = persistence["dgms"]
+     cocycles = persistence["cocycles"][1]
+     dists_land = persistence["dperm2all"]
+
+     births1 = diagrams[1][:, 0]
+     deaths1 = diagrams[1][:, 1]
+     deaths1[np.isinf(deaths1)] = 0
+     lives1 = deaths1 - births1
+
+     if lives1.size < num_circ or len(cocycles) < num_circ:
+         raise ValueError(
+             f"Requested num_circ={num_circ}, but only {lives1.size} H1 feature(s) are available. "
+             "This usually means the chosen time window is too short, the data are too sparse, "
+             "or the embedding parameters are not appropriate."
+         )
+
+     iMax = np.argsort(lives1)
+     coords1 = np.zeros((num_circ, len(indstemp)))
+
+     for i in range(num_circ):
+         idx = iMax[-(i + 1)]
+         threshold = births1[idx] + (deaths1[idx] - births1[idx]) * dec_tresh
+         cocycle = cocycles[idx]
+         f, verts = _get_coords(cocycle, threshold, len(indstemp), dists_land, coeff)
+         if len(verts) != len(indstemp):
+             raise ValueError(
+                 "Circular coordinate reconstruction returned fewer vertices than sampled points. "
+                 "Increase n_points/active_times or use denser data."
+             )
+         coords1[i, :] = f
+
+     sspikes = spike_data["spike"]
+     num_neurons = sspikes.shape[1]
+
+     centcosall = np.zeros((num_neurons, num_circ, n_points))
+     centsinall = np.zeros((num_neurons, num_circ, n_points))
+     dspk = preprocessing.scale(sspikes[movetimes[indstemp], :])
+
+     for n in range(num_neurons):
+         spktemp = dspk[:, n].copy()
+         centcosall[n, :, :] = np.multiply(np.cos(coords1 * 2 * np.pi), spktemp)
+         centsinall[n, :, :] = np.multiply(np.sin(coords1 * 2 * np.pi), spktemp)
+
+     times = np.where(np.sum(sspikes > 0, 1) >= 1)[0]
+     dspk = preprocessing.scale(sspikes)
+     sspikes = sspikes[times, :]
+     dspk = dspk[times, :]
+
+     a = np.zeros((len(sspikes), num_circ, num_neurons))
+     c = np.zeros((len(sspikes), num_circ, num_neurons))
+
+     for n in range(num_neurons):
+         a[:, :, n] = dspk[:, n : n + 1] * np.sum(centcosall[n, :, :], axis=1)
+         c[:, :, n] = dspk[:, n : n + 1] * np.sum(centsinall[n, :, :], axis=1)
+
+     mtot1 = np.sum(a, 2)
+     mtot2 = np.sum(c, 2)
+     coords = np.arctan2(mtot2, mtot1) % (2 * np.pi)
+
+     results = {
+         "coords": coords,
+         "coordsbox": coords.copy(),
+         "times": times,
+         "times_box": times.copy(),
+         "centcosall": centcosall,
+         "centsinall": centsinall,
+     }
+
+     if save_path is None:
+         os.makedirs("Results", exist_ok=True)
+         save_path = "Results/spikes_decoding.npz"
+
+     os.makedirs(os.path.dirname(save_path), exist_ok=True)
+     np.savez_compressed(save_path, **results)
+     return results
+
+
+ def _get_coords(cocycle, threshold, num_sampled, dists, coeff):
+     """
+     Reconstruct circular coordinates from cocycle information.
+
+     Parameters:
+         cocycle (ndarray): Persistent cocycle representative.
+         threshold (float): Maximum allowable edge distance.
+         num_sampled (int): Number of sampled points.
+         dists (ndarray): Pairwise distance matrix.
+         coeff (int): Finite field modulus for cohomology.
+
+     Returns:
+         f (ndarray): Circular coordinate values (in [0,1]).
+         verts (ndarray): Indices of used vertices.
+     """
+     zint = np.where(coeff - cocycle[:, 2] < cocycle[:, 2])
+     cocycle[zint, 2] = cocycle[zint, 2] - coeff
+     d = np.zeros((num_sampled, num_sampled))
+     d[np.tril_indices(num_sampled)] = np.nan
+     d[cocycle[:, 1], cocycle[:, 0]] = cocycle[:, 2]
+     d[dists > threshold] = np.nan
+     d[dists == 0] = np.nan
+     edges = np.where(~np.isnan(d))
+     verts = np.array(np.unique(edges))
+     num_edges = np.shape(edges)[1]
+     num_verts = np.size(verts)
+     values = d[edges]
+     A = np.zeros((num_edges, num_verts), dtype=int)
+     v1 = np.zeros((num_edges, 2), dtype=int)
+     v2 = np.zeros((num_edges, 2), dtype=int)
+     for i in range(num_edges):
+         # Extract scalar indices from np.where results
+         idx1 = np.where(verts == edges[0][i])[0]
+         idx2 = np.where(verts == edges[1][i])[0]
+
+         # Handle case where np.where returns multiple matches (shouldn't happen in valid data)
+         if len(idx1) > 0:
+             v1[i, :] = [i, idx1[0]]
+         else:
+             raise ValueError(f"No vertex found for edge {edges[0][i]}")
+
+         if len(idx2) > 0:
+             v2[i, :] = [i, idx2[0]]
+         else:
+             raise ValueError(f"No vertex found for edge {edges[1][i]}")
+
+     A[v1[:, 0], v1[:, 1]] = -1
+     A[v2[:, 0], v2[:, 1]] = 1
+
+     L = np.ones((num_edges,))
+     Aw = A * np.sqrt(L[:, np.newaxis])
+     Bw = values * np.sqrt(L)
+     f = lsmr(Aw, Bw)[0] % 1
+     return f, verts
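
For orientation, the decoder above sits at the end of the embed → TDA → decode chain sketched in its docstrings. A hedged end-to-end example, assuming `embed_spike_trains` and `tda_vis` are re-exported from canns.analyzer.data (the docstring examples import `tda_vis` from there; the re-export of `embed_spike_trains` is an assumption) and that `spike_data` is a dict with 'spike' (T, N), 't' (T,), and optional 'x'/'y' arrays:

    from canns.analyzer.data import (
        SpikeEmbeddingConfig,
        TDAConfig,
        decode_circular_coordinates,
        embed_spike_trains,  # assumed re-export; defined in asa/embedding.py
        tda_vis,
    )

    # Embed spike trains; embed_spike_trains returns a 4-tuple, spikes first.
    sspikes, *_ = embed_spike_trains(spike_data, config=SpikeEmbeddingConfig())

    # Persistent homology on the embedded point cloud (H1 only, no shuffles).
    persistence = tda_vis(sspikes, config=TDAConfig(maxdim=1, do_shuffle=False))

    decoding = decode_circular_coordinates(
        persistence,
        spike_data,
        real_ground=True,  # x/y/t available, so speed filtering applies
        real_of=True,      # open-field session: coordsbox mirrors coords
        save_path="Results/spikes_decoding.npz",
    )
    coords = decoding["coords"]  # angles in [0, 2*pi) per retained timepoint
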