canns 0.13.0__py3-none-any.whl → 0.13.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,2565 +0,0 @@
1
- import logging
2
- import multiprocessing as mp
3
- import numbers
4
- import os
5
- from dataclasses import dataclass
6
- from typing import Any
7
-
8
- import matplotlib.pyplot as plt
9
- import numpy as np
10
- from canns_lib.ripser import ripser
11
- from matplotlib import animation, cm, gridspec
12
- from numpy.exceptions import AxisError
13
-
14
- # from ripser import ripser
15
- from scipy import signal
16
- from scipy.ndimage import (
17
- _nd_image,
18
- _ni_support,
19
- binary_closing,
20
- gaussian_filter,
21
- gaussian_filter1d,
22
- )
23
- from scipy.ndimage._filters import _invalid_origin
24
- from scipy.sparse import coo_matrix
25
- from scipy.sparse.linalg import lsmr
26
- from scipy.spatial.distance import pdist, squareform
27
- from scipy.stats import binned_statistic_2d, multivariate_normal
28
- from sklearn import preprocessing
29
- from tqdm import tqdm
30
-
31
- # Import PlotConfig for unified plotting
32
- from ...visualization import PlotConfig
33
- from ...visualization.core.jupyter_utils import (
34
- display_animation_in_jupyter,
35
- is_jupyter_environment,
36
- )
37
-
38
-
39
- # ==================== Configuration Classes ====================
40
- @dataclass
41
- class SpikeEmbeddingConfig:
42
- """Configuration for spike train embedding."""
43
-
44
- res: int = 100000
45
- dt: int = 1000
46
- sigma: int = 5000
47
- smooth: bool = True
48
- speed_filter: bool = True
49
- min_speed: float = 2.5
50
-
51
-
52
- @dataclass
53
- class TDAConfig:
54
- """Configuration for Topological Data Analysis."""
55
-
56
- dim: int = 6
57
- num_times: int = 5
58
- active_times: int = 15000
59
- k: int = 1000
60
- n_points: int = 1200
61
- metric: str = "cosine"
62
- nbs: int = 800
63
- maxdim: int = 1
64
- coeff: int = 47
65
- show: bool = True
66
- do_shuffle: bool = False
67
- num_shuffles: int = 1000
68
- progress_bar: bool = True
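
A minimal instantiation sketch for the two configuration dataclasses above (illustrative only, not part of the packaged file; the override values are arbitrary):

embed_cfg = SpikeEmbeddingConfig(dt=1000, sigma=5000, speed_filter=True, min_speed=2.5)
tda_cfg = TDAConfig(maxdim=1, n_points=1200, do_shuffle=False, progress_bar=False)
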
69
-
70
-
71
- @dataclass
72
- class CANN2DPlotConfig(PlotConfig):
73
- """Specialized PlotConfig for CANN2D visualizations."""
74
-
75
- # 3D projection specific parameters
76
- zlabel: str = "Component 3"
77
- dpi: int = 300
78
-
79
- # Torus animation specific parameters
80
- numangsint: int = 51
81
- r1: float = 1.5 # Major radius
82
- r2: float = 1.0 # Minor radius
83
- window_size: int = 300
84
- frame_step: int = 5
85
- n_frames: int = 20
86
-
87
- @classmethod
88
- def for_projection_3d(cls, **kwargs) -> "CANN2DPlotConfig":
89
- """Create configuration for 3D projection plots."""
90
- defaults = {
91
- "title": "3D Data Projection",
92
- "xlabel": "Component 1",
93
- "ylabel": "Component 2",
94
- "zlabel": "Component 3",
95
- "figsize": (10, 8),
96
- "dpi": 300,
97
- }
98
- defaults.update(kwargs)
99
- return cls.for_static_plot(**defaults)
100
-
101
- @classmethod
102
- def for_torus_animation(cls, **kwargs) -> "CANN2DPlotConfig":
103
- """Create configuration for 3D torus bump animations."""
104
- defaults = {
105
- "title": "3D Bump on Torus",
106
- "figsize": (8, 8),
107
- "fps": 5,
108
- "repeat": True,
109
- "show_progress_bar": True,
110
- "numangsint": 51,
111
- "r1": 1.5,
112
- "r2": 1.0,
113
- "window_size": 300,
114
- "frame_step": 5,
115
- "n_frames": 20,
116
- }
117
- defaults.update(kwargs)
118
- time_steps = kwargs.get("time_steps_per_second", 1000)
119
- config = cls.for_animation(time_steps, **defaults)
120
- # Add torus-specific attributes
121
- config.numangsint = defaults["numangsint"]
122
- config.r1 = defaults["r1"]
123
- config.r2 = defaults["r2"]
124
- config.window_size = defaults["window_size"]
125
- config.frame_step = defaults["frame_step"]
126
- config.n_frames = defaults["n_frames"]
127
- return config
128
-
129
-
130
- # ==================== Constants ====================
131
- class Constants:
132
- """Constants used throughout CANN2D analysis."""
133
-
134
- DEFAULT_FIGSIZE = (10, 8)
135
- DEFAULT_DPI = 300
136
- GAUSSIAN_SIGMA_FACTOR = 100
137
- SPEED_CONVERSION_FACTOR = 100
138
- TIME_CONVERSION_FACTOR = 0.01
139
- MULTIPROCESSING_CORES = 4
140
-
141
-
142
- # ==================== Custom Exceptions ====================
143
- class CANN2DError(Exception):
144
- """Base exception for CANN2D analysis errors."""
145
-
146
- pass
147
-
148
-
149
- class DataLoadError(CANN2DError):
150
- """Raised when data loading fails."""
151
-
152
- pass
153
-
154
-
155
- class ProcessingError(CANN2DError):
156
- """Raised when data processing fails."""
157
-
158
- pass
159
-
160
-
161
- try:
162
- from numba import jit, njit, prange
163
-
164
- HAS_NUMBA = True
165
- except ImportError:
166
- HAS_NUMBA = False
167
- print(
168
- "Using numba for FAST CANN2D analysis, now using pure numpy implementation.",
169
- "Try numba by `pip install numba` to speed up the process.",
170
- )
171
-
172
- # Create dummy decorators if numba is not available
173
- def jit(*args, **kwargs):
174
- def decorator(func):
175
- return func
176
-
177
- return decorator
178
-
179
- def njit(*args, **kwargs):
180
- def decorator(func):
181
- return func
182
-
183
- return decorator
184
-
185
- def prange(x):
186
- return range(x)
187
-
188
-
189
- def embed_spike_trains(spike_trains, config: SpikeEmbeddingConfig | None = None, **kwargs):
190
- """
191
- Load and preprocess spike train data (typically loaded from an npz file).
192
-
193
- This function converts raw spike times into a time-binned spike matrix,
194
- optionally applying Gaussian smoothing and filtering based on animal movement speed.
195
-
196
- Parameters:
197
- spike_trains : dict containing 'spike', 't', and optionally 'x', 'y'.
198
- config : SpikeEmbeddingConfig, optional configuration object
199
- **kwargs : backward compatibility parameters
200
-
201
- Returns:
202
- spikes_bin (ndarray): Binned and optionally smoothed spike matrix of shape (T, N).
203
- xx (ndarray, optional): X coordinates (if speed_filter=True).
204
- yy (ndarray, optional): Y coordinates (if speed_filter=True).
205
- tt (ndarray, optional): Time points (if speed_filter=True).
206
- """
207
- # Handle backward compatibility and configuration
208
- if config is None:
209
- config = SpikeEmbeddingConfig(
210
- res=kwargs.get("res", 100000),
211
- dt=kwargs.get("dt", 1000),
212
- sigma=kwargs.get("sigma", 5000),
213
- smooth=kwargs.get("smooth0", True),
214
- speed_filter=kwargs.get("speed0", True),
215
- min_speed=kwargs.get("min_speed", 2.5),
216
- )
217
-
218
- try:
219
- # Step 1: Extract and filter spike data
220
- spikes_filtered = _extract_spike_data(spike_trains, config)
221
-
222
- # Step 2: Create time bins
223
- time_bins = _create_time_bins(spike_trains["t"], config)
224
-
225
- # Step 3: Bin spike data
226
- spikes_bin = _bin_spike_data(spikes_filtered, time_bins, config)
227
-
228
- # Step 4: Apply temporal smoothing if requested
229
- if config.smooth:
230
- spikes_bin = _apply_temporal_smoothing(spikes_bin, config)
231
-
232
- # Step 5: Apply speed filtering if requested
233
- if config.speed_filter:
234
- return _apply_speed_filtering(spikes_bin, spike_trains, config)
235
-
236
- return spikes_bin
237
-
238
- except Exception as e:
239
- raise ProcessingError(f"Failed to embed spike trains: {e}") from e
240
-
241
-
242
- def _extract_spike_data(
243
- spike_trains: dict[str, Any], config: SpikeEmbeddingConfig
244
- ) -> dict[int, np.ndarray]:
245
- """Extract and filter spike data within time window."""
246
- try:
247
- # Handle different spike data formats
248
- spike_data = spike_trains["spike"]
249
- if hasattr(spike_data, "item") and callable(spike_data.item):
250
- # numpy array with .item() method (from npz file)
251
- spikes_all = spike_data[()]
252
- elif isinstance(spike_data, dict):
253
- # Already a dictionary
254
- spikes_all = spike_data
255
- elif isinstance(spike_data, list | np.ndarray):
256
- # List or array format
257
- spikes_all = spike_data
258
- else:
259
- # Try direct access
260
- spikes_all = spike_data
261
-
262
- t = spike_trains["t"]
263
-
264
- min_time0 = np.min(t)
265
- max_time0 = np.max(t)
266
-
267
- # Extract spike intervals for each cell
268
- if isinstance(spikes_all, dict):
269
- # Dictionary format
270
- spikes = {}
271
- for i, key in enumerate(spikes_all.keys()):
272
- s = np.array(spikes_all[key])
273
- spikes[i] = s[(s >= min_time0) & (s < max_time0)]
274
- else:
275
- # List/array format
276
- cell_inds = np.arange(len(spikes_all))
277
- spikes = {}
278
-
279
- for i, m in enumerate(cell_inds):
280
- s = np.array(spikes_all[m]) if len(spikes_all[m]) > 0 else np.array([])
281
- # Filter spikes within time window
282
- if len(s) > 0:
283
- spikes[i] = s[(s >= min_time0) & (s < max_time0)]
284
- else:
285
- spikes[i] = np.array([])
286
-
287
- return spikes
288
-
289
- except KeyError as e:
290
- raise DataLoadError(f"Missing required data key: {e}") from e
291
- except Exception as e:
292
- raise ProcessingError(f"Error extracting spike data: {e}") from e
293
-
294
-
295
- def _create_time_bins(t: np.ndarray, config: SpikeEmbeddingConfig) -> np.ndarray:
296
- """Create time bins for spike discretization."""
297
- min_time0 = np.min(t)
298
- max_time0 = np.max(t)
299
-
300
- min_time = min_time0 * config.res
301
- max_time = max_time0 * config.res
302
-
303
- return np.arange(np.floor(min_time), np.ceil(max_time) + 1, config.dt)
304
-
305
-
306
- def _bin_spike_data(
307
- spikes: dict[int, np.ndarray], time_bins: np.ndarray, config: SpikeEmbeddingConfig
308
- ) -> np.ndarray:
309
- """Convert spike times to binned spike matrix."""
310
- min_time = time_bins[0]
311
- max_time = time_bins[-1]
312
-
313
- spikes_bin = np.zeros((len(time_bins), len(spikes)), dtype=int)
314
-
315
- for n in spikes:
316
- spike_times = np.array(spikes[n] * config.res - min_time, dtype=int)
317
- # Filter valid spike times
318
- spike_times = spike_times[(spike_times < (max_time - min_time)) & (spike_times > 0)]
319
- spike_times = np.array(spike_times / config.dt, int)
320
-
321
- # Bin spikes
322
- for j in spike_times:
323
- if j < len(time_bins):
324
- spikes_bin[j, n] += 1
325
-
326
- return spikes_bin
327
-
328
-
329
- def _apply_temporal_smoothing(spikes_bin: np.ndarray, config: SpikeEmbeddingConfig) -> np.ndarray:
330
- """Apply Gaussian temporal smoothing to spike matrix."""
331
- # Calculate smoothing parameters (legacy implementation used custom kernel)
332
- # Current implementation uses scipy's gaussian_filter1d for better performance
333
-
334
- # Apply smoothing (simplified version - could be further optimized)
335
- smoothed = np.zeros((spikes_bin.shape[0], spikes_bin.shape[1]))
336
-
337
- # Use scipy's gaussian_filter1d for better performance
338
-
339
- sigma_bins = config.sigma / config.dt
340
-
341
- for n in range(spikes_bin.shape[1]):
342
- smoothed[:, n] = gaussian_filter1d(
343
- spikes_bin[:, n].astype(float), sigma=sigma_bins, mode="constant"
344
- )
345
-
346
- # Normalize
347
- normalization_factor = 1 / np.sqrt(2 * np.pi * (config.sigma / config.res) ** 2)
348
- return smoothed * normalization_factor
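
A small sketch of the smoothing step in isolation, assuming the default SpikeEmbeddingConfig: the kernel width in bins is sigma/dt (5000/1000 = 5 bins), and the result is rescaled by the Gaussian normalization 1/sqrt(2*pi*(sigma/res)^2) ≈ 7.98.

import numpy as np

cfg = SpikeEmbeddingConfig()          # sigma=5000, dt=1000, res=100000
spikes_bin = np.zeros((200, 1), dtype=int)
spikes_bin[100, 0] = 1                # a single spike in one time bin
smoothed = _apply_temporal_smoothing(spikes_bin, cfg)
print(smoothed[95:106, 0])            # a Gaussian bump centered on bin 100
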
349
-
350
-
351
- def _apply_speed_filtering(
352
- spikes_bin: np.ndarray, spike_trains: dict[str, Any], config: SpikeEmbeddingConfig
353
- ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
354
- """Apply speed-based filtering to spike data."""
355
- try:
356
- xx, yy, tt_pos, speed = _load_pos(
357
- spike_trains["t"], spike_trains["x"], spike_trains["y"], res=config.res, dt=config.dt
358
- )
359
-
360
- valid = speed > config.min_speed
361
-
362
- return (spikes_bin[valid, :], xx[valid], yy[valid], tt_pos[valid])
363
-
364
- except KeyError as e:
365
- raise DataLoadError(f"Missing position data for speed filtering: {e}") from e
366
- except Exception as e:
367
- raise ProcessingError(f"Error in speed filtering: {e}") from e
368
-
369
-
370
- def plot_projection(
371
- reduce_func,
372
- embed_data,
373
- config: CANN2DPlotConfig | None = None,
374
- title="Projection (3D)",
375
- xlabel="Component 1",
376
- ylabel="Component 2",
377
- zlabel="Component 3",
378
- save_path=None,
379
- show=True,
380
- dpi=300,
381
- figsize=(10, 8),
382
- **kwargs,
383
- ):
384
- """
385
- Plot a 3D projection of the embedded data.
386
-
387
- Parameters:
388
- reduce_func (callable): Function to reduce the dimensionality of the data.
389
- embed_data (ndarray): Data to be projected.
390
- config (PlotConfig, optional): Configuration object for unified plotting parameters
391
- **kwargs: backward compatibility parameters
392
- title (str): Title of the plot.
393
- xlabel (str): Label for the x-axis.
394
- ylabel (str): Label for the y-axis.
395
- zlabel (str): Label for the z-axis.
396
- save_path (str, optional): Path to save the plot. If None, plot will not be saved.
397
- show (bool): Whether to display the plot.
398
- dpi (int): Dots per inch for saving the figure.
399
- figsize (tuple): Size of the figure.
400
-
401
- Returns:
402
- fig: The created figure object.
403
- """
404
-
405
- # Handle backward compatibility and configuration
406
- if config is None:
407
- config = CANN2DPlotConfig.for_projection_3d(
408
- title=title,
409
- xlabel=xlabel,
410
- ylabel=ylabel,
411
- zlabel=zlabel,
412
- save_path=save_path,
413
- show=show,
414
- figsize=figsize,
415
- dpi=dpi,
416
- **kwargs,
417
- )
418
-
419
- reduced_data = reduce_func(embed_data[::5])
420
-
421
- fig = plt.figure(figsize=config.figsize)
422
- ax = fig.add_subplot(111, projection="3d")
423
- ax.scatter(reduced_data[:, 0], reduced_data[:, 1], reduced_data[:, 2], s=1, alpha=0.5)
424
-
425
- ax.set_title(config.title)
426
- ax.set_xlabel(config.xlabel)
427
- ax.set_ylabel(config.ylabel)
428
- ax.set_zlabel(config.zlabel)
429
-
430
- if config.save_path is None and config.show is None:
431
- raise ValueError("Either save path or show must be provided.")
432
- if config.save_path:
433
- plt.savefig(config.save_path, dpi=config.dpi)
434
- if config.show:
435
- plt.show()
436
-
437
- plt.close(fig)
438
-
439
- return fig
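
An illustrative call sketch for plot_projection, assuming module scope (preprocessing and _pca as imported/defined in this file). reduce_to_3d is a hypothetical wrapper supplying the required reduce_func, and embed_data is random stand-in data.

import numpy as np

def reduce_to_3d(data):
    components, _, _ = _pca(preprocessing.scale(data), dim=3)
    return components

embed_data = np.random.default_rng(1).random((2000, 50))
fig = plot_projection(reduce_to_3d, embed_data, show=False, save_path="projection_3d.png")
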
440
-
441
-
442
- def tda_vis(embed_data: np.ndarray, config: TDAConfig | None = None, **kwargs) -> dict[str, Any]:
443
- """
444
- Topological Data Analysis visualization with optional shuffle testing.
445
-
446
- Parameters:
447
- embed_data : ndarray
448
- Embedded spike train data.
449
- config : TDAConfig, optional
450
- Configuration object with all TDA parameters
451
- **kwargs : backward compatibility parameters
452
-
453
- Returns:
454
- dict : Dictionary containing:
455
- - persistence: persistence diagrams from real data
456
- - indstemp: indices of sampled points
457
- - movetimes: selected time points
458
- - n_points: number of sampled points
459
- - shuffle_max: shuffle analysis results (if do_shuffle=True, otherwise None)
460
- """
461
- # Handle backward compatibility and configuration
462
- if config is None:
463
- config = TDAConfig(
464
- dim=kwargs.get("dim", 6),
465
- num_times=kwargs.get("num_times", 5),
466
- active_times=kwargs.get("active_times", 15000),
467
- k=kwargs.get("k", 1000),
468
- n_points=kwargs.get("n_points", 1200),
469
- metric=kwargs.get("metric", "cosine"),
470
- nbs=kwargs.get("nbs", 800),
471
- maxdim=kwargs.get("maxdim", 1),
472
- coeff=kwargs.get("coeff", 47),
473
- show=kwargs.get("show", True),
474
- do_shuffle=kwargs.get("do_shuffle", False),
475
- num_shuffles=kwargs.get("num_shuffles", 1000),
476
- progress_bar=kwargs.get("progress_bar", True),
477
- )
478
-
479
- try:
480
- # Compute persistent homology for real data
481
- print("Computing persistent homology for real data...")
482
- real_persistence = _compute_real_persistence(embed_data, config)
483
-
484
- # Perform shuffle analysis if requested
485
- shuffle_max = None
486
- if config.do_shuffle:
487
- shuffle_max = _perform_shuffle_analysis(embed_data, config)
488
-
489
- # Visualization
490
- _handle_visualization(real_persistence["persistence"], shuffle_max, config)
491
-
492
- # Return results as dictionary
493
- return {
494
- "persistence": real_persistence["persistence"],
495
- "indstemp": real_persistence["indstemp"],
496
- "movetimes": real_persistence["movetimes"],
497
- "n_points": real_persistence["n_points"],
498
- "shuffle_max": shuffle_max,
499
- }
500
-
501
- except Exception as e:
502
- raise ProcessingError(f"TDA analysis failed: {e}") from e
503
-
504
-
505
- def _compute_real_persistence(embed_data: np.ndarray, config: TDAConfig) -> dict[str, Any]:
506
- """Compute persistent homology for real data with progress tracking."""
507
-
508
- logging.info("Processing real data - Starting TDA analysis (5 steps)")
509
-
510
- # Step 1: Time point downsampling
511
- logging.info("Step 1/5: Time point downsampling")
512
- times_cube = _downsample_timepoints(embed_data, config.num_times)
513
-
514
- # Step 2: Select most active time points
515
- logging.info("Step 2/5: Selecting active time points")
516
- movetimes = _select_active_timepoints(embed_data, times_cube, config.active_times)
517
-
518
- # Step 3: PCA dimensionality reduction
519
- logging.info("Step 3/5: PCA dimensionality reduction")
520
- dimred = _apply_pca_reduction(embed_data, movetimes, config.dim)
521
-
522
- # Step 4: Point cloud sampling (denoising)
523
- logging.info("Step 4/5: Point cloud denoising")
524
- indstemp = _apply_denoising(dimred, config)
525
-
526
- # Step 5: Compute persistent homology
527
- logging.info("Step 5/5: Computing persistent homology")
528
- persistence = _compute_persistence_homology(dimred, indstemp, config)
529
-
530
- logging.info("TDA analysis completed successfully")
531
-
532
- # Return all necessary data in dictionary format
533
- return {
534
- "persistence": persistence,
535
- "indstemp": indstemp,
536
- "movetimes": movetimes,
537
- "n_points": config.n_points,
538
- }
539
-
540
-
541
- def _downsample_timepoints(embed_data: np.ndarray, num_times: int) -> np.ndarray:
542
- """Downsample timepoints for computational efficiency."""
543
- return np.arange(0, embed_data.shape[0], num_times)
544
-
545
-
546
- def _select_active_timepoints(
547
- embed_data: np.ndarray, times_cube: np.ndarray, active_times: int
548
- ) -> np.ndarray:
549
- """Select most active timepoints based on total activity."""
550
- activity_scores = np.sum(embed_data[times_cube, :], 1)
551
- # Match external TDAvis: sort indices first, then map to times_cube
552
- movetimes = np.sort(np.argsort(activity_scores)[-active_times:])
553
- return times_cube[movetimes]
554
-
555
-
556
- def _apply_pca_reduction(embed_data: np.ndarray, movetimes: np.ndarray, dim: int) -> np.ndarray:
557
- """Apply PCA dimensionality reduction."""
558
- scaled_data = preprocessing.scale(embed_data[movetimes, :])
559
- dimred, *_ = _pca(scaled_data, dim=dim)
560
- return dimred
561
-
562
-
563
- def _apply_denoising(dimred: np.ndarray, config: TDAConfig) -> np.ndarray:
564
- """Apply point cloud denoising."""
565
- indstemp, *_ = _sample_denoising(
566
- dimred,
567
- k=config.k,
568
- num_sample=config.n_points,
569
- omega=1, # Match external TDAvis: uses 1, not default 0.2
570
- metric=config.metric,
571
- )
572
- return indstemp
573
-
574
-
575
- def _compute_persistence_homology(
576
- dimred: np.ndarray, indstemp: np.ndarray, config: TDAConfig
577
- ) -> dict[str, Any]:
578
- """Compute persistent homology using ripser."""
579
- d = _second_build(dimred, indstemp, metric=config.metric, nbs=config.nbs)
580
- np.fill_diagonal(d, 0)
581
-
582
- return ripser(
583
- d,
584
- maxdim=config.maxdim,
585
- coeff=config.coeff,
586
- do_cocycles=True,
587
- distance_matrix=True,
588
- progress_bar=config.progress_bar,
589
- )
590
-
591
-
592
- def _perform_shuffle_analysis(embed_data: np.ndarray, config: TDAConfig) -> dict[int, Any]:
593
- """Perform shuffle analysis with progress tracking."""
594
- print(f"\nStarting shuffle analysis with {config.num_shuffles} iterations...")
595
-
596
- # Create parameters dict for shuffle analysis
597
- shuffle_params = {
598
- "dim": config.dim,
599
- "num_times": config.num_times,
600
- "active_times": config.active_times,
601
- "k": config.k,
602
- "n_points": config.n_points,
603
- "metric": config.metric,
604
- "nbs": config.nbs,
605
- "maxdim": config.maxdim,
606
- "coeff": config.coeff,
607
- }
608
-
609
- shuffle_max = _run_shuffle_analysis(
610
- embed_data,
611
- num_shuffles=config.num_shuffles,
612
- num_cores=Constants.MULTIPROCESSING_CORES,
613
- progress_bar=config.progress_bar,
614
- **shuffle_params,
615
- )
616
-
617
- # Print shuffle analysis summary
618
- _print_shuffle_summary(shuffle_max)
619
-
620
- return shuffle_max
621
-
622
-
623
- def _print_shuffle_summary(shuffle_max: dict[int, Any]) -> None:
624
- """Print summary of shuffle analysis results."""
625
- print("\nSummary of shuffle-based analysis:")
626
- for dim_idx in [0, 1, 2]:
627
- if shuffle_max and dim_idx in shuffle_max and shuffle_max[dim_idx]:
628
- values = shuffle_max[dim_idx]
629
- print(
630
- f"H{dim_idx}: {len(values)} valid iterations | "
631
- f"Mean maximum persistence: {np.mean(values):.4f} | "
632
- f"99.9th percentile: {np.percentile(values, 99.9):.4f}"
633
- )
634
-
635
-
636
- def _handle_visualization(
637
- real_persistence: dict[str, Any], shuffle_max: dict[int, Any] | None, config: TDAConfig
638
- ) -> None:
639
- """Handle visualization based on configuration."""
640
- if config.show:
641
- if config.do_shuffle and shuffle_max is not None:
642
- _plot_barcode_with_shuffle(real_persistence, shuffle_max)
643
- else:
644
- _plot_barcode(real_persistence)
645
- plt.show()
646
- else:
647
- plt.close()
648
-
649
-
650
- def _load_pos(t, x, y, res=100000, dt=1000):
651
- """
652
- Compute animal position and speed from tracking arrays aligned to the spike time bins.
653
-
654
- Interpolates animal positions to match spike time bins and computes smoothed velocity vectors and speed.
655
-
656
- Parameters:
657
- t (ndarray): Time points of the spikes (in seconds).
658
- x (ndarray): X coordinates of the animal's position.
659
- y (ndarray): Y coordinates of the animal's position.
660
- res (int): Time scaling factor to align with spike resolution.
661
- dt (int): Temporal bin size in microseconds.
662
-
663
- Returns:
664
- xx (ndarray): Interpolated x positions.
665
- yy (ndarray): Interpolated y positions.
666
- tt (ndarray): Corresponding time points (in seconds).
667
- speed (ndarray): Speed at each time point (in cm/s).
668
- """
669
-
670
- min_time0 = np.min(t)
671
- max_time0 = np.max(t)
672
-
673
- times = np.where((t >= min_time0) & (t < max_time0))
674
- x = x[times]
675
- y = y[times]
676
- t = t[times]
677
-
678
- min_time = min_time0 * res
679
- max_time = max_time0 * res
680
-
681
- tt = np.arange(np.floor(min_time), np.ceil(max_time) + 1, dt) / res
682
-
683
- idt = np.concatenate(([0], np.digitize(t[1:-1], tt[:]) - 1, [len(tt) + 1]))
684
- idtt = np.digitize(np.arange(len(tt)), idt) - 1
685
-
686
- idx = np.concatenate((np.unique(idtt), [np.max(idtt) + 1]))
687
- divisor = np.bincount(idtt)
688
- steps = 1.0 / divisor[divisor > 0]
689
- N = np.max(divisor)
690
- ranges = np.multiply(np.arange(N)[np.newaxis, :], steps[:, np.newaxis])
691
- ranges[ranges >= 1] = np.nan
692
-
693
- rangesx = x[idx[:-1], np.newaxis] + np.multiply(
694
- ranges, (x[idx[1:]] - x[idx[:-1]])[:, np.newaxis]
695
- )
696
- xx = rangesx[~np.isnan(ranges)]
697
-
698
- rangesy = y[idx[:-1], np.newaxis] + np.multiply(
699
- ranges, (y[idx[1:]] - y[idx[:-1]])[:, np.newaxis]
700
- )
701
- yy = rangesy[~np.isnan(ranges)]
702
-
703
- xxs = _gaussian_filter1d(xx - np.min(xx), sigma=100)
704
- yys = _gaussian_filter1d(yy - np.min(yy), sigma=100)
705
- dx = (xxs[1:] - xxs[:-1]) * 100
706
- dy = (yys[1:] - yys[:-1]) * 100
707
- speed = np.sqrt(dx**2 + dy**2) / 0.01
708
- speed = np.concatenate(([speed[0]], speed))
709
- return xx, yy, tt, speed
710
-
711
-
712
- def _gaussian_filter1d(
713
- input,
714
- sigma,
715
- axis=-1,
716
- order=0,
717
- output=None,
718
- mode="reflect",
719
- cval=0.0,
720
- truncate=4.0,
721
- *,
722
- radius=None,
723
- ):
724
- """1-D Gaussian filter.
725
-
726
- Parameters
727
- ----------
728
- %(input)s
729
- sigma : scalar
730
- standard deviation for Gaussian kernel
731
- %(axis)s
732
- order : int, optional
733
- An order of 0 corresponds to convolution with a Gaussian
734
- kernel. A positive order corresponds to convolution with
735
- that derivative of a Gaussian.
736
- %(output)s
737
- %(mode_reflect)s
738
- %(cval)s
739
- truncate : float, optional
740
- Truncate the filter at this many standard deviations.
741
- Default is 4.0.
742
- radius : None or int, optional
743
- Radius of the Gaussian kernel. If specified, the size of
744
- the kernel will be ``2*radius + 1``, and `truncate` is ignored.
745
- Default is None.
746
-
747
- Returns
748
- -------
749
- gaussian_filter1d : ndarray
750
-
751
- Notes
752
- -----
753
- The Gaussian kernel will have size ``2*radius + 1`` along each axis. If
754
- `radius` is None, a default ``radius = round(truncate * sigma)`` will be
755
- used.
756
-
757
- Examples
758
- --------
759
- >>> from scipy.ndimage import gaussian_filter1d
760
- >>> import numpy as np
761
- >>> _gaussian_filter1d([1.0, 2.0, 3.0, 4.0, 5.0], 1)
762
- array([ 1.42704095, 2.06782203, 3. , 3.93217797, 4.57295905])
763
- >>> _gaussian_filter1d([1.0, 2.0, 3.0, 4.0, 5.0], 4)
764
- array([ 2.91948343, 2.95023502, 3. , 3.04976498, 3.08051657])
765
- >>> import matplotlib.pyplot as plt
766
- >>> rng = np.random.default_rng()
767
- >>> x = rng.standard_normal(101).cumsum()
768
- >>> y3 = _gaussian_filter1d(x, 3)
769
- >>> y6 = _gaussian_filter1d(x, 6)
770
- >>> plt.plot(x, 'k', label='original data')
771
- >>> plt.plot(y3, '--', label='filtered, sigma=3')
772
- >>> plt.plot(y6, ':', label='filtered, sigma=6')
773
- >>> plt.legend()
774
- >>> plt.grid()
775
- >>> plt.show()
776
-
777
- """
778
- sd = float(sigma)
779
- # make the radius of the filter equal to truncate standard deviations
780
- lw = int(truncate * sd + 0.5)
781
- if radius is not None:
782
- lw = radius
783
- if not isinstance(lw, numbers.Integral) or lw < 0:
784
- raise ValueError("Radius must be a nonnegative integer.")
785
- # Since we are calling correlate, not convolve, revert the kernel
786
- weights = _gaussian_kernel1d(sigma, order, lw)[::-1]
787
- return _correlate1d(input, weights, axis, output, mode, cval, 0)
788
-
789
-
790
- def _gaussian_kernel1d(sigma, order, radius):
791
- """
792
- Computes a 1-D Gaussian convolution kernel.
793
- """
794
- if order < 0:
795
- raise ValueError("order must be non-negative")
796
- exponent_range = np.arange(order + 1)
797
- sigma2 = sigma * sigma
798
- x = np.arange(-radius, radius + 1)
799
- phi_x = np.exp(-0.5 / sigma2 * x**2)
800
- phi_x = phi_x / phi_x.sum()
801
-
802
- if order == 0:
803
- return phi_x
804
- else:
805
- # f(x) = q(x) * phi(x) = q(x) * exp(p(x))
806
- # f'(x) = (q'(x) + q(x) * p'(x)) * phi(x)
807
- # p'(x) = -1 / sigma ** 2
808
- # Implement q'(x) + q(x) * p'(x) as a matrix operator and apply to the
809
- # coefficients of q(x)
810
- q = np.zeros(order + 1)
811
- q[0] = 1
812
- D = np.diag(exponent_range[1:], 1) # D @ q(x) = q'(x)
813
- P = np.diag(np.ones(order) / -sigma2, -1) # P @ q(x) = q(x) * p'(x)
814
- Q_deriv = D + P
815
- for _ in range(order):
816
- q = Q_deriv.dot(q)
817
- q = (x[:, None] ** exponent_range).dot(q)
818
- return q * phi_x
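
A one-line numeric check of the kernel helper above: for order=0 the returned weights form a normalized Gaussian window of length 2*radius + 1.

w = _gaussian_kernel1d(sigma=2.0, order=0, radius=8)
print(len(w), w.sum())   # 17, and the weights sum to ~1.0
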
819
-
820
-
821
- def _correlate1d(input, weights, axis=-1, output=None, mode="reflect", cval=0.0, origin=0):
822
- """Calculate a 1-D correlation along the given axis.
823
-
824
- The lines of the array along the given axis are correlated with the
825
- given weights.
826
-
827
- Parameters
828
- ----------
829
- %(input)s
830
- weights : array
831
- 1-D sequence of numbers.
832
- %(axis)s
833
- %(output)s
834
- %(mode_reflect)s
835
- %(cval)s
836
- %(origin)s
837
-
838
- Returns
839
- -------
840
- result : ndarray
841
- Correlation result. Has the same shape as `input`.
842
-
843
- Examples
844
- --------
845
- >>> from scipy.ndimage import correlate1d
846
- >>> correlate1d([2, 8, 0, 4, 1, 9, 9, 0], weights=[1, 3])
847
- array([ 8, 26, 8, 12, 7, 28, 36, 9])
848
- """
849
- input = np.asarray(input)
850
- weights = np.asarray(weights)
851
- complex_input = input.dtype.kind == "c"
852
- complex_weights = weights.dtype.kind == "c"
853
- if complex_input or complex_weights:
854
- if complex_weights:
855
- weights = weights.conj()
856
- weights = weights.astype(np.complex128, copy=False)
857
- kwargs = dict(axis=axis, mode=mode, origin=origin)
858
- output = _ni_support._get_output(output, input, complex_output=True)
859
- return _complex_via_real_components(_correlate1d, input, weights, output, cval, **kwargs)
860
-
861
- output = _ni_support._get_output(output, input)
862
- weights = np.asarray(weights, dtype=np.float64)
863
- if weights.ndim != 1 or weights.shape[0] < 1:
864
- raise RuntimeError("no filter weights given")
865
- if not weights.flags.contiguous:
866
- weights = weights.copy()
867
- axis = _normalize_axis_index(axis, input.ndim)
868
- if _invalid_origin(origin, len(weights)):
869
- raise ValueError(
870
- "Invalid origin; origin must satisfy "
871
- "-(len(weights) // 2) <= origin <= "
872
- "(len(weights)-1) // 2"
873
- )
874
- mode = _ni_support._extend_mode_to_code(mode)
875
- _nd_image.correlate1d(input, weights, axis, output, mode, cval, origin)
876
- return output
877
-
878
-
879
- def _complex_via_real_components(func, input, weights, output, cval, **kwargs):
880
- """Complex convolution via a linear combination of real convolutions."""
881
- complex_input = input.dtype.kind == "c"
882
- complex_weights = weights.dtype.kind == "c"
883
- if complex_input and complex_weights:
884
- # real component of the output
885
- func(input.real, weights.real, output=output.real, cval=np.real(cval), **kwargs)
886
- output.real -= func(input.imag, weights.imag, output=None, cval=np.imag(cval), **kwargs)
887
- # imaginary component of the output
888
- func(input.real, weights.imag, output=output.imag, cval=np.real(cval), **kwargs)
889
- output.imag += func(input.imag, weights.real, output=None, cval=np.imag(cval), **kwargs)
890
- elif complex_input:
891
- func(input.real, weights, output=output.real, cval=np.real(cval), **kwargs)
892
- func(input.imag, weights, output=output.imag, cval=np.imag(cval), **kwargs)
893
- else:
894
- if np.iscomplexobj(cval):
895
- raise ValueError("Cannot provide a complex-valued cval when the input is real.")
896
- func(input, weights.real, output=output.real, cval=cval, **kwargs)
897
- func(input, weights.imag, output=output.imag, cval=cval, **kwargs)
898
- return output
899
-
900
-
901
- def _normalize_axis_index(axis, ndim):
902
- # Check if `axis` is in the correct range and normalize it
903
- if axis < -ndim or axis >= ndim:
904
- msg = f"axis {axis} is out of bounds for array of dimension {ndim}"
905
- raise AxisError(msg)
906
-
907
- if axis < 0:
908
- axis = axis + ndim
909
- return axis
910
-
911
-
912
- def _compute_persistence(
913
- sspikes,
914
- dim=6,
915
- num_times=5,
916
- active_times=15000,
917
- k=1000,
918
- n_points=1200,
919
- metric="cosine",
920
- nbs=800,
921
- maxdim=1,
922
- coeff=47,
923
- progress_bar=True,
924
- ):
925
- # Time point downsampling
926
- times_cube = np.arange(0, sspikes.shape[0], num_times)
927
-
928
- # Select most active time points
929
- movetimes = np.sort(np.argsort(np.sum(sspikes[times_cube, :], 1))[-active_times:])
930
- movetimes = times_cube[movetimes]
931
-
932
- # PCA dimensionality reduction
933
- scaled_data = preprocessing.scale(sspikes[movetimes, :])
934
- dimred, *_ = _pca(scaled_data, dim=dim)
935
-
936
- # Point cloud sampling (denoising)
937
- indstemp, *_ = _sample_denoising(dimred, k, n_points, 1, metric)
938
-
939
- # Build distance matrix
940
- d = _second_build(dimred, indstemp, metric=metric, nbs=nbs)
941
- np.fill_diagonal(d, 0)
942
-
943
- # Compute persistent homology
944
- persistence = ripser(
945
- d,
946
- maxdim=maxdim,
947
- coeff=coeff,
948
- do_cocycles=True,
949
- distance_matrix=True,
950
- progress_bar=progress_bar,
951
- )
952
-
953
- return persistence
954
-
955
-
956
- def _pca(data, dim=2):
957
- """
958
- Perform PCA (Principal Component Analysis) for dimensionality reduction.
959
-
960
- Parameters:
961
- data (ndarray): Input data matrix of shape (N_samples, N_features).
962
- dim (int): Target dimension for PCA projection.
963
-
964
- Returns:
965
- components (ndarray): Projected data of shape (N_samples, dim).
966
- var_exp (list): Variance explained by each principal component.
967
- evals (ndarray): Eigenvalues corresponding to the selected components.
968
- """
969
- if dim < 2:
970
- return data, [0]
971
- m, n = data.shape
972
- # mean center the data
973
- # data -= data.mean(axis=0)
974
- # calculate the covariance matrix
975
- R = np.cov(data, rowvar=False)
976
- # calculate eigenvectors & eigenvalues of the covariance matrix
977
- # note: R is symmetric, so np.linalg.eigh would be faster and guarantee
978
- # real-valued output; this implementation keeps np.linalg.eig
979
- evals, evecs = np.linalg.eig(R)
980
- # sort eigenvalue in decreasing order
981
- idx = np.argsort(evals)[::-1]
982
- evecs = evecs[:, idx]
983
- # sort eigenvectors according to same index
984
- evals = evals[idx]
985
- # select the first n eigenvectors (n is desired dimension
986
- # of rescaled data array, or dims_rescaled_data)
987
- evecs = evecs[:, :dim]
988
- # carry out the transformation on the data using eigenvectors
989
- # and return the re-scaled data, eigenvalues, and eigenvectors
990
-
991
- tot = np.sum(evals)
992
- var_exp = [(i / tot) * 100 for i in sorted(evals[:dim], reverse=True)]
993
- components = np.dot(evecs.T, data.T).T
994
- return components, var_exp, evals[:dim]
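
A small worked sketch of the _pca helper on synthetic data, checking the projected shape and the explained-variance percentages it reports.

import numpy as np

rng = np.random.default_rng(3)
data = rng.standard_normal((500, 10))
components, var_exp, evals = _pca(data, dim=2)
print(components.shape)   # (500, 2)
print(sum(var_exp))       # percent of variance captured by the first two components
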
995
-
996
-
997
- def _sample_denoising(data, k=10, num_sample=500, omega=0.2, metric="euclidean"):
998
- """
999
- Perform denoising and greedy sampling based on mutual k-NN graph.
1000
-
1001
- Parameters:
1002
- data (ndarray): High-dimensional point cloud data.
1003
- k (int): Number of neighbors for local density estimation.
1004
- num_sample (int): Number of samples to retain.
1005
- omega (float): Suppression factor during greedy sampling.
1006
- metric (str): Distance metric used for kNN ('euclidean', 'cosine', etc).
1007
-
1008
- Returns:
1009
- inds (ndarray): Indices of sampled points.
1010
- d (ndarray): Pairwise similarity matrix of sampled points.
1011
- Fs (ndarray): Sampling scores at each step.
1012
- """
1013
- if HAS_NUMBA:
1014
- return _sample_denoising_numba(data, k, num_sample, omega, metric)
1015
- else:
1016
- return _sample_denoising_numpy(data, k, num_sample, omega, metric)
1017
-
1018
-
1019
- def _sample_denoising_numpy(data, k=10, num_sample=500, omega=0.2, metric="euclidean"):
1020
- """Original numpy implementation for fallback."""
1021
- n = data.shape[0]
1022
- X = squareform(pdist(data, metric))
1023
- knn_indices = np.argsort(X)[:, :k]
1024
- knn_dists = X[np.arange(X.shape[0])[:, None], knn_indices].copy()
1025
-
1026
- sigmas, rhos = _smooth_knn_dist(knn_dists, k, local_connectivity=0)
1027
- rows, cols, vals = _compute_membership_strengths(knn_indices, knn_dists, sigmas, rhos)
1028
- result = coo_matrix((vals, (rows, cols)), shape=(n, n))
1029
- result.eliminate_zeros()
1030
- transpose = result.transpose()
1031
- prod_matrix = result.multiply(transpose)
1032
- result = result + transpose - prod_matrix
1033
- result.eliminate_zeros()
1034
- X = result.toarray()
1035
- F = np.sum(X, 1)
1036
- Fs = np.zeros(num_sample)
1037
- Fs[0] = np.max(F)
1038
- i = np.argmax(F)
1039
- inds_all = np.arange(n)
1040
- inds_left = inds_all > -1
1041
- inds_left[i] = False
1042
- inds = np.zeros(num_sample, dtype=int)
1043
- inds[0] = i
1044
- for j in np.arange(1, num_sample):
1045
- F -= omega * X[i, :]
1046
- Fmax = np.argmax(F[inds_left])
1047
- # Exactly match external TDAvis implementation (including the indexing logic)
1048
- Fs[j] = F[Fmax]
1049
- i = inds_all[inds_left][Fmax]
1050
-
1051
- inds_left[i] = False
1052
- inds[j] = i
1053
- d = np.zeros((num_sample, num_sample))
1054
-
1055
- for j, i in enumerate(inds):
1056
- d[j, :] = X[i, inds]
1057
- return inds, d, Fs
1058
-
1059
-
1060
- def _sample_denoising_numba(data, k=10, num_sample=500, omega=0.2, metric="euclidean"):
1061
- """Optimized numba implementation."""
1062
- n = data.shape[0]
1063
- X = squareform(pdist(data, metric))
1064
- knn_indices = np.argsort(X)[:, :k]
1065
- knn_dists = X[np.arange(X.shape[0])[:, None], knn_indices].copy()
1066
-
1067
- sigmas, rhos = _smooth_knn_dist(knn_dists, k, local_connectivity=0)
1068
- rows, cols, vals = _compute_membership_strengths(knn_indices, knn_dists, sigmas, rhos)
1069
-
1070
- # Build symmetric adjacency matrix using optimized function
1071
- X_adj = _build_adjacency_matrix_numba(rows, cols, vals, n)
1072
-
1073
- # Greedy sampling using optimized function
1074
- inds, Fs = _greedy_sampling_numba(X_adj, num_sample, omega)
1075
-
1076
- # Build final distance matrix
1077
- d = _build_distance_matrix_numba(X_adj, inds)
1078
-
1079
- return inds, d, Fs
1080
-
1081
-
1082
- @njit(fastmath=True)
1083
- def _build_adjacency_matrix_numba(rows, cols, vals, n):
1084
- """Build symmetric adjacency matrix efficiently with numba.
1085
-
1086
- This matches the scipy sparse matrix operations:
1087
- result = result + transpose - prod_matrix
1088
- where prod_matrix = result.multiply(transpose)
1089
- """
1090
- # Initialize matrices
1091
- X = np.zeros((n, n), dtype=np.float64)
1092
- X_T = np.zeros((n, n), dtype=np.float64)
1093
-
1094
- # Build adjacency matrix and its transpose simultaneously
1095
- for i in range(len(rows)):
1096
- X[rows[i], cols[i]] = vals[i]
1097
- X_T[cols[i], rows[i]] = vals[i] # Transpose
1098
-
1099
- # Apply the symmetrization formula: A = A + A^T - A ⊙ A^T (vectorized)
1100
- # This matches scipy's: result + transpose - prod_matrix
1101
- X[:, :] = X + X_T - X * X_T
1102
-
1103
- return X
1104
-
1105
-
1106
- @njit(fastmath=True)
1107
- def _greedy_sampling_numba(X, num_sample, omega):
1108
- """Optimized greedy sampling with numba."""
1109
- n = X.shape[0]
1110
- F = np.sum(X, axis=1)
1111
- Fs = np.zeros(num_sample)
1112
- inds = np.zeros(num_sample, dtype=np.int64)
1113
- inds_left = np.ones(n, dtype=np.bool_)
1114
-
1115
- # Initialize with maximum F
1116
- i = np.argmax(F)
1117
- Fs[0] = F[i]
1118
- inds[0] = i
1119
- inds_left[i] = False
1120
-
1121
- # Greedy sampling loop
1122
- for j in range(1, num_sample):
1123
- # Update F values
1124
- for k in range(n):
1125
- F[k] -= omega * X[i, k]
1126
-
1127
- # Find maximum among remaining points (matching numpy logic exactly)
1128
- max_val = -np.inf
1129
- max_idx = -1
1130
- for k in range(n):
1131
- if inds_left[k] and F[k] > max_val:
1132
- max_val = F[k]
1133
- max_idx = k
1134
-
1135
- # Record the F value using the selected index (matching external TDAvis)
1136
- i = max_idx
1137
- Fs[j] = F[i]
1138
- inds[j] = i
1139
- inds_left[i] = False
1140
-
1141
- return inds, Fs
1142
-
1143
-
1144
- @njit(fastmath=True)
1145
- def _build_distance_matrix_numba(X, inds):
1146
- """Build final distance matrix efficiently with numba."""
1147
- num_sample = len(inds)
1148
- d = np.zeros((num_sample, num_sample))
1149
-
1150
- for j in range(num_sample):
1151
- for k in range(num_sample):
1152
- d[j, k] = X[inds[j], inds[k]]
1153
-
1154
- return d
1155
-
1156
-
1157
- @njit(fastmath=True)
1158
- def _smooth_knn_dist(distances, k, n_iter=64, local_connectivity=0.0, bandwidth=1.0):
1159
- """
1160
- Compute smoothed local distances for kNN graph with entropy balancing.
1161
-
1162
- Parameters:
1163
- distances (ndarray): kNN distance matrix.
1164
- k (int): Number of neighbors.
1165
- n_iter (int): Number of binary search iterations.
1166
- local_connectivity (float): Minimum local connectivity.
1167
- bandwidth (float): Bandwidth parameter.
1168
-
1169
- Returns:
1170
- sigmas (ndarray): Smoothed sigma values for each point.
1171
- rhos (ndarray): Minimum distances (connectivity cutoff) for each point.
1172
- """
1173
- target = np.log2(k) * bandwidth
1174
- # target = np.log(k) * bandwidth
1175
- # target = k
1176
-
1177
- rho = np.zeros(distances.shape[0])
1178
- result = np.zeros(distances.shape[0])
1179
-
1180
- mean_distances = np.mean(distances)
1181
-
1182
- for i in range(distances.shape[0]):
1183
- lo = 0.0
1184
- hi = np.inf
1185
- mid = 1.0
1186
-
1187
- # Vectorized computation of non-zero distances
1188
- ith_distances = distances[i]
1189
- non_zero_dists = ith_distances[ith_distances > 0.0]
1190
- if non_zero_dists.shape[0] >= local_connectivity:
1191
- index = int(np.floor(local_connectivity))
1192
- interpolation = local_connectivity - index
1193
- if index > 0:
1194
- rho[i] = non_zero_dists[index - 1]
1195
- if interpolation > 1e-5:
1196
- rho[i] += interpolation * (non_zero_dists[index] - non_zero_dists[index - 1])
1197
- else:
1198
- rho[i] = interpolation * non_zero_dists[0]
1199
- elif non_zero_dists.shape[0] > 0:
1200
- rho[i] = np.max(non_zero_dists)
1201
-
1202
- # Vectorized binary search loop - compute all at once instead of loop
1203
- for _ in range(n_iter):
1204
- # Vectorized computation: compute all distances at once
1205
- d_array = distances[i, 1:] - rho[i]
1206
- # Vectorized conditional: use np.where for conditional computation
1207
- psum = np.sum(np.where(d_array > 0, np.exp(-(d_array / mid)), 1.0))
1208
-
1209
- if np.fabs(psum - target) < 1e-5:
1210
- break
1211
-
1212
- if psum > target:
1213
- hi = mid
1214
- mid = (lo + hi) / 2.0
1215
- else:
1216
- lo = mid
1217
- if hi == np.inf:
1218
- mid *= 2
1219
- else:
1220
- mid = (lo + hi) / 2.0
1221
- result[i] = mid
1222
- # Optimized mean computation - reuse ith_distances
1223
- if rho[i] > 0.0:
1224
- mean_ith_distances = np.mean(ith_distances)
1225
- if result[i] < 1e-3 * mean_ith_distances:
1226
- result[i] = 1e-3 * mean_ith_distances
1227
- else:
1228
- if result[i] < 1e-3 * mean_distances:
1229
- result[i] = 1e-3 * mean_distances
1230
-
1231
- return result, rho
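
A quick numeric sketch of what the binary search above targets, assuming module scope (pdist/squareform imported at the top of this file): for each point, sigma is tuned so that the smoothed neighbor weights sum to roughly log2(k).

import numpy as np

rng = np.random.default_rng(5)
pts = rng.standard_normal((300, 3))
X = squareform(pdist(pts, "euclidean"))
k = 15
knn_idx = np.argsort(X)[:, :k]
knn_d = X[np.arange(300)[:, None], knn_idx]
sigmas, rhos = _smooth_knn_dist(knn_d, k, local_connectivity=0)
i = 0
weights = np.where(knn_d[i, 1:] - rhos[i] > 0,
                   np.exp(-(knn_d[i, 1:] - rhos[i]) / sigmas[i]), 1.0)
print(weights.sum(), np.log2(k))   # the two values should be close
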
1232
-
1233
-
1234
- @njit(parallel=True, fastmath=True)
1235
- def _compute_membership_strengths(knn_indices, knn_dists, sigmas, rhos):
1236
- """
1237
- Compute membership strength matrix from smoothed kNN graph.
1238
-
1239
- Parameters:
1240
- knn_indices (ndarray): Indices of k-nearest neighbors.
1241
- knn_dists (ndarray): Corresponding distances.
1242
- sigmas (ndarray): Local bandwidths.
1243
- rhos (ndarray): Minimum distance thresholds.
1244
-
1245
- Returns:
1246
- rows (ndarray): Row indices for sparse matrix.
1247
- cols (ndarray): Column indices for sparse matrix.
1248
- vals (ndarray): Weight values for sparse matrix.
1249
- """
1250
- n_samples = knn_indices.shape[0]
1251
- n_neighbors = knn_indices.shape[1]
1252
- rows = np.zeros((n_samples * n_neighbors), dtype=np.int64)
1253
- cols = np.zeros((n_samples * n_neighbors), dtype=np.int64)
1254
- vals = np.zeros((n_samples * n_neighbors), dtype=np.float64)
1255
- for i in range(n_samples):
1256
- for j in range(n_neighbors):
1257
- if knn_indices[i, j] == -1:
1258
- continue # We didn't get the full knn for i
1259
- if knn_indices[i, j] == i:
1260
- val = 0.0
1261
- elif knn_dists[i, j] - rhos[i] <= 0.0:
1262
- val = 1.0
1263
- else:
1264
- val = np.exp(-((knn_dists[i, j] - rhos[i]) / (sigmas[i])))
1265
- # val = ((knn_dists[i, j] - rhos[i]) / (sigmas[i]))
1266
-
1267
- rows[i * n_neighbors + j] = i
1268
- cols[i * n_neighbors + j] = knn_indices[i, j]
1269
- vals[i * n_neighbors + j] = val
1270
-
1271
- return rows, cols, vals
1272
-
1273
-
1274
- def _second_build(data, indstemp, nbs=800, metric="cosine"):
1275
- """
1276
- Reconstruct distance matrix after denoising for persistent homology.
1277
-
1278
- Parameters:
1279
- data (ndarray): PCA-reduced data matrix.
1280
- indstemp (ndarray): Indices of sampled points.
1281
- nbs (int): Number of neighbors in reconstructed graph.
1282
- metric (str): Distance metric ('cosine', 'euclidean', etc).
1283
-
1284
- Returns:
1285
- d (ndarray): Symmetric distance matrix used for persistent homology.
1286
- """
1287
- # Filter the data using the sampled point indices
1288
- data = data[indstemp, :]
1289
-
1290
- # Compute the pairwise distance matrix
1291
- X = squareform(pdist(data, metric))
1292
- knn_indices = np.argsort(X)[:, :nbs]
1293
- knn_dists = X[np.arange(X.shape[0])[:, None], knn_indices].copy()
1294
-
1295
- # Compute smoothed kernel widths
1296
- sigmas, rhos = _smooth_knn_dist(knn_dists, nbs, local_connectivity=0)
1297
- rows, cols, vals = _compute_membership_strengths(knn_indices, knn_dists, sigmas, rhos)
1298
-
1299
- # Construct a sparse graph
1300
- result = coo_matrix((vals, (rows, cols)), shape=(X.shape[0], X.shape[0]))
1301
- result.eliminate_zeros()
1302
- transpose = result.transpose()
1303
- prod_matrix = result.multiply(transpose)
1304
- result = result + transpose - prod_matrix
1305
- result.eliminate_zeros()
1306
-
1307
- # Build the final distance matrix
1308
- d = result.toarray()
1309
- # Match external TDAvis: direct negative log without epsilon handling
1310
- # Temporarily suppress divide by zero warning to match external behavior
1311
- with np.errstate(divide="ignore", invalid="ignore"):
1312
- d = -np.log(d)
1313
- np.fill_diagonal(d, 0)
1314
-
1315
- return d
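
A tiny numeric sketch of the symmetrization used above (and mirrored in _build_adjacency_matrix_numba): the directed membership matrix A is combined with its transpose via the fuzzy union A + A^T - A*A^T, and the persistence input distance is the negative log of the result.

import numpy as np

A = np.array([[0.0, 0.9, 0.0],
              [0.2, 0.0, 0.5],
              [0.0, 0.0, 0.0]])
sym = A + A.T - A * A.T          # element-wise fuzzy union of the directed weights
with np.errstate(divide="ignore"):
    dist = -np.log(sym)          # zero weights become +inf (no edge)
np.fill_diagonal(dist, 0)
print(sym[0, 1], sym[1, 0])      # both 0.92 = 0.9 + 0.2 - 0.9 * 0.2
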
1316
-
1317
-
1318
- def _run_shuffle_analysis(sspikes, num_shuffles=1000, num_cores=4, progress_bar=True, **kwargs):
1319
- """Perform shuffle analysis with optimized computation."""
1320
- return _run_shuffle_analysis_multiprocessing(
1321
- sspikes, num_shuffles, num_cores, progress_bar, **kwargs
1322
- )
1323
-
1324
-
1325
- def _run_shuffle_analysis_multiprocessing(
1326
- sspikes, num_shuffles=1000, num_cores=4, progress_bar=True, **kwargs
1327
- ):
1328
- """Original multiprocessing implementation for fallback."""
1329
- # Use numpy arrays with NaN for failed results (more efficient than None filtering)
1330
- max_lifetimes = {
1331
- 0: np.full(num_shuffles, np.nan),
1332
- 1: np.full(num_shuffles, np.nan),
1333
- 2: np.full(num_shuffles, np.nan),
1334
- }
1335
-
1336
- # Estimate runtime with a test iteration
1337
- logging.info("Running test iteration to estimate runtime...")
1338
-
1339
- _ = _process_single_shuffle((0, sspikes, kwargs))
1340
-
1341
- # Prepare task list
1342
- tasks = [(i, sspikes, kwargs) for i in range(num_shuffles)]
1343
- logging.info(
1344
- f"Starting shuffle analysis with {num_shuffles} iterations using {num_cores} cores..."
1345
- )
1346
-
1347
- # Use multiprocessing pool for parallel processing
1348
- with mp.Pool(processes=num_cores) as pool:
1349
- results = list(pool.imap(_process_single_shuffle, tasks))
1350
- logging.info("Shuffle analysis completed")
1351
-
1352
- # Collect results - use indexing instead of append for better performance
1353
- for idx, res in enumerate(results):
1354
- for dim, lifetime in res.items():
1355
- max_lifetimes[dim][idx] = lifetime
1356
-
1357
- # Filter out NaN values (failed results) - convert to list for consistency
1358
- for dim in max_lifetimes:
1359
- max_lifetimes[dim] = max_lifetimes[dim][~np.isnan(max_lifetimes[dim])].tolist()
1360
-
1361
- return max_lifetimes
1362
-
1363
-
1364
- @njit(fastmath=True)
1365
- def _fast_pca_transform(data, components):
1366
- """Fast PCA transformation using numba."""
1367
- return np.dot(data, components.T)
1368
-
1369
-
1370
- def _process_single_shuffle(args):
1371
- """Process a single shuffle task."""
1372
- i, sspikes, kwargs = args
1373
- try:
1374
- shuffled_data = _shuffle_spike_trains(sspikes)
1375
- persistence = _compute_persistence(shuffled_data, **kwargs)
1376
-
1377
- dim_max_lifetimes = {}
1378
- for dim in [0, 1, 2]:
1379
- if dim < len(persistence["dgms"]):
1380
- # Filter out infinite values
1381
- valid_bars = [bar for bar in persistence["dgms"][dim] if not np.isinf(bar[1])]
1382
- if valid_bars:
1383
- lifetimes = [bar[1] - bar[0] for bar in valid_bars]
1384
- if lifetimes:
1385
- dim_max_lifetimes[dim] = max(lifetimes)
1386
- return dim_max_lifetimes
1387
- except Exception as e:
1388
- print(f"Shuffle {i} failed: {str(e)}")
1389
- return {}
1390
-
1391
-
1392
- def _shuffle_spike_trains(sspikes):
1393
- """Perform random circular shift on spike trains."""
1394
- shuffled = sspikes.copy()
1395
- num_neurons = shuffled.shape[1]
1396
-
1397
- # Independent shift for each neuron
1398
- for n in range(num_neurons):
1399
- shift = np.random.randint(0, int(shuffled.shape[0] * 0.1))
1400
- shuffled[:, n] = np.roll(shuffled[:, n], shift)
1401
-
1402
- return shuffled
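
A brief check-style sketch of the null model: each neuron's column gets an independent circular shift (of up to 10% of the recording length), which preserves per-neuron spike counts while breaking coactivation structure.

import numpy as np

sspikes = np.random.default_rng(4).poisson(0.2, size=(1000, 8)).astype(float)
shuffled = _shuffle_spike_trains(sspikes)
print(np.allclose(shuffled.sum(axis=0), sspikes.sum(axis=0)))   # True: counts unchanged
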
1403
-
1404
-
1405
- def _plot_barcode(persistence):
1406
- """
1407
- Plot barcode diagram from persistent homology result.
1408
-
1409
- Parameters:
1410
- persistence (dict): Persistent homology result with 'dgms' key.
1411
- """
1412
- cs = np.repeat([[0, 0.55, 0.2]], 3).reshape(3, 3).T # RGB color for each dimension
1413
- alpha = 1
1414
- inf_delta = 0.1
1415
- colormap = cs
1416
- dgms = persistence["dgms"]
1417
- maxdim = len(dgms) - 1
1418
- dims = np.arange(maxdim + 1)
1419
- labels = ["$H_0$", "$H_1$", "$H_2$"]
1420
-
1421
- # Determine axis range
1422
- min_birth, max_death = 0, 0
1423
- for dim in dims:
1424
- persistence_dim = dgms[dim][~np.isinf(dgms[dim][:, 1]), :]
1425
- if persistence_dim.size > 0:
1426
- min_birth = min(min_birth, np.min(persistence_dim))
1427
- max_death = max(max_death, np.max(persistence_dim))
1428
-
1429
- delta = (max_death - min_birth) * inf_delta
1430
- infinity = max_death + delta
1431
- axis_start = min_birth - delta
1432
-
1433
- # Create plot
1434
- fig = plt.figure(figsize=(10, 6))
1435
- gs = gridspec.GridSpec(len(dims), 1)
1436
-
1437
- for dim in dims:
1438
- axes = plt.subplot(gs[dim])
1439
- axes.axis("on")
1440
- axes.set_yticks([])
1441
- axes.set_ylabel(labels[dim], rotation=0, labelpad=20, fontsize=12)
1442
-
1443
- d = np.copy(dgms[dim])
1444
- d[np.isinf(d[:, 1]), 1] = infinity
1445
- dlife = d[:, 1] - d[:, 0]
1446
-
1447
- # Select top 30 bars by lifetime
1448
- dinds = np.argsort(dlife)[-30:]
1449
- if dim > 0:
1450
- dinds = dinds[np.flip(np.argsort(d[dinds, 0]))]
1451
-
1452
- axes.barh(
1453
- 0.5 + np.arange(len(dinds)),
1454
- dlife[dinds],
1455
- height=0.8,
1456
- left=d[dinds, 0],
1457
- alpha=alpha,
1458
- color=colormap[dim],
1459
- linewidth=0,
1460
- )
1461
-
1462
- axes.plot([0, 0], [0, len(dinds)], c="k", linestyle="-", lw=1)
1463
- axes.plot([0, len(dinds)], [0, 0], c="k", linestyle="-", lw=1)
1464
- axes.set_xlim([axis_start, infinity])
1465
-
1466
- plt.tight_layout()
1467
- return fig
1468
-
1469
-
1470
- def _plot_barcode_with_shuffle(persistence, shuffle_max):
1471
- """
1472
- Plot barcode with shuffle region markers.
1473
- """
1474
- # Handle case where shuffle_max is None
1475
- if shuffle_max is None:
1476
- shuffle_max = {}
1477
-
1478
- cs = np.repeat([[0, 0.55, 0.2]], 3).reshape(3, 3).T
1479
- alpha = 1
1480
- inf_delta = 0.1
1481
- colormap = cs
1482
- maxdim = len(persistence["dgms"]) - 1
1483
- dims = np.arange(maxdim + 1)
1484
-
1485
- min_birth, max_death = 0, 0
1486
- for dim in dims:
1487
- # Filter out infinite values
1488
- valid_bars = [bar for bar in persistence["dgms"][dim] if not np.isinf(bar[1])]
1489
- if valid_bars:
1490
- min_birth = min(min_birth, np.min(valid_bars))
1491
- max_death = max(max_death, np.max(valid_bars))
1492
-
1493
- # Handle case with no valid bars
1494
- if max_death == 0 and min_birth == 0:
1495
- min_birth = 0
1496
- max_death = 1
1497
-
1498
- delta = (max_death - min_birth) * inf_delta
1499
- infinity = max_death + delta
1500
-
1501
- # Create figure
1502
- fig = plt.figure(figsize=(10, 8))
1503
- gs = gridspec.GridSpec(len(dims), 1)
1504
-
1505
- # Get shuffle thresholds (99.9th percentile for each dimension)
1506
- thresholds = {}
1507
- for dim in dims:
1508
- if dim in shuffle_max and shuffle_max[dim]:
1509
- thresholds[dim] = np.percentile(shuffle_max[dim], 99.9)
1510
- else:
1511
- thresholds[dim] = 0
1512
-
1513
- for _, dim in enumerate(dims):
1514
- axes = plt.subplot(gs[dim])
1515
- axes.axis("off")
1516
-
1517
- # Add gray background to represent shuffle region
1518
- if dim in thresholds:
1519
- axes.axvspan(0, thresholds[dim], alpha=0.2, color="gray", zorder=-3)
1520
- axes.axvline(x=thresholds[dim], color="gray", linestyle="--", alpha=0.7)
1521
-
1522
- # Do not pre-filter out infinite bars; copy the full diagram instead
1523
- d = np.copy(persistence["dgms"][dim])
1524
- if d.size == 0:
1525
- d = np.zeros((0, 2))
1526
-
1527
- # Map infinite death values to a finite upper bound for visualization
1528
- d[np.isinf(d[:, 1]), 1] = infinity
1529
- dlife = d[:, 1] - d[:, 0]
1530
-
1531
- # Select top 30 longest-lived bars
1532
- if len(dlife) > 0:
1533
- dinds = np.argsort(dlife)[-30:]
1534
- if dim > 0:
1535
- dinds = dinds[np.flip(np.argsort(d[dinds, 0]))]
1536
-
1537
- # Mark significant bars
1538
- significant_bars = []
1539
- for idx in dinds:
1540
- if dlife[idx] > thresholds.get(dim, 0):
1541
- significant_bars.append(idx)
1542
-
1543
- # Draw bars
1544
- for i, idx in enumerate(dinds):
1545
- color = "red" if idx in significant_bars else colormap[dim]
1546
- axes.barh(
1547
- 0.5 + i,
1548
- dlife[idx],
1549
- height=0.8,
1550
- left=d[idx, 0],
1551
- alpha=alpha,
1552
- color=color,
1553
- linewidth=0,
1554
- )
1555
-
1556
- indsall = len(dinds)
1557
- else:
1558
- indsall = 0
1559
-
1560
- axes.plot([0, 0], [0, indsall], c="k", linestyle="-", lw=1)
1561
- axes.plot([0, indsall], [0, 0], c="k", linestyle="-", lw=1)
1562
- axes.set_xlim([0, infinity])
1563
- axes.set_title(f"$H_{dim}$", loc="left")
1564
-
1565
- plt.tight_layout()
1566
- return fig
1567
-
1568
-
1569
- def decode_circular_coordinates(
1570
- persistence_result: dict[str, Any],
1571
- spike_data: dict[str, Any],
1572
- real_ground: bool = True,
1573
- real_of: bool = True,
1574
- save_path: str | None = None,
1575
- ) -> dict[str, Any]:
1576
- """
1577
- Decode circular coordinates (bump positions) from cohomology.
1578
-
1579
- Parameters:
1580
- persistence_result : dict containing persistence analysis results with keys:
1581
- - 'persistence': persistent homology result
1582
- - 'indstemp': indices of sampled points
1583
- - 'movetimes': selected time points
1584
- - 'n_points': number of sampled points
1585
- spike_data : dict, optional
1586
- Spike data dictionary containing 'spike', 't', and optionally 'x', 'y'
1587
- real_ground : bool
1588
- Whether x, y, t ground truth exists
1589
- real_of : bool
1590
- Whether experiment was performed in open field
1591
- save_path : str, optional
1592
- Path to save decoding results. If None, saves to 'Results/spikes_decoding.npz'
1593
-
1594
- Returns:
1595
- dict : Dictionary containing decoding results with keys:
1596
- - 'coords': decoded coordinates for all timepoints
1597
- - 'coordsbox': decoded coordinates for box timepoints
1598
- - 'times': time indices for coords
1599
- - 'times_box': time indices for coordsbox
1600
- - 'centcosall': cosine centroids
1601
- - 'centsinall': sine centroids
1602
- """
1603
- ph_classes = [0, 1] # Decode the ith most persistent cohomology class
1604
- num_circ = len(ph_classes)
1605
- dec_tresh = 0.99
1606
- coeff = 47
1607
-
1608
- # Extract persistence analysis results
1609
- persistence = persistence_result["persistence"]
1610
- indstemp = persistence_result["indstemp"]
1611
- movetimes = persistence_result["movetimes"]
1612
- n_points = persistence_result["n_points"]
1613
-
1614
- diagrams = persistence["dgms"] # the multiset describing the lives of the persistence classes
1615
- cocycles = persistence["cocycles"][1] # the cocycle representatives for the 1-dim classes
1616
- dists_land = persistence["dperm2all"] # the pairwise distance between the points
1617
- births1 = diagrams[1][:, 0] # the time of birth for the 1-dim classes
1618
- deaths1 = diagrams[1][:, 1] # the time of death for the 1-dim classes
1619
- deaths1[np.isinf(deaths1)] = 0
1620
- lives1 = deaths1 - births1 # the lifetime for the 1-dim classes
1621
- iMax = np.argsort(lives1)
1622
- coords1 = np.zeros((num_circ, len(indstemp)))
1623
- threshold = births1[iMax[-2]] + (deaths1[iMax[-2]] - births1[iMax[-2]]) * dec_tresh
1624
-
1625
- for c in ph_classes:
1626
- cocycle = cocycles[iMax[-(c + 1)]]
1627
- coords1[c, :], inds = _get_coords(cocycle, threshold, len(indstemp), dists_land, coeff)
1628
-
1629
- if real_ground:  # whether the provided data includes real x, y, t ground truth
1630
- sspikes, xx, yy, tt = embed_spike_trains(
1631
- spike_data, config=SpikeEmbeddingConfig(smooth=True, speed_filter=True)
1632
- )
1633
- else:
1634
- sspikes = embed_spike_trains(
1635
- spike_data, config=SpikeEmbeddingConfig(smooth=True, speed_filter=False)
1636
- )
1637
-
1638
- num_neurons = sspikes.shape[1]
1639
- centcosall = np.zeros((num_neurons, 2, n_points))
1640
- centsinall = np.zeros((num_neurons, 2, n_points))
1641
- dspk = preprocessing.scale(sspikes[movetimes[indstemp], :])
1642
-
1643
- for neurid in range(num_neurons):
1644
- spktemp = dspk[:, neurid].copy()
1645
- centcosall[neurid, :, :] = np.multiply(np.cos(coords1[:, :] * 2 * np.pi), spktemp)
1646
- centsinall[neurid, :, :] = np.multiply(np.sin(coords1[:, :] * 2 * np.pi), spktemp)
1647
-
1648
- if real_ground:  # does the provided data include ground-truth x, y, t?
1649
- sspikes, xx, yy, tt = embed_spike_trains(
1650
- spike_data, config=SpikeEmbeddingConfig(smooth=True, speed_filter=True)
1651
- )
1652
- spikes, __, __, __ = embed_spike_trains(
1653
- spike_data, config=SpikeEmbeddingConfig(smooth=False, speed_filter=True)
1654
- )
1655
- else:
1656
- sspikes = embed_spike_trains(
1657
- spike_data, config=SpikeEmbeddingConfig(smooth=True, speed_filter=False)
1658
- )
1659
- spikes = embed_spike_trains(
1660
- spike_data, config=SpikeEmbeddingConfig(smooth=False, speed_filter=False)
1661
- )
1662
-
1663
- times = np.where(np.sum(spikes > 0, 1) >= 1)[0]
1664
- dspk = preprocessing.scale(sspikes)
1665
- sspikes = sspikes[times, :]
1666
- dspk = dspk[times, :]
1667
-
1668
- a = np.zeros((len(sspikes[:, 0]), 2, num_neurons))
1669
- for n in range(num_neurons):
1670
- a[:, :, n] = np.multiply(dspk[:, n : n + 1], np.sum(centcosall[n, :, :], 1))
1671
-
1672
- c = np.zeros((len(sspikes[:, 0]), 2, num_neurons))
1673
- for n in range(num_neurons):
1674
- c[:, :, n] = np.multiply(dspk[:, n : n + 1], np.sum(centsinall[n, :, :], 1))
1675
-
1676
- mtot2 = np.sum(c, 2)
1677
- mtot1 = np.sum(a, 2)
1678
- coords = np.arctan2(mtot2, mtot1) % (2 * np.pi)
1679
-
1680
- if real_of:  # was the experiment performed in a real open-field (OF) arena?
1681
- coordsbox = coords.copy()
1682
- times_box = times.copy()
1683
- else:
1684
- sspikes, xx, yy, tt = embed_spike_trains(
1685
- spike_data, config=SpikeEmbeddingConfig(smooth=True, speed_filter=True)
1686
- )
1687
- spikes, __, __, __ = embed_spike_trains(
1688
- spike_data, config=SpikeEmbeddingConfig(smooth=False, speed_filter=True)
1689
- )
1690
- dspk = preprocessing.scale(sspikes)
1691
- times_box = np.where(np.sum(spikes > 0, 1) >= 1)[0]
1692
- dspk = dspk[times_box, :]
1693
-
1694
- a = np.zeros((len(times_box), 2, num_neurons))
1695
- for n in range(num_neurons):
1696
- a[:, :, n] = np.multiply(dspk[:, n : n + 1], np.sum(centcosall[n, :, :], 1))
1697
-
1698
- c = np.zeros((len(times_box), 2, num_neurons))
1699
- for n in range(num_neurons):
1700
- c[:, :, n] = np.multiply(dspk[:, n : n + 1], np.sum(centsinall[n, :, :], 1))
1701
-
1702
- mtot2 = np.sum(c, 2)
1703
- mtot1 = np.sum(a, 2)
1704
- coordsbox = np.arctan2(mtot2, mtot1) % (2 * np.pi)
1705
-
1706
- # Prepare results dictionary
1707
- results = {
1708
- "coords": coords,
1709
- "coordsbox": coordsbox,
1710
- "times": times,
1711
- "times_box": times_box,
1712
- "centcosall": centcosall,
1713
- "centsinall": centsinall,
1714
- }
1715
-
1716
- # Save results
1717
- if save_path is None:
1718
- os.makedirs("Results", exist_ok=True)
1719
- save_path = "Results/spikes_decoding.npz"
1720
-
1721
- os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)  # handle bare filenames without a directory part
1722
- np.savez_compressed(save_path, **results)
1723
-
1724
- return results
1725
-
1726
-
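The decoding step above is essentially a population-vector (circular-mean) estimate: each neuron votes with its activity times the cosine/sine of its preferred phase, and the angle is recovered with `arctan2`. A toy sketch of that idea, with illustrative variable names and synthetic tuning:

```python
import numpy as np

rng = np.random.default_rng(0)
preferred = rng.uniform(0, 2 * np.pi, 50)            # preferred phase of each neuron
true_angle = 1.3
activity = np.exp(np.cos(preferred - true_angle))    # bump-like firing around true_angle

# Activity-weighted sums of sin/cos, then arctan2 -- mirrors the mtot1/mtot2 step above.
decoded = np.arctan2(activity @ np.sin(preferred), activity @ np.cos(preferred)) % (2 * np.pi)
print(f"true={true_angle:.2f}  decoded={decoded:.2f}")   # decoded is close to 1.30
```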
1727
- def plot_cohomap(
1728
- decoding_result: dict[str, Any],
1729
- position_data: dict[str, Any],
1730
- save_path: str | None = None,
1731
- show: bool = False,
1732
- figsize: tuple[int, int] = (10, 4),
1733
- dpi: int = 300,
1734
- subsample: int = 10,
1735
- ) -> plt.Figure:
1736
- """
1737
- Visualize CohoMap 1.0: decoded circular coordinates mapped onto spatial trajectory.
1738
-
1739
- Creates a two-panel visualization showing how the two decoded circular coordinates
1740
- vary across the animal's spatial trajectory. Each panel displays the spatial path
1741
- colored by the cosine of one circular coordinate dimension.
1742
-
1743
- Parameters:
1744
- decoding_result : dict
1745
- Dictionary from decode_circular_coordinates() containing:
1746
- - 'coordsbox': decoded coordinates for box timepoints (n_times x n_dims)
1747
- - 'times_box': time indices for coordsbox
1748
- position_data : dict
1749
- Position data containing 'x' and 'y' arrays for spatial coordinates
1750
- save_path : str, optional
1751
- Path to save the visualization. If None, no save performed
1752
- show : bool, default=False
1753
- Whether to display the visualization
1754
- figsize : tuple[int, int], default=(10, 4)
1755
- Figure size (width, height) in inches
1756
- dpi : int, default=300
1757
- Resolution for saved figure
1758
- subsample : int, default=10
1759
- Subsampling interval for plotting (plot every Nth timepoint)
1760
-
1761
- Returns:
1762
- plt.Figure : The matplotlib figure object
1763
-
1764
- Raises:
1765
- KeyError : If required keys are missing from input dictionaries
1766
- ValueError : If data dimensions are inconsistent
1767
- IndexError : If time indices are out of bounds
1768
-
1769
- Examples:
1770
- >>> # Decode coordinates
1771
- >>> decoding = decode_circular_coordinates(persistence_result, spike_data)
1772
- >>> # Visualize with trajectory data
1773
- >>> fig = plot_cohomap(
1774
- ... decoding,
1775
- ... position_data={'x': xx, 'y': yy},
1776
- ... save_path='cohomap.png',
1777
- ... show=True
1778
- ... )
1779
- """
1780
- try:
1781
- # Extract data
1782
- coordsbox = decoding_result["coordsbox"]
1783
- times_box = decoding_result["times_box"]
1784
- xx = position_data["x"]
1785
- yy = position_data["y"]
1786
-
1787
- # Subsample time indices for plotting
1788
- plot_times = np.arange(0, len(coordsbox), subsample)
1789
-
1790
- # Create a two-panel figure (one per cohomology dimension)
1791
- plt.set_cmap("viridis")
1792
- fig, ax = plt.subplots(1, 2, figsize=figsize)
1793
-
1794
- # Plot for the first circular coordinate
1795
- ax[0].axis("off")
1796
- ax[0].set_aspect("equal", "box")
1797
- im0 = ax[0].scatter(
1798
- xx[times_box][plot_times],
1799
- yy[times_box][plot_times],
1800
- c=np.cos(coordsbox[plot_times, 0]),
1801
- s=8,
1802
- cmap="viridis",
1803
- )
1804
- plt.colorbar(im0, ax=ax[0], label="cos(coord)")
1805
- ax[0].set_title("CohoMap Dim 1", fontsize=10)
1806
-
1807
- # Plot for the second circular coordinate
1808
- ax[1].axis("off")
1809
- ax[1].set_aspect("equal", "box")
1810
- im1 = ax[1].scatter(
1811
- xx[times_box][plot_times],
1812
- yy[times_box][plot_times],
1813
- c=np.cos(coordsbox[plot_times, 1]),
1814
- s=8,
1815
- cmap="viridis",
1816
- )
1817
- plt.colorbar(im1, ax=ax[1], label="cos(coord)")
1818
- ax[1].set_title("CohoMap Dim 2", fontsize=10)
1819
-
1820
- plt.tight_layout()
1821
-
1822
- # Save if path provided
1823
- if save_path:
1824
- try:
1825
- os.makedirs(os.path.dirname(save_path), exist_ok=True)
1826
- plt.savefig(save_path, dpi=dpi)
1827
- print(f"CohoMap visualization saved to {save_path}")
1828
- except Exception as e:
1829
- print(f"Error saving CohoMap visualization: {e}")
1830
-
1831
- # Show if requested
1832
- if show:
1833
- plt.show()
1834
- else:
1835
- plt.close(fig)
1836
-
1837
- return fig
1838
-
1839
- except (KeyError, ValueError, IndexError) as e:
1840
- print(f"CohoMap visualization failed: {e}")
1841
- raise
1842
- except Exception as e:
1843
- print(f"Unexpected error in CohoMap visualization: {e}")
1844
- raise
1845
-
1846
-
1847
- def plot_3d_bump_on_torus(
1848
- decoding_result: dict[str, Any] | str,
1849
- spike_data: dict[str, Any],
1850
- config: CANN2DPlotConfig | None = None,
1851
- save_path: str | None = None,
1852
- numangsint: int = 51,
1853
- r1: float = 1.5,
1854
- r2: float = 1.0,
1855
- window_size: int = 300,
1856
- frame_step: int = 5,
1857
- n_frames: int = 20,
1858
- fps: int = 5,
1859
- show_progress: bool = True,
1860
- show: bool = True,
1861
- figsize: tuple[int, int] = (8, 8),
1862
- **kwargs,
1863
- ) -> animation.FuncAnimation:
1864
- """
1865
- Visualize the movement of the neural activity bump on a torus using matplotlib animation.
1866
-
1867
- This function follows the canns.analyzer.plotting patterns for animation generation
1868
- with progress tracking and proper resource cleanup.
1869
-
1870
- Parameters:
1871
- decoding_result : dict or str
1872
- Dictionary containing decoding results with 'coordsbox' and 'times_box' keys,
1873
- or path to .npz file containing these results
1874
- spike_data : dict
1875
- Spike data dictionary containing spike information
1876
- config : PlotConfig, optional
1877
- Configuration object for unified plotting parameters
1878
- **kwargs : backward compatibility parameters
1879
- save_path : str, optional
1880
- Path to save the animation (e.g., 'animation.gif' or 'animation.mp4')
1881
- numangsint : int
1882
- Grid resolution for the torus surface
1883
- r1 : float
1884
- Major radius of the torus
1885
- r2 : float
1886
- Minor radius of the torus
1887
- window_size : int
1888
- Time window (in number of time points) for each frame
1889
- frame_step : int
1890
- Step size to slide the time window between frames
1891
- n_frames : int
1892
- Total number of frames in the animation
1893
- fps : int
1894
- Frames per second for the output animation
1895
- show_progress : bool
1896
- Whether to show progress bar during generation
1897
- show : bool
1898
- Whether to display the animation
1899
- figsize : tuple[int, int]
1900
- Figure size for the animation
1901
-
1902
- Returns:
1903
- matplotlib.animation.FuncAnimation : The animation object (None when show=True in a Jupyter environment, to avoid double display)
1904
- """
1905
- # Handle backward compatibility and configuration
1906
- if config is None:
1907
- config = CANN2DPlotConfig.for_torus_animation(**kwargs)
1908
-
1909
- # Override config with any explicitly passed parameters
1910
- for key, value in kwargs.items():
1911
- if hasattr(config, key):
1912
- setattr(config, key, value)
1913
-
1914
- # Extract configuration values
1915
- save_path = config.save_path if config.save_path else save_path
1916
- show = config.show
1917
- figsize = config.figsize
1918
- fps = config.fps
1919
- show_progress = config.show_progress_bar
1920
- numangsint = config.numangsint
1921
- r1 = config.r1
1922
- r2 = config.r2
1923
- window_size = config.window_size
1924
- frame_step = config.frame_step
1925
- n_frames = config.n_frames
1926
-
1927
- # Load decoding results if path is provided
1928
- if isinstance(decoding_result, str):
1929
- f = np.load(decoding_result, allow_pickle=True)
1930
- coords = f["coordsbox"]
1931
- times = f["times_box"]
1932
- f.close()
1933
- else:
1934
- coords = decoding_result["coordsbox"]
1935
- times = decoding_result["times_box"]
1936
-
1937
- spk, *_ = embed_spike_trains(
1938
- spike_data, config=SpikeEmbeddingConfig(smooth=False, speed_filter=True)
1939
- )
1940
-
1941
- # Pre-compute torus geometry (constant across frames - optimization)
1942
- # Create grid for torus surface
1943
- x_edge = np.linspace(0, 2 * np.pi, numangsint)
1944
- y_edge = np.linspace(0, 2 * np.pi, numangsint)
1945
- X_grid, Y_grid = np.meshgrid(x_edge, y_edge)
1946
- X_transformed = (X_grid + np.pi / 5) % (2 * np.pi)
1947
-
1948
- # Pre-compute torus geometry (only done once!)
1949
- torus_x = (r1 + r2 * np.cos(X_transformed)) * np.cos(Y_grid)
1950
- torus_y = (r1 + r2 * np.cos(X_transformed)) * np.sin(Y_grid)
1951
- torus_z = -r2 * np.sin(X_transformed) # Flip torus surface orientation
1952
-
1953
- # Prepare animation data (now only stores colors, not geometry)
1954
- frame_data = []
1955
- prev_m = None
1956
-
1957
- for frame_idx in tqdm(range(n_frames), desc="Processing frames"):
1958
- start_idx = frame_idx * frame_step
1959
- end_idx = start_idx + window_size
1960
- if end_idx > np.max(times):
1961
- break
1962
-
1963
- mask = (times >= start_idx) & (times < end_idx)
1964
- coords_window = coords[mask]
1965
- if len(coords_window) == 0:
1966
- continue
1967
-
1968
- spk_window = spk[times[mask], :]
1969
- activity = np.sum(spk_window, axis=1)
1970
-
1971
- m, _, _, _ = binned_statistic_2d(
1972
- coords_window[:, 0],
1973
- coords_window[:, 1],
1974
- activity,
1975
- statistic="sum",
1976
- bins=np.linspace(0, 2 * np.pi, numangsint - 1),
1977
- )
1978
- m = np.nan_to_num(m)
1979
- m = _smooth_tuning_map(m, numangsint - 1, sig=4.0, bClose=True)
1980
- m = gaussian_filter(m, sigma=1.0)
1981
-
1982
- if prev_m is not None:
1983
- m = 0.7 * prev_m + 0.3 * m
1984
- prev_m = m
1985
-
1986
- # Store only activity map (m) and metadata, reuse geometry
1987
- frame_data.append({"m": m, "time": start_idx})  # window start index (same units as times)
1988
-
1989
- if not frame_data:
1990
- raise ProcessingError("No valid frames generated for animation")
1991
-
1992
- # Create figure and animation with optimized geometry reuse
1993
- fig = plt.figure(figsize=figsize)
1994
-
1995
- try:
1996
- ax = fig.add_subplot(111, projection="3d")
1997
- # Batch set axis properties (reduces overhead)
1998
- ax.set_zlim(-2, 2)
1999
- ax.view_init(-125, 135)
2000
- ax.axis("off")
2001
-
2002
- # Initialize with first frame
2003
- first_frame = frame_data[0]
2004
- surface = ax.plot_surface(
2005
- torus_x, # Pre-computed geometry
2006
- torus_y, # Pre-computed geometry
2007
- torus_z, # Pre-computed geometry
2008
- facecolors=cm.viridis(first_frame["m"] / (np.max(first_frame["m"]) + 1e-9)),
2009
- alpha=1,
2010
- linewidth=0.1,
2011
- antialiased=True,
2012
- rstride=1,
2013
- cstride=1,
2014
- shade=False,
2015
- )
2016
-
2017
- def animate(frame_idx):
2018
- """Optimized animation update - reuses pre-computed geometry."""
2019
- if frame_idx >= len(frame_data):
2020
- return (surface,)
2021
-
2022
- frame = frame_data[frame_idx]
2023
-
2024
- # 3D surfaces require clear (no blitting support), but minimize overhead
2025
- ax.clear()
2026
-
2027
- # Batch axis settings together (reduces function call overhead)
2028
- ax.set_zlim(-2, 2)
2029
- ax.view_init(-125, 135)
2030
- ax.axis("off")
2031
-
2032
- # Reuse pre-computed geometry, only update colors
2033
- new_surface = ax.plot_surface(
2034
- torus_x, # Pre-computed, not recalculated!
2035
- torus_y, # Pre-computed, not recalculated!
2036
- torus_z, # Pre-computed, not recalculated!
2037
- facecolors=cm.viridis(frame["m"] / (np.max(frame["m"]) + 1e-9)),
2038
- alpha=1,
2039
- linewidth=0.1,
2040
- antialiased=True,
2041
- rstride=1,
2042
- cstride=1,
2043
- shade=False,
2044
- )
2045
-
2046
- # Update time text
2047
- time_text = ax.text2D(
2048
- 0.05,
2049
- 0.95,
2050
- f"Frame: {frame_idx + 1}/{len(frame_data)}",
2051
- transform=ax.transAxes,
2052
- fontsize=12,
2053
- bbox=dict(facecolor="white", alpha=0.7),
2054
- )
2055
-
2056
- return new_surface, time_text
2057
-
2058
- # Create animation (blit=False due to 3D limitation)
2059
- interval_ms = 1000 / fps
2060
- ani = animation.FuncAnimation(
2061
- fig,
2062
- animate,
2063
- frames=len(frame_data),
2064
- interval=interval_ms,
2065
- blit=False,  # 3D surface plots do not support blitting
2066
- repeat=True,
2067
- )
2068
-
2069
- # Save animation if path provided
2070
- if save_path:
2071
- # Warn if both saving and showing (causes double rendering)
2072
- if show and len(frame_data) > 50:
2073
- from ...visualization.core import warn_double_rendering
2074
-
2075
- warn_double_rendering(len(frame_data), save_path, stacklevel=2)
2076
-
2077
- if show_progress:
2078
- pbar = tqdm(total=len(frame_data), desc=f"Saving animation to {save_path}")
2079
-
2080
- def progress_callback(current_frame, total_frames):
2081
- pbar.update(1)
2082
-
2083
- try:
2084
- writer = animation.PillowWriter(fps=fps)
2085
- ani.save(save_path, writer=writer, progress_callback=progress_callback)
2086
- pbar.close()
2087
- print(f"\nAnimation saved to: {save_path}")
2088
- except Exception as e:
2089
- pbar.close()
2090
- print(f"\nError saving animation: {e}")
2091
- else:
2092
- try:
2093
- writer = animation.PillowWriter(fps=fps)
2094
- ani.save(save_path, writer=writer)
2095
- print(f"Animation saved to: {save_path}")
2096
- except Exception as e:
2097
- print(f"Error saving animation: {e}")
2098
-
2099
- if show:
2100
- # Automatically detect Jupyter and display as HTML/JS
2101
- if is_jupyter_environment():
2102
- display_animation_in_jupyter(ani)
2103
- plt.close(fig) # Close after HTML conversion to prevent auto-display
2104
- else:
2105
- plt.show()
2106
- else:
2107
- plt.close(fig) # Close if not showing
2108
-
2109
- # Return None in Jupyter when showing to avoid double display
2110
- if show and is_jupyter_environment():
2111
- return None
2112
- return ani
2113
-
2114
- except Exception as e:
2115
- plt.close(fig)
2116
- raise ProcessingError(f"Failed to create torus animation: {e}") from e
2117
-
2118
-
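The 3D rendering above boils down to a fixed parametric torus whose face colors are replaced each frame. A minimal sketch of that pattern (the random `m` stands in for the binned activity map; all names are illustrative):

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

numangsint, r1, r2 = 51, 1.5, 1.0
theta, phi = np.meshgrid(np.linspace(0, 2 * np.pi, numangsint),
                         np.linspace(0, 2 * np.pi, numangsint))
# Standard torus parametrisation: r1 is the major radius, r2 the tube radius.
x = (r1 + r2 * np.cos(theta)) * np.cos(phi)
y = (r1 + r2 * np.cos(theta)) * np.sin(phi)
z = r2 * np.sin(theta)

m = np.random.rand(numangsint - 1, numangsint - 1)   # placeholder activity map (one value per face)
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.plot_surface(x, y, z, facecolors=cm.viridis(m / m.max()),
                rstride=1, cstride=1, shade=False)
ax.axis("off")
plt.show()
```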
2119
- def plot_2d_bump_on_manifold(
2120
- decoding_result: dict[str, Any] | str,
2121
- spike_data: dict[str, Any],
2122
- save_path: str | None = None,
2123
- fps: int = 20,
2124
- show: bool = True,
2125
- mode: str = "fast",
2126
- window_size: int = 10,
2127
- frame_step: int = 5,
2128
- numangsint: int = 20,
2129
- figsize: tuple[int, int] = (8, 6),
2130
- show_progress: bool = False,
2131
- ):
2132
- """
2133
- Create 2D projection animation of CANN2D bump activity with full blitting support.
2134
-
2135
- This function provides a fast 2D heatmap visualization as an alternative to the
2136
- 3D torus animation. It achieves 10-20x speedup using matplotlib blitting
2137
- optimization, making it ideal for rapid prototyping and daily analysis.
2138
-
2139
- Args:
2140
- decoding_result: Decoding results containing coords and times (dict or file path)
2141
- spike_data: Dictionary containing spike train data
2142
- save_path: Path to save animation (None to skip saving)
2143
- fps: Frames per second
2144
- show: Whether to display the animation
2145
- mode: Visualization mode - 'fast' for 2D heatmap (default), '3d' falls back to 3D
2146
- window_size: Time window for activity aggregation
2147
- frame_step: Time step between frames
2148
- numangsint: Number of angular bins for spatial discretization
2149
- figsize: Figure size (width, height) in inches
2150
- show_progress: Show progress bar during processing
2151
-
2152
- Returns:
2153
- FuncAnimation object (or None in Jupyter when showing)
2154
-
2155
- Raises:
2156
- ProcessingError: If mode is invalid or animation generation fails
2157
-
2158
- Example:
2159
- >>> # Fast 2D visualization (recommended for daily use)
2160
- >>> ani = plot_2d_bump_on_manifold(
2161
- ... decoding_result, spike_data,
2162
- ... save_path='bump_2d.mp4', mode='fast'
2163
- ... )
2164
- >>> # For publication-ready 3D visualization, use mode='3d'
2165
- >>> ani = plot_2d_bump_on_manifold(
2166
- ... decoding_result, spike_data, mode='3d'
2167
- ... )
2168
- """
2169
- import matplotlib.animation as animation
2170
-
2171
- from ...visualization.core.jupyter_utils import (
2172
- display_animation_in_jupyter,
2173
- is_jupyter_environment,
2174
- )
2175
-
2176
- # Validate inputs
2177
- if mode == "3d":
2178
- # Fall back to 3D visualization
2179
- return plot_3d_bump_on_torus(
2180
- decoding_result=decoding_result,
2181
- spike_data=spike_data,
2182
- save_path=save_path,
2183
- fps=fps,
2184
- show=show,
2185
- window_size=window_size,
2186
- frame_step=frame_step,
2187
- numangsint=numangsint,
2188
- figsize=figsize,
2189
- show_progress=show_progress,
2190
- )
2191
-
2192
- if mode != "fast":
2193
- raise ProcessingError(f"Invalid mode '{mode}'. Must be 'fast' or '3d'.")
2194
-
2195
- # Load decoding results
2196
- if isinstance(decoding_result, str):
2197
- f = np.load(decoding_result, allow_pickle=True)
2198
- coords = f["coordsbox"]
2199
- times = f["times_box"]
2200
- f.close()
2201
- else:
2202
- coords = decoding_result["coordsbox"]
2203
- times = decoding_result["times_box"]
2204
-
2205
- # Process spike data for 2D projection
2206
- spk, *_ = embed_spike_trains(
2207
- spike_data, config=SpikeEmbeddingConfig(smooth=False, speed_filter=True)
2208
- )
2209
-
2210
- # Process frames
2211
- n_frames = (np.max(times) - window_size) // frame_step
2212
- frame_activity_maps = []
2213
- prev_m = None
2214
-
2215
- for frame_idx in tqdm(range(n_frames), desc="Processing frames", disable=not show_progress):
2216
- start_idx = frame_idx * frame_step
2217
- end_idx = start_idx + window_size
2218
- if end_idx > np.max(times):
2219
- break
2220
-
2221
- mask = (times >= start_idx) & (times < end_idx)
2222
- coords_window = coords[mask]
2223
- if len(coords_window) == 0:
2224
- continue
2225
-
2226
- spk_window = spk[times[mask], :]
2227
- activity = np.sum(spk_window, axis=1)
2228
-
2229
- m, _, _, _ = binned_statistic_2d(
2230
- coords_window[:, 0],
2231
- coords_window[:, 1],
2232
- activity,
2233
- statistic="sum",
2234
- bins=np.linspace(0, 2 * np.pi, numangsint - 1),
2235
- )
2236
- m = np.nan_to_num(m)
2237
- m = _smooth_tuning_map(m, numangsint - 1, sig=4.0, bClose=True)
2238
- m = gaussian_filter(m, sigma=1.0)
2239
-
2240
- if prev_m is not None:
2241
- m = 0.7 * prev_m + 0.3 * m
2242
- prev_m = m
2243
-
2244
- frame_activity_maps.append(m)
2245
-
2246
- if not frame_activity_maps:
2247
- raise ProcessingError("No valid frames generated for animation")
2248
-
2249
- # Create 2D visualization with blitting
2250
- fig, ax = plt.subplots(figsize=figsize)
2251
- ax.set_xlabel("Manifold Dimension 1 (rad)", fontsize=12)
2252
- ax.set_ylabel("Manifold Dimension 2 (rad)", fontsize=12)
2253
- ax.set_title("CANN2D Bump Activity (2D Projection)", fontsize=14, fontweight="bold")
2254
-
2255
- # Pre-create artists for blitting
2256
- # Heatmap
2257
- im = ax.imshow(
2258
- frame_activity_maps[0].T, # Transpose for correct orientation
2259
- extent=[0, 2 * np.pi, 0, 2 * np.pi],
2260
- origin="lower",
2261
- cmap="viridis",
2262
- animated=True,
2263
- aspect="auto",
2264
- )
2265
- # Colorbar (static)
2266
- cbar = plt.colorbar(im, ax=ax)
2267
- cbar.set_label("Activity", fontsize=11)
2268
-
2269
- # Time text
2270
- time_text = ax.text(
2271
- 0.02,
2272
- 0.98,
2273
- "",
2274
- transform=ax.transAxes,
2275
- fontsize=11,
2276
- verticalalignment="top",
2277
- bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
2278
- animated=True,
2279
- )
2280
-
2281
- def init():
2282
- """Initialize animation"""
2283
- im.set_array(frame_activity_maps[0].T)
2284
- time_text.set_text("")
2285
- return im, time_text
2286
-
2287
- def update(frame_idx):
2288
- """Update function - only modify data using blitting"""
2289
- if frame_idx >= len(frame_activity_maps):
2290
- return im, time_text
2291
-
2292
- # Update heatmap data
2293
- im.set_array(frame_activity_maps[frame_idx].T)
2294
-
2295
- # Update time text
2296
- time_text.set_text(f"Frame: {frame_idx + 1}/{len(frame_activity_maps)}")
2297
-
2298
- return im, time_text
2299
-
2300
- # Check blitting support
2301
- use_blitting = True
2302
- try:
2303
- if not fig.canvas.supports_blit:
2304
- use_blitting = False
2305
- print("Warning: Backend does not support blitting. Using slower mode.")
2306
- except AttributeError:
2307
- use_blitting = False
2308
-
2309
- # Create animation with blitting enabled for 10-20x speedup
2310
- interval_ms = 1000 / fps
2311
- ani = animation.FuncAnimation(
2312
- fig,
2313
- update,
2314
- frames=len(frame_activity_maps),
2315
- init_func=init,
2316
- interval=interval_ms,
2317
- blit=use_blitting,
2318
- repeat=True,
2319
- )
2320
-
2321
- # Save animation if path provided
2322
- if save_path:
2323
- # Warn if both saving and showing (causes double rendering)
2324
- if show and len(frame_activity_maps) > 50:
2325
- from ...visualization.core import warn_double_rendering
2326
-
2327
- warn_double_rendering(len(frame_activity_maps), save_path, stacklevel=2)
2328
-
2329
- if show_progress:
2330
- pbar = tqdm(total=len(frame_activity_maps), desc=f"Saving animation to {save_path}")
2331
-
2332
- def progress_callback(current_frame, total_frames):
2333
- pbar.update(1)
2334
-
2335
- try:
2336
- if save_path.endswith(".mp4"):
2337
- from matplotlib.animation import FFMpegWriter
2338
-
2339
- writer = FFMpegWriter(
2340
- fps=fps, codec="libx264", extra_args=["-pix_fmt", "yuv420p"]
2341
- )
2342
- else:
2343
- from matplotlib.animation import PillowWriter
2344
-
2345
- writer = PillowWriter(fps=fps)
2346
-
2347
- ani.save(save_path, writer=writer, progress_callback=progress_callback)
2348
- pbar.close()
2349
- print(f"\nAnimation saved to: {save_path}")
2350
- except Exception as e:
2351
- pbar.close()
2352
- print(f"\nError saving animation: {e}")
2353
- raise
2354
- else:
2355
- try:
2356
- if save_path.endswith(".mp4"):
2357
- from matplotlib.animation import FFMpegWriter
2358
-
2359
- writer = FFMpegWriter(
2360
- fps=fps, codec="libx264", extra_args=["-pix_fmt", "yuv420p"]
2361
- )
2362
- else:
2363
- from matplotlib.animation import PillowWriter
2364
-
2365
- writer = PillowWriter(fps=fps)
2366
-
2367
- ani.save(save_path, writer=writer)
2368
- print(f"Animation saved to: {save_path}")
2369
- except Exception as e:
2370
- print(f"Error saving animation: {e}")
2371
- raise
2372
-
2373
- if show:
2374
- # Automatically detect Jupyter and display as HTML/JS
2375
- if is_jupyter_environment():
2376
- display_animation_in_jupyter(ani)
2377
- plt.close(fig) # Close after HTML conversion to prevent auto-display
2378
- else:
2379
- plt.show()
2380
- else:
2381
- plt.close(fig) # Close if not showing
2382
-
2383
- # Return None in Jupyter when showing to avoid double display
2384
- if show and is_jupyter_environment():
2385
- return None
2386
- return ani
2387
-
2388
-
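The speedup in the 2D mode comes from the standard blitting recipe: create the artists once, let the update callback only change their data, and pass `blit=True`. A self-contained sketch with placeholder data:

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

frames = [np.random.rand(20, 20) for _ in range(30)]   # placeholder activity maps

fig, ax = plt.subplots()
im = ax.imshow(frames[0], origin="lower", animated=True)
txt = ax.text(0.02, 0.98, "", transform=ax.transAxes, va="top", animated=True)

def update(i):
    im.set_array(frames[i])                  # only the data changes, no new artists
    txt.set_text(f"frame {i + 1}/{len(frames)}")
    return im, txt                           # blitting redraws just these artists

ani = FuncAnimation(fig, update, frames=len(frames), interval=50, blit=True)
plt.show()
```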
2389
- def _get_coords(cocycle, threshold, num_sampled, dists, coeff):
2390
- """
2391
- Reconstruct circular coordinates from cocycle information.
2392
-
2393
- Parameters:
2394
- cocycle (ndarray): Persistent cocycle representative.
2395
- threshold (float): Maximum allowable edge distance.
2396
- num_sampled (int): Number of sampled points.
2397
- dists (ndarray): Pairwise distance matrix.
2398
- coeff (int): Finite field modulus for cohomology.
2399
-
2400
- Returns:
2401
- f (ndarray): Circular coordinate values (in [0,1]).
2402
- verts (ndarray): Indices of used vertices.
2403
- """
2404
- zint = np.where(coeff - cocycle[:, 2] < cocycle[:, 2])
2405
- cocycle[zint, 2] = cocycle[zint, 2] - coeff
2406
- d = np.zeros((num_sampled, num_sampled))
2407
- d[np.tril_indices(num_sampled)] = np.nan
2408
- d[cocycle[:, 1], cocycle[:, 0]] = cocycle[:, 2]
2409
- d[dists > threshold] = np.nan
2410
- d[dists == 0] = np.nan
2411
- edges = np.where(~np.isnan(d))
2412
- verts = np.array(np.unique(edges))
2413
- num_edges = np.shape(edges)[1]
2414
- num_verts = np.size(verts)
2415
- values = d[edges]
2416
- A = np.zeros((num_edges, num_verts), dtype=int)
2417
- v1 = np.zeros((num_edges, 2), dtype=int)
2418
- v2 = np.zeros((num_edges, 2), dtype=int)
2419
- for i in range(num_edges):
2420
- # Extract scalar indices from np.where results
2421
- idx1 = np.where(verts == edges[0][i])[0]
2422
- idx2 = np.where(verts == edges[1][i])[0]
2423
-
2424
- # Handle case where np.where returns multiple matches (shouldn't happen in valid data)
2425
- if len(idx1) > 0:
2426
- v1[i, :] = [i, idx1[0]]
2427
- else:
2428
- raise ValueError(f"No vertex found for edge {edges[0][i]}")
2429
-
2430
- if len(idx2) > 0:
2431
- v2[i, :] = [i, idx2[0]]
2432
- else:
2433
- raise ValueError(f"No vertex found for edge {edges[1][i]}")
2434
-
2435
- A[v1[:, 0], v1[:, 1]] = -1
2436
- A[v2[:, 0], v2[:, 1]] = 1
2437
-
2438
- L = np.ones((num_edges,))
2439
- Aw = A * np.sqrt(L[:, np.newaxis])
2440
- Bw = values * np.sqrt(L)
2441
- f = lsmr(Aw, Bw)[0] % 1
2442
- return f, verts
2443
-
2444
-
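_get_coords solves a least-squares problem on the 1-skeleton: a signed edge-vertex incidence matrix A, edge values from the integer-lifted cocycle, and `lsmr` recovering vertex phases modulo 1. A toy 4-cycle carrying one unit of winding illustrates the step (names are illustrative):

```python
import numpy as np
from scipy.sparse.linalg import lsmr

edges = [(0, 1), (1, 2), (2, 3), (3, 0)]
values = np.array([0.25, 0.25, 0.25, 0.25 - 1.0])   # last edge carries the integer lift

A = np.zeros((len(edges), 4))
for i, (u, v) in enumerate(edges):
    A[i, u], A[i, v] = -1, 1                         # row encodes f[v] - f[u] = value

f = lsmr(A, values)[0] % 1
print(np.round(f, 3))   # four phases spaced 0.25 apart around the circle (up to a common rotation)
```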
2445
- def _smooth_tuning_map(mtot, numangsint, sig, bClose=True):
2446
- """
2447
- Smooth activity map over circular topology (e.g., torus).
2448
-
2449
- Parameters:
2450
- mtot (ndarray): Raw activity map matrix.
2451
- numangsint (int): Grid resolution.
2452
- sig (float): Smoothing kernel standard deviation.
2453
- bClose (bool): Whether to assume circular boundary conditions.
2454
-
2455
- Returns:
2456
- mtot_out (ndarray): Smoothed map matrix.
2457
- """
2458
- numangsint_1 = numangsint - 1
2459
- mid = int((numangsint_1) / 2)
2460
- indstemp1 = np.zeros((numangsint_1, numangsint_1), dtype=int)
2461
- indstemp1[indstemp1 == 0] = np.arange((numangsint_1) ** 2)
2462
- mid = int((numangsint_1) / 2)
2463
- mtemp1_3 = mtot.copy()
2464
- for i in range(numangsint_1):
2465
- mtemp1_3[i, :] = np.roll(mtemp1_3[i, :], int(i / 2))
2466
- mtot_out = np.zeros_like(mtot)
2467
- mtemp1_4 = np.concatenate((mtemp1_3, mtemp1_3, mtemp1_3), 1)
2468
- mtemp1_5 = np.zeros_like(mtemp1_4)
2469
- mtemp1_5[:, :mid] = mtemp1_4[:, (numangsint_1) * 3 - mid :]
2470
- mtemp1_5[:, mid:] = mtemp1_4[:, : (numangsint_1) * 3 - mid]
2471
- if bClose:
2472
- mtemp1_6 = _smooth_image(np.concatenate((mtemp1_5, mtemp1_4, mtemp1_5)), sigma=sig)
2473
- else:
2474
- mtemp1_6 = gaussian_filter(np.concatenate((mtemp1_5, mtemp1_4, mtemp1_5)), sigma=sig)
2475
- for i in range(numangsint_1):
2476
- mtot_out[i, :] = mtemp1_6[
2477
- (numangsint_1) + i,
2478
- (numangsint_1) + (int(i / 2) + 1) : (numangsint_1) * 2 + (int(i / 2) + 1),
2479
- ]
2480
- return mtot_out
2481
-
2482
-
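When no bins are missing, the circular-boundary smoothing above can be approximated much more simply with `gaussian_filter(mode="wrap")`, which wraps both axes periodically. This is only an alternative sketch, not the twisted-tiling scheme implemented in `_smooth_tuning_map`:

```python
import numpy as np
from scipy.ndimage import gaussian_filter

m = np.random.rand(50, 50)                              # placeholder activity map on the torus
m_smooth = gaussian_filter(m, sigma=4.0, mode="wrap")   # periodic smoothing along both axes
```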
2483
- def _smooth_image(img, sigma):
2484
- """
2485
- Smooth image using multivariate Gaussian kernel, handling missing (NaN) values.
2486
-
2487
- Parameters:
2488
- img (ndarray): Input image matrix.
2489
- sigma (float): Standard deviation of smoothing kernel.
2490
-
2491
- Returns:
2492
- imgC (ndarray): Smoothed image with inpainting around NaNs.
2493
- """
2494
- filterSize = max(np.shape(img))
2495
- grid = np.arange(-filterSize + 1, filterSize, 1)
2496
- xx, yy = np.meshgrid(grid, grid)
2497
-
2498
- pos = np.dstack((xx, yy))
2499
-
2500
- var = multivariate_normal(mean=[0, 0], cov=[[sigma**2, 0], [0, sigma**2]])
2501
- k = var.pdf(pos)
2502
- k = k / np.sum(k)
2503
-
2504
- nans = np.isnan(img)
2505
- imgA = img.copy()
2506
- imgA[nans] = 0
2507
- imgA = signal.convolve2d(imgA, k, mode="valid")
2508
- imgD = img.copy()
2509
- imgD[nans] = 0
2510
- imgD[~nans] = 1
2511
- radius = 1
2512
- L = np.arange(-radius, radius + 1)
2513
- X, Y = np.meshgrid(L, L)
2514
- dk = np.array((X**2 + Y**2) <= radius**2, dtype=bool)
2515
- imgE = np.zeros((filterSize + 2, filterSize + 2))
2516
- imgE[1:-1, 1:-1] = imgD
2517
- imgE = binary_closing(imgE, iterations=1, structure=dk)
2518
- imgD = imgE[1:-1, 1:-1]
2519
-
2520
- imgB = np.divide(
2521
- signal.convolve2d(imgD, k, mode="valid"),
2522
- signal.convolve2d(np.ones(np.shape(imgD)), k, mode="valid"),
2523
- )
2524
- imgC = np.divide(imgA, imgB)
2525
- imgC[imgD == 0] = -np.inf
2526
- return imgC
2527
-
2528
-
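The core idea in `_smooth_image` is normalized convolution: zero out the missing bins, smooth the data and a validity mask separately, and divide, so NaNs do not drag their neighbours toward zero. A compact sketch of that idea (using `gaussian_filter` rather than the explicit kernel above; names are illustrative):

```python
import numpy as np
from scipy.ndimage import gaussian_filter

def smooth_with_nans(img, sigma):
    valid = ~np.isnan(img)
    filled = np.where(valid, img, 0.0)
    num = gaussian_filter(filled, sigma=sigma)                # smoothed data with NaNs zeroed
    den = gaussian_filter(valid.astype(float), sigma=sigma)   # smoothed validity mask
    out = num / np.maximum(den, 1e-12)
    out[~valid] = np.nan                                      # keep originally missing bins marked
    return out
```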
2529
- if __name__ == "__main__":
2530
- from canns.data.loaders import load_grid_data
2531
-
2532
- data = load_grid_data()
2533
-
2534
- spikes, xx, yy, tt = embed_spike_trains(data)
2535
-
2536
- # import umap
2537
- #
2538
- # reducer = umap.UMAP(
2539
- # n_neighbors=15,
2540
- # min_dist=0.1,
2541
- # n_components=3,
2542
- # metric='euclidean',
2543
- # random_state=42
2544
- # )
2545
- #
2546
- # reduce_func = reducer.fit_transform
2547
- #
2548
- # plot_projection(reduce_func=reduce_func, embed_data=spikes, show=True)
2549
- results = tda_vis(embed_data=spikes, maxdim=1, do_shuffle=False, show=True)
2550
- decoding = decode_circular_coordinates(
2551
- persistence_result=results,
2552
- spike_data=data,
2553
- real_ground=True,
2554
- real_of=True,
2555
- )
2556
-
2557
- # Visualize CohoMap
2558
- plot_cohomap(
2559
- decoding_result=decoding,
2560
- position_data={"x": xx, "y": yy},
2561
- save_path="Results/cohomap.png",
2562
- show=True,
2563
- )
2564
-
2565
- # results = tda_vis(embed_data=spikes, maxdim=1, do_shuffle=True, num_shuffles=10, show=True)