canns-0.12.7-py3-none-any.whl → canns-0.13.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/analyzer/data/__init__.py +3 -11
- canns/analyzer/data/asa/__init__.py +74 -0
- canns/analyzer/data/asa/cohospace.py +905 -0
- canns/analyzer/data/asa/config.py +246 -0
- canns/analyzer/data/asa/decode.py +448 -0
- canns/analyzer/data/asa/embedding.py +269 -0
- canns/analyzer/data/asa/filters.py +208 -0
- canns/analyzer/data/asa/fr.py +439 -0
- canns/analyzer/data/asa/path.py +389 -0
- canns/analyzer/data/asa/plotting.py +1276 -0
- canns/analyzer/data/asa/tda.py +901 -0
- canns/analyzer/data/legacy/__init__.py +6 -0
- canns/analyzer/data/{cann1d.py → legacy/cann1d.py} +2 -2
- canns/analyzer/data/{cann2d.py → legacy/cann2d.py} +3 -3
- canns/analyzer/visualization/core/backend.py +1 -1
- canns/analyzer/visualization/core/config.py +77 -0
- canns/analyzer/visualization/core/rendering.py +10 -6
- canns/analyzer/visualization/energy_plots.py +22 -8
- canns/analyzer/visualization/spatial_plots.py +31 -11
- canns/analyzer/visualization/theta_sweep_plots.py +15 -6
- canns/pipeline/__init__.py +4 -8
- canns/pipeline/asa/__init__.py +21 -0
- canns/pipeline/asa/__main__.py +11 -0
- canns/pipeline/asa/app.py +1000 -0
- canns/pipeline/asa/runner.py +1095 -0
- canns/pipeline/asa/screens.py +215 -0
- canns/pipeline/asa/state.py +248 -0
- canns/pipeline/asa/styles.tcss +221 -0
- canns/pipeline/asa/widgets.py +233 -0
- canns/pipeline/gallery/__init__.py +7 -0
- canns/task/open_loop_navigation.py +3 -1
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/METADATA +6 -3
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/RECORD +36 -17
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/entry_points.txt +1 -0
- canns/pipeline/theta_sweep.py +0 -573
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/WHEEL +0 -0
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/licenses/LICENSE +0 -0
- canns/analyzer/data/asa/tda.py
@@ -0,0 +1,901 @@

```python
from __future__ import annotations

import logging
import multiprocessing as mp
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
from canns_lib.ripser import ripser
from matplotlib import gridspec
from scipy.sparse import coo_matrix
from scipy.spatial.distance import pdist, squareform
from sklearn import preprocessing

from .config import Constants, ProcessingError, TDAConfig

try:
    from numba import njit

    HAS_NUMBA = True
except ImportError:
    HAS_NUMBA = False

    def njit(*args, **kwargs):
        def decorator(func):
            return func

        return decorator


def tda_vis(embed_data: np.ndarray, config: TDAConfig | None = None, **kwargs) -> dict[str, Any]:
    """
    Topological Data Analysis visualization with optional shuffle testing.

    Parameters
    ----------
    embed_data : np.ndarray
        Embedded spike train data of shape (T, N).
    config : TDAConfig, optional
        Configuration object with all TDA parameters. If None, legacy kwargs are used.
    **kwargs : Any
        Legacy keyword parameters (``dim``, ``num_times``, ``active_times``, ``k``,
        ``n_points``, ``metric``, ``nbs``, ``maxdim``, ``coeff``, ``show``,
        ``do_shuffle``, ``num_shuffles``, ``progress_bar``).

    Returns
    -------
    dict
        Dictionary containing:
        - ``persistence``: persistence diagrams from real data.
        - ``indstemp``: indices of sampled points.
        - ``movetimes``: selected time points.
        - ``n_points``: number of sampled points.
        - ``shuffle_max``: shuffle analysis results (if ``do_shuffle=True``), else ``None``.

    Examples
    --------
    >>> from canns.analyzer.data import TDAConfig, tda_vis
    >>> cfg = TDAConfig(maxdim=1, do_shuffle=False, show=False)
    >>> result = tda_vis(embed_data, config=cfg)  # doctest: +SKIP
    >>> sorted(result.keys())
    ['indstemp', 'movetimes', 'n_points', 'persistence', 'shuffle_max']
    """
    # Handle backward compatibility and configuration
    if config is None:
        config = TDAConfig(
            dim=kwargs.get("dim", 6),
            num_times=kwargs.get("num_times", 5),
            active_times=kwargs.get("active_times", 15000),
            k=kwargs.get("k", 1000),
            n_points=kwargs.get("n_points", 1200),
            metric=kwargs.get("metric", "cosine"),
            nbs=kwargs.get("nbs", 800),
            maxdim=kwargs.get("maxdim", 1),
            coeff=kwargs.get("coeff", 47),
            show=kwargs.get("show", True),
            do_shuffle=kwargs.get("do_shuffle", False),
            num_shuffles=kwargs.get("num_shuffles", 1000),
            progress_bar=kwargs.get("progress_bar", True),
        )

    try:
        # Compute persistent homology for real data
        print("Computing persistent homology for real data...")
        real_persistence = _compute_real_persistence(embed_data, config)

        # Perform shuffle analysis if requested
        shuffle_max = None
        if config.do_shuffle:
            shuffle_max = _perform_shuffle_analysis(embed_data, config)

        # Visualization
        _handle_visualization(real_persistence["persistence"], shuffle_max, config)

        # Return results as dictionary
        return {
            "persistence": real_persistence["persistence"],
            "indstemp": real_persistence["indstemp"],
            "movetimes": real_persistence["movetimes"],
            "n_points": real_persistence["n_points"],
            "shuffle_max": shuffle_max,
        }

    except Exception as e:
        raise ProcessingError(f"TDA analysis failed: {e}") from e


def _compute_real_persistence(embed_data: np.ndarray, config: TDAConfig) -> dict[str, Any]:
    """Compute persistent homology for real data with progress tracking."""

    logging.info("Processing real data - Starting TDA analysis (5 steps)")

    # Step 1: Time point downsampling
    logging.info("Step 1/5: Time point downsampling")
    times_cube = _downsample_timepoints(embed_data, config.num_times)

    # Step 2: Select most active time points
    logging.info("Step 2/5: Selecting active time points")
    movetimes = _select_active_timepoints(embed_data, times_cube, config.active_times)

    # Step 3: PCA dimensionality reduction
    logging.info("Step 3/5: PCA dimensionality reduction")
    dimred = _apply_pca_reduction(embed_data, movetimes, config.dim)

    # Step 4: Point cloud sampling (denoising)
    logging.info("Step 4/5: Point cloud denoising")
    indstemp = _apply_denoising(dimred, config)

    # Step 5: Compute persistent homology
    logging.info("Step 5/5: Computing persistent homology")
    persistence = _compute_persistence_homology(dimred, indstemp, config)

    logging.info("TDA analysis completed successfully")

    # Return all necessary data in dictionary format
    return {
        "persistence": persistence,
        "indstemp": indstemp,
        "movetimes": movetimes,
        "n_points": config.n_points,
    }


def _downsample_timepoints(embed_data: np.ndarray, num_times: int) -> np.ndarray:
    """Downsample timepoints for computational efficiency."""
    return np.arange(0, embed_data.shape[0], num_times)


def _select_active_timepoints(
    embed_data: np.ndarray, times_cube: np.ndarray, active_times: int
) -> np.ndarray:
    """Select most active timepoints based on total activity."""
    activity_scores = np.sum(embed_data[times_cube, :], 1)
    # Match external TDAvis: sort indices first, then map to times_cube
    movetimes = np.sort(np.argsort(activity_scores)[-active_times:])
    return times_cube[movetimes]


def _apply_pca_reduction(embed_data: np.ndarray, movetimes: np.ndarray, dim: int) -> np.ndarray:
    """Apply PCA dimensionality reduction."""
    scaled_data = preprocessing.scale(embed_data[movetimes, :])
    dimred, *_ = _pca(scaled_data, dim=dim)
    return dimred


def _apply_denoising(dimred: np.ndarray, config: TDAConfig) -> np.ndarray:
    """Apply point cloud denoising."""
    indstemp, *_ = _sample_denoising(
        dimred,
        k=config.k,
        num_sample=config.n_points,
        omega=1,  # Match external TDAvis: uses 1, not default 0.2
        metric=config.metric,
    )
    return indstemp


def _compute_persistence_homology(
    dimred: np.ndarray, indstemp: np.ndarray, config: TDAConfig
) -> dict[str, Any]:
    """Compute persistent homology using ripser."""
    d = _second_build(dimred, indstemp, metric=config.metric, nbs=config.nbs)
    np.fill_diagonal(d, 0)

    return ripser(
        d,
        maxdim=config.maxdim,
        coeff=config.coeff,
        do_cocycles=True,
        distance_matrix=True,
        progress_bar=config.progress_bar,
    )


def _perform_shuffle_analysis(embed_data: np.ndarray, config: TDAConfig) -> dict[int, Any]:
    """Perform shuffle analysis with progress tracking."""
    print(f"\nStarting shuffle analysis with {config.num_shuffles} iterations...")

    # Create parameters dict for shuffle analysis
    shuffle_params = {
        "dim": config.dim,
        "num_times": config.num_times,
        "active_times": config.active_times,
        "k": config.k,
        "n_points": config.n_points,
        "metric": config.metric,
        "nbs": config.nbs,
        "maxdim": config.maxdim,
        "coeff": config.coeff,
    }

    shuffle_max = _run_shuffle_analysis(
        embed_data,
        num_shuffles=config.num_shuffles,
        num_cores=Constants.MULTIPROCESSING_CORES,
        progress_bar=config.progress_bar,
        **shuffle_params,
    )

    # Print shuffle analysis summary
    _print_shuffle_summary(shuffle_max)

    return shuffle_max


def _print_shuffle_summary(shuffle_max: dict[int, Any]) -> None:
    """Print summary of shuffle analysis results."""
    print("\nSummary of shuffle-based analysis:")
    for dim_idx in [0, 1, 2]:
        if shuffle_max and dim_idx in shuffle_max and shuffle_max[dim_idx]:
            values = shuffle_max[dim_idx]
            print(
                f"H{dim_idx}: {len(values)} valid iterations | "
                f"Mean maximum persistence: {np.mean(values):.4f} | "
                f"99.9th percentile: {np.percentile(values, 99.9):.4f}"
            )


def _handle_visualization(
    real_persistence: dict[str, Any], shuffle_max: dict[int, Any] | None, config: TDAConfig
) -> None:
    """Handle visualization based on configuration."""
    if config.show:
        if config.do_shuffle and shuffle_max is not None:
            _plot_barcode_with_shuffle(real_persistence, shuffle_max)
        else:
            _plot_barcode(real_persistence)
        plt.show()
    else:
        plt.close()


def _compute_persistence(
    sspikes,
    dim=6,
    num_times=5,
    active_times=15000,
    k=1000,
    n_points=1200,
    metric="cosine",
    nbs=800,
    maxdim=1,
    coeff=47,
    progress_bar=True,
):
    # Time point downsampling
    times_cube = np.arange(0, sspikes.shape[0], num_times)

    # Select most active time points
    movetimes = np.sort(np.argsort(np.sum(sspikes[times_cube, :], 1))[-active_times:])
    movetimes = times_cube[movetimes]

    # PCA dimensionality reduction
    scaled_data = preprocessing.scale(sspikes[movetimes, :])
    dimred, *_ = _pca(scaled_data, dim=dim)

    # Point cloud sampling (denoising)
    indstemp, *_ = _sample_denoising(dimred, k, n_points, 1, metric)

    # Build distance matrix
    d = _second_build(dimred, indstemp, metric=metric, nbs=nbs)
    np.fill_diagonal(d, 0)

    # Compute persistent homology
    persistence = ripser(
        d,
        maxdim=maxdim,
        coeff=coeff,
        do_cocycles=True,
        distance_matrix=True,
        progress_bar=progress_bar,
    )

    return persistence


def _pca(data, dim=2):
    """
    Perform PCA (Principal Component Analysis) for dimensionality reduction.

    Parameters:
        data (ndarray): Input data matrix of shape (N_samples, N_features).
        dim (int): Target dimension for PCA projection.

    Returns:
        components (ndarray): Projected data of shape (N_samples, dim).
        var_exp (list): Variance explained by each principal component.
        evals (ndarray): Eigenvalues corresponding to the selected components.
    """
    if dim < 2:
        return data, [0], np.array([])
    _ = data.shape
    # mean center the data
    # data -= data.mean(axis=0)
    # calculate the covariance matrix
    R = np.cov(data, rowvar=False)
    # calculate eigenvectors & eigenvalues of the covariance matrix
    # use 'eigh' rather than 'eig' since R is symmetric,
    # the performance gain is substantial
    evals, evecs = np.linalg.eig(R)
    # sort eigenvalue in decreasing order
    idx = np.argsort(evals)[::-1]
    evecs = evecs[:, idx]
    # sort eigenvectors according to same index
    evals = evals[idx]
    # select the first n eigenvectors (n is desired dimension
    # of rescaled data array, or dims_rescaled_data)
    evecs = evecs[:, :dim]
    # carry out the transformation on the data using eigenvectors
    # and return the re-scaled data, eigenvalues, and eigenvectors

    tot = np.sum(evals)
    var_exp = [(i / tot) * 100 for i in sorted(evals[:dim], reverse=True)]
    components = np.dot(evecs.T, data.T).T
    return components, var_exp, evals[:dim]


def _sample_denoising(data, k=10, num_sample=500, omega=0.2, metric="euclidean"):
    """
    Perform denoising and greedy sampling based on mutual k-NN graph.

    Parameters:
        data (ndarray): High-dimensional point cloud data.
        k (int): Number of neighbors for local density estimation.
        num_sample (int): Number of samples to retain.
        omega (float): Suppression factor during greedy sampling.
        metric (str): Distance metric used for kNN ('euclidean', 'cosine', etc).

    Returns:
        inds (ndarray): Indices of sampled points.
        d (ndarray): Pairwise similarity matrix of sampled points.
        Fs (ndarray): Sampling scores at each step.
    """
    if HAS_NUMBA:
        return _sample_denoising_numba(data, k, num_sample, omega, metric)
    else:
        return _sample_denoising_numpy(data, k, num_sample, omega, metric)


def _sample_denoising_numpy(data, k=10, num_sample=500, omega=0.2, metric="euclidean"):
    """Original numpy implementation for fallback."""
    n = data.shape[0]
    X = squareform(pdist(data, metric))
    knn_indices = np.argsort(X)[:, :k]
    knn_dists = X[np.arange(X.shape[0])[:, None], knn_indices].copy()

    sigmas, rhos = _smooth_knn_dist(knn_dists, k, local_connectivity=0)
    rows, cols, vals = _compute_membership_strengths(knn_indices, knn_dists, sigmas, rhos)
    result = coo_matrix((vals, (rows, cols)), shape=(n, n))
    result.eliminate_zeros()
    transpose = result.transpose()
    prod_matrix = result.multiply(transpose)
    result = result + transpose - prod_matrix
    result.eliminate_zeros()
    X = result.toarray()
    F = np.sum(X, 1)
    Fs = np.zeros(num_sample)
    Fs[0] = np.max(F)
    i = np.argmax(F)
    inds_all = np.arange(n)
    inds_left = inds_all > -1
    inds_left[i] = False
    inds = np.zeros(num_sample, dtype=int)
    inds[0] = i
    for j in np.arange(1, num_sample):
        F -= omega * X[i, :]
        Fmax = np.argmax(F[inds_left])
        # Exactly match external TDAvis implementation (including the indexing logic)
        Fs[j] = F[Fmax]
        i = inds_all[inds_left][Fmax]

        inds_left[i] = False
        inds[j] = i
    d = np.zeros((num_sample, num_sample))

    for j, i in enumerate(inds):
        d[j, :] = X[i, inds]
    return inds, d, Fs


def _sample_denoising_numba(data, k=10, num_sample=500, omega=0.2, metric="euclidean"):
    """Optimized numba implementation."""
    n = data.shape[0]
    X = squareform(pdist(data, metric))
    knn_indices = np.argsort(X)[:, :k]
    knn_dists = X[np.arange(X.shape[0])[:, None], knn_indices].copy()

    sigmas, rhos = _smooth_knn_dist(knn_dists, k, local_connectivity=0)
    rows, cols, vals = _compute_membership_strengths(knn_indices, knn_dists, sigmas, rhos)

    # Build symmetric adjacency matrix using optimized function
    X_adj = _build_adjacency_matrix_numba(rows, cols, vals, n)

    # Greedy sampling using optimized function
    inds, Fs = _greedy_sampling_numba(X_adj, num_sample, omega)

    # Build final distance matrix
    d = _build_distance_matrix_numba(X_adj, inds)

    return inds, d, Fs


@njit(fastmath=True)
def _build_adjacency_matrix_numba(rows, cols, vals, n):
    """Build symmetric adjacency matrix efficiently with numba.

    This matches the scipy sparse matrix operations:
        result = result + transpose - prod_matrix
    where prod_matrix = result.multiply(transpose)
    """
    # Initialize matrices
    X = np.zeros((n, n), dtype=np.float64)
    X_T = np.zeros((n, n), dtype=np.float64)

    # Build adjacency matrix and its transpose simultaneously
    for i in range(len(rows)):
        X[rows[i], cols[i]] = vals[i]
        X_T[cols[i], rows[i]] = vals[i]  # Transpose

    # Apply the symmetrization formula: A = A + A^T - A ⊙ A^T (vectorized)
    # This matches scipy's: result + transpose - prod_matrix
    X[:, :] = X + X_T - X * X_T

    return X


@njit(fastmath=True)
def _greedy_sampling_numba(X, num_sample, omega):
    """Optimized greedy sampling with numba."""
    n = X.shape[0]
    F = np.sum(X, axis=1)
    Fs = np.zeros(num_sample)
    inds = np.zeros(num_sample, dtype=np.int64)
    inds_left = np.ones(n, dtype=np.bool_)

    # Initialize with maximum F
    i = np.argmax(F)
    Fs[0] = F[i]
    inds[0] = i
    inds_left[i] = False

    # Greedy sampling loop
    for j in range(1, num_sample):
        # Update F values
        for k in range(n):
            F[k] -= omega * X[i, k]

        # Find maximum among remaining points (matching numpy logic exactly)
        max_val = -np.inf
        max_idx = -1
        for k in range(n):
            if inds_left[k] and F[k] > max_val:
                max_val = F[k]
                max_idx = k

        # Record the F value using the selected index (matching external TDAvis)
        i = max_idx
        Fs[j] = F[i]
        inds[j] = i
        inds_left[i] = False

    return inds, Fs


@njit(fastmath=True)
def _build_distance_matrix_numba(X, inds):
    """Build final distance matrix efficiently with numba."""
    num_sample = len(inds)
    d = np.zeros((num_sample, num_sample))

    for j in range(num_sample):
        for k in range(num_sample):
            d[j, k] = X[inds[j], inds[k]]

    return d


@njit(fastmath=True)
def _smooth_knn_dist(distances, k, n_iter=64, local_connectivity=0.0, bandwidth=1.0):
    """
    Compute smoothed local distances for kNN graph with entropy balancing.

    Parameters:
        distances (ndarray): kNN distance matrix.
        k (int): Number of neighbors.
        n_iter (int): Number of binary search iterations.
        local_connectivity (float): Minimum local connectivity.
        bandwidth (float): Bandwidth parameter.

    Returns:
        sigmas (ndarray): Smoothed sigma values for each point.
        rhos (ndarray): Minimum distances (connectivity cutoff) for each point.
    """
    target = np.log2(k) * bandwidth
    # target = np.log(k) * bandwidth
    # target = k

    rho = np.zeros(distances.shape[0])
    result = np.zeros(distances.shape[0])

    mean_distances = np.mean(distances)

    for i in range(distances.shape[0]):
        lo = 0.0
        hi = np.inf
        mid = 1.0

        # Vectorized computation of non-zero distances
        ith_distances = distances[i]
        non_zero_dists = ith_distances[ith_distances > 0.0]
        if non_zero_dists.shape[0] >= local_connectivity:
            index = int(np.floor(local_connectivity))
            interpolation = local_connectivity - index
            if index > 0:
                rho[i] = non_zero_dists[index - 1]
                if interpolation > 1e-5:
                    rho[i] += interpolation * (non_zero_dists[index] - non_zero_dists[index - 1])
            else:
                rho[i] = interpolation * non_zero_dists[0]
        elif non_zero_dists.shape[0] > 0:
            rho[i] = np.max(non_zero_dists)

        # Vectorized binary search loop - compute all at once instead of loop
        for _ in range(n_iter):
            # Vectorized computation: compute all distances at once
            d_array = distances[i, 1:] - rho[i]
            # Vectorized conditional: use np.where for conditional computation
            psum = np.sum(np.where(d_array > 0, np.exp(-(d_array / mid)), 1.0))

            if np.fabs(psum - target) < 1e-5:
                break

            if psum > target:
                hi = mid
                mid = (lo + hi) / 2.0
            else:
                lo = mid
                if hi == np.inf:
                    mid *= 2
                else:
                    mid = (lo + hi) / 2.0
        result[i] = mid
        # Optimized mean computation - reuse ith_distances
        if rho[i] > 0.0:
            mean_ith_distances = np.mean(ith_distances)
            if result[i] < 1e-3 * mean_ith_distances:
                result[i] = 1e-3 * mean_ith_distances
        else:
            if result[i] < 1e-3 * mean_distances:
                result[i] = 1e-3 * mean_distances

    return result, rho


@njit(parallel=True, fastmath=True)
def _compute_membership_strengths(knn_indices, knn_dists, sigmas, rhos):
    """
    Compute membership strength matrix from smoothed kNN graph.

    Parameters:
        knn_indices (ndarray): Indices of k-nearest neighbors.
        knn_dists (ndarray): Corresponding distances.
        sigmas (ndarray): Local bandwidths.
        rhos (ndarray): Minimum distance thresholds.

    Returns:
        rows (ndarray): Row indices for sparse matrix.
        cols (ndarray): Column indices for sparse matrix.
        vals (ndarray): Weight values for sparse matrix.
    """
    n_samples = knn_indices.shape[0]
    n_neighbors = knn_indices.shape[1]
    rows = np.zeros((n_samples * n_neighbors), dtype=np.int64)
    cols = np.zeros((n_samples * n_neighbors), dtype=np.int64)
    vals = np.zeros((n_samples * n_neighbors), dtype=np.float64)
    for i in range(n_samples):
        for j in range(n_neighbors):
            if knn_indices[i, j] == -1:
                continue  # We didn't get the full knn for i
            if knn_indices[i, j] == i:
                val = 0.0
            elif knn_dists[i, j] - rhos[i] <= 0.0:
                val = 1.0
            else:
                val = np.exp(-((knn_dists[i, j] - rhos[i]) / (sigmas[i])))
                # val = ((knn_dists[i, j] - rhos[i]) / (sigmas[i]))

            rows[i * n_neighbors + j] = i
            cols[i * n_neighbors + j] = knn_indices[i, j]
            vals[i * n_neighbors + j] = val

    return rows, cols, vals


def _second_build(data, indstemp, nbs=800, metric="cosine"):
    """
    Reconstruct distance matrix after denoising for persistent homology.

    Parameters:
        data (ndarray): PCA-reduced data matrix.
        indstemp (ndarray): Indices of sampled points.
        nbs (int): Number of neighbors in reconstructed graph.
        metric (str): Distance metric ('cosine', 'euclidean', etc).

    Returns:
        d (ndarray): Symmetric distance matrix used for persistent homology.
    """
    # Filter the data using the sampled point indices
    data = data[indstemp, :]

    # Compute the pairwise distance matrix
    X = squareform(pdist(data, metric))
    knn_indices = np.argsort(X)[:, :nbs]
    knn_dists = X[np.arange(X.shape[0])[:, None], knn_indices].copy()

    # Compute smoothed kernel widths
    sigmas, rhos = _smooth_knn_dist(knn_dists, nbs, local_connectivity=0)
    rows, cols, vals = _compute_membership_strengths(knn_indices, knn_dists, sigmas, rhos)

    # Construct a sparse graph
    result = coo_matrix((vals, (rows, cols)), shape=(X.shape[0], X.shape[0]))
    result.eliminate_zeros()
    transpose = result.transpose()
    prod_matrix = result.multiply(transpose)
    result = result + transpose - prod_matrix
    result.eliminate_zeros()

    # Build the final distance matrix
    d = result.toarray()
    # Match external TDAvis: direct negative log without epsilon handling
    # Temporarily suppress divide by zero warning to match external behavior
    with np.errstate(divide="ignore", invalid="ignore"):
        d = -np.log(d)
    np.fill_diagonal(d, 0)

    return d


def _fast_pca_transform(data, components):
    """Fast PCA transformation using numba."""
    return np.dot(data, components.T)


def _run_shuffle_analysis(sspikes, num_shuffles=1000, num_cores=4, progress_bar=True, **kwargs):
    """Perform shuffle analysis with optimized computation."""
    return _run_shuffle_analysis_multiprocessing(
        sspikes, num_shuffles, num_cores, progress_bar, **kwargs
    )


def _run_shuffle_analysis_multiprocessing(
    sspikes, num_shuffles=1000, num_cores=4, progress_bar=True, **kwargs
):
    """Original multiprocessing implementation for fallback."""
    # Use numpy arrays with NaN for failed results (more efficient than None filtering)
    max_lifetimes = {
        0: np.full(num_shuffles, np.nan),
        1: np.full(num_shuffles, np.nan),
        2: np.full(num_shuffles, np.nan),
    }

    # Prepare task list
    tasks = [(i, sspikes, kwargs) for i in range(num_shuffles)]
    logging.info(
        f"Starting shuffle analysis with {num_shuffles} iterations using {num_cores} cores..."
    )

    # Use multiprocessing pool for parallel processing
    with mp.Pool(processes=num_cores) as pool:
        results = list(pool.imap(_process_single_shuffle, tasks))
    logging.info("Shuffle analysis completed")

    # Collect results - use indexing instead of append for better performance
    for idx, res in enumerate(results):
        for dim, lifetime in res.items():
            max_lifetimes[dim][idx] = lifetime

    # Filter out NaN values (failed results) - convert to list for consistency
    for dim in max_lifetimes:
        max_lifetimes[dim] = max_lifetimes[dim][~np.isnan(max_lifetimes[dim])].tolist()

    return max_lifetimes


def _process_single_shuffle(args):
    """Process a single shuffle task."""
    i, sspikes, kwargs = args
    try:
        shuffled_data = _shuffle_spike_trains(sspikes)
        persistence = _compute_persistence(shuffled_data, **kwargs)

        dim_max_lifetimes = {}
        for dim in [0, 1, 2]:
            if dim < len(persistence["dgms"]):
                # Filter out infinite values
                valid_bars = [bar for bar in persistence["dgms"][dim] if not np.isinf(bar[1])]
                if valid_bars:
                    lifetimes = [bar[1] - bar[0] for bar in valid_bars]
                    if lifetimes:
                        dim_max_lifetimes[dim] = max(lifetimes)
        return dim_max_lifetimes
    except Exception as e:
        print(f"Shuffle {i} failed: {str(e)}")
        return {}


def _shuffle_spike_trains(sspikes):
    """Perform random circular shift on spike trains."""
    shuffled = sspikes.copy()
    num_neurons = shuffled.shape[1]

    # Independent shift for each neuron
    for n in range(num_neurons):
        shift = np.random.randint(0, int(shuffled.shape[0] * 0.1))
        shuffled[:, n] = np.roll(shuffled[:, n], shift)

    return shuffled


def _plot_barcode(persistence):
    """
    Plot barcode diagram from persistent homology result.

    Parameters:
        persistence (dict): Persistent homology result with 'dgms' key.
    """
    cs = np.repeat([[0, 0.55, 0.2]], 3).reshape(3, 3).T  # RGB color for each dimension
    alpha = 1
    inf_delta = 0.1
    colormap = cs
    dgms = persistence["dgms"]
    maxdim = len(dgms) - 1
    dims = np.arange(maxdim + 1)
    labels = ["$H_0$", "$H_1$", "$H_2$"]

    # Determine axis range
    min_birth, max_death = 0, 0
    for dim in dims:
        persistence_dim = dgms[dim][~np.isinf(dgms[dim][:, 1]), :]
        if persistence_dim.size > 0:
            min_birth = min(min_birth, np.min(persistence_dim))
            max_death = max(max_death, np.max(persistence_dim))

    delta = (max_death - min_birth) * inf_delta
    infinity = max_death + delta
    axis_start = min_birth - delta

    # Create plot
    fig = plt.figure(figsize=(10, 6))
    gs = gridspec.GridSpec(len(dims), 1)

    for dim in dims:
        axes = plt.subplot(gs[dim])
        axes.axis("on")
        axes.set_yticks([])
        axes.set_ylabel(labels[dim], rotation=0, labelpad=20, fontsize=12)

        d = np.copy(dgms[dim])
        d[np.isinf(d[:, 1]), 1] = infinity
        dlife = d[:, 1] - d[:, 0]

        # Select top 30 bars by lifetime
        dinds = np.argsort(dlife)[-30:]
        if dim > 0:
            dinds = dinds[np.flip(np.argsort(d[dinds, 0]))]

        axes.barh(
            0.5 + np.arange(len(dinds)),
            dlife[dinds],
            height=0.8,
            left=d[dinds, 0],
            alpha=alpha,
            color=colormap[dim],
            linewidth=0,
        )

        axes.plot([0, 0], [0, len(dinds)], c="k", linestyle="-", lw=1)
        axes.plot([0, len(dinds)], [0, 0], c="k", linestyle="-", lw=1)
        axes.set_xlim([axis_start, infinity])

    plt.tight_layout()
    return fig


def _plot_barcode_with_shuffle(persistence, shuffle_max):
    """
    Plot barcode with shuffle region markers.
    """
    # Handle case where shuffle_max is None
    if shuffle_max is None:
        shuffle_max = {}

    cs = np.repeat([[0, 0.55, 0.2]], 3).reshape(3, 3).T
    alpha = 1
    inf_delta = 0.1
    colormap = cs
    maxdim = len(persistence["dgms"]) - 1
    dims = np.arange(maxdim + 1)

    min_birth, max_death = 0, 0
    for dim in dims:
        # Filter out infinite values
        valid_bars = [bar for bar in persistence["dgms"][dim] if not np.isinf(bar[1])]
        if valid_bars:
            min_birth = min(min_birth, np.min(valid_bars))
            max_death = max(max_death, np.max(valid_bars))

    # Handle case with no valid bars
    if max_death == 0 and min_birth == 0:
        min_birth = 0
        max_death = 1

    delta = (max_death - min_birth) * inf_delta
    infinity = max_death + delta

    # Create figure
    fig = plt.figure(figsize=(10, 8))
    gs = gridspec.GridSpec(len(dims), 1)

    # Get shuffle thresholds (99.9th percentile for each dimension)
    thresholds = {}
    for dim in dims:
        if dim in shuffle_max and shuffle_max[dim]:
            thresholds[dim] = np.percentile(shuffle_max[dim], 99.9)
        else:
            thresholds[dim] = 0

    for _, dim in enumerate(dims):
        axes = plt.subplot(gs[dim])
        axes.axis("off")

        # Add gray background to represent shuffle region
        if dim in thresholds:
            axes.axvspan(0, thresholds[dim], alpha=0.2, color="gray", zorder=-3)
            axes.axvline(x=thresholds[dim], color="gray", linestyle="--", alpha=0.7)

        # Do not pre-filter out infinite bars; copy the full diagram instead
        d = np.copy(persistence["dgms"][dim])
        if d.size == 0:
            d = np.zeros((0, 2))

        # Map infinite death values to a finite upper bound for visualization
        d[np.isinf(d[:, 1]), 1] = infinity
        dlife = d[:, 1] - d[:, 0]

        # Select top 30 longest-lived bars
        if len(dlife) > 0:
            dinds = np.argsort(dlife)[-30:]
            if dim > 0:
                dinds = dinds[np.flip(np.argsort(d[dinds, 0]))]

            # Mark significant bars
            significant_bars = []
            for idx in dinds:
                if dlife[idx] > thresholds.get(dim, 0):
                    significant_bars.append(idx)

            # Draw bars
            for i, idx in enumerate(dinds):
                color = "red" if idx in significant_bars else colormap[dim]
                axes.barh(
                    0.5 + i,
                    dlife[idx],
                    height=0.8,
                    left=d[idx, 0],
                    alpha=alpha,
                    color=color,
                    linewidth=0,
                )

            indsall = len(dinds)
        else:
            indsall = 0

        axes.plot([0, 0], [0, indsall], c="k", linestyle="-", lw=1)
        axes.plot([0, indsall], [0, 0], c="k", linestyle="-", lw=1)
        axes.set_xlim([0, infinity])
        axes.set_title(f"$H_{dim}$", loc="left")

    plt.tight_layout()
    return fig
```
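
For orientation, here is a minimal usage sketch of the `tda_vis` entry point added in this release, based only on the signatures and docstring shown in the diff above; the synthetic `embed_data` array and its shape are stand-ins, not data or defaults shipped with the package.

```python
# Minimal sketch assuming the import path given in the tda_vis docstring;
# the random (T, N) array below is a synthetic placeholder for an embedded
# spike train, used only to exercise the API shape.
import numpy as np

from canns.analyzer.data import TDAConfig, tda_vis

embed_data = np.random.default_rng(0).random((20000, 64))  # (T, N) placeholder

cfg = TDAConfig(
    maxdim=1,          # compute H0 and H1 only
    do_shuffle=False,  # skip the circular-shift null distribution
    show=False,        # return results without opening a barcode figure
)
result = tda_vis(embed_data, config=cfg)
print(sorted(result.keys()))
# -> ['indstemp', 'movetimes', 'n_points', 'persistence', 'shuffle_max']
```

Setting `do_shuffle=True` additionally runs the per-neuron circular-shift shuffles and, in the barcode plot, highlights bars whose lifetime exceeds the 99.9th percentile of the shuffled maximum-persistence distribution.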