canns 0.12.7__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. canns/analyzer/data/__init__.py +3 -11
  2. canns/analyzer/data/asa/__init__.py +74 -0
  3. canns/analyzer/data/asa/cohospace.py +905 -0
  4. canns/analyzer/data/asa/config.py +246 -0
  5. canns/analyzer/data/asa/decode.py +448 -0
  6. canns/analyzer/data/asa/embedding.py +269 -0
  7. canns/analyzer/data/asa/filters.py +208 -0
  8. canns/analyzer/data/asa/fr.py +439 -0
  9. canns/analyzer/data/asa/path.py +389 -0
  10. canns/analyzer/data/asa/plotting.py +1276 -0
  11. canns/analyzer/data/asa/tda.py +901 -0
  12. canns/analyzer/data/legacy/__init__.py +6 -0
  13. canns/analyzer/data/{cann1d.py → legacy/cann1d.py} +2 -2
  14. canns/analyzer/data/{cann2d.py → legacy/cann2d.py} +3 -3
  15. canns/analyzer/visualization/core/backend.py +1 -1
  16. canns/analyzer/visualization/core/config.py +77 -0
  17. canns/analyzer/visualization/core/rendering.py +10 -6
  18. canns/analyzer/visualization/energy_plots.py +22 -8
  19. canns/analyzer/visualization/spatial_plots.py +31 -11
  20. canns/analyzer/visualization/theta_sweep_plots.py +15 -6
  21. canns/pipeline/__init__.py +4 -8
  22. canns/pipeline/asa/__init__.py +21 -0
  23. canns/pipeline/asa/__main__.py +11 -0
  24. canns/pipeline/asa/app.py +1000 -0
  25. canns/pipeline/asa/runner.py +1095 -0
  26. canns/pipeline/asa/screens.py +215 -0
  27. canns/pipeline/asa/state.py +248 -0
  28. canns/pipeline/asa/styles.tcss +221 -0
  29. canns/pipeline/asa/widgets.py +233 -0
  30. canns/pipeline/gallery/__init__.py +7 -0
  31. canns/task/open_loop_navigation.py +3 -1
  32. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/METADATA +6 -3
  33. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/RECORD +36 -17
  34. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/entry_points.txt +1 -0
  35. canns/pipeline/theta_sweep.py +0 -573
  36. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/WHEEL +0 -0
  37. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/licenses/LICENSE +0 -0
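A major addition in 0.13.0 is the new `canns.analyzer.data.asa` package (files 2–11 above); its `tda.py` module is shown in full below. For orientation only, and going solely by the docstring and configuration defaults inside that module, a shuffle-based significance run might look like the sketch below, where `embed_data` stands for a (T, N) spike-train array prepared by the caller and is not provided by the package:

>>> import numpy as np
>>> from canns.analyzer.data import TDAConfig, tda_vis
>>> cfg = TDAConfig(maxdim=1, do_shuffle=True, num_shuffles=100, show=False)
>>> result = tda_vis(embed_data, config=cfg)
>>> h1_diagram = result["persistence"]["dgms"][1]              # H1 birth/death pairs
>>> h1_cutoff = np.percentile(result["shuffle_max"][1], 99.9)  # threshold used by the barcode plot

Bars whose lifetime (death minus birth) exceeds that cutoff are the ones `_plot_barcode_with_shuffle` draws in red.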
@@ -0,0 +1,901 @@
+ from __future__ import annotations
+
+ import logging
+ import multiprocessing as mp
+ from typing import Any
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from canns_lib.ripser import ripser
+ from matplotlib import gridspec
+ from scipy.sparse import coo_matrix
+ from scipy.spatial.distance import pdist, squareform
+ from sklearn import preprocessing
+
+ from .config import Constants, ProcessingError, TDAConfig
+
+ try:
+     from numba import njit
+
+     HAS_NUMBA = True
+ except ImportError:
+     HAS_NUMBA = False
+
+     def njit(*args, **kwargs):
+         def decorator(func):
+             return func
+
+         return decorator
+
+
+ def tda_vis(embed_data: np.ndarray, config: TDAConfig | None = None, **kwargs) -> dict[str, Any]:
+     """
+     Topological Data Analysis visualization with optional shuffle testing.
+
+     Parameters
+     ----------
+     embed_data : np.ndarray
+         Embedded spike train data of shape (T, N).
+     config : TDAConfig, optional
+         Configuration object with all TDA parameters. If None, legacy kwargs are used.
+     **kwargs : Any
+         Legacy keyword parameters (``dim``, ``num_times``, ``active_times``, ``k``,
+         ``n_points``, ``metric``, ``nbs``, ``maxdim``, ``coeff``, ``show``,
+         ``do_shuffle``, ``num_shuffles``, ``progress_bar``).
+
+     Returns
+     -------
+     dict
+         Dictionary containing:
+         - ``persistence``: persistence diagrams from real data.
+         - ``indstemp``: indices of sampled points.
+         - ``movetimes``: selected time points.
+         - ``n_points``: number of sampled points.
+         - ``shuffle_max``: shuffle analysis results (if ``do_shuffle=True``), else ``None``.
+
+     Examples
+     --------
+     >>> from canns.analyzer.data import TDAConfig, tda_vis
+     >>> cfg = TDAConfig(maxdim=1, do_shuffle=False, show=False)
+     >>> result = tda_vis(embed_data, config=cfg)  # doctest: +SKIP
+     >>> sorted(result.keys())
+     ['indstemp', 'movetimes', 'n_points', 'persistence', 'shuffle_max']
+     """
+     # Handle backward compatibility and configuration
+     if config is None:
+         config = TDAConfig(
+             dim=kwargs.get("dim", 6),
+             num_times=kwargs.get("num_times", 5),
+             active_times=kwargs.get("active_times", 15000),
+             k=kwargs.get("k", 1000),
+             n_points=kwargs.get("n_points", 1200),
+             metric=kwargs.get("metric", "cosine"),
+             nbs=kwargs.get("nbs", 800),
+             maxdim=kwargs.get("maxdim", 1),
+             coeff=kwargs.get("coeff", 47),
+             show=kwargs.get("show", True),
+             do_shuffle=kwargs.get("do_shuffle", False),
+             num_shuffles=kwargs.get("num_shuffles", 1000),
+             progress_bar=kwargs.get("progress_bar", True),
+         )
+
+     try:
+         # Compute persistent homology for real data
+         print("Computing persistent homology for real data...")
+         real_persistence = _compute_real_persistence(embed_data, config)
+
+         # Perform shuffle analysis if requested
+         shuffle_max = None
+         if config.do_shuffle:
+             shuffle_max = _perform_shuffle_analysis(embed_data, config)
+
+         # Visualization
+         _handle_visualization(real_persistence["persistence"], shuffle_max, config)
+
+         # Return results as dictionary
+         return {
+             "persistence": real_persistence["persistence"],
+             "indstemp": real_persistence["indstemp"],
+             "movetimes": real_persistence["movetimes"],
+             "n_points": real_persistence["n_points"],
+             "shuffle_max": shuffle_max,
+         }
+
+     except Exception as e:
+         raise ProcessingError(f"TDA analysis failed: {e}") from e
+
+
+ def _compute_real_persistence(embed_data: np.ndarray, config: TDAConfig) -> dict[str, Any]:
+     """Compute persistent homology for real data with progress tracking."""
+
+     logging.info("Processing real data - Starting TDA analysis (5 steps)")
+
+     # Step 1: Time point downsampling
+     logging.info("Step 1/5: Time point downsampling")
+     times_cube = _downsample_timepoints(embed_data, config.num_times)
+
+     # Step 2: Select most active time points
+     logging.info("Step 2/5: Selecting active time points")
+     movetimes = _select_active_timepoints(embed_data, times_cube, config.active_times)
+
+     # Step 3: PCA dimensionality reduction
+     logging.info("Step 3/5: PCA dimensionality reduction")
+     dimred = _apply_pca_reduction(embed_data, movetimes, config.dim)
+
+     # Step 4: Point cloud sampling (denoising)
+     logging.info("Step 4/5: Point cloud denoising")
+     indstemp = _apply_denoising(dimred, config)
+
+     # Step 5: Compute persistent homology
+     logging.info("Step 5/5: Computing persistent homology")
+     persistence = _compute_persistence_homology(dimred, indstemp, config)
+
+     logging.info("TDA analysis completed successfully")
+
+     # Return all necessary data in dictionary format
+     return {
+         "persistence": persistence,
+         "indstemp": indstemp,
+         "movetimes": movetimes,
+         "n_points": config.n_points,
+     }
+
+
+ def _downsample_timepoints(embed_data: np.ndarray, num_times: int) -> np.ndarray:
+     """Downsample timepoints for computational efficiency."""
+     return np.arange(0, embed_data.shape[0], num_times)
+
+
+ def _select_active_timepoints(
+     embed_data: np.ndarray, times_cube: np.ndarray, active_times: int
+ ) -> np.ndarray:
+     """Select most active timepoints based on total activity."""
+     activity_scores = np.sum(embed_data[times_cube, :], 1)
+     # Match external TDAvis: sort indices first, then map to times_cube
+     movetimes = np.sort(np.argsort(activity_scores)[-active_times:])
+     return times_cube[movetimes]
+
+
+ def _apply_pca_reduction(embed_data: np.ndarray, movetimes: np.ndarray, dim: int) -> np.ndarray:
+     """Apply PCA dimensionality reduction."""
+     scaled_data = preprocessing.scale(embed_data[movetimes, :])
+     dimred, *_ = _pca(scaled_data, dim=dim)
+     return dimred
+
+
+ def _apply_denoising(dimred: np.ndarray, config: TDAConfig) -> np.ndarray:
+     """Apply point cloud denoising."""
+     indstemp, *_ = _sample_denoising(
+         dimred,
+         k=config.k,
+         num_sample=config.n_points,
+         omega=1,  # Match external TDAvis: uses 1, not default 0.2
+         metric=config.metric,
+     )
+     return indstemp
+
+
+ def _compute_persistence_homology(
+     dimred: np.ndarray, indstemp: np.ndarray, config: TDAConfig
+ ) -> dict[str, Any]:
+     """Compute persistent homology using ripser."""
+     d = _second_build(dimred, indstemp, metric=config.metric, nbs=config.nbs)
+     np.fill_diagonal(d, 0)
+
+     return ripser(
+         d,
+         maxdim=config.maxdim,
+         coeff=config.coeff,
+         do_cocycles=True,
+         distance_matrix=True,
+         progress_bar=config.progress_bar,
+     )
+
+
+ def _perform_shuffle_analysis(embed_data: np.ndarray, config: TDAConfig) -> dict[int, Any]:
+     """Perform shuffle analysis with progress tracking."""
+     print(f"\nStarting shuffle analysis with {config.num_shuffles} iterations...")
+
+     # Create parameters dict for shuffle analysis
+     shuffle_params = {
+         "dim": config.dim,
+         "num_times": config.num_times,
+         "active_times": config.active_times,
+         "k": config.k,
+         "n_points": config.n_points,
+         "metric": config.metric,
+         "nbs": config.nbs,
+         "maxdim": config.maxdim,
+         "coeff": config.coeff,
+     }
+
+     shuffle_max = _run_shuffle_analysis(
+         embed_data,
+         num_shuffles=config.num_shuffles,
+         num_cores=Constants.MULTIPROCESSING_CORES,
+         progress_bar=config.progress_bar,
+         **shuffle_params,
+     )
+
+     # Print shuffle analysis summary
+     _print_shuffle_summary(shuffle_max)
+
+     return shuffle_max
+
+
+ def _print_shuffle_summary(shuffle_max: dict[int, Any]) -> None:
+     """Print summary of shuffle analysis results."""
+     print("\nSummary of shuffle-based analysis:")
+     for dim_idx in [0, 1, 2]:
+         if shuffle_max and dim_idx in shuffle_max and shuffle_max[dim_idx]:
+             values = shuffle_max[dim_idx]
+             print(
+                 f"H{dim_idx}: {len(values)} valid iterations | "
+                 f"Mean maximum persistence: {np.mean(values):.4f} | "
+                 f"99.9th percentile: {np.percentile(values, 99.9):.4f}"
+             )
+
+
+ def _handle_visualization(
+     real_persistence: dict[str, Any], shuffle_max: dict[int, Any] | None, config: TDAConfig
+ ) -> None:
+     """Handle visualization based on configuration."""
+     if config.show:
+         if config.do_shuffle and shuffle_max is not None:
+             _plot_barcode_with_shuffle(real_persistence, shuffle_max)
+         else:
+             _plot_barcode(real_persistence)
+         plt.show()
+     else:
+         plt.close()
+
+
+ def _compute_persistence(
+     sspikes,
+     dim=6,
+     num_times=5,
+     active_times=15000,
+     k=1000,
+     n_points=1200,
+     metric="cosine",
+     nbs=800,
+     maxdim=1,
+     coeff=47,
+     progress_bar=True,
+ ):
+     # Time point downsampling
+     times_cube = np.arange(0, sspikes.shape[0], num_times)
+
+     # Select most active time points
+     movetimes = np.sort(np.argsort(np.sum(sspikes[times_cube, :], 1))[-active_times:])
+     movetimes = times_cube[movetimes]
+
+     # PCA dimensionality reduction
+     scaled_data = preprocessing.scale(sspikes[movetimes, :])
+     dimred, *_ = _pca(scaled_data, dim=dim)
+
+     # Point cloud sampling (denoising)
+     indstemp, *_ = _sample_denoising(dimred, k, n_points, 1, metric)
+
+     # Build distance matrix
+     d = _second_build(dimred, indstemp, metric=metric, nbs=nbs)
+     np.fill_diagonal(d, 0)
+
+     # Compute persistent homology
+     persistence = ripser(
+         d,
+         maxdim=maxdim,
+         coeff=coeff,
+         do_cocycles=True,
+         distance_matrix=True,
+         progress_bar=progress_bar,
+     )
+
+     return persistence
+
+
+ def _pca(data, dim=2):
+     """
+     Perform PCA (Principal Component Analysis) for dimensionality reduction.
+
+     Parameters:
+         data (ndarray): Input data matrix of shape (N_samples, N_features).
+         dim (int): Target dimension for PCA projection.
+
+     Returns:
+         components (ndarray): Projected data of shape (N_samples, dim).
+         var_exp (list): Variance explained by each principal component.
+         evals (ndarray): Eigenvalues corresponding to the selected components.
+     """
+     if dim < 2:
+         return data, [0], np.array([])
+     _ = data.shape
+     # mean center the data
+     # data -= data.mean(axis=0)
+     # calculate the covariance matrix
+     R = np.cov(data, rowvar=False)
+     # calculate eigenvectors & eigenvalues of the covariance matrix
+     # use 'eigh' rather than 'eig' since R is symmetric,
+     # the performance gain is substantial
+     evals, evecs = np.linalg.eig(R)
+     # sort eigenvalue in decreasing order
+     idx = np.argsort(evals)[::-1]
+     evecs = evecs[:, idx]
+     # sort eigenvectors according to same index
+     evals = evals[idx]
+     # select the first n eigenvectors (n is desired dimension
+     # of rescaled data array, or dims_rescaled_data)
+     evecs = evecs[:, :dim]
+     # carry out the transformation on the data using eigenvectors
+     # and return the re-scaled data, eigenvalues, and eigenvectors
+
+     tot = np.sum(evals)
+     var_exp = [(i / tot) * 100 for i in sorted(evals[:dim], reverse=True)]
+     components = np.dot(evecs.T, data.T).T
+     return components, var_exp, evals[:dim]
+
+
+ def _sample_denoising(data, k=10, num_sample=500, omega=0.2, metric="euclidean"):
+     """
+     Perform denoising and greedy sampling based on mutual k-NN graph.
+
+     Parameters:
+         data (ndarray): High-dimensional point cloud data.
+         k (int): Number of neighbors for local density estimation.
+         num_sample (int): Number of samples to retain.
+         omega (float): Suppression factor during greedy sampling.
+         metric (str): Distance metric used for kNN ('euclidean', 'cosine', etc).
+
+     Returns:
+         inds (ndarray): Indices of sampled points.
+         d (ndarray): Pairwise similarity matrix of sampled points.
+         Fs (ndarray): Sampling scores at each step.
+     """
+     if HAS_NUMBA:
+         return _sample_denoising_numba(data, k, num_sample, omega, metric)
+     else:
+         return _sample_denoising_numpy(data, k, num_sample, omega, metric)
+
+
+ def _sample_denoising_numpy(data, k=10, num_sample=500, omega=0.2, metric="euclidean"):
+     """Original numpy implementation for fallback."""
+     n = data.shape[0]
+     X = squareform(pdist(data, metric))
+     knn_indices = np.argsort(X)[:, :k]
+     knn_dists = X[np.arange(X.shape[0])[:, None], knn_indices].copy()
+
+     sigmas, rhos = _smooth_knn_dist(knn_dists, k, local_connectivity=0)
+     rows, cols, vals = _compute_membership_strengths(knn_indices, knn_dists, sigmas, rhos)
+     result = coo_matrix((vals, (rows, cols)), shape=(n, n))
+     result.eliminate_zeros()
+     transpose = result.transpose()
+     prod_matrix = result.multiply(transpose)
+     result = result + transpose - prod_matrix
+     result.eliminate_zeros()
+     X = result.toarray()
+     F = np.sum(X, 1)
+     Fs = np.zeros(num_sample)
+     Fs[0] = np.max(F)
+     i = np.argmax(F)
+     inds_all = np.arange(n)
+     inds_left = inds_all > -1
+     inds_left[i] = False
+     inds = np.zeros(num_sample, dtype=int)
+     inds[0] = i
+     for j in np.arange(1, num_sample):
+         F -= omega * X[i, :]
+         Fmax = np.argmax(F[inds_left])
+         # Exactly match external TDAvis implementation (including the indexing logic)
+         Fs[j] = F[Fmax]
+         i = inds_all[inds_left][Fmax]
+
+         inds_left[i] = False
+         inds[j] = i
+     d = np.zeros((num_sample, num_sample))
+
+     for j, i in enumerate(inds):
+         d[j, :] = X[i, inds]
+     return inds, d, Fs
+
+
+ def _sample_denoising_numba(data, k=10, num_sample=500, omega=0.2, metric="euclidean"):
+     """Optimized numba implementation."""
+     n = data.shape[0]
+     X = squareform(pdist(data, metric))
+     knn_indices = np.argsort(X)[:, :k]
+     knn_dists = X[np.arange(X.shape[0])[:, None], knn_indices].copy()
+
+     sigmas, rhos = _smooth_knn_dist(knn_dists, k, local_connectivity=0)
+     rows, cols, vals = _compute_membership_strengths(knn_indices, knn_dists, sigmas, rhos)
+
+     # Build symmetric adjacency matrix using optimized function
+     X_adj = _build_adjacency_matrix_numba(rows, cols, vals, n)
+
+     # Greedy sampling using optimized function
+     inds, Fs = _greedy_sampling_numba(X_adj, num_sample, omega)
+
+     # Build final distance matrix
+     d = _build_distance_matrix_numba(X_adj, inds)
+
+     return inds, d, Fs
+
+
+ @njit(fastmath=True)
+ def _build_adjacency_matrix_numba(rows, cols, vals, n):
+     """Build symmetric adjacency matrix efficiently with numba.
+
+     This matches the scipy sparse matrix operations:
+         result = result + transpose - prod_matrix
+     where prod_matrix = result.multiply(transpose)
+     """
+     # Initialize matrices
+     X = np.zeros((n, n), dtype=np.float64)
+     X_T = np.zeros((n, n), dtype=np.float64)
+
+     # Build adjacency matrix and its transpose simultaneously
+     for i in range(len(rows)):
+         X[rows[i], cols[i]] = vals[i]
+         X_T[cols[i], rows[i]] = vals[i]  # Transpose
+
+     # Apply the symmetrization formula: A = A + A^T - A ⊙ A^T (vectorized)
+     # This matches scipy's: result + transpose - prod_matrix
+     X[:, :] = X + X_T - X * X_T
+
+     return X
+
+
+ @njit(fastmath=True)
+ def _greedy_sampling_numba(X, num_sample, omega):
+     """Optimized greedy sampling with numba."""
+     n = X.shape[0]
+     F = np.sum(X, axis=1)
+     Fs = np.zeros(num_sample)
+     inds = np.zeros(num_sample, dtype=np.int64)
+     inds_left = np.ones(n, dtype=np.bool_)
+
+     # Initialize with maximum F
+     i = np.argmax(F)
+     Fs[0] = F[i]
+     inds[0] = i
+     inds_left[i] = False
+
+     # Greedy sampling loop
+     for j in range(1, num_sample):
+         # Update F values
+         for k in range(n):
+             F[k] -= omega * X[i, k]
+
+         # Find maximum among remaining points (matching numpy logic exactly)
+         max_val = -np.inf
+         max_idx = -1
+         for k in range(n):
+             if inds_left[k] and F[k] > max_val:
+                 max_val = F[k]
+                 max_idx = k
+
+         # Record the F value using the selected index (matching external TDAvis)
+         i = max_idx
+         Fs[j] = F[i]
+         inds[j] = i
+         inds_left[i] = False
+
+     return inds, Fs
+
+
+ @njit(fastmath=True)
+ def _build_distance_matrix_numba(X, inds):
+     """Build final distance matrix efficiently with numba."""
+     num_sample = len(inds)
+     d = np.zeros((num_sample, num_sample))
+
+     for j in range(num_sample):
+         for k in range(num_sample):
+             d[j, k] = X[inds[j], inds[k]]
+
+     return d
+
+
+ @njit(fastmath=True)
+ def _smooth_knn_dist(distances, k, n_iter=64, local_connectivity=0.0, bandwidth=1.0):
+     """
+     Compute smoothed local distances for kNN graph with entropy balancing.
+
+     Parameters:
+         distances (ndarray): kNN distance matrix.
+         k (int): Number of neighbors.
+         n_iter (int): Number of binary search iterations.
+         local_connectivity (float): Minimum local connectivity.
+         bandwidth (float): Bandwidth parameter.
+
+     Returns:
+         sigmas (ndarray): Smoothed sigma values for each point.
+         rhos (ndarray): Minimum distances (connectivity cutoff) for each point.
+     """
+     target = np.log2(k) * bandwidth
+     # target = np.log(k) * bandwidth
+     # target = k
+
+     rho = np.zeros(distances.shape[0])
+     result = np.zeros(distances.shape[0])
+
+     mean_distances = np.mean(distances)
+
+     for i in range(distances.shape[0]):
+         lo = 0.0
+         hi = np.inf
+         mid = 1.0
+
+         # Vectorized computation of non-zero distances
+         ith_distances = distances[i]
+         non_zero_dists = ith_distances[ith_distances > 0.0]
+         if non_zero_dists.shape[0] >= local_connectivity:
+             index = int(np.floor(local_connectivity))
+             interpolation = local_connectivity - index
+             if index > 0:
+                 rho[i] = non_zero_dists[index - 1]
+                 if interpolation > 1e-5:
+                     rho[i] += interpolation * (non_zero_dists[index] - non_zero_dists[index - 1])
+             else:
+                 rho[i] = interpolation * non_zero_dists[0]
+         elif non_zero_dists.shape[0] > 0:
+             rho[i] = np.max(non_zero_dists)
+
+         # Vectorized binary search loop - compute all at once instead of loop
+         for _ in range(n_iter):
+             # Vectorized computation: compute all distances at once
+             d_array = distances[i, 1:] - rho[i]
+             # Vectorized conditional: use np.where for conditional computation
+             psum = np.sum(np.where(d_array > 0, np.exp(-(d_array / mid)), 1.0))
+
+             if np.fabs(psum - target) < 1e-5:
+                 break
+
+             if psum > target:
+                 hi = mid
+                 mid = (lo + hi) / 2.0
+             else:
+                 lo = mid
+                 if hi == np.inf:
+                     mid *= 2
+                 else:
+                     mid = (lo + hi) / 2.0
+         result[i] = mid
+         # Optimized mean computation - reuse ith_distances
+         if rho[i] > 0.0:
+             mean_ith_distances = np.mean(ith_distances)
+             if result[i] < 1e-3 * mean_ith_distances:
+                 result[i] = 1e-3 * mean_ith_distances
+         else:
+             if result[i] < 1e-3 * mean_distances:
+                 result[i] = 1e-3 * mean_distances
+
+     return result, rho
+
+
+ @njit(parallel=True, fastmath=True)
+ def _compute_membership_strengths(knn_indices, knn_dists, sigmas, rhos):
+     """
+     Compute membership strength matrix from smoothed kNN graph.
+
+     Parameters:
+         knn_indices (ndarray): Indices of k-nearest neighbors.
+         knn_dists (ndarray): Corresponding distances.
+         sigmas (ndarray): Local bandwidths.
+         rhos (ndarray): Minimum distance thresholds.
+
+     Returns:
+         rows (ndarray): Row indices for sparse matrix.
+         cols (ndarray): Column indices for sparse matrix.
+         vals (ndarray): Weight values for sparse matrix.
+     """
+     n_samples = knn_indices.shape[0]
+     n_neighbors = knn_indices.shape[1]
+     rows = np.zeros((n_samples * n_neighbors), dtype=np.int64)
+     cols = np.zeros((n_samples * n_neighbors), dtype=np.int64)
+     vals = np.zeros((n_samples * n_neighbors), dtype=np.float64)
+     for i in range(n_samples):
+         for j in range(n_neighbors):
+             if knn_indices[i, j] == -1:
+                 continue  # We didn't get the full knn for i
+             if knn_indices[i, j] == i:
+                 val = 0.0
+             elif knn_dists[i, j] - rhos[i] <= 0.0:
+                 val = 1.0
+             else:
+                 val = np.exp(-((knn_dists[i, j] - rhos[i]) / (sigmas[i])))
+                 # val = ((knn_dists[i, j] - rhos[i]) / (sigmas[i]))
+
+             rows[i * n_neighbors + j] = i
+             cols[i * n_neighbors + j] = knn_indices[i, j]
+             vals[i * n_neighbors + j] = val
+
+     return rows, cols, vals
+
+
+ def _second_build(data, indstemp, nbs=800, metric="cosine"):
+     """
+     Reconstruct distance matrix after denoising for persistent homology.
+
+     Parameters:
+         data (ndarray): PCA-reduced data matrix.
+         indstemp (ndarray): Indices of sampled points.
+         nbs (int): Number of neighbors in reconstructed graph.
+         metric (str): Distance metric ('cosine', 'euclidean', etc).
+
+     Returns:
+         d (ndarray): Symmetric distance matrix used for persistent homology.
+     """
+     # Filter the data using the sampled point indices
+     data = data[indstemp, :]
+
+     # Compute the pairwise distance matrix
+     X = squareform(pdist(data, metric))
+     knn_indices = np.argsort(X)[:, :nbs]
+     knn_dists = X[np.arange(X.shape[0])[:, None], knn_indices].copy()
+
+     # Compute smoothed kernel widths
+     sigmas, rhos = _smooth_knn_dist(knn_dists, nbs, local_connectivity=0)
+     rows, cols, vals = _compute_membership_strengths(knn_indices, knn_dists, sigmas, rhos)
+
+     # Construct a sparse graph
+     result = coo_matrix((vals, (rows, cols)), shape=(X.shape[0], X.shape[0]))
+     result.eliminate_zeros()
+     transpose = result.transpose()
+     prod_matrix = result.multiply(transpose)
+     result = result + transpose - prod_matrix
+     result.eliminate_zeros()
+
+     # Build the final distance matrix
+     d = result.toarray()
+     # Match external TDAvis: direct negative log without epsilon handling
+     # Temporarily suppress divide by zero warning to match external behavior
+     with np.errstate(divide="ignore", invalid="ignore"):
+         d = -np.log(d)
+     np.fill_diagonal(d, 0)
+
+     return d
+
+
+ def _fast_pca_transform(data, components):
+     """Fast PCA transformation using numba."""
+     return np.dot(data, components.T)
+
+
+ def _run_shuffle_analysis(sspikes, num_shuffles=1000, num_cores=4, progress_bar=True, **kwargs):
+     """Perform shuffle analysis with optimized computation."""
+     return _run_shuffle_analysis_multiprocessing(
+         sspikes, num_shuffles, num_cores, progress_bar, **kwargs
+     )
+
+
+ def _run_shuffle_analysis_multiprocessing(
+     sspikes, num_shuffles=1000, num_cores=4, progress_bar=True, **kwargs
+ ):
+     """Original multiprocessing implementation for fallback."""
+     # Use numpy arrays with NaN for failed results (more efficient than None filtering)
+     max_lifetimes = {
+         0: np.full(num_shuffles, np.nan),
+         1: np.full(num_shuffles, np.nan),
+         2: np.full(num_shuffles, np.nan),
+     }
+
+     # Prepare task list
+     tasks = [(i, sspikes, kwargs) for i in range(num_shuffles)]
+     logging.info(
+         f"Starting shuffle analysis with {num_shuffles} iterations using {num_cores} cores..."
+     )
+
+     # Use multiprocessing pool for parallel processing
+     with mp.Pool(processes=num_cores) as pool:
+         results = list(pool.imap(_process_single_shuffle, tasks))
+     logging.info("Shuffle analysis completed")
+
+     # Collect results - use indexing instead of append for better performance
+     for idx, res in enumerate(results):
+         for dim, lifetime in res.items():
+             max_lifetimes[dim][idx] = lifetime
+
+     # Filter out NaN values (failed results) - convert to list for consistency
+     for dim in max_lifetimes:
+         max_lifetimes[dim] = max_lifetimes[dim][~np.isnan(max_lifetimes[dim])].tolist()
+
+     return max_lifetimes
+
+
+ def _process_single_shuffle(args):
+     """Process a single shuffle task."""
+     i, sspikes, kwargs = args
+     try:
+         shuffled_data = _shuffle_spike_trains(sspikes)
+         persistence = _compute_persistence(shuffled_data, **kwargs)
+
+         dim_max_lifetimes = {}
+         for dim in [0, 1, 2]:
+             if dim < len(persistence["dgms"]):
+                 # Filter out infinite values
+                 valid_bars = [bar for bar in persistence["dgms"][dim] if not np.isinf(bar[1])]
+                 if valid_bars:
+                     lifetimes = [bar[1] - bar[0] for bar in valid_bars]
+                     if lifetimes:
+                         dim_max_lifetimes[dim] = max(lifetimes)
+         return dim_max_lifetimes
+     except Exception as e:
+         print(f"Shuffle {i} failed: {str(e)}")
+         return {}
+
+
+ def _shuffle_spike_trains(sspikes):
+     """Perform random circular shift on spike trains."""
+     shuffled = sspikes.copy()
+     num_neurons = shuffled.shape[1]
+
+     # Independent shift for each neuron
+     for n in range(num_neurons):
+         shift = np.random.randint(0, int(shuffled.shape[0] * 0.1))
+         shuffled[:, n] = np.roll(shuffled[:, n], shift)
+
+     return shuffled
+
+
+ def _plot_barcode(persistence):
+     """
+     Plot barcode diagram from persistent homology result.
+
+     Parameters:
+         persistence (dict): Persistent homology result with 'dgms' key.
+     """
+     cs = np.repeat([[0, 0.55, 0.2]], 3).reshape(3, 3).T  # RGB color for each dimension
+     alpha = 1
+     inf_delta = 0.1
+     colormap = cs
+     dgms = persistence["dgms"]
+     maxdim = len(dgms) - 1
+     dims = np.arange(maxdim + 1)
+     labels = ["$H_0$", "$H_1$", "$H_2$"]
+
+     # Determine axis range
+     min_birth, max_death = 0, 0
+     for dim in dims:
+         persistence_dim = dgms[dim][~np.isinf(dgms[dim][:, 1]), :]
+         if persistence_dim.size > 0:
+             min_birth = min(min_birth, np.min(persistence_dim))
+             max_death = max(max_death, np.max(persistence_dim))
+
+     delta = (max_death - min_birth) * inf_delta
+     infinity = max_death + delta
+     axis_start = min_birth - delta
+
+     # Create plot
+     fig = plt.figure(figsize=(10, 6))
+     gs = gridspec.GridSpec(len(dims), 1)
+
+     for dim in dims:
+         axes = plt.subplot(gs[dim])
+         axes.axis("on")
+         axes.set_yticks([])
+         axes.set_ylabel(labels[dim], rotation=0, labelpad=20, fontsize=12)
+
+         d = np.copy(dgms[dim])
+         d[np.isinf(d[:, 1]), 1] = infinity
+         dlife = d[:, 1] - d[:, 0]
+
+         # Select top 30 bars by lifetime
+         dinds = np.argsort(dlife)[-30:]
+         if dim > 0:
+             dinds = dinds[np.flip(np.argsort(d[dinds, 0]))]
+
+         axes.barh(
+             0.5 + np.arange(len(dinds)),
+             dlife[dinds],
+             height=0.8,
+             left=d[dinds, 0],
+             alpha=alpha,
+             color=colormap[dim],
+             linewidth=0,
+         )
+
+         axes.plot([0, 0], [0, len(dinds)], c="k", linestyle="-", lw=1)
+         axes.plot([0, len(dinds)], [0, 0], c="k", linestyle="-", lw=1)
+         axes.set_xlim([axis_start, infinity])
+
+     plt.tight_layout()
+     return fig
+
+
+ def _plot_barcode_with_shuffle(persistence, shuffle_max):
+     """
+     Plot barcode with shuffle region markers.
+     """
+     # Handle case where shuffle_max is None
+     if shuffle_max is None:
+         shuffle_max = {}
+
+     cs = np.repeat([[0, 0.55, 0.2]], 3).reshape(3, 3).T
+     alpha = 1
+     inf_delta = 0.1
+     colormap = cs
+     maxdim = len(persistence["dgms"]) - 1
+     dims = np.arange(maxdim + 1)
+
+     min_birth, max_death = 0, 0
+     for dim in dims:
+         # Filter out infinite values
+         valid_bars = [bar for bar in persistence["dgms"][dim] if not np.isinf(bar[1])]
+         if valid_bars:
+             min_birth = min(min_birth, np.min(valid_bars))
+             max_death = max(max_death, np.max(valid_bars))
+
+     # Handle case with no valid bars
+     if max_death == 0 and min_birth == 0:
+         min_birth = 0
+         max_death = 1
+
+     delta = (max_death - min_birth) * inf_delta
+     infinity = max_death + delta
+
+     # Create figure
+     fig = plt.figure(figsize=(10, 8))
+     gs = gridspec.GridSpec(len(dims), 1)
+
+     # Get shuffle thresholds (99.9th percentile for each dimension)
+     thresholds = {}
+     for dim in dims:
+         if dim in shuffle_max and shuffle_max[dim]:
+             thresholds[dim] = np.percentile(shuffle_max[dim], 99.9)
+         else:
+             thresholds[dim] = 0
+
+     for _, dim in enumerate(dims):
+         axes = plt.subplot(gs[dim])
+         axes.axis("off")
+
+         # Add gray background to represent shuffle region
+         if dim in thresholds:
+             axes.axvspan(0, thresholds[dim], alpha=0.2, color="gray", zorder=-3)
+             axes.axvline(x=thresholds[dim], color="gray", linestyle="--", alpha=0.7)
+
+         # Do not pre-filter out infinite bars; copy the full diagram instead
+         d = np.copy(persistence["dgms"][dim])
+         if d.size == 0:
+             d = np.zeros((0, 2))
+
+         # Map infinite death values to a finite upper bound for visualization
+         d[np.isinf(d[:, 1]), 1] = infinity
+         dlife = d[:, 1] - d[:, 0]
+
+         # Select top 30 longest-lived bars
+         if len(dlife) > 0:
+             dinds = np.argsort(dlife)[-30:]
+             if dim > 0:
+                 dinds = dinds[np.flip(np.argsort(d[dinds, 0]))]
+
+             # Mark significant bars
+             significant_bars = []
+             for idx in dinds:
+                 if dlife[idx] > thresholds.get(dim, 0):
+                     significant_bars.append(idx)
+
+             # Draw bars
+             for i, idx in enumerate(dinds):
+                 color = "red" if idx in significant_bars else colormap[dim]
+                 axes.barh(
+                     0.5 + i,
+                     dlife[idx],
+                     height=0.8,
+                     left=d[idx, 0],
+                     alpha=alpha,
+                     color=color,
+                     linewidth=0,
+                 )
+
+             indsall = len(dinds)
+         else:
+             indsall = 0
+
+         axes.plot([0, 0], [0, indsall], c="k", linestyle="-", lw=1)
+         axes.plot([0, indsall], [0, 0], c="k", linestyle="-", lw=1)
+         axes.set_xlim([0, infinity])
+         axes.set_title(f"$H_{dim}$", loc="left")
+
+     plt.tight_layout()
+     return fig
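A note on the significance test implemented above: `_plot_barcode_with_shuffle` treats a bar as significant when its lifetime exceeds the 99.9th percentile of the shuffled maximum lifetimes for the same homology dimension, the same percentile reported by `_print_shuffle_summary`. A standalone helper in the same spirit, written here as an illustrative sketch rather than part of the released module, and dropping infinite bars where the package instead maps them to a finite cap, could look like:

import numpy as np

def significant_bars(dgm, shuffle_lifetimes, q=99.9):
    """Return finite bars whose lifetime exceeds the q-th percentile of
    shuffled maximum lifetimes. Illustrative only; not part of canns 0.13.0."""
    dgm = np.asarray(dgm, dtype=float)
    finite = dgm[~np.isinf(dgm[:, 1])]  # drop infinite-death bars
    if len(shuffle_lifetimes) == 0:
        threshold = 0.0  # matches the plotting code's fallback threshold
    else:
        threshold = np.percentile(shuffle_lifetimes, q)
    lifetimes = finite[:, 1] - finite[:, 0]
    return finite[lifetimes > threshold]

For example, `significant_bars(result["persistence"]["dgms"][1], result["shuffle_max"][1])` would return the H1 bars that the barcode plot would draw in red.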