canns 0.13.0__py3-none-any.whl → 0.13.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/analyzer/data/asa/__init__.py +10 -0
- canns/analyzer/data/asa/decode.py +18 -21
- canns/analyzer/data/{legacy/cann1d.py → asa/fly_roi.py} +96 -43
- canns/analyzer/data/asa/fr.py +4 -12
- canns/analyzer/data/asa/plotting.py +12 -1
- canns/pipeline/asa/widgets.py +1 -1
- {canns-0.13.0.dist-info → canns-0.13.1.dist-info}/METADATA +1 -1
- {canns-0.13.0.dist-info → canns-0.13.1.dist-info}/RECORD +11 -13
- canns/analyzer/data/legacy/__init__.py +0 -6
- canns/analyzer/data/legacy/cann2d.py +0 -2565
- {canns-0.13.0.dist-info → canns-0.13.1.dist-info}/WHEEL +0 -0
- {canns-0.13.0.dist-info → canns-0.13.1.dist-info}/entry_points.txt +0 -0
- {canns-0.13.0.dist-info → canns-0.13.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,2565 +0,0 @@
|
|
|
1
|
-
import logging
|
|
2
|
-
import multiprocessing as mp
|
|
3
|
-
import numbers
|
|
4
|
-
import os
|
|
5
|
-
from dataclasses import dataclass
|
|
6
|
-
from typing import Any
|
|
7
|
-
|
|
8
|
-
import matplotlib.pyplot as plt
|
|
9
|
-
import numpy as np
|
|
10
|
-
from canns_lib.ripser import ripser
|
|
11
|
-
from matplotlib import animation, cm, gridspec
|
|
12
|
-
from numpy.exceptions import AxisError
|
|
13
|
-
|
|
14
|
-
# from ripser import ripser
|
|
15
|
-
from scipy import signal
|
|
16
|
-
from scipy.ndimage import (
|
|
17
|
-
_nd_image,
|
|
18
|
-
_ni_support,
|
|
19
|
-
binary_closing,
|
|
20
|
-
gaussian_filter,
|
|
21
|
-
gaussian_filter1d,
|
|
22
|
-
)
|
|
23
|
-
from scipy.ndimage._filters import _invalid_origin
|
|
24
|
-
from scipy.sparse import coo_matrix
|
|
25
|
-
from scipy.sparse.linalg import lsmr
|
|
26
|
-
from scipy.spatial.distance import pdist, squareform
|
|
27
|
-
from scipy.stats import binned_statistic_2d, multivariate_normal
|
|
28
|
-
from sklearn import preprocessing
|
|
29
|
-
from tqdm import tqdm
|
|
30
|
-
|
|
31
|
-
# Import PlotConfig for unified plotting
|
|
32
|
-
from ...visualization import PlotConfig
|
|
33
|
-
from ...visualization.core.jupyter_utils import (
|
|
34
|
-
display_animation_in_jupyter,
|
|
35
|
-
is_jupyter_environment,
|
|
36
|
-
)
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
# ==================== Configuration Classes ====================
|
|
40
|
-
@dataclass
class SpikeEmbeddingConfig:
    """Parameters that control how raw spike times become a binned spike matrix.

    Consumed by ``embed_spike_trains`` and its private helpers. Positional order
    of the fields is: ``res, dt, sigma, smooth, speed_filter, min_speed``.
    """

    # Multiplier applied to timestamps (``t * res``) before binning.
    res: int = 100000
    # Bin width, expressed in the same scaled units as ``t * res``.
    dt: int = 1000
    # Std. dev. of the temporal Gaussian, same units as ``dt`` (``sigma / dt`` bins).
    sigma: int = 5000
    # Whether to Gaussian-smooth the binned counts along time.
    smooth: bool = True
    # Whether to keep only bins whose speed is strictly greater than ``min_speed``.
    speed_filter: bool = True
    # Speed threshold used when ``speed_filter`` is enabled.
    min_speed: float = 2.5
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
@dataclass
class TDAConfig:
    """Parameters for the persistent-homology (TDA) pipeline run by ``tda_vis``."""

    # Number of PCA components kept (passed to ``_pca`` as ``dim``).
    dim: int = 6
    # Stride used when downsampling time points.
    num_times: int = 5
    # How many downsampled time points with the largest summed activity are kept.
    active_times: int = 15000
    # Forwarded as ``k`` to ``_sample_denoising``.
    k: int = 1000
    # Number of points retained by denoising (``num_sample``).
    n_points: int = 1200
    # Distance metric used by denoising and by ``_second_build``.
    metric: str = "cosine"
    # Forwarded as ``nbs`` to ``_second_build``.
    nbs: int = 800
    # Maximum homology dimension computed by ripser.
    maxdim: int = 1
    # Coefficient field passed to ripser as ``coeff``.
    coeff: int = 47
    # Whether to display the barcode plot.
    show: bool = True
    # Whether to run the shuffle-based null analysis.
    do_shuffle: bool = False
    # Number of shuffle iterations when ``do_shuffle`` is enabled.
    num_shuffles: int = 1000
    # Whether ripser / shuffle analysis show progress bars.
    progress_bar: bool = True
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
@dataclass
class CANN2DPlotConfig(PlotConfig):
    """Specialized PlotConfig for CANN2D visualizations.

    Adds fields for 3D projection plots and torus bump animations, plus two
    alternate constructors that fill sensible defaults for each plot type.
    """

    # 3D projection specific parameters
    zlabel: str = "Component 3"
    dpi: int = 300

    # Torus animation specific parameters (geometry/windowing of the animation;
    # NOTE(review): exact semantics are defined by the animation code, confirm there)
    numangsint: int = 51
    r1: float = 1.5  # Major radius
    r2: float = 1.0  # Minor radius
    window_size: int = 300
    frame_step: int = 5
    n_frames: int = 20

    @classmethod
    def for_projection_3d(cls, **kwargs) -> "CANN2DPlotConfig":
        """Create configuration for 3D projection plots.

        Keyword arguments override the defaults below and are forwarded to
        ``for_static_plot``.
        """
        defaults = {
            "title": "3D Data Projection",
            "xlabel": "Component 1",
            "ylabel": "Component 2",
            "zlabel": "Component 3",
            "figsize": (10, 8),
            "dpi": 300,
        }
        defaults.update(kwargs)
        return cls.for_static_plot(**defaults)

    @classmethod
    def for_torus_animation(cls, **kwargs) -> "CANN2DPlotConfig":
        """Create configuration for 3D torus bump animations.

        Keyword arguments override the defaults below. ``time_steps_per_second``
        (default 1000) is passed positionally to ``for_animation``; the torus
        fields are then set explicitly on the returned config.
        """
        defaults = {
            "title": "3D Bump on Torus",
            "figsize": (8, 8),
            "fps": 5,
            "repeat": True,
            "show_progress_bar": True,
            "numangsint": 51,
            "r1": 1.5,
            "r2": 1.0,
            "window_size": 300,
            "frame_step": 5,
            "n_frames": 20,
        }
        defaults.update(kwargs)
        # Pop rather than get: leaving it in ``defaults`` would pass the value both
        # positionally and as a keyword to ``for_animation``.
        time_steps = defaults.pop("time_steps_per_second", 1000)
        config = cls.for_animation(time_steps, **defaults)
        # Add torus-specific attributes
        for name in ("numangsint", "r1", "r2", "window_size", "frame_step", "n_frames"):
            setattr(config, name, defaults[name])
        return config
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
# ==================== Constants ====================
|
|
131
|
-
class Constants:
    """Module-wide numeric constants for the CANN2D analysis helpers."""

    # Plot defaults.
    DEFAULT_FIGSIZE = (10, 8)
    DEFAULT_DPI = 300

    # Scale factors. NOTE(review): confirm which helpers are meant to use these.
    GAUSSIAN_SIGMA_FACTOR = 100
    SPEED_CONVERSION_FACTOR = 100
    TIME_CONVERSION_FACTOR = 0.01

    # Passed as ``num_cores`` to the shuffle analysis.
    MULTIPROCESSING_CORES = 4
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
# ==================== Custom Exceptions ====================
|
|
143
|
-
class CANN2DError(Exception):
    """Root of the CANN2D analysis exception hierarchy."""


class DataLoadError(CANN2DError):
    """Signals that required input data could not be loaded or was missing."""


class ProcessingError(CANN2DError):
    """Signals a failure while transforming or analyzing loaded data."""
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
try:
    from numba import jit, njit, prange

    HAS_NUMBA = True
except ImportError:
    HAS_NUMBA = False
    print(
        "Numba is not installed; CANN2D analysis falls back to a pure numpy implementation.",
        "Try numba by `pip install numba` to speed up the process.",
    )

    # Create no-op stand-ins if numba is not available.
    def jit(*args, **kwargs):
        """Fallback for ``numba.jit``: return the decorated function unchanged.

        Supports both bare use (``@jit``) and parameterized use
        (``@jit(nopython=True)``), mirroring how numba's decorator is applied.
        """
        if len(args) == 1 and not kwargs and callable(args[0]):
            # Bare ``@jit``: the function itself is the only argument.
            return args[0]

        def decorator(func):
            return func

        return decorator

    def njit(*args, **kwargs):
        """Fallback for ``numba.njit``; same pass-through semantics as ``jit``."""
        return jit(*args, **kwargs)

    def prange(*args):
        """Fallback for ``numba.prange``: a plain ``range`` over the same arguments."""
        return range(*args)
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
def embed_spike_trains(spike_trains, config: SpikeEmbeddingConfig | None = None, **kwargs):
    """
    Turn per-cell spike times into a time-binned (optionally smoothed) spike matrix.

    Pipeline: restrict spikes to the ``[min(t), max(t))`` window, bin them with
    width ``config.dt`` (on the ``t * config.res`` scale), optionally Gaussian-smooth
    along time, and optionally keep only bins where the animal moves faster than
    ``config.min_speed``.

    Parameters:
        spike_trains : mapping with ``'spike'`` and ``'t'`` (plus ``'x'``, ``'y'``
            when speed filtering is enabled), e.g. a loaded ``.npz`` archive.
        config : SpikeEmbeddingConfig, optional. When given, ``kwargs`` are ignored.
        **kwargs : legacy keyword interface used only when ``config`` is None:
            ``res``, ``dt``, ``sigma``, ``smooth0`` (-> smooth),
            ``speed0`` (-> speed_filter), ``min_speed``.

    Returns:
        spikes_bin (ndarray) of shape (T, N) when speed filtering is off; otherwise the
        tuple ``(spikes_bin, xx, yy, tt)`` restricted to the fast-moving bins.

    Raises:
        ProcessingError: wraps any failure raised while embedding.
    """
    if config is None:
        # Map legacy keyword names onto SpikeEmbeddingConfig fields; anything not
        # supplied keeps the dataclass default.
        legacy_to_field = {
            "res": "res",
            "dt": "dt",
            "sigma": "sigma",
            "smooth0": "smooth",
            "speed0": "speed_filter",
            "min_speed": "min_speed",
        }
        config = SpikeEmbeddingConfig(
            **{field: kwargs[name] for name, field in legacy_to_field.items() if name in kwargs}
        )

    try:
        per_cell = _extract_spike_data(spike_trains, config)
        bin_edges = _create_time_bins(spike_trains["t"], config)
        binned = _bin_spike_data(per_cell, bin_edges, config)

        if config.smooth:
            binned = _apply_temporal_smoothing(binned, config)

        if not config.speed_filter:
            return binned
        return _apply_speed_filtering(binned, spike_trains, config)
    except Exception as e:
        raise ProcessingError(f"Failed to embed spike trains: {e}") from e
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
def _extract_spike_data(
    spike_trains: dict[str, Any], config: SpikeEmbeddingConfig
) -> dict[int, np.ndarray]:
    """Return per-cell spike times restricted to the recording window.

    The window is ``[min(t), max(t))`` with ``t = spike_trains["t"]``.
    ``spike_trains["spike"]`` may be a dict of cell -> spike times, a list/array of
    per-cell sequences, or an object exposing a callable ``.item`` (such as a 0-d
    object array from an ``.npz`` file), which is unwrapped with ``[()]``.

    Cells are re-keyed ``0..N-1`` in iteration order. Empty per-cell sequences in
    the list/array form become ``np.array([])``. ``config`` is not read here.

    Raises:
        DataLoadError: on a KeyError (e.g. ``"spike"`` or ``"t"`` missing).
        ProcessingError: on any other failure.
    """
    try:
        raw = spike_trains["spike"]
        spikes_all = raw[()] if hasattr(raw, "item") and callable(raw.item) else raw

        t = spike_trains["t"]
        t_start, t_stop = np.min(t), np.max(t)

        def clip_to_window(times):
            return times[(times >= t_start) & (times < t_stop)]

        if isinstance(spikes_all, dict):
            return {
                cell: clip_to_window(np.array(times))
                for cell, times in enumerate(spikes_all.values())
            }

        windowed = {}
        for cell in range(len(spikes_all)):
            times = spikes_all[cell]
            windowed[cell] = clip_to_window(np.array(times)) if len(times) > 0 else np.array([])
        return windowed

    except KeyError as e:
        raise DataLoadError(f"Missing required data key: {e}") from e
    except Exception as e:
        raise ProcessingError(f"Error extracting spike data: {e}") from e
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
def _create_time_bins(t: np.ndarray, config: SpikeEmbeddingConfig) -> np.ndarray:
    """Return bin start points spanning ``t`` on the ``t * config.res`` scale.

    Starts at ``floor(min(t) * res)`` and steps by ``config.dt`` up to (and
    possibly including) ``ceil(max(t) * res)``.
    """
    first = np.floor(np.min(t) * config.res)
    stop = np.ceil(np.max(t) * config.res) + 1
    return np.arange(first, stop, config.dt)
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
def _bin_spike_data(
    spikes: dict[int, np.ndarray], time_bins: np.ndarray, config: SpikeEmbeddingConfig
) -> np.ndarray:
    """Count spikes per time bin for each cell.

    Returns an integer matrix of shape ``(len(time_bins), len(spikes))``; column
    ``n`` holds the counts for ``spikes[n]`` (keys are used directly as column
    indices). Spike times are mapped to integer offsets from ``time_bins[0]`` on
    the ``t * config.res`` scale, then divided by ``config.dt``.

    NOTE(review): the strict ``> 0`` test drops spikes whose offset truncates to 0
    (within one tick of the window start); confirm this is intended.
    """
    origin = time_bins[0]
    span = time_bins[-1] - origin
    n_bins = len(time_bins)

    counts = np.zeros((n_bins, len(spikes)), dtype=int)

    for column, times in spikes.items():
        offsets = np.array(times * config.res - origin, dtype=int)
        offsets = offsets[(offsets < span) & (offsets > 0)]
        bin_idx = np.array(offsets / config.dt, int)
        bin_idx = bin_idx[bin_idx < n_bins]
        # Vectorized equivalent of incrementing counts[j, column] once per spike.
        counts[:, column] += np.bincount(bin_idx, minlength=n_bins)

    return counts
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
def _apply_temporal_smoothing(spikes_bin: np.ndarray, config: SpikeEmbeddingConfig) -> np.ndarray:
    """Gaussian-smooth each column of the binned spike matrix along time.

    The kernel std. dev. is ``config.sigma / config.dt`` bins, and values beyond
    the edges are treated as zero (``mode="constant"``). scipy's kernel sums to 1.
    The smoothed result is then scaled by ``1 / sqrt(2 * pi * (sigma / res) ** 2)``.
    Returns a float array with the same shape as ``spikes_bin``.
    """
    width_in_bins = config.sigma / config.dt

    # One call along axis 0 performs the same 1-D filtering on every column.
    filtered = gaussian_filter1d(
        spikes_bin.astype(float), sigma=width_in_bins, axis=0, mode="constant"
    )

    scale = 1 / np.sqrt(2 * np.pi * (config.sigma / config.res) ** 2)
    return filtered * scale
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
def _apply_speed_filtering(
    spikes_bin: np.ndarray, spike_trains: dict[str, Any], config: SpikeEmbeddingConfig
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Keep only time bins where the computed speed exceeds ``config.min_speed``.

    Positions and speed come from ``_load_pos`` applied to ``spike_trains["t"]``,
    ``["x"]`` and ``["y"]``. Returns ``(spikes_bin, xx, yy, tt)``, all masked
    by the same speed criterion.

    Raises:
        DataLoadError: on a KeyError (e.g. position keys missing).
        ProcessingError: on any other failure.
    """
    try:
        t, x, y = (spike_trains[key] for key in ("t", "x", "y"))
        pos_x, pos_y, pos_t, speed = _load_pos(t, x, y, res=config.res, dt=config.dt)

        moving = speed > config.min_speed
        return spikes_bin[moving, :], pos_x[moving], pos_y[moving], pos_t[moving]

    except KeyError as e:
        raise DataLoadError(f"Missing position data for speed filtering: {e}") from e
    except Exception as e:
        raise ProcessingError(f"Error in speed filtering: {e}") from e
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
def plot_projection(
    reduce_func,
    embed_data,
    config: CANN2DPlotConfig | None = None,
    title="Projection (3D)",
    xlabel="Component 1",
    ylabel="Component 2",
    zlabel="Component 3",
    save_path=None,
    show=True,
    dpi=300,
    figsize=(10, 8),
    **kwargs,
):
    """
    Plot a 3D projection of the embedded data.

    Every 5th row of ``embed_data`` is passed to ``reduce_func``. The first three
    columns of the result are drawn as a 3D scatter.

    Parameters:
        reduce_func (callable): Function to reduce the dimensionality of the data;
            its output must have at least 3 columns.
        embed_data (ndarray): Data to be projected.
        config (CANN2DPlotConfig, optional): Unified plotting configuration. When
            given, the individual keyword parameters below are ignored.
        title, xlabel, ylabel, zlabel (str): Plot title and axis labels.
        save_path (str, optional): Path to save the plot. If None, plot is not saved.
        show (bool): Whether to display the plot.
        dpi (int): Dots per inch for saving the figure.
        figsize (tuple): Size of the figure.
        **kwargs: extra parameters forwarded to ``CANN2DPlotConfig.for_projection_3d``.

    Returns:
        fig: The created figure. It has already been closed via ``plt.close`` but
        can still be saved with ``fig.savefig``.

    Raises:
        ValueError: if ``config.save_path`` is None and ``config.show`` is None.
    """
    # Handle backward compatibility and configuration
    if config is None:
        config = CANN2DPlotConfig.for_projection_3d(
            title=title,
            xlabel=xlabel,
            ylabel=ylabel,
            zlabel=zlabel,
            save_path=save_path,
            show=show,
            figsize=figsize,
            dpi=dpi,
            **kwargs,
        )

    # Validate before creating the figure so the error path cannot leak an open
    # figure. ``show=False`` without a save path is allowed (the figure is returned).
    if config.save_path is None and config.show is None:
        raise ValueError("Either save path or show must be provided.")

    reduced_data = reduce_func(embed_data[::5])

    fig = plt.figure(figsize=config.figsize)
    try:
        ax = fig.add_subplot(111, projection="3d")
        ax.scatter(reduced_data[:, 0], reduced_data[:, 1], reduced_data[:, 2], s=1, alpha=0.5)

        ax.set_title(config.title)
        ax.set_xlabel(config.xlabel)
        ax.set_ylabel(config.ylabel)
        ax.set_zlabel(config.zlabel)

        if config.save_path:
            # Save this figure explicitly rather than whatever pyplot considers current.
            fig.savefig(config.save_path, dpi=config.dpi)
        if config.show:
            plt.show()
    finally:
        # Always release the figure, including when plotting or saving raises.
        plt.close(fig)

    return fig
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
def tda_vis(embed_data: np.ndarray, config: TDAConfig | None = None, **kwargs) -> dict[str, Any]:
    """
    Run the TDA pipeline on embedded data, optionally with a shuffle null model.

    Parameters:
        embed_data : ndarray
            Embedded spike train data.
        config : TDAConfig, optional
            Full pipeline configuration. When given, ``kwargs`` are ignored.
        **kwargs : legacy keyword interface used only when ``config`` is None; any
            TDAConfig field name is accepted, other keys are ignored.

    Returns:
        dict with keys:
            - persistence: persistence diagrams from the real data
            - indstemp: indices of the sampled (denoised) points
            - movetimes: selected time points
            - n_points: number of sampled points
            - shuffle_max: shuffle analysis results if do_shuffle, otherwise None

    Raises:
        ProcessingError: wraps any failure during the analysis.
    """
    if config is None:
        legacy_fields = (
            "dim", "num_times", "active_times", "k", "n_points", "metric", "nbs",
            "maxdim", "coeff", "show", "do_shuffle", "num_shuffles", "progress_bar",
        )
        # Fields not supplied keep their TDAConfig defaults.
        config = TDAConfig(**{name: kwargs[name] for name in legacy_fields if name in kwargs})

    try:
        print("Computing persistent homology for real data...")
        real = _compute_real_persistence(embed_data, config)

        shuffle_max = _perform_shuffle_analysis(embed_data, config) if config.do_shuffle else None

        _handle_visualization(real["persistence"], shuffle_max, config)

        result = {key: real[key] for key in ("persistence", "indstemp", "movetimes", "n_points")}
        result["shuffle_max"] = shuffle_max
        return result

    except Exception as e:
        raise ProcessingError(f"TDA analysis failed: {e}") from e
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
def _compute_real_persistence(embed_data: np.ndarray, config: TDAConfig) -> dict[str, Any]:
    """Run the five TDA steps on the real (unshuffled) data, logging each step.

    Steps: downsample time points, keep the most active ones, PCA-reduce, denoise
    the point cloud, then compute persistent homology with ripser.

    Returns a dict with ``persistence``, ``indstemp``, ``movetimes`` and ``n_points``.
    """
    logging.info("Processing real data - Starting TDA analysis (5 steps)")

    logging.info("Step 1/5: Time point downsampling")
    sampled_times = _downsample_timepoints(embed_data, config.num_times)

    logging.info("Step 2/5: Selecting active time points")
    active = _select_active_timepoints(embed_data, sampled_times, config.active_times)

    logging.info("Step 3/5: PCA dimensionality reduction")
    reduced = _apply_pca_reduction(embed_data, active, config.dim)

    logging.info("Step 4/5: Point cloud denoising")
    kept = _apply_denoising(reduced, config)

    logging.info("Step 5/5: Computing persistent homology")
    diagrams = _compute_persistence_homology(reduced, kept, config)

    logging.info("TDA analysis completed successfully")

    return dict(persistence=diagrams, indstemp=kept, movetimes=active, n_points=config.n_points)
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
def _downsample_timepoints(embed_data: np.ndarray, num_times: int) -> np.ndarray:
    """Return row indices ``0, num_times, 2*num_times, ...`` into ``embed_data``."""
    n_rows = embed_data.shape[0]
    return np.arange(0, n_rows, num_times)
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
def _select_active_timepoints(
    embed_data: np.ndarray, times_cube: np.ndarray, active_times: int
) -> np.ndarray:
    """Return the ``active_times`` entries of ``times_cube`` with the largest row sums.

    Row sums of ``embed_data[times_cube]`` rank the candidates. The winners are
    returned in ascending index order, matching the external TDAvis reference.
    """
    totals = embed_data[times_cube, :].sum(axis=1)
    winners = np.argsort(totals)[-active_times:]
    winners.sort()
    return times_cube[winners]
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
def _apply_pca_reduction(embed_data: np.ndarray, movetimes: np.ndarray, dim: int) -> np.ndarray:
    """Standardize the selected rows and project them onto ``dim`` PCA components.

    Only the first value returned by ``_pca`` (the projected data) is kept.
    """
    selected = embed_data[movetimes, :]
    standardized = preprocessing.scale(selected)
    projected, *_unused = _pca(standardized, dim=dim)
    return projected
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
def _apply_denoising(dimred: np.ndarray, config: TDAConfig) -> np.ndarray:
    """Select ``config.n_points`` representative point indices via ``_sample_denoising``.

    Only the first value returned by ``_sample_denoising`` (the indices) is kept.
    """
    sampled_idx, *_unused = _sample_denoising(
        dimred,
        k=config.k,
        num_sample=config.n_points,
        # Match external TDAvis: omega=1 rather than the helper's default.
        omega=1,
        metric=config.metric,
    )
    return sampled_idx
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
def _compute_persistence_homology(
    dimred: np.ndarray, indstemp: np.ndarray, config: TDAConfig
) -> dict[str, Any]:
    """Build a distance matrix over the sampled points and run ripser on it.

    The diagonal of the matrix from ``_second_build`` is zeroed in place before
    the call. Cocycles are requested (``do_cocycles=True``).
    """
    dist = _second_build(dimred, indstemp, metric=config.metric, nbs=config.nbs)
    np.fill_diagonal(dist, 0)

    ripser_options = dict(
        maxdim=config.maxdim,
        coeff=config.coeff,
        do_cocycles=True,
        distance_matrix=True,
        progress_bar=config.progress_bar,
    )
    return ripser(dist, **ripser_options)
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
def _perform_shuffle_analysis(embed_data: np.ndarray, config: TDAConfig) -> dict[int, Any]:
    """Run the shuffle null model with the pipeline settings from ``config``.

    Forwards the TDA parameters to ``_run_shuffle_analysis``, with
    ``num_cores=Constants.MULTIPROCESSING_CORES``. It then prints a summary and
    returns the result unchanged.
    """
    print(f"\nStarting shuffle analysis with {config.num_shuffles} iterations...")

    forwarded = (
        "dim", "num_times", "active_times", "k", "n_points",
        "metric", "nbs", "maxdim", "coeff",
    )
    pipeline_params = {name: getattr(config, name) for name in forwarded}

    result = _run_shuffle_analysis(
        embed_data,
        num_shuffles=config.num_shuffles,
        num_cores=Constants.MULTIPROCESSING_CORES,
        progress_bar=config.progress_bar,
        **pipeline_params,
    )

    _print_shuffle_summary(result)
    return result
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
def _print_shuffle_summary(shuffle_max: dict[int, Any]) -> None:
    """Print per-dimension statistics (H0..H2) of the shuffle analysis.

    A dimension is reported only if it is present in ``shuffle_max`` with a
    non-empty sequence of values. Emptiness is tested with ``len()`` rather than
    truthiness, so NumPy arrays are handled too (``bool(array)`` raises for
    multi-element arrays). ``None`` or an empty mapping prints only the header.
    """
    print("\nSummary of shuffle-based analysis:")
    if not shuffle_max:
        return
    for dim_idx in (0, 1, 2):
        if dim_idx not in shuffle_max:
            continue
        values = shuffle_max[dim_idx]
        if values is None or len(values) == 0:
            continue
        print(
            f"H{dim_idx}: {len(values)} valid iterations | "
            f"Mean maximum persistence: {np.mean(values):.4f} | "
            f"99.9th percentile: {np.percentile(values, 99.9):.4f}"
        )
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
def _handle_visualization(
    real_persistence: dict[str, Any], shuffle_max: dict[int, Any] | None, config: TDAConfig
) -> None:
    """Draw and show the persistence barcode when ``config.show`` is set.

    ``real_persistence`` is the ripser result for the real data. When shuffle
    results are available and ``config.do_shuffle`` is set, the barcode is drawn
    against the shuffle maxima.
    """
    if not config.show:
        # Nothing is drawn in this branch, so a bare plt.close() here would only
        # close an unrelated figure the caller currently has open.
        return
    if config.do_shuffle and shuffle_max is not None:
        _plot_barcode_with_shuffle(real_persistence, shuffle_max)
    else:
        _plot_barcode(real_persistence)
    plt.show()
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
def _load_pos(t, x, y, res=100000, dt=1000):
    """
    Interpolate position samples onto the spike time bins and compute speed.

    Positions are linearly interpolated onto the bin grid used for spike binning,
    i.e. ``arange(floor(min(t)*res), ceil(max(t)*res) + 1, dt) / res``. They are
    then smoothed with a Gaussian (sigma = 100 bins), differenced, and converted
    to a speed.

    Parameters:
        t (ndarray): Timestamps of the x/y samples (seconds).
        x (ndarray): X coordinates of the animal's position.
        y (ndarray): Y coordinates of the animal's position.
        res (int): Time scaling factor to align with spike resolution.
        dt (int): Temporal bin size, in units of ``t * res``.

    Returns:
        xx (ndarray): Interpolated x positions.
        yy (ndarray): Interpolated y positions.
        tt (ndarray): Bin time points (in seconds).
        speed (ndarray): Speed per bin. Position deltas are scaled by 100 and divided
            by the bin duration ``dt / res`` seconds (documented as cm/s,
            NOTE(review): assumes x/y are in metres).
    """
    min_time0 = np.min(t)
    max_time0 = np.max(t)

    # Keep samples in [min, max); the final sample is excluded.
    in_window = np.where((t >= min_time0) & (t < max_time0))
    x = x[in_window]
    y = y[in_window]
    t = t[in_window]

    min_time = min_time0 * res
    max_time = max_time0 * res

    tt = np.arange(np.floor(min_time), np.ceil(max_time) + 1, dt) / res

    # Bin index of each interior sample, with sentinels pinning the first and last.
    idt = np.concatenate(([0], np.digitize(t[1:-1], tt[:]) - 1, [len(tt) + 1]))
    # For every bin, the index of the sample interval it falls into.
    idtt = np.digitize(np.arange(len(tt)), idt) - 1

    idx = np.concatenate((np.unique(idtt), [np.max(idtt) + 1]))
    divisor = np.bincount(idtt)
    steps = 1.0 / divisor[divisor > 0]
    N = np.max(divisor)
    # Fractional positions (0, 1/k, 2/k, ...) within each sample interval; padding -> NaN.
    ranges = np.multiply(np.arange(N)[np.newaxis, :], steps[:, np.newaxis])
    ranges[ranges >= 1] = np.nan

    rangesx = x[idx[:-1], np.newaxis] + np.multiply(
        ranges, (x[idx[1:]] - x[idx[:-1]])[:, np.newaxis]
    )
    xx = rangesx[~np.isnan(ranges)]

    rangesy = y[idx[:-1], np.newaxis] + np.multiply(
        ranges, (y[idx[1:]] - y[idx[:-1]])[:, np.newaxis]
    )
    yy = rangesy[~np.isnan(ranges)]

    xxs = _gaussian_filter1d(xx - np.min(xx), sigma=100)
    yys = _gaussian_filter1d(yy - np.min(yy), sigma=100)
    dx = (xxs[1:] - xxs[:-1]) * 100
    dy = (yys[1:] - yys[:-1]) * 100
    # Divide by the real bin duration in seconds. A hard-coded 0.01 is only correct
    # for the default dt=1000, res=100000; for those defaults this is bit-identical.
    bin_seconds = dt / res
    speed = np.sqrt(dx**2 + dy**2) / bin_seconds
    # Pad so speed has one entry per interpolated sample.
    speed = np.concatenate(([speed[0]], speed))
    return xx, yy, tt, speed
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
def _gaussian_filter1d(
    input,
    sigma,
    axis=-1,
    order=0,
    output=None,
    mode="reflect",
    cval=0.0,
    truncate=4.0,
    *,
    radius=None,
):
    """1-D Gaussian filter (local copy of ``scipy.ndimage.gaussian_filter1d``).

    Builds a Gaussian (or Gaussian-derivative) kernel with ``_gaussian_kernel1d``
    and correlates it with ``input`` along ``axis`` via ``_correlate1d``.

    Parameters
    ----------
    input : array_like
        The input array.
    sigma : scalar
        Standard deviation for the Gaussian kernel.
    axis : int, optional
        The axis of ``input`` along which to filter. Default is -1.
    order : int, optional
        An order of 0 corresponds to convolution with a Gaussian kernel. A
        positive order corresponds to convolution with that derivative of a
        Gaussian.
    output : array or dtype, optional
        Passed through to ``_correlate1d`` (in scipy: the array in which to place
        the output, or the dtype of the returned array).
    mode : str, optional
        Boundary handling mode, passed through to ``_correlate1d``. Default is
        ``"reflect"``.
    cval : scalar, optional
        Fill value used past the edges when ``mode`` is ``"constant"``.
        Default is 0.0.
    truncate : float, optional
        Truncate the filter at this many standard deviations. Default is 4.0.
    radius : None or int, optional
        Radius of the Gaussian kernel. If specified, the size of the kernel
        will be ``2*radius + 1`` and ``truncate`` is ignored. Default is None.

    Returns
    -------
    ndarray
        The filtered array, as returned by ``_correlate1d``.

    Raises
    ------
    ValueError
        If ``radius`` is not a non-negative integer, or ``order`` is negative.

    Notes
    -----
    The kernel has size ``2*radius + 1``. If ``radius`` is None, a default
    ``radius = int(truncate * sigma + 0.5)`` is used.
    """
    sd = float(sigma)
    # make the radius of the filter equal to truncate standard deviations
    lw = int(truncate * sd + 0.5)
    if radius is not None:
        lw = radius
    if not isinstance(lw, numbers.Integral) or lw < 0:
        raise ValueError("Radius must be a nonnegative integer.")
    # Since we are calling correlate, not convolve, revert the kernel
    weights = _gaussian_kernel1d(sigma, order, lw)[::-1]
    return _correlate1d(input, weights, axis, output, mode, cval, 0)
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
def _gaussian_kernel1d(sigma, order, radius):
|
|
791
|
-
"""
|
|
792
|
-
Computes a 1-D Gaussian convolution kernel.
|
|
793
|
-
"""
|
|
794
|
-
if order < 0:
|
|
795
|
-
raise ValueError("order must be non-negative")
|
|
796
|
-
exponent_range = np.arange(order + 1)
|
|
797
|
-
sigma2 = sigma * sigma
|
|
798
|
-
x = np.arange(-radius, radius + 1)
|
|
799
|
-
phi_x = np.exp(-0.5 / sigma2 * x**2)
|
|
800
|
-
phi_x = phi_x / phi_x.sum()
|
|
801
|
-
|
|
802
|
-
if order == 0:
|
|
803
|
-
return phi_x
|
|
804
|
-
else:
|
|
805
|
-
# f(x) = q(x) * phi(x) = q(x) * exp(p(x))
|
|
806
|
-
# f'(x) = (q'(x) + q(x) * p'(x)) * phi(x)
|
|
807
|
-
# p'(x) = -1 / sigma ** 2
|
|
808
|
-
# Implement q'(x) + q(x) * p'(x) as a matrix operator and apply to the
|
|
809
|
-
# coefficients of q(x)
|
|
810
|
-
q = np.zeros(order + 1)
|
|
811
|
-
q[0] = 1
|
|
812
|
-
D = np.diag(exponent_range[1:], 1) # D @ q(x) = q'(x)
|
|
813
|
-
P = np.diag(np.ones(order) / -sigma2, -1) # P @ q(x) = q(x) * p'(x)
|
|
814
|
-
Q_deriv = D + P
|
|
815
|
-
for _ in range(order):
|
|
816
|
-
q = Q_deriv.dot(q)
|
|
817
|
-
q = (x[:, None] ** exponent_range).dot(q)
|
|
818
|
-
return q * phi_x
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
def _correlate1d(input, weights, axis=-1, output=None, mode="reflect", cval=0.0, origin=0):
    """Calculate a 1-D correlation along the given axis.

    The lines of the array along the given axis are correlated with the
    given weights. This is a local copy of ``scipy.ndimage.correlate1d`` that
    calls scipy's compiled ``_nd_image.correlate1d`` directly.

    Parameters
    ----------
    input : array_like
        The input array.
    weights : array_like
        1-D sequence of numbers. Complex weights are conjugated before use
        (correlation, not convolution).
    axis : int, optional
        Axis of ``input`` along which to correlate. Default is -1.
    output : ndarray or dtype, optional
        Array in which to place the output, or the dtype of the returned array.
    mode : str, optional
        Boundary mode name understood by ``scipy.ndimage`` (e.g. 'reflect',
        'constant', 'nearest', 'mirror', 'wrap'). Default is 'reflect'.
    cval : scalar, optional
        Value used past the edges when ``mode == 'constant'``. Default is 0.0.
    origin : int, optional
        Shift of the filter placement; must satisfy
        ``-(len(weights) // 2) <= origin <= (len(weights) - 1) // 2``.

    Returns
    -------
    result : ndarray
        Correlation result. Has the same shape as `input`.

    Raises
    ------
    RuntimeError
        If ``weights`` is not a non-empty 1-D sequence.
    ValueError
        If ``origin`` is out of range.
    numpy.exceptions.AxisError
        If ``axis`` is out of range for ``input``.

    Examples
    --------
    >>> _correlate1d([2, 8, 0, 4, 1, 9, 9, 0], weights=[1, 3])
    array([ 8, 26,  8, 12,  7, 28, 36,  9])
    """
    input = np.asarray(input)
    weights = np.asarray(weights)
    complex_input = input.dtype.kind == "c"
    complex_weights = weights.dtype.kind == "c"
    if complex_input or complex_weights:
        if complex_weights:
            # Correlating with a complex kernel uses its conjugate.
            weights = weights.conj()
            weights = weights.astype(np.complex128, copy=False)
        kwargs = dict(axis=axis, mode=mode, origin=origin)
        output = _ni_support._get_output(output, input, complex_output=True)
        # Split into real-valued correlations of the real/imaginary parts.
        return _complex_via_real_components(_correlate1d, input, weights, output, cval, **kwargs)

    output = _ni_support._get_output(output, input)
    weights = np.asarray(weights, dtype=np.float64)
    if weights.ndim != 1 or weights.shape[0] < 1:
        raise RuntimeError("no filter weights given")
    if not weights.flags.contiguous:
        # The compiled routine needs a contiguous weight buffer.
        weights = weights.copy()
    axis = _normalize_axis_index(axis, input.ndim)
    if _invalid_origin(origin, len(weights)):
        raise ValueError(
            "Invalid origin; origin must satisfy "
            "-(len(weights) // 2) <= origin <= "
            "(len(weights)-1) // 2"
        )
    mode = _ni_support._extend_mode_to_code(mode)
    _nd_image.correlate1d(input, weights, axis, output, mode, cval, origin)
    return output
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
def _complex_via_real_components(func, input, weights, output, cval, **kwargs):
|
|
880
|
-
"""Complex convolution via a linear combination of real convolutions."""
|
|
881
|
-
complex_input = input.dtype.kind == "c"
|
|
882
|
-
complex_weights = weights.dtype.kind == "c"
|
|
883
|
-
if complex_input and complex_weights:
|
|
884
|
-
# real component of the output
|
|
885
|
-
func(input.real, weights.real, output=output.real, cval=np.real(cval), **kwargs)
|
|
886
|
-
output.real -= func(input.imag, weights.imag, output=None, cval=np.imag(cval), **kwargs)
|
|
887
|
-
# imaginary component of the output
|
|
888
|
-
func(input.real, weights.imag, output=output.imag, cval=np.real(cval), **kwargs)
|
|
889
|
-
output.imag += func(input.imag, weights.real, output=None, cval=np.imag(cval), **kwargs)
|
|
890
|
-
elif complex_input:
|
|
891
|
-
func(input.real, weights, output=output.real, cval=np.real(cval), **kwargs)
|
|
892
|
-
func(input.imag, weights, output=output.imag, cval=np.imag(cval), **kwargs)
|
|
893
|
-
else:
|
|
894
|
-
if np.iscomplexobj(cval):
|
|
895
|
-
raise ValueError("Cannot provide a complex-valued cval when the input is real.")
|
|
896
|
-
func(input, weights.real, output=output.real, cval=cval, **kwargs)
|
|
897
|
-
func(input, weights.imag, output=output.imag, cval=cval, **kwargs)
|
|
898
|
-
return output
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
def _normalize_axis_index(axis, ndim):
|
|
902
|
-
# Check if `axis` is in the correct range and normalize it
|
|
903
|
-
if axis < -ndim or axis >= ndim:
|
|
904
|
-
msg = f"axis {axis} is out of bounds for array of dimension {ndim}"
|
|
905
|
-
raise AxisError(msg)
|
|
906
|
-
|
|
907
|
-
if axis < 0:
|
|
908
|
-
axis = axis + ndim
|
|
909
|
-
return axis
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
def _compute_persistence(
    sspikes,
    dim=6,
    num_times=5,
    active_times=15000,
    k=1000,
    n_points=1200,
    metric="cosine",
    nbs=800,
    maxdim=1,
    coeff=47,
    progress_bar=True,
):
    """Run the subsample -> PCA -> denoise -> distance -> ripser pipeline.

    Parameters
    ----------
    sspikes : ndarray
        2-D array whose rows are time bins (axis 0 is subsampled below).
    dim : int
        Number of principal components kept.
    num_times : int
        Stride used to subsample rows before selecting active ones.
    active_times : int
        Number of subsampled rows with the largest summed activity to keep.
    k : int
        Neighbour count for the denoising k-NN graph.
    n_points : int
        Number of points retained by greedy sampling.
    metric : str
        Distance metric passed to ``pdist``.
    nbs : int
        Neighbour count for the final fuzzy distance graph.
    maxdim : int
        Maximum homology dimension computed by ripser.
    coeff : int
        Field coefficient passed to ripser.
    progress_bar : bool
        Forwarded to ripser.

    Returns
    -------
    The result of ``ripser`` computed with cocycles on a precomputed
    distance matrix.
    """
    subsampled = np.arange(0, sspikes.shape[0], num_times)

    # Keep the `active_times` most active subsampled rows, restored to time order.
    activity = np.sum(sspikes[subsampled, :], 1)
    busiest = np.argsort(activity)[-active_times:]
    movetimes = subsampled[np.sort(busiest)]

    standardized = preprocessing.scale(sspikes[movetimes, :])
    dimred = _pca(standardized, dim=dim)[0]

    # Greedy-sampling suppression factor (omega) is fixed at 1 here.
    indstemp = _sample_denoising(dimred, k, n_points, 1, metric)[0]

    dist = _second_build(dimred, indstemp, metric=metric, nbs=nbs)
    np.fill_diagonal(dist, 0)

    return ripser(
        dist,
        maxdim=maxdim,
        coeff=coeff,
        do_cocycles=True,
        distance_matrix=True,
        progress_bar=progress_bar,
    )
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
def _pca(data, dim=2):
|
|
957
|
-
"""
|
|
958
|
-
Perform PCA (Principal Component Analysis) for dimensionality reduction.
|
|
959
|
-
|
|
960
|
-
Parameters:
|
|
961
|
-
data (ndarray): Input data matrix of shape (N_samples, N_features).
|
|
962
|
-
dim (int): Target dimension for PCA projection.
|
|
963
|
-
|
|
964
|
-
Returns:
|
|
965
|
-
components (ndarray): Projected data of shape (N_samples, dim).
|
|
966
|
-
var_exp (list): Variance explained by each principal component.
|
|
967
|
-
evals (ndarray): Eigenvalues corresponding to the selected components.
|
|
968
|
-
"""
|
|
969
|
-
if dim < 2:
|
|
970
|
-
return data, [0]
|
|
971
|
-
m, n = data.shape
|
|
972
|
-
# mean center the data
|
|
973
|
-
# data -= data.mean(axis=0)
|
|
974
|
-
# calculate the covariance matrix
|
|
975
|
-
R = np.cov(data, rowvar=False)
|
|
976
|
-
# calculate eigenvectors & eigenvalues of the covariance matrix
|
|
977
|
-
# use 'eigh' rather than 'eig' since R is symmetric,
|
|
978
|
-
# the performance gain is substantial
|
|
979
|
-
evals, evecs = np.linalg.eig(R)
|
|
980
|
-
# sort eigenvalue in decreasing order
|
|
981
|
-
idx = np.argsort(evals)[::-1]
|
|
982
|
-
evecs = evecs[:, idx]
|
|
983
|
-
# sort eigenvectors according to same index
|
|
984
|
-
evals = evals[idx]
|
|
985
|
-
# select the first n eigenvectors (n is desired dimension
|
|
986
|
-
# of rescaled data array, or dims_rescaled_data)
|
|
987
|
-
evecs = evecs[:, :dim]
|
|
988
|
-
# carry out the transformation on the data using eigenvectors
|
|
989
|
-
# and return the re-scaled data, eigenvalues, and eigenvectors
|
|
990
|
-
|
|
991
|
-
tot = np.sum(evals)
|
|
992
|
-
var_exp = [(i / tot) * 100 for i in sorted(evals[:dim], reverse=True)]
|
|
993
|
-
components = np.dot(evecs.T, data.T).T
|
|
994
|
-
return components, var_exp, evals[:dim]
|
|
995
|
-
|
|
996
|
-
|
|
997
|
-
def _sample_denoising(data, k=10, num_sample=500, omega=0.2, metric="euclidean"):
    """
    Denoise a point cloud via a fuzzy k-NN graph and greedily subsample it.

    Uses ``_sample_denoising_numba`` when ``HAS_NUMBA`` is true, otherwise the
    ``_sample_denoising_numpy`` fallback. Both take the same arguments and
    return the same ``(inds, d, Fs)`` structure.

    Parameters:
        data (ndarray): High-dimensional point cloud data.
        k (int): Number of neighbors for local density estimation.
        num_sample (int): Number of samples to retain.
        omega (float): Suppression factor during greedy sampling.
        metric (str): Distance metric used for kNN ('euclidean', 'cosine', etc).

    Returns:
        inds (ndarray): Indices of sampled points.
        d (ndarray): Membership (similarity) weights among the sampled points.
        Fs (ndarray): Sampling scores recorded at each step.
    """
    impl = _sample_denoising_numba if HAS_NUMBA else _sample_denoising_numpy
    return impl(data, k, num_sample, omega, metric)
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
def _sample_denoising_numpy(data, k=10, num_sample=500, omega=0.2, metric="euclidean"):
    """Pure NumPy/SciPy implementation of ``_sample_denoising``.

    Builds a fuzzy k-NN membership graph A, symmetrizes it as
    A + A^T - A*A^T, then picks ``num_sample`` points greedily. Each point's
    score starts as its total edge weight. At every step the highest-scoring
    unpicked point is taken, and all scores are reduced by ``omega`` times
    their weight to the point just picked.

    Returns ``(inds, d, Fs)``: picked indices in pick order, the membership
    sub-matrix among them, and the per-step scores.
    """
    n_pts = data.shape[0]
    dists = squareform(pdist(data, metric))
    knn_indices = np.argsort(dists)[:, :k]
    knn_dists = dists[np.arange(dists.shape[0])[:, None], knn_indices].copy()

    sigmas, rhos = _smooth_knn_dist(knn_dists, k, local_connectivity=0)
    rows, cols, vals = _compute_membership_strengths(knn_indices, knn_dists, sigmas, rhos)

    graph = coo_matrix((vals, (rows, cols)), shape=(n_pts, n_pts))
    graph.eliminate_zeros()
    graph_t = graph.transpose()
    both = graph.multiply(graph_t)
    graph = graph + graph_t - both
    graph.eliminate_zeros()
    weights = graph.toarray()

    scores = np.sum(weights, 1)
    Fs = np.zeros(num_sample)
    Fs[0] = np.max(scores)
    current = np.argmax(scores)
    all_idx = np.arange(n_pts)
    remaining = np.ones(n_pts, dtype=bool)
    remaining[current] = False
    inds = np.zeros(num_sample, dtype=int)
    inds[0] = current

    for step in range(1, num_sample):
        scores -= omega * weights[current, :]
        pos = np.argmax(scores[remaining])
        # NOTE: `pos` indexes the *remaining* subset, yet Fs reads scores[pos]
        # from the full array. This is kept deliberately to mirror the external
        # TDAvis code. The numba backend records the score of the picked point
        # instead, so Fs can differ between backends.
        Fs[step] = scores[pos]
        current = all_idx[remaining][pos]
        remaining[current] = False
        inds[step] = current

    d = weights[np.ix_(inds, inds)]
    return inds, d, Fs
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
def _sample_denoising_numba(data, k=10, num_sample=500, omega=0.2, metric="euclidean"):
    """Numba-accelerated implementation of ``_sample_denoising``.

    The k-NN search and membership computation match the NumPy fallback.
    Graph symmetrization, greedy sampling and sub-matrix extraction are
    delegated to njit-compiled helpers. Returns ``(inds, d, Fs)``.
    """
    pairwise = squareform(pdist(data, metric))
    knn_indices = np.argsort(pairwise)[:, :k]
    row_ids = np.arange(pairwise.shape[0])[:, None]
    knn_dists = pairwise[row_ids, knn_indices].copy()

    sigmas, rhos = _smooth_knn_dist(knn_dists, k, local_connectivity=0)
    rows, cols, vals = _compute_membership_strengths(knn_indices, knn_dists, sigmas, rhos)

    adjacency = _build_adjacency_matrix_numba(rows, cols, vals, data.shape[0])
    inds, Fs = _greedy_sampling_numba(adjacency, num_sample, omega)
    sub = _build_distance_matrix_numba(adjacency, inds)
    return inds, sub, Fs
|
|
1080
|
-
|
|
1081
|
-
|
|
1082
|
-
@njit(fastmath=True)
def _build_adjacency_matrix_numba(rows, cols, vals, n):
    """Densify COO triplets into an n x n matrix A and return A + A^T - A*A^T.

    This is the dense counterpart of the sparse symmetrization
    ``result + transpose - result.multiply(transpose)`` used in the NumPy path.
    Entries are assigned rather than accumulated, so a repeated (row, col)
    pair keeps its last value (scipy's COO format would sum duplicates).
    """
    A = np.zeros((n, n), dtype=np.float64)
    A_t = np.zeros((n, n), dtype=np.float64)

    for idx in range(len(rows)):
        r = rows[idx]
        c = cols[idx]
        v = vals[idx]
        A[r, c] = v
        A_t[c, r] = v

    # Probabilistic union of the two directed memberships, written in place.
    A[:, :] = A + A_t - A * A_t
    return A
|
|
1104
|
-
|
|
1105
|
-
|
|
1106
|
-
@njit(fastmath=True)
def _greedy_sampling_numba(X, num_sample, omega):
    """Greedily pick ``num_sample`` rows of the weighted adjacency matrix ``X``.

    Each row's initial score is its row sum. At every step the unpicked row
    with the highest score is chosen (ties go to the lowest index). Then
    every score ``k`` is reduced by ``omega * X[chosen, k]``.

    Returns ``(inds, Fs)``: chosen indices in pick order and each chosen
    row's score at the moment it was picked.
    """
    n = X.shape[0]
    score = np.sum(X, axis=1)
    Fs = np.zeros(num_sample)
    inds = np.zeros(num_sample, dtype=np.int64)
    available = np.ones(n, dtype=np.bool_)

    chosen = np.argmax(score)
    Fs[0] = score[chosen]
    inds[0] = chosen
    available[chosen] = False

    for step in range(1, num_sample):
        for col in range(n):
            score[col] -= omega * X[chosen, col]

        best = -1
        best_val = -np.inf
        for cand in range(n):
            if available[cand] and score[cand] > best_val:
                best_val = score[cand]
                best = cand

        chosen = best
        Fs[step] = score[chosen]
        inds[step] = chosen
        available[chosen] = False

    return inds, Fs
|
|
1142
|
-
|
|
1143
|
-
|
|
1144
|
-
@njit(fastmath=True)
def _build_distance_matrix_numba(X, inds):
    """Return the square sub-matrix ``X[inds][:, inds]`` as a new float64 array.

    Despite the name, the values are whatever ``X`` holds. In this module
    ``X`` is the symmetrized membership-weight matrix.
    """
    m = len(inds)
    sub = np.zeros((m, m))
    for a in range(m):
        src = inds[a]
        for b in range(m):
            sub[a, b] = X[src, inds[b]]
    return sub
|
|
1155
|
-
|
|
1156
|
-
|
|
1157
|
-
@njit(fastmath=True)
def _smooth_knn_dist(distances, k, n_iter=64, local_connectivity=0.0, bandwidth=1.0):
    """
    Compute per-row kernel bandwidths for a fuzzy kNN graph (UMAP-style).

    For each row i, ``rho_i`` is a local-connectivity offset. ``sigma_i`` is
    then found by bisection so that the sum over columns 1.. of
    ``exp(-(d_ij - rho_i) / sigma_i)`` approaches ``log2(k) * bandwidth``.
    Terms with ``d_ij <= rho_i`` count as 1. Column 0 is excluded from the
    sum; presumably it is the point itself, as produced by the argsort-based
    callers in this module.

    Parameters:
        distances (ndarray): kNN distance matrix (one row per point).
        k (int): Number of neighbors.
        n_iter (int): Maximum number of bisection iterations.
        local_connectivity (float): Minimum local connectivity.
        bandwidth (float): Bandwidth parameter.

    Returns:
        sigmas (ndarray): Smoothed sigma values for each point.
        rhos (ndarray): Connectivity offsets for each point (0 when not set).
    """
    target = np.log2(k) * bandwidth

    rho = np.zeros(distances.shape[0])
    result = np.zeros(distances.shape[0])

    mean_distances = np.mean(distances)

    for i in range(distances.shape[0]):
        lo = 0.0
        hi = np.inf
        mid = 1.0

        ith_distances = distances[i]
        non_zero_dists = ith_distances[ith_distances > 0.0]
        n_non_zero = non_zero_dists.shape[0]
        if n_non_zero >= local_connectivity:
            index = int(np.floor(local_connectivity))
            interpolation = local_connectivity - index
            if index > 0:
                rho[i] = non_zero_dists[index - 1]
                if interpolation > 1e-5:
                    rho[i] += interpolation * (non_zero_dists[index] - non_zero_dists[index - 1])
            elif n_non_zero > 0:
                # Guard: with local_connectivity == 0, a row whose neighbour
                # distances are all 0 has no non-zero entry. Indexing [0] would
                # then read out of bounds (unchecked under numba). rho stays 0.
                rho[i] = interpolation * non_zero_dists[0]
        elif n_non_zero > 0:
            rho[i] = np.max(non_zero_dists)

        # Bisection on sigma (`mid`): grow geometrically until the target is
        # bracketed, then halve the bracket.
        for _ in range(n_iter):
            d_array = distances[i, 1:] - rho[i]
            psum = np.sum(np.where(d_array > 0, np.exp(-(d_array / mid)), 1.0))

            if np.fabs(psum - target) < 1e-5:
                break

            if psum > target:
                hi = mid
                mid = (lo + hi) / 2.0
            else:
                lo = mid
                if hi == np.inf:
                    mid *= 2
                else:
                    mid = (lo + hi) / 2.0

        result[i] = mid
        # Floor sigma at 0.1% of the mean distance: the row mean when rho > 0,
        # otherwise the global mean.
        if rho[i] > 0.0:
            mean_ith_distances = np.mean(ith_distances)
            if result[i] < 1e-3 * mean_ith_distances:
                result[i] = 1e-3 * mean_ith_distances
        else:
            if result[i] < 1e-3 * mean_distances:
                result[i] = 1e-3 * mean_distances

    return result, rho
|
|
1232
|
-
|
|
1233
|
-
|
|
1234
|
-
@njit(parallel=True, fastmath=True)
def _compute_membership_strengths(knn_indices, knn_dists, sigmas, rhos):
    """
    Turn smoothed kNN distances into COO triplets of membership weights.

    The weight for neighbour slot j of point i is:
      * 0 if the neighbour is i itself,
      * 1 if ``knn_dists[i, j] - rhos[i] <= 0``,
      * ``exp(-(knn_dists[i, j] - rhos[i]) / sigmas[i])`` otherwise.
    Slots whose neighbour index is -1 are skipped and keep (row 0, col 0, 0.0).

    Returns:
        rows, cols, vals (ndarray): flat arrays of length
        ``n_samples * n_neighbors``, laid out row-major by (i, j).
    """
    n_samples, n_neighbors = knn_indices.shape
    total = n_samples * n_neighbors
    rows = np.zeros(total, dtype=np.int64)
    cols = np.zeros(total, dtype=np.int64)
    vals = np.zeros(total, dtype=np.float64)

    for i in range(n_samples):
        rho_i = rhos[i]
        sigma_i = sigmas[i]
        for j in range(n_neighbors):
            nbr = knn_indices[i, j]
            if nbr == -1:
                continue  # incomplete kNN list for this point
            gap = knn_dists[i, j] - rho_i
            if nbr == i:
                weight = 0.0
            elif gap <= 0.0:
                weight = 1.0
            else:
                weight = np.exp(-(gap / sigma_i))

            slot = i * n_neighbors + j
            rows[slot] = i
            cols[slot] = nbr
            vals[slot] = weight

    return rows, cols, vals
|
|
1272
|
-
|
|
1273
|
-
|
|
1274
|
-
def _second_build(data, indstemp, nbs=800, metric="cosine"):
    """
    Build the distance matrix fed to persistent homology for the sampled points.

    The fuzzy k-NN membership graph of the sampled points is symmetrized as
    A + A^T - A*A^T and converted to distances via ``-log(weight)``. Pairs
    with zero membership become ``inf``, and the diagonal is set to 0.

    Parameters:
        data (ndarray): PCA-reduced data matrix.
        indstemp (ndarray): Indices of sampled points.
        nbs (int): Number of neighbors in reconstructed graph.
        metric (str): Distance metric ('cosine', 'euclidean', etc).

    Returns:
        d (ndarray): Symmetric distance matrix over the sampled points.
    """
    points = data[indstemp, :]
    pairwise = squareform(pdist(points, metric))
    n_pts = pairwise.shape[0]
    knn_indices = np.argsort(pairwise)[:, :nbs]
    knn_dists = pairwise[np.arange(n_pts)[:, None], knn_indices].copy()

    sigmas, rhos = _smooth_knn_dist(knn_dists, nbs, local_connectivity=0)
    rows, cols, vals = _compute_membership_strengths(knn_indices, knn_dists, sigmas, rhos)

    graph = coo_matrix((vals, (rows, cols)), shape=(n_pts, n_pts))
    graph.eliminate_zeros()
    graph_t = graph.transpose()
    both = graph.multiply(graph_t)
    sym = graph + graph_t - both
    sym.eliminate_zeros()

    # Mirror external TDAvis: plain -log with no epsilon, so zero weights map
    # to +inf. Silence the resulting divide-by-zero warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        d = -np.log(sym.toarray())
    np.fill_diagonal(d, 0)
    return d
|
|
1316
|
-
|
|
1317
|
-
|
|
1318
|
-
def _run_shuffle_analysis(sspikes, num_shuffles=1000, num_cores=4, progress_bar=True, **kwargs):
    """Build a null distribution of max bar lifetimes from shuffled spike trains.

    Thin wrapper around ``_run_shuffle_analysis_multiprocessing``. ``kwargs``
    are passed along to ``_compute_persistence`` for every shuffle.
    """
    return _run_shuffle_analysis_multiprocessing(
        sspikes,
        num_shuffles=num_shuffles,
        num_cores=num_cores,
        progress_bar=progress_bar,
        **kwargs,
    )
|
|
1323
|
-
|
|
1324
|
-
|
|
1325
|
-
def _reseed_worker_rng():
    """Pool initializer: give each worker process its own global NumPy RNG state.

    With the ``fork`` start method, every worker inherits an identical copy of
    the parent's ``np.random`` state. ``_shuffle_spike_trains`` draws from
    ``np.random``, so the workers would otherwise generate the same shifts.
    Seeding from OS entropy makes the shuffles independent across workers.
    """
    np.random.seed()


def _run_shuffle_analysis_multiprocessing(
    sspikes, num_shuffles=1000, num_cores=4, progress_bar=True, **kwargs
):
    """Compute max finite bar lifetimes (H0-H2) over many shuffles in a process pool.

    Parameters
    ----------
    sspikes : ndarray
        Spike array passed to ``_process_single_shuffle`` for each task.
    num_shuffles : int
        Number of shuffles to run.
    num_cores : int
        Number of worker processes (``None`` lets ``multiprocessing`` decide).
    progress_bar : bool
        Accepted for interface compatibility; not used here.
    **kwargs
        Forwarded to ``_compute_persistence`` via ``_process_single_shuffle``.

    Returns
    -------
    dict
        ``{0: [...], 1: [...], 2: [...]}``: per dimension, the max lifetimes
        from the shuffles that produced one. Failed shuffles are dropped.

    Notes
    -----
    Workers are reseeded from OS entropy, so results are not reproducible
    from a parent-process ``np.random.seed`` call.
    """
    import time

    # NaN marks shuffles that produced no value for a dimension.
    max_lifetimes = {dim: np.full(num_shuffles, np.nan) for dim in (0, 1, 2)}

    # Time one shuffle in-process to give the user a rough runtime estimate.
    logging.info("Running test iteration to estimate runtime...")
    start = time.perf_counter()
    _process_single_shuffle((0, sspikes, kwargs))
    elapsed = time.perf_counter() - start
    workers = num_cores if num_cores else (os.cpu_count() or 1)
    logging.info(
        "Test iteration took %.2fs; estimated total ~%.1fs",
        elapsed,
        elapsed * num_shuffles / workers,
    )

    tasks = [(i, sspikes, kwargs) for i in range(num_shuffles)]
    logging.info(
        f"Starting shuffle analysis with {num_shuffles} iterations using {num_cores} cores..."
    )

    with mp.Pool(processes=num_cores, initializer=_reseed_worker_rng) as pool:
        results = list(pool.imap(_process_single_shuffle, tasks))
    logging.info("Shuffle analysis completed")

    for idx, res in enumerate(results):
        for dim, lifetime in res.items():
            max_lifetimes[dim][idx] = lifetime

    # Drop failed/missing entries and return plain lists.
    return {dim: arr[~np.isnan(arr)].tolist() for dim, arr in max_lifetimes.items()}
|
|
1362
|
-
|
|
1363
|
-
|
|
1364
|
-
@njit(fastmath=True)
def _fast_pca_transform(data, components):
    """Project ``data`` onto the rows of ``components``, i.e. ``data @ components.T``."""
    basis = components.T
    return np.dot(data, basis)
|
|
1368
|
-
|
|
1369
|
-
|
|
1370
|
-
def _process_single_shuffle(args):
    """Shuffle the spike trains once and return the longest finite bar per dimension.

    Parameters
    ----------
    args : tuple
        ``(i, sspikes, kwargs)``: shuffle index (used only in the failure
        message), spike array, and keyword arguments for
        ``_compute_persistence``. Packed into one tuple so the function can be
        mapped with ``Pool.imap``.

    Returns
    -------
    dict
        Maps homology dimension (0, 1 or 2, only those present in the result
        with at least one finite bar) to the maximum ``death - birth``.
        Returns ``{}`` if anything fails, so one bad shuffle does not abort
        the whole pool.
    """
    i, sspikes, kwargs = args
    try:
        shuffled_data = _shuffle_spike_trains(sspikes)
        persistence = _compute_persistence(shuffled_data, **kwargs)
        dgms = persistence["dgms"]

        dim_max_lifetimes = {}
        for dim in (0, 1, 2):
            if dim >= len(dgms):
                continue
            # Infinite bars have no meaningful lifetime; ignore them.
            lifetimes = [bar[1] - bar[0] for bar in dgms[dim] if not np.isinf(bar[1])]
            if lifetimes:
                dim_max_lifetimes[dim] = max(lifetimes)
        return dim_max_lifetimes
    except Exception as e:
        # Deliberate best-effort boundary: report through logging (like the
        # rest of the module) and skip this shuffle.
        logging.warning("Shuffle %s failed: %s", i, e)
        return {}
|
|
1390
|
-
|
|
1391
|
-
|
|
1392
|
-
def _shuffle_spike_trains(sspikes):
|
|
1393
|
-
"""Perform random circular shift on spike trains."""
|
|
1394
|
-
shuffled = sspikes.copy()
|
|
1395
|
-
num_neurons = shuffled.shape[1]
|
|
1396
|
-
|
|
1397
|
-
# Independent shift for each neuron
|
|
1398
|
-
for n in range(num_neurons):
|
|
1399
|
-
shift = np.random.randint(0, int(shuffled.shape[0] * 0.1))
|
|
1400
|
-
shuffled[:, n] = np.roll(shuffled[:, n], shift)
|
|
1401
|
-
|
|
1402
|
-
return shuffled
|
|
1403
|
-
|
|
1404
|
-
|
|
1405
|
-
def _plot_barcode(persistence):
    """
    Plot a barcode diagram from a persistent homology result.

    One subplot row per homology dimension. Each row shows up to the 30
    longest bars; for H_k with k > 0 they are ordered by birth. Infinite
    deaths are drawn up to a finite cap 10% of the data range beyond the
    largest finite value.

    Parameters:
        persistence (dict): Persistent homology result with a 'dgms' key
            holding one birth/death array of shape (n_bars, 2) per dimension.

    Returns:
        matplotlib.figure.Figure: The created figure.
    """
    bar_color = (0, 0.55, 0.2)  # same RGB for every dimension
    alpha = 1
    inf_delta = 0.1
    # Normalize each diagram to (n_bars, 2); this also accepts empty 1-D arrays.
    dgms = [np.asarray(dgm, dtype=float).reshape(-1, 2) for dgm in persistence["dgms"]]

    # Axis range from all finite birth/death values.
    min_birth, max_death = 0, 0
    for dgm in dgms:
        finite = dgm[~np.isinf(dgm[:, 1]), :]
        if finite.size > 0:
            min_birth = min(min_birth, np.min(finite))
            max_death = max(max_death, np.max(finite))
    # No finite bars at all: fall back to a unit range instead of [0, 0].
    if max_death == 0 and min_birth == 0:
        max_death = 1

    delta = (max_death - min_birth) * inf_delta
    infinity = max_death + delta
    axis_start = min_birth - delta

    fig = plt.figure(figsize=(10, 6))
    gs = gridspec.GridSpec(len(dgms), 1)

    for dim, dgm in enumerate(dgms):
        axes = fig.add_subplot(gs[dim])
        axes.axis("on")
        axes.set_yticks([])
        axes.set_ylabel(f"$H_{{{dim}}}$", rotation=0, labelpad=20, fontsize=12)

        d = np.copy(dgm)
        d[np.isinf(d[:, 1]), 1] = infinity
        dlife = d[:, 1] - d[:, 0]

        # Top 30 bars by lifetime; for k > 0 reorder by decreasing birth.
        dinds = np.argsort(dlife)[-30:]
        if dim > 0:
            dinds = dinds[np.flip(np.argsort(d[dinds, 0]))]

        axes.barh(
            0.5 + np.arange(len(dinds)),
            dlife[dinds],
            height=0.8,
            left=d[dinds, 0],
            alpha=alpha,
            color=bar_color,
            linewidth=0,
        )

        axes.plot([0, 0], [0, len(dinds)], c="k", linestyle="-", lw=1)
        axes.plot([0, len(dinds)], [0, 0], c="k", linestyle="-", lw=1)
        axes.set_xlim([axis_start, infinity])

    plt.tight_layout()
    return fig
|
|
1468
|
-
|
|
1469
|
-
|
|
1470
|
-
def _plot_barcode_with_shuffle(persistence, shuffle_max):
    """
    Plot a barcode diagram with a shaded shuffle-significance region per dimension.

    For each homology dimension, the 99.9th percentile of ``shuffle_max[dim]``
    is shaded in gray and marked with a dashed line. Among the 30 longest
    bars, those whose lifetime exceeds that threshold are drawn in red.

    Parameters:
        persistence (dict): Persistent homology result with a 'dgms' key.
        shuffle_max (dict | None): Maps dimension -> shuffled max lifetimes.
            ``None`` or a missing/empty entry gives a threshold of 0.

    Returns:
        matplotlib.figure.Figure: The created figure.
    """
    if shuffle_max is None:
        shuffle_max = {}

    palette = np.tile([0, 0.55, 0.2], (3, 1))  # one RGB row per dimension H0..H2
    dgms = persistence["dgms"]
    dims = np.arange(len(dgms))

    # Axis range from every finite birth/death value.
    lo, hi = 0, 0
    for dim in dims:
        finite = [bar for bar in dgms[dim] if not np.isinf(bar[1])]
        if not finite:
            continue
        lo = min(lo, np.min(finite))
        hi = max(hi, np.max(finite))
    if hi == 0 and lo == 0:
        lo, hi = 0, 1
    cap = hi + (hi - lo) * 0.1  # finite stand-in for infinite deaths

    fig = plt.figure(figsize=(10, 8))
    gs = gridspec.GridSpec(len(dims), 1)

    thresholds = {
        dim: (np.percentile(shuffle_max[dim], 99.9) if dim in shuffle_max and shuffle_max[dim] else 0)
        for dim in dims
    }

    for dim in dims:
        axes = plt.subplot(gs[dim])
        axes.axis("off")

        thr = thresholds[dim]
        axes.axvspan(0, thr, alpha=0.2, color="gray", zorder=-3)
        axes.axvline(x=thr, color="gray", linestyle="--", alpha=0.7)

        bars = np.copy(dgms[dim])
        if bars.size == 0:
            bars = np.zeros((0, 2))
        bars[np.isinf(bars[:, 1]), 1] = cap
        lifetimes = bars[:, 1] - bars[:, 0]

        n_drawn = 0
        if len(lifetimes) > 0:
            order = np.argsort(lifetimes)[-30:]
            if dim > 0:
                order = order[np.flip(np.argsort(bars[order, 0]))]
            for row, idx in enumerate(order):
                axes.barh(
                    0.5 + row,
                    lifetimes[idx],
                    height=0.8,
                    left=bars[idx, 0],
                    alpha=1,
                    color="red" if lifetimes[idx] > thr else palette[dim],
                    linewidth=0,
                )
            n_drawn = len(order)

        axes.plot([0, 0], [0, n_drawn], c="k", linestyle="-", lw=1)
        axes.plot([0, n_drawn], [0, 0], c="k", linestyle="-", lw=1)
        axes.set_xlim([0, cap])
        axes.set_title(f"$H_{dim}$", loc="left")

    plt.tight_layout()
    return fig
|
|
1567
|
-
|
|
1568
|
-
|
|
1569
|
-
def _decode_angles_from_centroids(
    dspk: np.ndarray, centcosall: np.ndarray, centsinall: np.ndarray
) -> np.ndarray:
    """Project z-scored activity onto per-neuron circular centroids; return angles in [0, 2*pi).

    ``dspk`` has shape (n_times, n_neurons); ``centcosall``/``centsinall`` have shape
    (n_neurons, n_circ, n_sampled). The result has shape (n_times, n_circ).
    """
    # Summing each neuron's centroid over the sampled points first turns the
    # per-neuron accumulation sum_n dspk[:, n] * sum_p cent[n, :, p] into one
    # matrix product, avoiding a (n_times, n_circ, n_neurons) intermediate.
    cos_weights = np.sum(centcosall, axis=2)
    sin_weights = np.sum(centsinall, axis=2)
    return np.arctan2(dspk @ sin_weights, dspk @ cos_weights) % (2 * np.pi)


def decode_circular_coordinates(
    persistence_result: dict[str, Any],
    spike_data: dict[str, Any],
    real_ground: bool = True,
    real_of: bool = True,
    save_path: str | None = None,
) -> dict[str, Any]:
    """
    Decode circular coordinates (bump positions) from cohomology.

    Circular coordinates are first reconstructed on the sampled points from the
    two most persistent H1 cocycles. Each neuron's cosine/sine centroid is then
    estimated, and every time bin with at least one spike is decoded by
    projecting its z-scored activity onto those centroids.

    Parameters:
        persistence_result : dict containing persistence analysis results with keys:
            - 'persistence': persistent homology result (with 'dgms', 'cocycles',
              'dperm2all'); it is NOT modified by this function
            - 'indstemp': indices of sampled points
            - 'movetimes': selected time points
            - 'n_points': number of sampled points
        spike_data : dict
            Spike data dictionary containing 'spike', 't', and optionally 'x', 'y'
        real_ground : bool
            Whether x, y, t ground truth exists (enables speed filtering)
        real_of : bool
            Whether the experiment was performed in an open field. If False, the
            "box" coordinates are decoded separately from speed-filtered activity.
        save_path : str, optional
            Path to save decoding results. If None, saves to 'Results/spikes_decoding.npz'.
            A bare file name (no directory part) is saved in the working directory.

    Returns:
        dict : Dictionary containing decoding results with keys:
            - 'coords': decoded coordinates for all timepoints
            - 'coordsbox': decoded coordinates for box timepoints
            - 'times': time indices for coords
            - 'times_box': time indices for coordsbox
            - 'centcosall': cosine centroids
            - 'centsinall': sine centroids
    """
    ph_classes = [0, 1]  # decode the 1st and 2nd most persistent H1 classes
    num_circ = len(ph_classes)
    dec_tresh = 0.99
    coeff = 47

    # Extract persistence analysis results
    persistence = persistence_result["persistence"]
    indstemp = persistence_result["indstemp"]
    movetimes = persistence_result["movetimes"]
    n_points = persistence_result["n_points"]

    diagrams = persistence["dgms"]  # persistence diagrams per homology dimension
    cocycles = persistence["cocycles"][1]  # cocycle representatives of the H1 classes
    dists_land = persistence["dperm2all"]  # pairwise distances between the points
    births1 = diagrams[1][:, 0]
    # Copy before patching infinite deaths: column slicing returns a view, and
    # writing into it would silently alter the caller's persistence diagram.
    deaths1 = diagrams[1][:, 1].copy()
    deaths1[np.isinf(deaths1)] = 0
    lives1 = deaths1 - births1
    order = np.argsort(lives1)  # ascending lifetime: last entries are most persistent
    n_sampled = len(indstemp)
    coords1 = np.zeros((num_circ, n_sampled))

    # Evaluate cocycles on the graph at dec_tresh of the way through the
    # second most persistent class's lifetime.
    second = order[-2]
    threshold = births1[second] + (deaths1[second] - births1[second]) * dec_tresh

    for c in ph_classes:
        cocycle = cocycles[order[-(c + 1)]]
        coords1[c, :], _ = _get_coords(cocycle, threshold, n_sampled, dists_land, coeff)

    # Embed the spike trains once (smoothed and raw). Speed filtering needs real
    # x/y/t ground truth. These embeddings are reused below instead of being
    # recomputed with identical settings.
    if real_ground:
        sspikes, _, _, _ = embed_spike_trains(
            spike_data, config=SpikeEmbeddingConfig(smooth=True, speed_filter=True)
        )
        spikes, _, _, _ = embed_spike_trains(
            spike_data, config=SpikeEmbeddingConfig(smooth=False, speed_filter=True)
        )
    else:
        sspikes = embed_spike_trains(
            spike_data, config=SpikeEmbeddingConfig(smooth=True, speed_filter=False)
        )
        spikes = embed_spike_trains(
            spike_data, config=SpikeEmbeddingConfig(smooth=False, speed_filter=False)
        )

    # Per-neuron centroids: activity-weighted cos/sin of the sampled-point angles.
    num_neurons = sspikes.shape[1]
    centcosall = np.zeros((num_neurons, num_circ, n_points))
    centsinall = np.zeros((num_neurons, num_circ, n_points))
    dspk = preprocessing.scale(sspikes[movetimes[indstemp], :])
    cos_angles = np.cos(coords1 * 2 * np.pi)
    sin_angles = np.sin(coords1 * 2 * np.pi)
    for neurid in range(num_neurons):
        spktemp = dspk[:, neurid]
        centcosall[neurid, :, :] = cos_angles * spktemp
        centsinall[neurid, :, :] = sin_angles * spktemp

    # Decode every time bin that contains at least one spike.
    times = np.where(np.sum(spikes > 0, 1) >= 1)[0]
    dspk = preprocessing.scale(sspikes)[times, :]
    coords = _decode_angles_from_centroids(dspk, centcosall, centsinall)

    if real_of:  # data come from a real open-field arena
        coordsbox = coords.copy()
        times_box = times.copy()
    else:
        if real_ground:
            # Already speed-filtered above; no need to embed again.
            sspikes_box, spikes_box = sspikes, spikes
        else:
            sspikes_box, _, _, _ = embed_spike_trains(
                spike_data, config=SpikeEmbeddingConfig(smooth=True, speed_filter=True)
            )
            spikes_box, _, _, _ = embed_spike_trains(
                spike_data, config=SpikeEmbeddingConfig(smooth=False, speed_filter=True)
            )
        times_box = np.where(np.sum(spikes_box > 0, 1) >= 1)[0]
        dspk_box = preprocessing.scale(sspikes_box)[times_box, :]
        coordsbox = _decode_angles_from_centroids(dspk_box, centcosall, centsinall)

    # Prepare results dictionary
    results = {
        "coords": coords,
        "coordsbox": coordsbox,
        "times": times,
        "times_box": times_box,
        "centcosall": centcosall,
        "centsinall": centsinall,
    }

    # Save results
    if save_path is None:
        save_path = "Results/spikes_decoding.npz"
    save_dir = os.path.dirname(save_path)
    if save_dir:  # os.makedirs("") raises, so skip it for bare file names
        os.makedirs(save_dir, exist_ok=True)
    np.savez_compressed(save_path, **results)

    return results
|
|
1725
|
-
|
|
1726
|
-
|
|
1727
|
-
def plot_cohomap(
    decoding_result: dict[str, Any],
    position_data: dict[str, Any],
    save_path: str | None = None,
    show: bool = False,
    figsize: tuple[int, int] = (10, 4),
    dpi: int = 300,
    subsample: int = 10,
) -> plt.Figure:
    """
    Visualize CohoMap 1.0: decoded circular coordinates mapped onto spatial trajectory.

    Creates a two-panel visualization showing how the two decoded circular coordinates
    vary across the animal's spatial trajectory. Each panel displays the spatial path
    colored by the cosine of one circular coordinate dimension.

    Parameters:
        decoding_result : dict
            Dictionary from decode_circular_coordinates() containing:
            - 'coordsbox': decoded coordinates for box timepoints (n_times x n_dims)
            - 'times_box': time indices for coordsbox
        position_data : dict
            Position data containing 'x' and 'y' arrays (or sequences) for spatial
            coordinates
        save_path : str, optional
            Path to save the visualization. If None, no save is performed. A bare
            file name is saved in the working directory.
        show : bool, default=False
            Whether to display the visualization
        figsize : tuple[int, int], default=(10, 4)
            Figure size (width, height) in inches
        dpi : int, default=300
            Resolution for saved figure
        subsample : int, default=10
            Subsampling interval for plotting (plot every Nth timepoint)

    Returns:
        plt.Figure : The matplotlib figure object (closed if show is False)

    Raises:
        KeyError : If required keys are missing from input dictionaries
        ValueError : If data dimensions are inconsistent
        IndexError : If time indices are out of bounds

    Examples:
        >>> # Decode coordinates
        >>> decoding = decode_circular_coordinates(persistence_result, spike_data)
        >>> # Visualize with trajectory data
        >>> fig = plot_cohomap(
        ...     decoding,
        ...     position_data={'x': xx, 'y': yy},
        ...     save_path='cohomap.png',
        ...     show=True
        ... )
    """
    try:
        # Extract data; asarray allows plain sequences for fancy indexing below.
        coordsbox = np.asarray(decoding_result["coordsbox"])
        times_box = np.asarray(decoding_result["times_box"])
        xx = np.asarray(position_data["x"])
        yy = np.asarray(position_data["y"])

        # Subsample time indices for plotting
        plot_times = np.arange(0, len(coordsbox), subsample)
        x_box = xx[times_box][plot_times]
        y_box = yy[times_box][plot_times]

        # One panel per cohomology dimension. The colormap is passed per artist:
        # plt.set_cmap() would mutate global rcParams and, via gci()/gcf(),
        # create a stray figure that is never closed.
        fig, ax = plt.subplots(1, 2, figsize=figsize)
        for dim in range(2):
            panel = ax[dim]
            panel.axis("off")
            panel.set_aspect("equal", "box")
            im = panel.scatter(
                x_box,
                y_box,
                c=np.cos(coordsbox[plot_times, dim]),
                s=8,
                cmap="viridis",
            )
            fig.colorbar(im, ax=panel, label="cos(coord)")
            panel.set_title(f"CohoMap Dim {dim + 1}", fontsize=10)

        fig.tight_layout()

        # Save if path provided
        if save_path:
            try:
                save_dir = os.path.dirname(save_path)
                if save_dir:  # os.makedirs("") raises for bare file names
                    os.makedirs(save_dir, exist_ok=True)
                fig.savefig(save_path, dpi=dpi)
                print(f"CohoMap visualization saved to {save_path}")
            except Exception as e:
                print(f"Error saving CohoMap visualization: {e}")

        # Show if requested
        if show:
            plt.show()
        else:
            plt.close(fig)

        return fig

    except (KeyError, ValueError, IndexError) as e:
        print(f"CohoMap visualization failed: {e}")
        raise
    except Exception as e:
        print(f"Unexpected error in CohoMap visualization: {e}")
        raise
|
|
1845
|
-
|
|
1846
|
-
|
|
1847
|
-
def plot_3d_bump_on_torus(
    decoding_result: dict[str, Any] | str,
    spike_data: dict[str, Any],
    config: CANN2DPlotConfig | None = None,
    save_path: str | None = None,
    numangsint: int = 51,
    r1: float = 1.5,
    r2: float = 1.0,
    window_size: int = 300,
    frame_step: int = 5,
    n_frames: int = 20,
    fps: int = 5,
    show_progress: bool = True,
    show: bool = True,
    figsize: tuple[int, int] = (8, 8),
    **kwargs,
) -> animation.FuncAnimation:
    """
    Visualize the movement of the neural activity bump on a torus using matplotlib animation.

    This function follows the canns.analyzer.plotting patterns for animation generation
    with progress tracking and proper resource cleanup.

    Parameters:
        decoding_result : dict or str
            Dictionary containing decoding results with 'coordsbox' and 'times_box' keys,
            or path to .npz file containing these results
        spike_data : dict
            Spike data dictionary containing spike information
        config : PlotConfig, optional
            Configuration object for unified plotting parameters. When given, its
            values take precedence over the explicit keyword parameters below.
            When omitted, a config is built from **kwargs and the explicit
            keyword parameters are applied to it.
        **kwargs : backward compatibility parameters (set on the config by name)
        save_path : str, optional
            Path to save the animation (e.g., 'animation.gif' or 'animation.mp4')
        numangsint : int
            Grid resolution for the torus surface
        r1 : float
            Major radius of the torus
        r2 : float
            Minor radius of the torus
        window_size : int
            Time window (in number of time points) for each frame
        frame_step : int
            Step size to slide the time window between frames
        n_frames : int
            Total number of frames in the animation
        fps : int
            Frames per second for the output animation
        show_progress : bool
            Whether to show progress bars during frame processing and saving
        show : bool
            Whether to display the animation
        figsize : tuple[int, int]
            Figure size for the animation

    Returns:
        matplotlib.animation.FuncAnimation : The animation object
        (None in Jupyter when showing, to avoid double display)

    Raises:
        ProcessingError : If no frames can be generated or animation creation fails
    """
    from ...visualization.core.jupyter_utils import (
        display_animation_in_jupyter,
        is_jupyter_environment,
    )

    # Handle backward compatibility and configuration
    if config is None:
        config = CANN2DPlotConfig.for_torus_animation(**kwargs)
        # Without an explicit config, honour this function's keyword parameters
        # instead of silently discarding them (e.g. the mode="3d" fallback of
        # plot_2d_bump_on_manifold passes window_size, numangsint, fps, ...).
        explicit_params = {
            "save_path": save_path,
            "numangsint": numangsint,
            "r1": r1,
            "r2": r2,
            "window_size": window_size,
            "frame_step": frame_step,
            "n_frames": n_frames,
            "fps": fps,
            "show_progress_bar": show_progress,
            "show": show,
            "figsize": figsize,
        }
        for key, value in explicit_params.items():
            if key not in kwargs and hasattr(config, key):
                setattr(config, key, value)

    # Override config with any explicitly passed backward-compatibility kwargs
    for key, value in kwargs.items():
        if hasattr(config, key):
            setattr(config, key, value)

    # Extract configuration values
    save_path = config.save_path if config.save_path else save_path
    show = config.show
    figsize = config.figsize
    fps = config.fps
    show_progress = config.show_progress_bar
    numangsint = config.numangsint
    r1 = config.r1
    r2 = config.r2
    window_size = config.window_size
    frame_step = config.frame_step
    n_frames = config.n_frames

    # Load decoding results if a path is provided
    if isinstance(decoding_result, str):
        with np.load(decoding_result, allow_pickle=True) as f:
            coords = f["coordsbox"]
            times = f["times_box"]
    else:
        coords = decoding_result["coordsbox"]
        times = decoding_result["times_box"]

    spk, *_ = embed_spike_trains(
        spike_data, config=SpikeEmbeddingConfig(smooth=False, speed_filter=True)
    )

    # Torus geometry is constant across frames, so compute it once.
    x_edge = np.linspace(0, 2 * np.pi, numangsint)
    y_edge = np.linspace(0, 2 * np.pi, numangsint)
    X_grid, Y_grid = np.meshgrid(x_edge, y_edge)
    X_transformed = (X_grid + np.pi / 5) % (2 * np.pi)

    torus_x = (r1 + r2 * np.cos(X_transformed)) * np.cos(Y_grid)
    torus_y = (r1 + r2 * np.cos(X_transformed)) * np.sin(Y_grid)
    torus_z = -r2 * np.sin(X_transformed)  # flip torus surface orientation

    # Per-frame activity maps (colors only; geometry is reused).
    frame_data = []
    prev_m = None
    # An empty `times` yields no frames (handled below) instead of a raw ValueError.
    max_time = np.max(times) if len(times) > 0 else -1
    bins = np.linspace(0, 2 * np.pi, numangsint - 1)

    for frame_idx in tqdm(range(n_frames), desc="Processing frames", disable=not show_progress):
        start_idx = frame_idx * frame_step
        end_idx = start_idx + window_size
        if end_idx > max_time:
            break

        mask = (times >= start_idx) & (times < end_idx)
        coords_window = coords[mask]
        if len(coords_window) == 0:
            continue

        spk_window = spk[times[mask], :]
        activity = np.sum(spk_window, axis=1)

        m, _, _, _ = binned_statistic_2d(
            coords_window[:, 0],
            coords_window[:, 1],
            activity,
            statistic="sum",
            bins=bins,
        )
        m = np.nan_to_num(m)
        m = _smooth_tuning_map(m, numangsint - 1, sig=4.0, bClose=True)
        m = gaussian_filter(m, sigma=1.0)

        # Temporal smoothing: exponential blend with the previous frame.
        if prev_m is not None:
            m = 0.7 * prev_m + 0.3 * m
        prev_m = m

        # "time" is the window start index (start_idx already includes frame_step).
        frame_data.append({"m": m, "time": start_idx})

    if not frame_data:
        raise ProcessingError("No valid frames generated for animation")

    fig = plt.figure(figsize=figsize)

    try:
        ax = fig.add_subplot(111, projection="3d")

        def _style_axes():
            """Apply the fixed camera/limits; needed again after every ax.clear()."""
            ax.set_zlim(-2, 2)
            ax.view_init(-125, 135)
            ax.axis("off")

        def _draw_surface(m):
            """Draw the pre-computed torus colored by the max-normalized map ``m``."""
            return ax.plot_surface(
                torus_x,
                torus_y,
                torus_z,
                facecolors=cm.viridis(m / (np.max(m) + 1e-9)),
                alpha=1,
                linewidth=0.1,
                antialiased=True,
                rstride=1,
                cstride=1,
                shade=False,
            )

        _style_axes()
        surface = _draw_surface(frame_data[0]["m"])

        def animate(frame_idx):
            """Redraw the surface for one frame (3D axes do not support blitting)."""
            if frame_idx >= len(frame_data):
                return (surface,)

            frame = frame_data[frame_idx]
            ax.clear()
            _style_axes()
            new_surface = _draw_surface(frame["m"])

            time_text = ax.text2D(
                0.05,
                0.95,
                f"Frame: {frame_idx + 1}/{len(frame_data)}",
                transform=ax.transAxes,
                fontsize=12,
                bbox=dict(facecolor="white", alpha=0.7),
            )

            return new_surface, time_text

        interval_ms = 1000 / fps
        ani = animation.FuncAnimation(
            fig,
            animate,
            frames=len(frame_data),
            interval=interval_ms,
            blit=False,  # 3D does not support blitting
            repeat=True,
        )

        # Save animation if path provided (best effort: errors are reported, not raised)
        if save_path:
            # Warn if both saving and showing (causes double rendering)
            if show and len(frame_data) > 50:
                from ...visualization.core import warn_double_rendering

                warn_double_rendering(len(frame_data), save_path, stacklevel=2)

            if show_progress:
                pbar = tqdm(total=len(frame_data), desc=f"Saving animation to {save_path}")

                def progress_callback(current_frame, total_frames):
                    pbar.update(1)

                try:
                    writer = animation.PillowWriter(fps=fps)
                    ani.save(save_path, writer=writer, progress_callback=progress_callback)
                    pbar.close()
                    print(f"\nAnimation saved to: {save_path}")
                except Exception as e:
                    pbar.close()
                    print(f"\nError saving animation: {e}")
            else:
                try:
                    writer = animation.PillowWriter(fps=fps)
                    ani.save(save_path, writer=writer)
                    print(f"Animation saved to: {save_path}")
                except Exception as e:
                    print(f"Error saving animation: {e}")

        if show:
            # Automatically detect Jupyter and display as HTML/JS
            if is_jupyter_environment():
                display_animation_in_jupyter(ani)
                plt.close(fig)  # close after HTML conversion to prevent auto-display
            else:
                plt.show()
        else:
            plt.close(fig)

        # Return None in Jupyter when showing to avoid double display
        if show and is_jupyter_environment():
            return None
        return ani

    except Exception as e:
        plt.close(fig)
        raise ProcessingError(f"Failed to create torus animation: {e}") from e
|
|
2117
|
-
|
|
2118
|
-
|
|
2119
|
-
def plot_2d_bump_on_manifold(
    decoding_result: dict[str, Any] | str,
    spike_data: dict[str, Any],
    save_path: str | None = None,
    fps: int = 20,
    show: bool = True,
    mode: str = "fast",
    window_size: int = 10,
    frame_step: int = 5,
    numangsint: int = 20,
    figsize: tuple[int, int] = (8, 6),
    show_progress: bool = False,
):
    """
    Create 2D projection animation of CANN2D bump activity with full blitting support.

    This function provides a fast 2D heatmap visualization as an alternative to the
    3D torus animation. It achieves 10-20x speedup using matplotlib blitting
    optimization, making it ideal for rapid prototyping and daily analysis.

    Every window with ``start + window_size <= max(times)`` becomes a frame. All
    frames share one color scale (global min/max), so intensities are comparable
    across frames and the colorbar stays valid.

    Args:
        decoding_result: Decoding results containing coords and times (dict or file path)
        spike_data: Dictionary containing spike train data
        save_path: Path to save animation (None to skip saving); '.mp4' (any case)
            uses FFmpeg, anything else uses Pillow
        fps: Frames per second
        show: Whether to display the animation
        mode: Visualization mode - 'fast' for 2D heatmap (default), '3d' falls back to 3D
        window_size: Time window for activity aggregation
        frame_step: Time step between frames
        numangsint: Number of angular bins for spatial discretization
        figsize: Figure size (width, height) in inches
        show_progress: Show progress bar during processing

    Returns:
        FuncAnimation object (or None in Jupyter when showing)

    Raises:
        ProcessingError: If mode is invalid or animation generation fails

    Example:
        >>> # Fast 2D visualization (recommended for daily use)
        >>> ani = plot_2d_bump_on_manifold(
        ...     decoding_result, spike_data,
        ...     save_path='bump_2d.mp4', mode='fast'
        ... )
        >>> # For publication-ready 3D visualization, use mode='3d'
        >>> ani = plot_2d_bump_on_manifold(
        ...     decoding_result, spike_data, mode='3d'
        ... )
    """
    import matplotlib.animation as animation

    from ...visualization.core.jupyter_utils import (
        display_animation_in_jupyter,
        is_jupyter_environment,
    )

    # Validate inputs
    if mode == "3d":
        # Fall back to 3D visualization
        return plot_3d_bump_on_torus(
            decoding_result=decoding_result,
            spike_data=spike_data,
            save_path=save_path,
            fps=fps,
            show=show,
            window_size=window_size,
            frame_step=frame_step,
            numangsint=numangsint,
            figsize=figsize,
            show_progress=show_progress,
        )

    if mode != "fast":
        raise ProcessingError(f"Invalid mode '{mode}'. Must be 'fast' or '3d'.")

    # Load decoding results
    if isinstance(decoding_result, str):
        with np.load(decoding_result, allow_pickle=True) as f:
            coords = f["coordsbox"]
            times = f["times_box"]
    else:
        coords = decoding_result["coordsbox"]
        times = decoding_result["times_box"]

    # Process spike data for 2D projection
    spk, *_ = embed_spike_trains(
        spike_data, config=SpikeEmbeddingConfig(smooth=False, speed_filter=True)
    )

    # Number of windows with start + window_size <= max_time; the +1 keeps the
    # last valid window. An empty `times` yields zero frames (ProcessingError below).
    max_time = np.max(times) if len(times) > 0 else -1
    n_frames = max(0, int((max_time - window_size) // frame_step) + 1)
    bins = np.linspace(0, 2 * np.pi, numangsint - 1)
    frame_activity_maps = []
    prev_m = None

    for frame_idx in tqdm(range(n_frames), desc="Processing frames", disable=not show_progress):
        start_idx = frame_idx * frame_step
        end_idx = start_idx + window_size
        if end_idx > max_time:
            break

        mask = (times >= start_idx) & (times < end_idx)
        coords_window = coords[mask]
        if len(coords_window) == 0:
            continue

        spk_window = spk[times[mask], :]
        activity = np.sum(spk_window, axis=1)

        m, _, _, _ = binned_statistic_2d(
            coords_window[:, 0],
            coords_window[:, 1],
            activity,
            statistic="sum",
            bins=bins,
        )
        m = np.nan_to_num(m)
        m = _smooth_tuning_map(m, numangsint - 1, sig=4.0, bClose=True)
        m = gaussian_filter(m, sigma=1.0)

        # Temporal smoothing: exponential blend with the previous frame.
        if prev_m is not None:
            m = 0.7 * prev_m + 0.3 * m
        prev_m = m

        frame_activity_maps.append(m)

    if not frame_activity_maps:
        raise ProcessingError("No valid frames generated for animation")

    # imshow fixes its norm from the first image; without explicit limits every
    # later frame would be clipped to frame 0's range. Use the global range.
    vmin = min(float(np.min(m)) for m in frame_activity_maps)
    vmax = max(float(np.max(m)) for m in frame_activity_maps)

    # Create 2D visualization with blitting
    fig, ax = plt.subplots(figsize=figsize)
    ax.set_xlabel("Manifold Dimension 1 (rad)", fontsize=12)
    ax.set_ylabel("Manifold Dimension 2 (rad)", fontsize=12)
    ax.set_title("CANN2D Bump Activity (2D Projection)", fontsize=14, fontweight="bold")

    # Pre-create artists for blitting
    im = ax.imshow(
        frame_activity_maps[0].T,  # transpose for correct orientation
        extent=[0, 2 * np.pi, 0, 2 * np.pi],
        origin="lower",
        cmap="viridis",
        vmin=vmin,
        vmax=vmax,
        animated=True,
        aspect="auto",
    )
    # Colorbar (static; valid for all frames thanks to the shared vmin/vmax)
    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label("Activity", fontsize=11)

    time_text = ax.text(
        0.02,
        0.98,
        "",
        transform=ax.transAxes,
        fontsize=11,
        verticalalignment="top",
        bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
        animated=True,
    )

    def init():
        """Initialize animation"""
        im.set_array(frame_activity_maps[0].T)
        time_text.set_text("")
        return im, time_text

    def update(frame_idx):
        """Update function - only modify data using blitting"""
        if frame_idx >= len(frame_activity_maps):
            return im, time_text
        im.set_array(frame_activity_maps[frame_idx].T)
        time_text.set_text(f"Frame: {frame_idx + 1}/{len(frame_activity_maps)}")
        return im, time_text

    # Check blitting support
    use_blitting = True
    try:
        if not fig.canvas.supports_blit:
            use_blitting = False
            print("Warning: Backend does not support blitting. Using slower mode.")
    except AttributeError:
        use_blitting = False

    interval_ms = 1000 / fps
    ani = animation.FuncAnimation(
        fig,
        update,
        frames=len(frame_activity_maps),
        init_func=init,
        interval=interval_ms,
        blit=use_blitting,
        repeat=True,
    )

    def _make_writer():
        """Pick FFmpeg (H.264, yuv420p) for .mp4 files and Pillow otherwise."""
        if save_path.lower().endswith(".mp4"):
            return animation.FFMpegWriter(
                fps=fps, codec="libx264", extra_args=["-pix_fmt", "yuv420p"]
            )
        return animation.PillowWriter(fps=fps)

    # Save animation if path provided
    if save_path:
        # Warn if both saving and showing (causes double rendering)
        if show and len(frame_activity_maps) > 50:
            from ...visualization.core import warn_double_rendering

            warn_double_rendering(len(frame_activity_maps), save_path, stacklevel=2)

        pbar = None
        save_kwargs = {}
        if show_progress:
            pbar = tqdm(total=len(frame_activity_maps), desc=f"Saving animation to {save_path}")
            save_kwargs["progress_callback"] = lambda _current, _total: pbar.update(1)
        prefix = "\n" if pbar is not None else ""

        try:
            ani.save(save_path, writer=_make_writer(), **save_kwargs)
        except Exception as e:
            if pbar is not None:
                pbar.close()
            print(f"{prefix}Error saving animation: {e}")
            raise
        if pbar is not None:
            pbar.close()
        print(f"{prefix}Animation saved to: {save_path}")

    if show:
        # Automatically detect Jupyter and display as HTML/JS
        if is_jupyter_environment():
            display_animation_in_jupyter(ani)
            plt.close(fig)  # close after HTML conversion to prevent auto-display
        else:
            plt.show()
    else:
        plt.close(fig)

    # Return None in Jupyter when showing to avoid double display
    if show and is_jupyter_environment():
        return None
    return ani
|
|
2387
|
-
|
|
2388
|
-
|
|
2389
|
-
def _get_coords(cocycle, threshold, num_sampled, dists, coeff):
    """
    Reconstruct circular coordinates from cocycle information.

    The cocycle coefficients are lifted from Z/coeff to the symmetric integer
    range. They are placed on the edges of the Rips graph at ``threshold`` and
    integrated by solving the least-squares problem ``A f = values``, where
    ``A`` is the signed edge/vertex incidence matrix. The solution is wrapped
    modulo 1.

    Parameters:
        cocycle (ndarray): Persistent cocycle representative. Columns 0 and 1
            are vertex indices and column 2 is the coefficient. The caller's
            array is not modified.
        threshold (float): Maximum allowable edge distance.
        num_sampled (int): Number of sampled points.
        dists (ndarray): Pairwise distance matrix, indexed like the
            ``(num_sampled, num_sampled)`` edge matrix built below.
        coeff (int): Finite field modulus for cohomology.

    Returns:
        f (ndarray): Circular coordinate values in [0, 1), one per entry of
            ``verts``.
        verts (ndarray): Sorted indices of the vertices that belong to at
            least one retained edge.
    """
    # Work on a copy so the lift below does not leak back into the caller's array.
    cocycle = np.array(cocycle, copy=True)
    # Lift Z/coeff values to the symmetric range: values closer to coeff become negative.
    flip = coeff - cocycle[:, 2] < cocycle[:, 2]
    cocycle[flip, 2] = cocycle[flip, 2] - coeff

    # Candidate edges start as the strict upper triangle with value 0 (lower
    # triangle and diagonal masked with NaN). Cocycle values are then written
    # at (col1, col0). Edges longer than threshold or of zero length are masked.
    d = np.zeros((num_sampled, num_sampled))
    d[np.tril_indices(num_sampled)] = np.nan
    d[cocycle[:, 1], cocycle[:, 0]] = cocycle[:, 2]
    d[dists > threshold] = np.nan
    d[dists == 0] = np.nan

    edges = np.where(~np.isnan(d))
    verts = np.unique(edges)
    values = d[edges]
    num_edges = edges[0].size
    num_verts = verts.size
    if num_edges == 0:
        # Nothing to integrate: no vertex participates in any edge.
        return np.zeros(0), verts

    # verts is sorted and contains every endpoint, so searchsorted returns each
    # endpoint's column directly (replaces the per-edge np.where scan).
    rows = np.arange(num_edges)
    tail_cols = np.searchsorted(verts, edges[0])
    head_cols = np.searchsorted(verts, edges[1])

    # Signed incidence matrix: -1 at the edge's row endpoint, +1 at its column endpoint.
    A = coo_matrix(
        (
            np.concatenate((-np.ones(num_edges), np.ones(num_edges))),
            (np.concatenate((rows, rows)), np.concatenate((tail_cols, head_cols))),
        ),
        shape=(num_edges, num_verts),
    ).tocsr()

    # All edges carry unit weight, so no re-weighting of A or values is needed.
    f = lsmr(A, values)[0] % 1
    return f, verts
|
|
2443
|
-
|
|
2444
|
-
|
|
2445
|
-
def _smooth_tuning_map(mtot, numangsint, sig, bClose=True):
    """
    Smooth an activity map over a periodic (torus-like) topology.

    Algorithm:

    1. Row ``i`` of the map is rolled by ``i // 2`` columns (a half-step
       shear).
    2. The sheared map is tiled three times horizontally.
    3. That band is stacked three times vertically. The upper and lower copies
       are additionally rolled by ``(numangsint - 1) // 2`` columns. This
       presumably aligns the vertical wrap-around with the accumulated shear
       of a twisted torus (confirm with the data's geometry).
    4. The stacked image is smoothed.
    5. The central tile is read back row by row, with row ``i`` starting
       ``i // 2 + 1`` columns into the centre band.

    Parameters:
        mtot (ndarray): Raw activity map. The code assumes it is
            ``(numangsint - 1, numangsint - 1)``.
        numangsint (int): Grid resolution; the map has ``numangsint - 1`` bins
            per side.
        sig (float): Smoothing kernel standard deviation, in bins.
        bClose (bool): If True, use the NaN-aware ``_smooth_image``, whose
            output marks bins outside the closed coverage mask with ``-inf``.
            Otherwise use ``scipy.ndimage.gaussian_filter``.

    Returns:
        mtot_out (ndarray): Smoothed map, same shape as ``mtot``. Integer inputs
            are promoted to float so smoothed values are not truncated.
    """
    n_bins = numangsint - 1
    half = n_bins // 2

    mtot = np.asarray(mtot)
    if not np.issubdtype(mtot.dtype, np.inexact):
        # Integer (e.g. count) maps would otherwise truncate the smoothed values.
        mtot = mtot.astype(float)

    # Half-step shear: row i is rolled by i // 2 columns.
    sheared = mtot.copy()
    for i in range(n_bins):
        sheared[i, :] = np.roll(sheared[i, :], i // 2)

    # 3 copies side by side; the copies placed above/below are shifted by half a tile.
    band = np.concatenate((sheared, sheared, sheared), axis=1)
    shifted_band = np.roll(band, half, axis=1)
    stacked = np.concatenate((shifted_band, band, shifted_band))

    if bClose:
        smoothed = _smooth_image(stacked, sigma=sig)
    else:
        smoothed = gaussian_filter(stacked, sigma=sig)

    # Read back the central tile, offsetting each row to match the shear.
    mtot_out = np.zeros_like(mtot)
    for i in range(n_bins):
        start = n_bins + i // 2 + 1
        mtot_out[i, :] = smoothed[n_bins + i, start : start + n_bins]
    return mtot_out
|
|
2481
|
-
|
|
2482
|
-
|
|
2483
|
-
def _smooth_image(img, sigma):
    """
    Smooth a 2-D image with an isotropic Gaussian kernel, tolerating NaN gaps.

    How NaNs are handled:

    - NaN pixels contribute nothing to the convolution.
    - The result is divided by the kernel-weighted fraction of valid pixels.
      That fraction is computed on the validity mask after a binary closing
      with a 3x3 cross, so gaps the closing fills are inpainted rather than
      dragged toward zero.
    - Pixels still outside the closed mask are set to ``-inf``.
    - The zero-padded image border is not corrected.

    Parameters:
        img (ndarray): 2-D input image; NaN marks missing values. Square and
            non-square images are both supported.
        sigma (float): Standard deviation of the smoothing kernel, in pixels.

    Returns:
        imgC (ndarray): Smoothed image with the same shape as ``img``.
    """
    img = np.asarray(img)
    size = max(img.shape)

    # Kernel covers offsets -(size-1)..(size-1), so every pixel reaches every other.
    offsets = np.arange(-size + 1, size, 1)
    xx, yy = np.meshgrid(offsets, offsets)
    gauss = multivariate_normal(mean=[0, 0], cov=[[sigma**2, 0], [0, sigma**2]])
    kernel = gauss.pdf(np.dstack((xx, yy)))
    kernel = kernel / np.sum(kernel)

    nans = np.isnan(img)
    filled = np.where(nans, 0, img)
    # With a (2*size-1)-wide centred kernel, mode="same" picks exactly the window
    # mode="valid" gave for square images, and keeps img's shape for non-square ones.
    numerator = signal.convolve2d(filled, kernel, mode="same")

    # Close small holes in the validity mask; pad by one pixel so the closing
    # does not treat the image edge as missing data.
    cross = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=bool)
    closed = binary_closing(np.pad(~nans, 1), iterations=1, structure=cross)[1:-1, 1:-1]

    density = np.divide(
        signal.convolve2d(closed.astype(float), kernel, mode="same"),
        signal.convolve2d(np.ones(img.shape), kernel, mode="same"),
    )
    imgC = np.divide(numerator, density)
    imgC[~closed] = -np.inf
    return imgC
|
|
2527
|
-
|
|
2528
|
-
|
|
2529
|
-
if __name__ == "__main__":
    # Demo: embed the bundled grid-cell data, compute persistence, decode
    # circular coordinates and plot the resulting CohoMap.
    from canns.data.loaders import load_grid_data

    data = load_grid_data()

    # The fourth return value (time stamps) is not needed for this demo.
    spikes, xx, yy, _ = embed_spike_trains(data)

    results = tda_vis(embed_data=spikes, maxdim=1, do_shuffle=False, show=True)
    # For a shuffle-based significance check, call tda_vis with
    # do_shuffle=True (e.g. num_shuffles=10).
    decoding = decode_circular_coordinates(
        persistence_result=results,
        spike_data=data,
        real_ground=True,
        real_of=True,
    )

    # Ensure the output directory exists before the figure is saved into it.
    os.makedirs("Results", exist_ok=True)

    # Visualize CohoMap
    plot_cohomap(
        decoding_result=decoding,
        position_data={"x": xx, "y": yy},
        save_path="Results/cohomap.png",
        show=True,
    )
|