shesha-geometry 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/tutorial.py +336 -0
- shesha/__init__.py +59 -0
- shesha/bio.py +315 -0
- shesha/core.py +648 -0
- shesha_geometry-0.1.0.dist-info/LICENSE +21 -0
- shesha_geometry-0.1.0.dist-info/METADATA +396 -0
- shesha_geometry-0.1.0.dist-info/RECORD +13 -0
- shesha_geometry-0.1.0.dist-info/WHEEL +5 -0
- shesha_geometry-0.1.0.dist-info/top_level.txt +3 -0
- tests/__init__.py +1 -0
- tests/test_bio.py +49 -0
- tests/test_core.py +164 -0
- tests/test_crispr.py +191 -0
examples/tutorial.py
ADDED
|
@@ -0,0 +1,336 @@
|
|
|
1
|
+
# %% [markdown]
|
|
2
|
+
# # Shesha Tutorial
|
|
3
|
+
#
|
|
4
|
+
# This tutorial demonstrates how to use Shesha (Self-consistency Metrics for Representational Stability) to measure the geometric stability of high-dimensional representations.
|
|
5
|
+
#
|
|
6
|
+
# **What you'll learn:**
|
|
7
|
+
# 1. Basic usage of unsupervised variants (feature_split, sample_split, anchor_stability)
|
|
8
|
+
# 2. Supervised variants with labels (variance_ratio, supervised_alignment)
|
|
9
|
+
# 3. Practical applications: comparing models, detecting drift, analyzing embeddings
|
|
10
|
+
|
|
11
|
+
# %% [markdown]
|
|
12
|
+
# ## Installation
|
|
13
|
+
#
|
|
14
|
+
# ```bash
|
|
15
|
+
# pip install shesha-geometry
|
|
16
|
+
# ```
|
|
17
|
+
|
|
18
|
+
# %%
|
|
19
|
+
import numpy as np
|
|
20
|
+
import matplotlib.pyplot as plt
|
|
21
|
+
import shesha
|
|
22
|
+
|
|
23
|
+
print(f"shesha version: {shesha.__version__}")
|
|
24
|
+
|
|
25
|
+
# Set random seed for reproducibility
|
|
26
|
+
np.random.seed(320)
|
|
27
|
+
|
|
28
|
+
# %% [markdown]
|
|
29
|
+
# ## 1. Understanding Shesha: The Core Idea
|
|
30
|
+
#
|
|
31
|
+
# Shesha measures **geometric stability** - whether a representation's distance structure
|
|
32
|
+
# is internally consistent. Unlike similarity metrics (CKA, Procrustes) that compare
|
|
33
|
+
# *between* representations, Shesha measures consistency *within* a single representation.
|
|
34
|
+
#
|
|
35
|
+
# **Key insight:** A stable representation should produce similar pairwise distance patterns
|
|
36
|
+
# (RDMs) when computed from different "views" of the same data.
|
|
37
|
+
|
|
38
|
+
# %% [markdown]
|
|
39
|
+
# ## 2. Feature-Split Shesha (Unsupervised)
|
|
40
|
+
#
|
|
41
|
+
# The most common variant. Splits feature dimensions into random halves and checks if
|
|
42
|
+
# both halves encode consistent distance relationships.
|
|
43
|
+
#
|
|
44
|
+
# **High stability** = geometric structure is distributed across features (redundant encoding)
|
|
45
|
+
# **Low stability** = structure concentrated in few features or noisy
|
|
46
|
+
|
|
47
|
+
# %%
|
|
48
|
+
# Example 1: Structured data (high stability expected)
|
|
49
|
+
# Create data with low-rank structure - geometry should be consistent across feature subsets
|
|
50
|
+
n_samples, latent_dim, feature_dim = 500, 20, 768
|
|
51
|
+
|
|
52
|
+
latent = np.random.randn(n_samples, latent_dim)
|
|
53
|
+
projection = np.random.randn(latent_dim, feature_dim)
|
|
54
|
+
X_structured = latent @ projection
|
|
55
|
+
|
|
56
|
+
stability_structured = shesha.feature_split(X_structured, n_splits=30, seed=320)
|
|
57
|
+
print(f"Structured data stability: {stability_structured:.3f}")
|
|
58
|
+
|
|
59
|
+
# %%
|
|
60
|
+
# Example 2: Random noise (lower stability expected)
|
|
61
|
+
# Each feature is independent - no consistent structure across subsets
|
|
62
|
+
X_random = np.random.randn(n_samples, feature_dim)
|
|
63
|
+
|
|
64
|
+
stability_random = shesha.feature_split(X_random, n_splits=30, seed=320)
|
|
65
|
+
print(f"Random noise stability: {stability_random:.3f}")
|
|
66
|
+
|
|
67
|
+
# %%
|
|
68
|
+
# Visualize the difference
|
|
69
|
+
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
|
|
70
|
+
|
|
71
|
+
# Run multiple seeds to show distribution
|
|
72
|
+
stabilities_structured = [shesha.feature_split(X_structured, n_splits=30, seed=s) for s in range(10)]
|
|
73
|
+
stabilities_random = [shesha.feature_split(X_random, n_splits=30, seed=s) for s in range(10)]
|
|
74
|
+
|
|
75
|
+
axes[0].bar(['Structured', 'Random'],
|
|
76
|
+
[np.mean(stabilities_structured), np.mean(stabilities_random)],
|
|
77
|
+
yerr=[np.std(stabilities_structured), np.std(stabilities_random)],
|
|
78
|
+
capsize=5, color=['steelblue', 'coral'])
|
|
79
|
+
axes[0].set_ylabel('Feature-Split Stability')
|
|
80
|
+
axes[0].set_title('Stability Comparison')
|
|
81
|
+
axes[0].set_ylim(0, 1)
|
|
82
|
+
|
|
83
|
+
# Show how stability varies with latent dimension
|
|
84
|
+
latent_dims = [5, 10, 20, 50, 100, 200]
|
|
85
|
+
stabilities_by_dim = []
|
|
86
|
+
|
|
87
|
+
for ld in latent_dims:
|
|
88
|
+
latent = np.random.randn(n_samples, ld)
|
|
89
|
+
proj = np.random.randn(ld, feature_dim)
|
|
90
|
+
X = latent @ proj
|
|
91
|
+
stabilities_by_dim.append(shesha.feature_split(X, n_splits=30, seed=320))
|
|
92
|
+
|
|
93
|
+
axes[1].plot(latent_dims, stabilities_by_dim, 'o-', color='steelblue', linewidth=2)
|
|
94
|
+
axes[1].set_xlabel('Latent Dimensionality')
|
|
95
|
+
axes[1].set_ylabel('Feature-Split Stability')
|
|
96
|
+
axes[1].set_title('Stability vs. Latent Rank')
|
|
97
|
+
axes[1].set_ylim(0, 1)
|
|
98
|
+
|
|
99
|
+
plt.tight_layout()
|
|
100
|
+
plt.savefig('tutorial_feature_split.png', dpi=150)
|
|
101
|
+
plt.show()
|
|
102
|
+
|
|
103
|
+
# %% [markdown]
|
|
104
|
+
# ## 3. Sample-Split Shesha (Bootstrap)
|
|
105
|
+
#
|
|
106
|
+
# Measures robustness to sampling variation by computing RDMs on different
|
|
107
|
+
# random subsets of data points.
|
|
108
|
+
|
|
109
|
+
# %%
|
|
110
|
+
# Compare sample-split stability
|
|
111
|
+
sample_stability_structured = shesha.sample_split(X_structured, n_splits=30, seed=320)
|
|
112
|
+
sample_stability_random = shesha.sample_split(X_random, n_splits=30, seed=320)
|
|
113
|
+
|
|
114
|
+
print(f"Sample-split stability (structured): {sample_stability_structured:.3f}")
|
|
115
|
+
print(f"Sample-split stability (random): {sample_stability_random:.3f}")
|
|
116
|
+
|
|
117
|
+
# %% [markdown]
|
|
118
|
+
# ## 4. Anchor Stability
|
|
119
|
+
#
|
|
120
|
+
# Uses fixed anchor points to measure distance profile consistency.
|
|
121
|
+
# More robust for large datasets.
|
|
122
|
+
|
|
123
|
+
# %%
|
|
124
|
+
anchor_stability_structured = shesha.anchor_stability(X_structured, n_splits=30, seed=320)
|
|
125
|
+
anchor_stability_random = shesha.anchor_stability(X_random, n_splits=30, seed=320)
|
|
126
|
+
|
|
127
|
+
print(f"Anchor stability (structured): {anchor_stability_structured:.3f}")
|
|
128
|
+
print(f"Anchor stability (random): {anchor_stability_random:.3f}")
|
|
129
|
+
|
|
130
|
+
# %% [markdown]
|
|
131
|
+
# ## 5. Supervised Variants
|
|
132
|
+
#
|
|
133
|
+
# When you have class labels, supervised variants measure task-relevant stability.
|
|
134
|
+
|
|
135
|
+
# %%
|
|
136
|
+
# Create labeled data with clear class structure
|
|
137
|
+
n_per_class = 100
|
|
138
|
+
n_classes = 5
|
|
139
|
+
|
|
140
|
+
# Well-separated clusters
|
|
141
|
+
X_separated = np.vstack([
|
|
142
|
+
np.random.randn(n_per_class, 768) + np.random.randn(768) * 3
|
|
143
|
+
for _ in range(n_classes)
|
|
144
|
+
])
|
|
145
|
+
y_separated = np.repeat(np.arange(n_classes), n_per_class)
|
|
146
|
+
|
|
147
|
+
# Overlapping clusters (harder to separate)
|
|
148
|
+
X_overlapping = np.vstack([
|
|
149
|
+
np.random.randn(n_per_class, 768) + np.random.randn(768) * 0.5
|
|
150
|
+
for _ in range(n_classes)
|
|
151
|
+
])
|
|
152
|
+
y_overlapping = np.repeat(np.arange(n_classes), n_per_class)
|
|
153
|
+
|
|
154
|
+
# %%
|
|
155
|
+
# Variance ratio: between-class / total variance
|
|
156
|
+
vr_separated = shesha.variance_ratio(X_separated, y_separated)
|
|
157
|
+
vr_overlapping = shesha.variance_ratio(X_overlapping, y_overlapping)
|
|
158
|
+
|
|
159
|
+
print(f"Variance ratio (well-separated): {vr_separated:.3f}")
|
|
160
|
+
print(f"Variance ratio (overlapping): {vr_overlapping:.3f}")
|
|
161
|
+
|
|
162
|
+
# %%
|
|
163
|
+
# Supervised alignment: correlation with ideal label-based RDM
|
|
164
|
+
align_separated = shesha.supervised_alignment(X_separated, y_separated, seed=320)
|
|
165
|
+
align_overlapping = shesha.supervised_alignment(X_overlapping, y_overlapping, seed=320)
|
|
166
|
+
|
|
167
|
+
print(f"Supervised alignment (well-separated): {align_separated:.3f}")
|
|
168
|
+
print(f"Supervised alignment (overlapping): {align_overlapping:.3f}")
|
|
169
|
+
|
|
170
|
+
# %% [markdown]
|
|
171
|
+
# ## 6. Practical Application: Comparing Embedding Models
|
|
172
|
+
#
|
|
173
|
+
# A common use case: which embedding model has more stable representations?
|
|
174
|
+
|
|
175
|
+
# %%
|
|
176
|
+
# Simulate embeddings from different "models" with varying structure
|
|
177
|
+
def simulate_embeddings(n_samples, n_features, latent_dim, noise_level, rng=None):
    """Simulate embeddings with controllable structure and noise.

    A Gaussian latent matrix is projected into feature space, rescaled to
    unit standard deviation, and then corrupted with additive Gaussian noise.

    Parameters
    ----------
    n_samples : int
        Number of rows (samples) to generate.
    n_features : int
        Number of columns (embedding dimensions).
    latent_dim : int
        Dimensionality of the latent space, i.e. the rank of the signal.
    noise_level : float
        Standard deviation of the additive Gaussian noise.
    rng : numpy.random.Generator, optional
        Random generator to draw from. When omitted, the global
        ``np.random`` state is used, so existing calls that rely on
        ``np.random.seed`` produce the same values as before.

    Returns
    -------
    np.ndarray
        Array of shape (n_samples, n_features).
    """
    # Both sources expose ``standard_normal(size)``; the global one consumes
    # the same stream as ``np.random.randn``.
    draw = np.random.standard_normal if rng is None else rng.standard_normal

    # Draw order (latent, projection, noise) is kept fixed so that seeded
    # runs stay reproducible.
    latent = draw((n_samples, latent_dim))
    projection = draw((latent_dim, n_features))
    signal = latent @ projection

    # A degenerate signal (e.g. latent_dim == 0) has zero spread; dividing by
    # it would fill the output with NaNs, so leave it unscaled instead.
    scale = np.std(signal)
    if scale > 0:
        signal = signal / scale

    noise = draw((n_samples, n_features)) * noise_level
    return signal + noise
|
|
185
|
+
|
|
186
|
+
# Simulate 4 "models" with different properties
|
|
187
|
+
models = {
|
|
188
|
+
'Model A (high rank, low noise)': simulate_embeddings(500, 768, 100, 0.1),
|
|
189
|
+
'Model B (high rank, high noise)': simulate_embeddings(500, 768, 100, 1.0),
|
|
190
|
+
'Model C (low rank, low noise)': simulate_embeddings(500, 768, 20, 0.1),
|
|
191
|
+
'Model D (low rank, high noise)': simulate_embeddings(500, 768, 20, 1.0),
|
|
192
|
+
}
|
|
193
|
+
|
|
194
|
+
# Compare stability across models
|
|
195
|
+
print("Model Comparison (Feature-Split Stability):")
|
|
196
|
+
print("-" * 50)
|
|
197
|
+
|
|
198
|
+
results = {}
|
|
199
|
+
for name, X in models.items():
|
|
200
|
+
stability = shesha.feature_split(X, n_splits=30, seed=320)
|
|
201
|
+
results[name] = stability
|
|
202
|
+
print(f"{name}: {stability:.3f}")
|
|
203
|
+
|
|
204
|
+
# %%
|
|
205
|
+
# Visualize model comparison
|
|
206
|
+
fig, ax = plt.subplots(figsize=(10, 5))
|
|
207
|
+
|
|
208
|
+
names = list(results.keys())
|
|
209
|
+
values = list(results.values())
|
|
210
|
+
colors = ['#2ecc71', '#e74c3c', '#3498db', '#9b59b6']
|
|
211
|
+
|
|
212
|
+
bars = ax.barh(names, values, color=colors)
|
|
213
|
+
ax.set_xlabel('Feature-Split Stability')
|
|
214
|
+
ax.set_title('Embedding Model Stability Comparison')
|
|
215
|
+
ax.set_xlim(0, 1)
|
|
216
|
+
|
|
217
|
+
for bar, val in zip(bars, values):
|
|
218
|
+
ax.text(val + 0.02, bar.get_y() + bar.get_height()/2,
|
|
219
|
+
f'{val:.3f}', va='center', fontsize=10)
|
|
220
|
+
|
|
221
|
+
plt.tight_layout()
|
|
222
|
+
plt.savefig('tutorial_model_comparison.png', dpi=150)
|
|
223
|
+
plt.show()
|
|
224
|
+
|
|
225
|
+
# %% [markdown]
|
|
226
|
+
# ## 7. Practical Application: Monitoring Training Drift
|
|
227
|
+
#
|
|
228
|
+
# Track how representation stability changes during fine-tuning.
|
|
229
|
+
|
|
230
|
+
# %%
|
|
231
|
+
# Simulate embeddings at different "epochs" of training
|
|
232
|
+
# Early: random initialization, Late: structured representations
|
|
233
|
+
|
|
234
|
+
def simulate_training_trajectory(n_epochs=10):
    """Simulate how embeddings evolve during training.

    Each epoch mixes a freshly drawn low-rank "structured" matrix with a
    freshly drawn pure-noise matrix. The structured share grows linearly
    from 0 at the first epoch to 1 at the last one.

    Parameters
    ----------
    n_epochs : int, default=10
        Number of snapshots to generate. A single epoch yields one pure-noise
        snapshot; zero epochs yields an empty list.

    Returns
    -------
    list of np.ndarray
        ``n_epochs`` arrays of shape (500, 768), in epoch order.
    """
    n_samples, n_features = 500, 768
    latent_dim = 50

    # With a single epoch there is no range to interpolate over; clamping the
    # denominator avoids a ZeroDivisionError and keeps the weight at 0.
    last_epoch = max(n_epochs - 1, 1)

    embeddings = []
    for epoch in range(n_epochs):
        structure_weight = epoch / last_epoch  # 0 to 1

        # Structured component: low-rank signal rescaled to unit std
        latent = np.random.randn(n_samples, latent_dim)
        projection = np.random.randn(latent_dim, n_features)
        structured = latent @ projection
        structured = structured / np.std(structured)

        # Unstructured component (named to avoid shadowing the stdlib
        # ``random`` module)
        noise = np.random.randn(n_samples, n_features)

        # Mix based on epoch
        embeddings.append(
            structure_weight * structured + (1 - structure_weight) * noise
        )

    return embeddings
|
|
257
|
+
|
|
258
|
+
# Generate trajectory
|
|
259
|
+
embeddings_over_time = simulate_training_trajectory(n_epochs=10)
|
|
260
|
+
|
|
261
|
+
# Track stability
|
|
262
|
+
stabilities = []
|
|
263
|
+
for epoch, X in enumerate(embeddings_over_time):
|
|
264
|
+
stability = shesha.feature_split(X, n_splits=30, seed=320)
|
|
265
|
+
stabilities.append(stability)
|
|
266
|
+
print(f"Epoch {epoch}: stability = {stability:.3f}")
|
|
267
|
+
|
|
268
|
+
# %%
|
|
269
|
+
# Plot training trajectory
|
|
270
|
+
fig, ax = plt.subplots(figsize=(8, 5))
|
|
271
|
+
|
|
272
|
+
ax.plot(range(len(stabilities)), stabilities, 'o-', linewidth=2, markersize=8, color='steelblue')
|
|
273
|
+
ax.fill_between(range(len(stabilities)), stabilities, alpha=0.3, color='steelblue')
|
|
274
|
+
|
|
275
|
+
ax.set_xlabel('Epoch')
|
|
276
|
+
ax.set_ylabel('Feature-Split Stability')
|
|
277
|
+
ax.set_title('Representation Stability During Training')
|
|
278
|
+
ax.set_ylim(0, 1)
|
|
279
|
+
ax.grid(True, alpha=0.3)
|
|
280
|
+
|
|
281
|
+
plt.tight_layout()
|
|
282
|
+
plt.savefig('tutorial_training_drift.png', dpi=150)
|
|
283
|
+
plt.show()
|
|
284
|
+
|
|
285
|
+
# %% [markdown]
|
|
286
|
+
# ## 8. Unified Interface
|
|
287
|
+
#
|
|
288
|
+
# Use `shesha.shesha()` for a single entry point to all variants.
|
|
289
|
+
|
|
290
|
+
# %%
|
|
291
|
+
X = np.random.randn(500, 768)
|
|
292
|
+
y = np.random.randint(0, 5, 500)
|
|
293
|
+
|
|
294
|
+
# All variants through unified interface
|
|
295
|
+
print("Unified interface examples:")
|
|
296
|
+
print(f" feature_split: {shesha.shesha(X, variant='feature_split', seed=320):.3f}")
|
|
297
|
+
print(f" sample_split: {shesha.shesha(X, variant='sample_split', seed=320):.3f}")
|
|
298
|
+
print(f" anchor: {shesha.shesha(X, variant='anchor', seed=320):.3f}")
|
|
299
|
+
print(f" variance: {shesha.shesha(X, y, variant='variance'):.3f}")
|
|
300
|
+
print(f" supervised: {shesha.shesha(X, y, variant='supervised', seed=320):.3f}")
|
|
301
|
+
|
|
302
|
+
# %% [markdown]
|
|
303
|
+
# ## 9. Tips and Best Practices
|
|
304
|
+
#
|
|
305
|
+
# ### Choosing a variant:
|
|
306
|
+
# - **feature_split**: Default choice for unsupervised analysis. Good for drift detection, intrinsic quality.
|
|
307
|
+
# - **sample_split**: When you care about robustness to sampling. Good for small datasets.
|
|
308
|
+
# - **anchor_stability**: For very large datasets where feature_split is slow.
|
|
309
|
+
# - **variance_ratio**: Quick supervised check. Computationally cheap.
|
|
310
|
+
# - **supervised_alignment**: When you want RDM-based task alignment (more nuanced than variance_ratio).
|
|
311
|
+
#
|
|
312
|
+
# ### Parameter recommendations:
|
|
313
|
+
# - `n_splits=30` is usually sufficient; increase to 50+ for publication-quality results
|
|
314
|
+
# - `seed` should always be set for reproducibility
|
|
315
|
+
# - `max_samples` prevents memory issues with large datasets
|
|
316
|
+
#
|
|
317
|
+
# ### Interpretation:
|
|
318
|
+
# - **feature_split > 0.7**: Strong internal consistency, distributed structure
|
|
319
|
+
# - **feature_split 0.3-0.7**: Moderate consistency
|
|
320
|
+
# - **feature_split < 0.3**: Weak consistency, possibly noisy or sparse structure
|
|
321
|
+
# - **variance_ratio**: Directly interpretable as "fraction of variance explained by classes"
|
|
322
|
+
|
|
323
|
+
# %% [markdown]
|
|
324
|
+
# ## 10. Summary
|
|
325
|
+
#
|
|
326
|
+
# | Variant | Supervised | Best For |
|
|
327
|
+
# |---------|------------|----------|
|
|
328
|
+
# | `feature_split` | No | General stability, drift detection |
|
|
329
|
+
# | `sample_split` | No | Sampling robustness |
|
|
330
|
+
# | `anchor_stability` | No | Large-scale analysis |
|
|
331
|
+
# | `variance_ratio` | Yes | Quick separability check |
|
|
332
|
+
# | `supervised_alignment` | Yes | Task alignment |
|
|
333
|
+
|
|
334
|
+
# %%
|
|
335
|
+
print("Tutorial complete!")
|
|
336
|
+
print("\nFor more information, see: https://github.com/prashantcraju/shesha")
|
shesha/__init__.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Shesha: Self-consistency Metrics for Representational Stability
|
|
3
|
+
|
|
4
|
+
A framework for measuring geometric stability via self-consistency of
|
|
5
|
+
Representational Dissimilarity Matrices (RDMs).
|
|
6
|
+
|
|
7
|
+
Basic usage:
|
|
8
|
+
>>> import shesha
|
|
9
|
+
>>> stability = shesha.feature_split(X, n_splits=30, seed=320)
|
|
10
|
+
|
|
11
|
+
>>> # Or with labels
|
|
12
|
+
>>> alignment = shesha.supervised_alignment(X, y)
|
|
13
|
+
|
|
14
|
+
>>> # Unified interface
|
|
15
|
+
>>> score = shesha.shesha(X, variant='feature_split')
|
|
16
|
+
|
|
17
|
+
>>> # Measure drift between representations
|
|
18
|
+
>>> similarity = shesha.rdm_similarity(X_before, X_after)
|
|
19
|
+
>>> drift = shesha.rdm_drift(X_before, X_after)
|
|
20
|
+
|
|
21
|
+
>>> # Biological perturbation analysis
|
|
22
|
+
>>> from shesha.bio import perturbation_stability
|
|
23
|
+
>>> stability = perturbation_stability(X_control, X_perturbed)
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
from .core import (
|
|
27
|
+
# Main function
|
|
28
|
+
shesha,
|
|
29
|
+
# Unsupervised variants
|
|
30
|
+
feature_split,
|
|
31
|
+
sample_split,
|
|
32
|
+
anchor_stability,
|
|
33
|
+
# Supervised variants
|
|
34
|
+
variance_ratio,
|
|
35
|
+
supervised_alignment,
|
|
36
|
+
# Drift metrics
|
|
37
|
+
rdm_similarity,
|
|
38
|
+
rdm_drift,
|
|
39
|
+
# Utilities
|
|
40
|
+
compute_rdm,
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
from . import bio
|
|
44
|
+
|
|
45
|
+
__version__ = "0.1.0"
|
|
46
|
+
__author__ = "Prashant Raju"
|
|
47
|
+
|
|
48
|
+
__all__ = [
|
|
49
|
+
"shesha",
|
|
50
|
+
"feature_split",
|
|
51
|
+
"sample_split",
|
|
52
|
+
"anchor_stability",
|
|
53
|
+
"variance_ratio",
|
|
54
|
+
"supervised_alignment",
|
|
55
|
+
"rdm_similarity",
|
|
56
|
+
"rdm_drift",
|
|
57
|
+
"compute_rdm",
|
|
58
|
+
"bio",
|
|
59
|
+
]
|
shesha/bio.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Shesha Bio: Stability metrics for biological perturbation experiments.
|
|
3
|
+
|
|
4
|
+
This module provides Shesha variants for single-cell and perturbation biology,
|
|
5
|
+
measuring the consistency of perturbation effects across individual cells.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from typing import Optional, Literal
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"perturbation_stability",
|
|
13
|
+
"perturbation_effect_size",
|
|
14
|
+
"compute_stability",
|
|
15
|
+
"compute_magnitude"
|
|
16
|
+
]
|
|
17
|
+
|
|
18
|
+
EPS = 1e-12
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def perturbation_stability(
    X_control: np.ndarray,
    X_perturbed: np.ndarray,
    metric: Literal["cosine", "euclidean"] = "cosine",
    seed: Optional[int] = None,
    max_samples: Optional[int] = 1000,
) -> float:
    """
    Perturbation stability: consistency of perturbation effects across samples.

    Measures whether individual perturbed samples shift in a consistent direction
    relative to the control population. High values indicate that the perturbation
    has a coherent, reproducible effect; low values suggest heterogeneous or noisy
    responses.

    The control centroid is the reference point: each perturbed sample's shift
    vector is ``x_perturbed - centroid_control`` and is compared with the mean
    shift over all perturbed samples.

    Parameters
    ----------
    X_control : np.ndarray
        Control population embeddings, shape (n_control, n_features).
    X_perturbed : np.ndarray
        Perturbed population embeddings, shape (n_perturbed, n_features).
    metric : str
        How to measure directional consistency:
        - 'cosine': Mean cosine similarity of shift vectors to the mean
          shift direction (default)
        - 'euclidean': One minus the ratio of the mean squared deviation
          from the mean shift to the mean squared shift length
    seed : int, optional
        Random seed for subsampling reproducibility.
    max_samples : int, optional
        Subsample the perturbed population (without replacement) if it has
        more rows than this. ``None`` disables subsampling.

    Returns
    -------
    float
        Stability score. For 'cosine' the score lies in [-1, 1]; for
        'euclidean' it is clipped to [-1, 1]. Higher = more consistent
        perturbation effect. Special cases:
        - ``nan`` if either population has fewer than 5 samples
        - ``0.0`` if the mean shift has (numerically) zero length
        - ``nan`` for 'euclidean' if the mean squared shift length is
          (numerically) zero

    Raises
    ------
    ValueError
        If ``metric`` is not 'cosine' or 'euclidean', if either input is not
        2-D, or if the two inputs have different numbers of features.

    Examples
    --------
    >>> # Control and perturbed cell populations
    >>> X_ctrl = np.random.randn(500, 50)  # 500 control cells, 50 genes
    >>>
    >>> # Coherent perturbation: all cells shift similarly
    >>> shift = np.random.randn(50)  # consistent direction
    >>> X_pert_coherent = X_ctrl + shift + np.random.randn(500, 50) * 0.1
    >>> stability = perturbation_stability(X_ctrl, X_pert_coherent)
    >>> print(f"Coherent perturbation: {stability:.3f}")  # High value
    >>>
    >>> # Incoherent perturbation: cells shift randomly
    >>> X_pert_random = X_ctrl + np.random.randn(500, 50)
    >>> stability = perturbation_stability(X_ctrl, X_pert_random)
    >>> print(f"Random perturbation: {stability:.3f}")  # Low value

    Notes
    -----
    This metric is designed for single-cell perturbation experiments (e.g.,
    Perturb-seq, CRISPR screens) where you want to assess whether a genetic
    perturbation produces a consistent phenotypic shift across cells.
    """
    # Reject unsupported metrics first, so a typo is reported regardless of
    # which early-return path the data would otherwise take.
    if metric not in ("cosine", "euclidean"):
        raise ValueError(f"Unknown metric: {metric}. Use 'cosine' or 'euclidean'.")

    X_control = np.asarray(X_control, dtype=np.float64)
    X_perturbed = np.asarray(X_perturbed, dtype=np.float64)

    if X_control.ndim != 2 or X_perturbed.ndim != 2:
        raise ValueError(
            "Inputs must be 2-D (n_samples, n_features): control has "
            f"{X_control.ndim} dimension(s), perturbed has "
            f"{X_perturbed.ndim} dimension(s)"
        )

    if X_control.shape[1] != X_perturbed.shape[1]:
        raise ValueError(
            f"Feature dimensions must match: control has {X_control.shape[1]}, "
            f"perturbed has {X_perturbed.shape[1]}"
        )

    # Too few samples to estimate a centroid or a mean shift meaningfully
    if len(X_control) < 5 or len(X_perturbed) < 5:
        return np.nan

    # Subsample perturbed if needed (the generator is only used here)
    if max_samples is not None and len(X_perturbed) > max_samples:
        rng = np.random.default_rng(seed)
        idx = rng.choice(len(X_perturbed), max_samples, replace=False)
        X_perturbed = X_perturbed[idx]

    # Shift of every perturbed sample away from the control centroid
    control_centroid = np.mean(X_control, axis=0)
    shift_vectors = X_perturbed - control_centroid

    # Mean shift and its length
    mean_shift = np.mean(shift_vectors, axis=0)
    mean_shift_norm = np.linalg.norm(mean_shift)

    if mean_shift_norm < EPS:
        # No net shift - perturbation has no coherent effect
        return 0.0

    if metric == "cosine":
        mean_shift_unit = mean_shift / mean_shift_norm

        # Floor the per-sample norms so a sample sitting exactly on the
        # control centroid contributes a cosine of 0 instead of dividing by 0
        shift_norms = np.linalg.norm(shift_vectors, axis=1, keepdims=True)
        shift_unit = shift_vectors / np.maximum(shift_norms, EPS)

        cosines = shift_unit @ mean_shift_unit
        return float(np.mean(cosines))

    # metric == "euclidean": how tight are the shifts around their mean,
    # relative to the overall squared shift length?
    deviations = shift_vectors - mean_shift
    deviation_var = np.mean(np.sum(deviations ** 2, axis=1))
    total_var = np.mean(np.sum(shift_vectors ** 2, axis=1))

    if total_var < EPS:
        return np.nan

    # 1 - (deviation / total) gives consistency score
    consistency = 1.0 - (deviation_var / total_var)
    return float(np.clip(consistency, -1, 1))
|
|
152
|
+
|
|
153
|
+
try:
|
|
154
|
+
from anndata import AnnData
|
|
155
|
+
except ImportError:
|
|
156
|
+
AnnData = None
|
|
157
|
+
|
|
158
|
+
def compute_stability(
    adata: "AnnData",
    perturbation_key: str,
    control_label: str = "control",
    layer: Optional[str] = None,
    **kwargs
) -> dict:
    """
    Scanpy-compatible wrapper for perturbation stability.

    Computes stability for all perturbations in an AnnData object by calling
    ``perturbation_stability`` once per non-control label, always against the
    same control population.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix.
    perturbation_key : str
        Column in adata.obs containing perturbation labels (e.g. 'guide_id').
    control_label : str
        The label in perturbation_key representing control cells (e.g. 'NT').
    layer : str, optional
        Name of the entry in ``adata.layers`` to read values from. When
        omitted (or empty), ``adata.X`` is used.
    **kwargs
        Forwarded unchanged to ``perturbation_stability``.

    Returns
    -------
    dict
        Dictionary mapping perturbation names to stability scores. The
        control label itself is not included.

    Raises
    ------
    ImportError
        If the ``anndata`` package is not installed.
    TypeError
        If ``adata`` is not an ``AnnData`` instance.
    """
    # Keep "package missing" and "wrong argument type" apart: reporting a
    # missing dependency when anndata is installed would mislead the caller.
    if AnnData is None:
        raise ImportError("anndata is required for this function.")
    if not isinstance(adata, AnnData):
        raise TypeError(
            f"adata must be an AnnData object, got {type(adata).__name__}."
        )

    # Look the label column up once instead of on every iteration
    labels = adata.obs[perturbation_key]

    def _dense_rows(mask):
        """Return the selected rows from the chosen layer (or X) as a dense array."""
        subset = adata[mask]
        matrix = subset.layers[layer] if layer else subset.X
        # Handle sparse matrices
        if hasattr(matrix, "toarray"):
            matrix = matrix.toarray()
        return matrix

    X_ctrl = _dense_rows(labels == control_label)

    results = {}
    for pert in labels.unique():
        if pert == control_label:
            continue
        X_pert = _dense_rows(labels == pert)
        results[pert] = perturbation_stability(X_ctrl, X_pert, **kwargs)

    return results
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def perturbation_effect_size(
    X_control: np.ndarray,
    X_perturbed: np.ndarray,
    metric: Literal["euclidean", "cohen"] = "euclidean"
) -> float:
    """
    Compute the magnitude of the perturbation effect.

    Parameters
    ----------
    X_control : np.ndarray
        Control population embeddings, one row per sample.
    X_perturbed : np.ndarray
        Perturbed population embeddings, one row per sample. Must have the
        same per-sample shape as ``X_control``.
    metric : str, default="euclidean"
        - 'euclidean': Raw L2 distance between centroids (Magnitude).
          Use this for geometric plots (Stability vs Magnitude).
        - 'cohen': Standardized effect size (Magnitude / Pooled SD), where
          the pooled SD is the square root of the per-feature sample
          variances averaged over both groups and all features.
          Use this for statistical power analysis.

    Returns
    -------
    float
        The calculated magnitude/effect size.

    Raises
    ------
    ValueError
        If ``metric`` is not 'euclidean' or 'cohen', or if the two inputs do
        not share the same per-sample shape.
    """
    if metric not in ("euclidean", "cohen"):
        raise ValueError(f"Unknown metric: {metric}")

    X_control = np.asarray(X_control, dtype=np.float64)
    X_perturbed = np.asarray(X_perturbed, dtype=np.float64)

    # Compare everything after the sample axis. Without this check NumPy
    # would silently broadcast e.g. a 1-feature control against a k-feature
    # perturbation and return a meaningless distance.
    if X_control.shape[1:] != X_perturbed.shape[1:]:
        raise ValueError(
            f"Feature dimensions must match: control has shape {X_control.shape}, "
            f"perturbed has shape {X_perturbed.shape}"
        )

    control_centroid = np.mean(X_control, axis=0)
    perturbed_centroid = np.mean(X_perturbed, axis=0)

    # 1. Raw Magnitude (Euclidean Distance)
    shift_magnitude = np.linalg.norm(perturbed_centroid - control_centroid)

    if metric == "euclidean":
        return float(shift_magnitude)

    # 2. Standardized Effect Size (Cohen's d-like)
    control_var = np.var(X_control, axis=0, ddof=1)
    perturbed_var = np.var(X_perturbed, axis=0, ddof=1)

    # Average variance across groups and features to get a scalar scale
    pooled_var = np.mean((control_var + perturbed_var) / 2)
    # EPS keeps the division finite when both groups have zero variance
    pooled_std = np.sqrt(pooled_var) + EPS

    return float(shift_magnitude / pooled_std)
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def compute_magnitude(
    adata: "AnnData",
    perturbation_key: str,
    control_label: str = "control",
    metric: str = "euclidean",
    layer: Optional[str] = None,
) -> dict:
    """
    Scanpy-compatible wrapper for perturbation magnitude.

    Computes the effect size of every perturbation in an AnnData object by
    calling ``perturbation_effect_size`` once per non-control label, always
    against the same control population.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix.
    perturbation_key : str
        Column in adata.obs containing perturbation labels.
    control_label : str
        The label in perturbation_key representing control cells.
    metric : str, default="euclidean"
        Passed to ``perturbation_effect_size`` as its ``metric`` argument.
    layer : str, optional
        Name of the entry in ``adata.layers`` to read values from. When
        omitted (or empty), ``adata.X`` is used.

    Returns
    -------
    dict
        Dictionary mapping perturbation names to effect sizes. The control
        label itself is not included.

    Raises
    ------
    ImportError
        If the ``anndata`` package is not installed.
    TypeError
        If ``adata`` is not an ``AnnData`` instance.
    """
    # Keep "package missing" and "wrong argument type" apart: reporting a
    # missing dependency when anndata is installed would mislead the caller.
    if AnnData is None:
        raise ImportError("anndata is required for this function.")
    if not isinstance(adata, AnnData):
        raise TypeError(
            f"adata must be an AnnData object, got {type(adata).__name__}."
        )

    # Look the label column up once instead of on every iteration
    labels = adata.obs[perturbation_key]

    def _dense_rows(mask):
        """Return the selected rows from the chosen layer (or X) as a dense array."""
        subset = adata[mask]
        matrix = subset.layers[layer] if layer else subset.X
        # Handle sparse matrices
        if hasattr(matrix, "toarray"):
            matrix = matrix.toarray()
        return matrix

    X_ctrl = _dense_rows(labels == control_label)

    results = {}
    for pert in labels.unique():
        if pert == control_label:
            continue
        X_pert = _dense_rows(labels == pert)
        results[pert] = perturbation_effect_size(X_ctrl, X_pert, metric=metric)

    return results
|