shesha-geometry 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/tutorial.py +336 -0
- shesha/__init__.py +59 -0
- shesha/bio.py +315 -0
- shesha/core.py +648 -0
- shesha_geometry-0.1.0.dist-info/LICENSE +21 -0
- shesha_geometry-0.1.0.dist-info/METADATA +396 -0
- shesha_geometry-0.1.0.dist-info/RECORD +13 -0
- shesha_geometry-0.1.0.dist-info/WHEEL +5 -0
- shesha_geometry-0.1.0.dist-info/top_level.txt +3 -0
- tests/__init__.py +1 -0
- tests/test_bio.py +49 -0
- tests/test_core.py +164 -0
- tests/test_crispr.py +191 -0
shesha/core.py
ADDED
|
@@ -0,0 +1,648 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Shesha: Self-consistency Metrics for Representational Stability
|
|
3
|
+
|
|
4
|
+
Core implementations of Shesha variants for measuring geometric stability
|
|
5
|
+
of high-dimensional representations.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from scipy.stats import spearmanr, pearsonr
|
|
10
|
+
from scipy.spatial.distance import pdist, cdist
|
|
11
|
+
from typing import Optional, Literal, Union
|
|
12
|
+
|
|
13
|
+
__all__ = [
    # Unsupervised variants
    "feature_split",
    "sample_split",
    "anchor_stability",
    # Supervised variants
    "variance_ratio",
    "supervised_alignment",
    # Drift metrics
    "rdm_similarity",
    "rdm_drift",
    # Utilities
    "compute_rdm",
    # Unified dispatcher over the variants above (public, documented API).
    "shesha",
]

# Small positive floor shared by this module: it clamps row norms before
# dividing, keeps variance denominators non-zero, and is the tolerance below
# which an RDM's standard deviation is treated as "constant".
EPS = 1e-12
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# =============================================================================
|
|
32
|
+
# RDM Utilities
|
|
33
|
+
# =============================================================================
|
|
34
|
+
|
|
35
|
+
def compute_rdm(
    X: np.ndarray,
    metric: Literal["cosine", "correlation", "euclidean"] = "cosine",
    normalize: bool = True,
) -> np.ndarray:
    """
    Build a Representational Dissimilarity Matrix (RDM) in condensed form.

    Parameters
    ----------
    X : np.ndarray
        Array of shape (n_samples, n_features); converted to float64.
    metric : str
        One of 'cosine', 'correlation', or 'euclidean'; forwarded unchanged
        to ``scipy.spatial.distance.pdist``.
    normalize : bool
        Only consulted when ``metric == 'cosine'``. If True, every row is
        scaled to unit L2 norm first; norms are floored at ``EPS`` so all-zero
        rows are not divided by zero.

    Returns
    -------
    np.ndarray
        Condensed distance vector (upper triangle of the RDM, in ``pdist``
        order).
    """
    data = np.asarray(X, dtype=np.float64)

    wants_unit_rows = metric == "cosine" and normalize
    if wants_unit_rows:
        row_norms = np.linalg.norm(data, axis=1, keepdims=True)
        data = data / np.maximum(row_norms, EPS)

    return pdist(data, metric=metric)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
# =============================================================================
|
|
67
|
+
# Unsupervised Variants
|
|
68
|
+
# =============================================================================
|
|
69
|
+
|
|
70
|
+
def feature_split(
    X: np.ndarray,
    n_splits: int = 30,
    metric: Literal["cosine", "correlation"] = "cosine",
    seed: Optional[int] = None,
    max_samples: Optional[int] = 1600,
) -> float:
    """
    Feature-Split Shesha: internal geometric consistency of a representation.

    On every iteration the feature axis is randomly permuted and cut into two
    disjoint halves of equal size (one feature is left out when the count is
    odd). An RDM is built from each half over the same samples, and the two
    RDMs are compared with Spearman's rank correlation. A high mean value means
    the sample geometry is encoded redundantly across features.

    Parameters
    ----------
    X : np.ndarray
        Array of shape (n_samples, n_features).
    n_splits : int
        How many random feature partitions to average over.
    metric : str
        Distance used for both half-RDMs ('cosine' or 'correlation').
    seed : int, optional
        Seed for ``np.random.default_rng``; fixes subsampling and partitions.
    max_samples : int, optional
        If the sample count exceeds this, a random subset of this size is used
        (keeps the O(n^2) RDMs affordable). ``None`` disables subsampling.

    Returns
    -------
    float
        Mean Spearman rho over the partitions that produced a usable, finite
        correlation, in [-1, 1]. NaN when there are fewer than 4 samples or 4
        features, or when no partition was usable.

    Examples
    --------
    >>> X = np.random.randn(500, 768)  # 500 samples, 768-dim embeddings
    >>> stability = feature_split(X, n_splits=30, seed=320)
    >>> print(f"Feature-split stability: {stability:.3f}")
    """
    data = np.asarray(X, dtype=np.float64)
    n_rows, n_cols = data.shape

    # Too few features to split, or too few samples for a meaningful RDM.
    if n_cols < 4 or n_rows < 4:
        return np.nan

    rng = np.random.default_rng(seed)

    if max_samples is not None and n_rows > max_samples:
        keep = rng.choice(n_rows, max_samples, replace=False)
        data = data[keep]

    if metric == "cosine":
        data = data / np.maximum(np.linalg.norm(data, axis=1, keepdims=True), EPS)

    half = n_cols // 2
    rhos = []
    for _ in range(n_splits):
        order = rng.permutation(n_cols)
        # Zero half-rows can give NaN distances; those are mapped to 1.0.
        left_rdm, right_rdm = (
            np.nan_to_num(pdist(data[:, cols], metric=metric), nan=1.0)
            for cols in (order[:half], order[half:2 * half])
        )
        # A constant RDM carries no rank structure to correlate.
        if np.std(left_rdm) < EPS or np.std(right_rdm) < EPS:
            continue
        rho, _ = spearmanr(left_rdm, right_rdm)
        if np.isfinite(rho):
            rhos.append(rho)

    return float(np.mean(rhos)) if rhos else np.nan
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def sample_split(
    X: np.ndarray,
    n_splits: int = 30,
    subsample_fraction: float = 0.4,
    metric: Literal["cosine", "correlation"] = "cosine",
    seed: Optional[int] = None,
    max_samples: Optional[int] = 1500,
) -> float:
    """
    Sample-Split Shesha (bootstrap RDM): robustness to which samples are drawn.

    Each iteration draws two independent random subsets of
    ``int(n_samples * subsample_fraction)`` rows (without replacement within a
    subset), builds a condensed RDM for each, and compares the two with
    Spearman's rank correlation.

    NOTE(review): the two subsets are drawn independently and in random order,
    so position k of the two condensed RDMs generally refers to *different*
    sample pairs. It is not obvious that the element-wise correlation compares
    like with like — confirm this is the intended estimator.

    Parameters
    ----------
    X : np.ndarray
        Array of shape (n_samples, n_features).
    n_splits : int
        Number of subset pairs to draw.
    subsample_fraction : float
        Fraction of the (possibly capped) sample count used per subset.
    metric : str
        Distance used for the RDMs ('cosine' or 'correlation').
    seed : int, optional
        Seed for ``np.random.default_rng``.
    max_samples : int, optional
        Cap on the number of rows considered; ``None`` disables the cap.

    Returns
    -------
    float
        Mean Spearman rho over iterations with non-constant RDMs and a finite
        correlation, in [-1, 1]. NaN when there are fewer than 10 samples,
        when a subset would hold fewer than 5 rows, or when no iteration was
        usable.

    Examples
    --------
    >>> X = np.random.randn(1000, 384)
    >>> stability = sample_split(X, n_splits=50, seed=320)
    """
    data = np.asarray(X, dtype=np.float64)
    total = data.shape[0]

    if total < 10:
        return np.nan

    rng = np.random.default_rng(seed)

    if max_samples is not None and total > max_samples:
        data = data[rng.choice(total, max_samples, replace=False)]
        total = max_samples

    draw = int(total * subsample_fraction)
    if draw < 5:
        return np.nan

    rhos = []
    for _ in range(n_splits):
        first = rng.choice(total, draw, replace=False)
        second = rng.choice(total, draw, replace=False)

        rdm_a = pdist(data[first], metric=metric)
        rdm_b = pdist(data[second], metric=metric)

        # Skip degenerate (constant) RDMs; they have no rank structure.
        if np.std(rdm_a) < EPS or np.std(rdm_b) < EPS:
            continue

        rho, _ = spearmanr(rdm_a, rdm_b)
        if np.isfinite(rho):
            rhos.append(rho)

    return float(np.mean(rhos)) if rhos else np.nan
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def anchor_stability(
    X: np.ndarray,
    n_splits: int = 30,
    n_anchors: int = 100,
    n_per_split: int = 200,
    metric: Literal["cosine", "euclidean"] = "cosine",
    rank_normalize: bool = True,
    seed: Optional[int] = None,
    max_samples: Optional[int] = 1500,
) -> float:
    """
    Anchor-based Shesha: stability of distance profiles measured from fixed anchors.

    A fixed set of anchor rows is drawn once. On each iteration the remaining
    rows are shuffled and two disjoint splits of ``n_per_split`` rows are
    taken. The anchor-to-split distance matrices (shape n_anchors x
    n_per_split) are optionally rank-transformed row by row, flattened, and
    compared with Spearman's rank correlation.

    If the data cannot cover ``n_anchors + 2 * n_per_split`` rows, both sizes
    are scaled by ``0.9 * available / required``, with floors of 10 anchors
    and 20 rows per split.

    NOTE(review): the two splits hold *different* samples, so entry [a, k] of
    the first matrix and entry [a, k] of the second are distances from the
    same anchor to unrelated points. With ``rank_normalize=True`` each row
    becomes a ranking of randomly ordered points, so the flattened correlation
    looks like it should hover near zero regardless of the data. Without it,
    the correlation mostly reflects how anchors' average distances differ.
    Confirm this is the intended estimator.

    Parameters
    ----------
    X : np.ndarray
        Array of shape (n_samples, n_features).
    n_splits : int
        Number of shuffles (split pairs) to evaluate.
    n_anchors : int
        Number of fixed anchor rows.
    n_per_split : int
        Rows per split.
    metric : str
        Distance passed to ``scipy.spatial.distance.cdist``.
    rank_normalize : bool
        If True, replace each anchor's distances by their ranks before
        correlating.
    seed : int, optional
        Seed for ``np.random.default_rng``.
    max_samples : int, optional
        Cap on the number of rows considered; ``None`` disables the cap.

    Returns
    -------
    float
        Mean Spearman rho over iterations with a finite correlation. NaN when
        there are too few rows for the (possibly shrunk) anchors plus two
        splits, or when no iteration produced a finite value.
    """
    from scipy.stats import rankdata

    data = np.asarray(X, dtype=np.float64)
    total = data.shape[0]
    rng = np.random.default_rng(seed)

    if max_samples is not None and total > max_samples:
        data = data[rng.choice(total, max_samples, replace=False)]
        total = max_samples

    needed = n_anchors + 2 * n_per_split
    if total < needed:
        shrink = total / needed * 0.9
        n_anchors = max(10, int(n_anchors * shrink))
        n_per_split = max(20, int(n_per_split * shrink))
        # The size floors may still demand more rows than are available.
        if total < n_anchors + 2 * n_per_split:
            return np.nan

    anchor_rows = rng.choice(total, n_anchors, replace=False)
    anchors = data[anchor_rows]
    pool = np.setdiff1d(np.arange(total), anchor_rows)

    if len(pool) < 2 * n_per_split:
        return np.nan

    def _profile(rows):
        # Anchor-by-sample distance matrix, optionally ranked within each anchor.
        dists = cdist(anchors, data[rows], metric=metric)
        return np.apply_along_axis(rankdata, 1, dists) if rank_normalize else dists

    rhos = []
    for _ in range(n_splits):
        shuffled = rng.permutation(pool)
        first = _profile(shuffled[:n_per_split])
        second = _profile(shuffled[n_per_split:2 * n_per_split])

        rho, _ = spearmanr(first.ravel(), second.ravel())
        if np.isfinite(rho):
            rhos.append(rho)

    return float(np.mean(rhos)) if rhos else np.nan
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
# =============================================================================
|
|
334
|
+
# Supervised Variants
|
|
335
|
+
# =============================================================================
|
|
336
|
+
|
|
337
|
+
def variance_ratio(
    X: np.ndarray,
    y: np.ndarray,
) -> float:
    """
    Variance Ratio Shesha: ratio of between-class to total variance.

    A simple, efficient measure of how much geometric structure is explained
    by class labels. It is the R-squared of predicting coordinates from class
    membership: the sum over classes of ``n_c * ||mean_c - mean||^2``,
    divided by the total sum of squared deviations from the global mean.

    Parameters
    ----------
    X : np.ndarray
        Data matrix of shape (n_samples, n_features).
    y : np.ndarray
        Class labels, one per row of X. Any label dtype accepted by
        ``np.unique`` works (ints, strings, ...). Column vectors of shape
        (n_samples, 1) are flattened.

    Returns
    -------
    float
        Between-class variance / total variance, in [0, 1]. NaN when fewer
        than two distinct labels are present. ``EPS`` is added to the
        denominator, so constant data yields 0.0 rather than a division error.

    Raises
    ------
    ValueError
        If ``y`` does not provide exactly one label per row of ``X``.

    Examples
    --------
    >>> X = np.random.randn(500, 768)
    >>> y = np.random.randint(0, 10, 500)
    >>> vr = variance_ratio(X, y)
    """
    X = np.asarray(X, dtype=np.float64)
    # Flatten so (n, 1) label columns index X row-wise instead of failing.
    y = np.asarray(y).ravel()

    if y.shape[0] != X.shape[0]:
        raise ValueError(
            f"Label count must match sample count: X has {X.shape[0]}, y has {y.shape[0]}"
        )

    classes = np.unique(y)
    if len(classes) < 2:
        return np.nan

    global_mean = np.mean(X, axis=0)
    ss_total = np.sum((X - global_mean) ** 2) + EPS

    ss_between = 0.0
    for c in classes:
        mask = y == c
        n_c = np.count_nonzero(mask)
        # Presumably guards labels (e.g. NaN) that never compare equal to
        # themselves, which would otherwise give an empty class.
        if n_c == 0:
            continue
        class_mean = np.mean(X[mask], axis=0)
        ss_between += n_c * np.sum((class_mean - global_mean) ** 2)

    return float(ss_between / ss_total)
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
def supervised_alignment(
    X: np.ndarray,
    y: np.ndarray,
    metric: Literal["cosine", "correlation"] = "correlation",
    seed: Optional[int] = None,
    max_samples: Optional[int] = 300,
) -> float:
    """
    Supervised RDM Alignment: correlation between model RDM and ideal label RDM.

    Measures how well the representation's distance structure aligns with
    task-defined similarity (same class = similar, different class =
    dissimilar). The ideal RDM is 0 for same-class pairs and 1 for
    different-class pairs.

    Parameters
    ----------
    X : np.ndarray
        Data matrix of shape (n_samples, n_features).
    y : np.ndarray
        Class labels, one per row of X. Any label dtype accepted by
        ``np.unique`` works, including strings.
    metric : str
        Distance metric for the model RDM.
    seed : int, optional
        Random seed for subsampling.
    max_samples : int, optional
        Randomly subsample to this many samples when exceeded (RDM computation
        is O(n^2)). ``None`` disables subsampling.

    Returns
    -------
    float
        Spearman correlation between model and ideal RDMs, in [-1, 1]. NaN
        when the ideal RDM is constant (a single class, or every sample in its
        own class) or the correlation is not finite.

    Raises
    ------
    ValueError
        If ``y`` does not provide exactly one label per row of ``X``.
    """
    X = np.asarray(X, dtype=np.float64)
    y = np.asarray(y).ravel()

    if y.shape[0] != X.shape[0]:
        raise ValueError(
            f"Label count must match sample count: X has {X.shape[0]}, y has {y.shape[0]}"
        )

    rng = np.random.default_rng(seed)

    if max_samples is not None and len(X) > max_samples:
        idx = rng.choice(len(X), max_samples, replace=False)
        X, y = X[idx], y[idx]

    # Encode labels as integer codes: pdist converts its input to float, so
    # string labels passed straight to the "hamming" metric would fail.
    _, codes = np.unique(y, return_inverse=True)
    codes = codes.ravel()

    # Center features across samples before building the model RDM.
    X = X - np.mean(X, axis=0)
    model_rdm = pdist(X, metric=metric)

    # Ideal RDM: Hamming distance on 1-D codes is 0 (same class) or 1 (different).
    ideal_rdm = pdist(codes.reshape(-1, 1), metric="hamming")

    # A constant ideal RDM has no rank structure; the correlation is undefined.
    if ideal_rdm.size == 0 or np.all(ideal_rdm == ideal_rdm[0]):
        return np.nan

    rho, _ = spearmanr(model_rdm, ideal_rdm)
    return float(rho) if np.isfinite(rho) else np.nan
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
# =============================================================================
|
|
443
|
+
# Drift Metrics
|
|
444
|
+
# =============================================================================
|
|
445
|
+
|
|
446
|
+
def rdm_similarity(
    X: np.ndarray,
    Y: np.ndarray,
    method: Literal["spearman", "pearson"] = "spearman",
    metric: Literal["cosine", "correlation", "euclidean"] = "cosine",
) -> float:
    """
    Compute RDM similarity between two representations.

    Measures how similar the pairwise distance structures are between two
    representations of the same samples. Useful for measuring
    representational drift, comparing models, or tracking changes during
    training.

    Parameters
    ----------
    X : np.ndarray
        First representation matrix of shape (n_samples, n_features_x).
    Y : np.ndarray
        Second representation matrix of shape (n_samples, n_features_y).
        Must have the same number of samples as X.
    method : str
        Correlation method: 'spearman' (rank-based, default) or 'pearson'.
    metric : str
        Distance metric for RDM computation: 'cosine', 'correlation', or
        'euclidean'.

    Returns
    -------
    float
        Correlation between RDMs, in [-1, 1]; higher means more similar
        geometric structure. NaN when there are fewer than 3 samples. 0.0 when
        either RDM is constant or the correlation is not finite. NaN distances
        (e.g. from zero vectors) are replaced by 1.0 before correlating.

    Raises
    ------
    ValueError
        If ``method`` is not 'spearman' or 'pearson'. This is checked before
        any other work, so it is raised even when an early return would apply.
        Also raised if X and Y have different sample counts.

    Examples
    --------
    >>> # Compare representations before and after training
    >>> X_before = model_before.encode(data)
    >>> X_after = model_after.encode(data)
    >>> similarity = rdm_similarity(X_before, X_after)
    >>> print(f"RDM similarity: {similarity:.3f}")

    >>> # Compare two different models
    >>> X_model1 = model1.encode(data)
    >>> X_model2 = model2.encode(data)
    >>> similarity = rdm_similarity(X_model1, X_model2, method='pearson')

    Notes
    -----
    - Spearman (default) is more robust to outliers and non-linear relationships.
    - Pearson captures linear relationships in distance magnitudes.
    - The representations can have different feature dimensions (only the
      sample count must match).
    """
    # Validate up front so a bad method is never masked by an early return.
    if method not in ("spearman", "pearson"):
        raise ValueError(f"Unknown method: {method}. Use 'spearman' or 'pearson'")

    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)

    if X.shape[0] != Y.shape[0]:
        raise ValueError(f"Sample counts must match: X has {X.shape[0]}, Y has {Y.shape[0]}")

    if X.shape[0] < 3:
        return np.nan

    # NaN distances (e.g. zero vectors under cosine) are treated as maximal-ish 1.0.
    rdm_x = np.nan_to_num(pdist(X, metric=metric), nan=1.0)
    rdm_y = np.nan_to_num(pdist(Y, metric=metric), nan=1.0)

    # A constant RDM has no structure to correlate; reported as 0.0 similarity.
    if np.std(rdm_x) < EPS or np.std(rdm_y) < EPS:
        return 0.0

    if method == "spearman":
        rho, _ = spearmanr(rdm_x, rdm_y)
    else:
        rho, _ = pearsonr(rdm_x, rdm_y)

    return float(rho) if np.isfinite(rho) else 0.0
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
def rdm_drift(
    X: np.ndarray,
    Y: np.ndarray,
    method: Literal["spearman", "pearson"] = "spearman",
    metric: Literal["cosine", "correlation", "euclidean"] = "cosine",
) -> float:
    """
    Representational drift between two representations of the same samples.

    Defined as ``1 - rdm_similarity(X, Y, method=method, metric=metric)``, so
    larger values mean the geometric structure changed more. Handy for
    tracking change over training or after an intervention (fine-tuning,
    perturbation, ...).

    Parameters
    ----------
    X : np.ndarray
        Baseline ("before") representation, shape (n_samples, n_features_x).
    Y : np.ndarray
        Comparison ("after") representation, shape (n_samples, n_features_y).
        Must have the same number of samples as X.
    method : str
        'spearman' (rank-based, default) or 'pearson'; forwarded to
        ``rdm_similarity``.
    metric : str
        Distance metric for the RDMs; forwarded to ``rdm_similarity``.

    Returns
    -------
    float
        Drift score ``1 - RDM_correlation``, in [0, 2]:

        - 0: identical geometric structure
        - 1: uncorrelated (random relationship)
        - 2: perfectly anti-correlated (inverted structure)

        NaN whenever ``rdm_similarity`` returns NaN.

    Examples
    --------
    >>> # Track drift during training
    >>> X_epoch0 = model.encode(data)
    >>> for epoch in range(10):
    ...     train_one_epoch(model)
    ...     X_current = model.encode(data)
    ...     drift = rdm_drift(X_epoch0, X_current)
    ...     print(f"Epoch {epoch+1}: drift = {drift:.3f}")

    >>> # Measure drift due to noise perturbation
    >>> X_clean = model.encode(clean_data)
    >>> X_noisy = model.encode(noisy_data)
    >>> drift = rdm_drift(X_clean, X_noisy)
    >>> print(f"Noise-induced drift: {drift:.3f}")

    See Also
    --------
    rdm_similarity : The inverse metric (similarity instead of drift)
    """
    sim = rdm_similarity(X, Y, method=method, metric=metric)
    # Keep "cannot compare" distinct from a numeric drift value.
    return np.nan if np.isnan(sim) else 1.0 - sim
|
|
589
|
+
|
|
590
|
+
|
|
591
|
+
# =============================================================================
|
|
592
|
+
# Convenience function
|
|
593
|
+
# =============================================================================
|
|
594
|
+
|
|
595
|
+
def shesha(
    X: np.ndarray,
    y: Optional[np.ndarray] = None,
    variant: Literal["feature_split", "sample_split", "anchor", "variance", "supervised"] = "feature_split",
    **kwargs,
) -> float:
    """
    Single entry point that dispatches to one of the Shesha stability metrics.

    Parameters
    ----------
    X : np.ndarray
        Data matrix of shape (n_samples, n_features).
    y : np.ndarray, optional
        Class labels; mandatory for the supervised variants and ignored by
        the unsupervised ones.
    variant : str
        Which metric to compute:

        - 'feature_split': unsupervised, ``feature_split`` (feature partitions)
        - 'sample_split': unsupervised, ``sample_split`` (bootstrap resampling)
        - 'anchor': unsupervised, ``anchor_stability``
        - 'variance': supervised, ``variance_ratio`` (extra kwargs are ignored)
        - 'supervised': supervised, ``supervised_alignment``
    **kwargs
        Forwarded to the selected function (except for 'variance').

    Returns
    -------
    float
        Whatever the selected variant returns.

    Raises
    ------
    ValueError
        If a supervised variant is requested without ``y``, or ``variant``
        is not recognised.

    Examples
    --------
    >>> # Unsupervised
    >>> stability = shesha(X, variant='feature_split', n_splits=30, seed=320)

    >>> # Supervised
    >>> alignment = shesha(X, y, variant='supervised')
    """
    unsupervised = (
        ("feature_split", feature_split),
        ("sample_split", sample_split),
        ("anchor", anchor_stability),
    )
    for name, func in unsupervised:
        if variant == name:
            return func(X, **kwargs)

    if variant == "variance":
        if y is None:
            raise ValueError("Labels required for variance_ratio")
        return variance_ratio(X, y)

    if variant == "supervised":
        if y is None:
            raise ValueError("Labels required for supervised_alignment")
        return supervised_alignment(X, y, **kwargs)

    raise ValueError(f"Unknown variant: {variant}")
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Prashant C. Raju
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|