shesha-geometry 0.1.0__py3-none-any.whl

examples/tutorial.py ADDED
@@ -0,0 +1,336 @@
+ # %% [markdown]
+ # # Shesha Tutorial
+ #
+ # This tutorial demonstrates how to use Shesha (self-consistency metrics for representational stability analysis) to measure the geometric stability of high-dimensional representations.
+ #
+ # **What you'll learn:**
+ # 1. Basic usage of unsupervised variants (feature_split, sample_split, anchor_stability)
+ # 2. Supervised variants with labels (variance_ratio, supervised_alignment)
+ # 3. Practical applications: comparing models, detecting drift, analyzing embeddings
+
+ # %% [markdown]
+ # ## Installation
+ #
+ # ```bash
+ # pip install shesha-geometry
+ # ```
+
+ # %%
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import shesha
+
+ print(f"shesha version: {shesha.__version__}")
+
+ # Set random seed for reproducibility
+ np.random.seed(320)
+
+ # %% [markdown]
+ # ## 1. Understanding Shesha: The Core Idea
+ #
+ # Shesha measures **geometric stability**: whether a representation's distance structure
+ # is internally consistent. Unlike similarity metrics (CKA, Procrustes) that compare
+ # *between* representations, Shesha measures consistency *within* a single representation.
+ #
+ # **Key insight:** A stable representation should produce similar pairwise distance patterns
+ # (RDMs) when computed from different "views" of the same data (see the sketch in the next cell).
+
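+ # %% [markdown]
+ # To make this concrete, the next cell is a minimal, hand-rolled sketch of the
+ # self-consistency idea (an illustration only, not the library's exact
+ # implementation): build an RDM from each of two disjoint feature halves of the
+ # same data and correlate their upper triangles.
+
+ # %%
+ # Toy data and a toy RDM helper; the names here are illustrative, not part of the API.
+ rng_demo = np.random.default_rng(0)              # local RNG so the global seed above is untouched
+ X_demo = rng_demo.standard_normal((100, 64))     # 100 samples x 64 features
+ perm = rng_demo.permutation(X_demo.shape[1])
+ half_a, half_b = perm[:32], perm[32:]            # two disjoint feature halves
+
+ def toy_rdm(A):
+     """Upper triangle of the Euclidean distance matrix, flattened."""
+     sq = np.sum(A ** 2, axis=1)
+     d2 = np.maximum(sq[:, None] + sq[None, :] - 2 * A @ A.T, 0.0)
+     iu = np.triu_indices(len(A), k=1)
+     return np.sqrt(d2[iu])
+
+ rdm_a, rdm_b = toy_rdm(X_demo[:, half_a]), toy_rdm(X_demo[:, half_b])
+ agreement = np.corrcoef(rdm_a, rdm_b)[0, 1]      # agreement between the two "views"
+ print(f"Toy RDM agreement between feature halves: {agreement:.3f}")
+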
+ # %% [markdown]
+ # ## 2. Feature-Split Shesha (Unsupervised)
+ #
+ # The most common variant. It splits the feature dimensions into random halves and checks
+ # whether both halves encode consistent distance relationships.
+ #
+ # - **High stability**: geometric structure is distributed across features (redundant encoding)
+ # - **Low stability**: structure is concentrated in a few features, or the data are noisy
+
+ # %%
+ # Example 1: Structured data (high stability expected)
+ # Create data with low-rank structure: geometry should be consistent across feature subsets
+ n_samples, latent_dim, feature_dim = 500, 20, 768
+
+ latent = np.random.randn(n_samples, latent_dim)
+ projection = np.random.randn(latent_dim, feature_dim)
+ X_structured = latent @ projection
+
+ stability_structured = shesha.feature_split(X_structured, n_splits=30, seed=320)
+ print(f"Structured data stability: {stability_structured:.3f}")
+
+ # %%
+ # Example 2: Random noise (lower stability expected)
+ # Each feature is independent, so there is no consistent structure across subsets
+ X_random = np.random.randn(n_samples, feature_dim)
+
+ stability_random = shesha.feature_split(X_random, n_splits=30, seed=320)
+ print(f"Random noise stability: {stability_random:.3f}")
+
+ # %%
+ # Visualize the difference
+ fig, axes = plt.subplots(1, 2, figsize=(10, 4))
+
+ # Run multiple seeds to show distribution
+ stabilities_structured = [shesha.feature_split(X_structured, n_splits=30, seed=s) for s in range(10)]
+ stabilities_random = [shesha.feature_split(X_random, n_splits=30, seed=s) for s in range(10)]
+
+ axes[0].bar(['Structured', 'Random'],
+             [np.mean(stabilities_structured), np.mean(stabilities_random)],
+             yerr=[np.std(stabilities_structured), np.std(stabilities_random)],
+             capsize=5, color=['steelblue', 'coral'])
+ axes[0].set_ylabel('Feature-Split Stability')
+ axes[0].set_title('Stability Comparison')
+ axes[0].set_ylim(0, 1)
+
+ # Show how stability varies with latent dimension
+ latent_dims = [5, 10, 20, 50, 100, 200]
+ stabilities_by_dim = []
+
+ for ld in latent_dims:
+     latent = np.random.randn(n_samples, ld)
+     proj = np.random.randn(ld, feature_dim)
+     X = latent @ proj
+     stabilities_by_dim.append(shesha.feature_split(X, n_splits=30, seed=320))
+
+ axes[1].plot(latent_dims, stabilities_by_dim, 'o-', color='steelblue', linewidth=2)
+ axes[1].set_xlabel('Latent Dimensionality')
+ axes[1].set_ylabel('Feature-Split Stability')
+ axes[1].set_title('Stability vs. Latent Rank')
+ axes[1].set_ylim(0, 1)
+
+ plt.tight_layout()
+ plt.savefig('tutorial_feature_split.png', dpi=150)
+ plt.show()
+
+ # %% [markdown]
+ # ## 3. Sample-Split Shesha (Bootstrap)
+ #
+ # Measures robustness to sampling variation by computing RDMs on different
+ # random subsets of data points.
+
+ # %%
+ # Compare sample-split stability
+ sample_stability_structured = shesha.sample_split(X_structured, n_splits=30, seed=320)
+ sample_stability_random = shesha.sample_split(X_random, n_splits=30, seed=320)
+
+ print(f"Sample-split stability (structured): {sample_stability_structured:.3f}")
+ print(f"Sample-split stability (random): {sample_stability_random:.3f}")
+
+ # %% [markdown]
+ # ## 4. Anchor Stability
+ #
+ # Uses fixed anchor points to measure distance profile consistency.
+ # More robust for large datasets.
+
+ # %%
+ anchor_stability_structured = shesha.anchor_stability(X_structured, n_splits=30, seed=320)
+ anchor_stability_random = shesha.anchor_stability(X_random, n_splits=30, seed=320)
+
+ print(f"Anchor stability (structured): {anchor_stability_structured:.3f}")
+ print(f"Anchor stability (random): {anchor_stability_random:.3f}")
+
+ # %% [markdown]
+ # ## 5. Supervised Variants
+ #
+ # When you have class labels, supervised variants measure task-relevant stability.
+
+ # %%
+ # Create labeled data with clear class structure
+ n_per_class = 100
+ n_classes = 5
+
+ # Well-separated clusters
+ X_separated = np.vstack([
+     np.random.randn(n_per_class, 768) + np.random.randn(768) * 3
+     for _ in range(n_classes)
+ ])
+ y_separated = np.repeat(np.arange(n_classes), n_per_class)
+
+ # Overlapping clusters (harder to separate)
+ X_overlapping = np.vstack([
+     np.random.randn(n_per_class, 768) + np.random.randn(768) * 0.5
+     for _ in range(n_classes)
+ ])
+ y_overlapping = np.repeat(np.arange(n_classes), n_per_class)
+
+ # %%
+ # Variance ratio: between-class / total variance
+ vr_separated = shesha.variance_ratio(X_separated, y_separated)
+ vr_overlapping = shesha.variance_ratio(X_overlapping, y_overlapping)
+
+ print(f"Variance ratio (well-separated): {vr_separated:.3f}")
+ print(f"Variance ratio (overlapping): {vr_overlapping:.3f}")
+
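+ # %% [markdown]
+ # For intuition, the quantity behind `variance_ratio` can be sketched by hand as the
+ # textbook between-class / total sum-of-squares decomposition (the library's exact
+ # normalization may differ; the helper below is purely illustrative).
+
+ # %%
+ def toy_variance_ratio(X, y):
+     """Between-class sum of squares as a fraction of the total sum of squares."""
+     grand_mean = X.mean(axis=0)
+     total = np.sum((X - grand_mean) ** 2)
+     between = sum(
+         (y == c).sum() * np.sum((X[y == c].mean(axis=0) - grand_mean) ** 2)
+         for c in np.unique(y)
+     )
+     return between / total
+
+ print(f"Hand-rolled ratio (well-separated): {toy_variance_ratio(X_separated, y_separated):.3f}")
+ print(f"Hand-rolled ratio (overlapping): {toy_variance_ratio(X_overlapping, y_overlapping):.3f}")
+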
+ # %%
+ # Supervised alignment: correlation with ideal label-based RDM
+ align_separated = shesha.supervised_alignment(X_separated, y_separated, seed=320)
+ align_overlapping = shesha.supervised_alignment(X_overlapping, y_overlapping, seed=320)
+
+ print(f"Supervised alignment (well-separated): {align_separated:.3f}")
+ print(f"Supervised alignment (overlapping): {align_overlapping:.3f}")
+
+ # %% [markdown]
+ # ## 6. Practical Application: Comparing Embedding Models
+ #
+ # A common use case: which embedding model has more stable representations?
+
+ # %%
+ # Simulate embeddings from different "models" with varying structure
+ def simulate_embeddings(n_samples, n_features, latent_dim, noise_level):
+     """Simulate embeddings with controllable structure and noise."""
+     latent = np.random.randn(n_samples, latent_dim)
+     projection = np.random.randn(latent_dim, n_features)
+     signal = latent @ projection
+     signal = signal / np.std(signal)
+     noise = np.random.randn(n_samples, n_features) * noise_level
+     return signal + noise
+
+ # Simulate 4 "models" with different properties
+ models = {
+     'Model A (high rank, low noise)': simulate_embeddings(500, 768, 100, 0.1),
+     'Model B (high rank, high noise)': simulate_embeddings(500, 768, 100, 1.0),
+     'Model C (low rank, low noise)': simulate_embeddings(500, 768, 20, 0.1),
+     'Model D (low rank, high noise)': simulate_embeddings(500, 768, 20, 1.0),
+ }
+
+ # Compare stability across models
+ print("Model Comparison (Feature-Split Stability):")
+ print("-" * 50)
+
+ results = {}
+ for name, X in models.items():
+     stability = shesha.feature_split(X, n_splits=30, seed=320)
+     results[name] = stability
+     print(f"{name}: {stability:.3f}")
+
+ # %%
+ # Visualize model comparison
+ fig, ax = plt.subplots(figsize=(10, 5))
+
+ names = list(results.keys())
+ values = list(results.values())
+ colors = ['#2ecc71', '#e74c3c', '#3498db', '#9b59b6']
+
+ bars = ax.barh(names, values, color=colors)
+ ax.set_xlabel('Feature-Split Stability')
+ ax.set_title('Embedding Model Stability Comparison')
+ ax.set_xlim(0, 1)
+
+ for bar, val in zip(bars, values):
+     ax.text(val + 0.02, bar.get_y() + bar.get_height()/2,
+             f'{val:.3f}', va='center', fontsize=10)
+
+ plt.tight_layout()
+ plt.savefig('tutorial_model_comparison.png', dpi=150)
+ plt.show()
+
+ # %% [markdown]
+ # ## 7. Practical Application: Monitoring Training Drift
+ #
+ # Track how representation stability changes during fine-tuning.
+
+ # %%
+ # Simulate embeddings at different "epochs" of training
+ # Early epochs: random initialization; late epochs: structured representations
+
+ def simulate_training_trajectory(n_epochs=10):
+     """Simulate how embeddings evolve during training."""
+     n_samples, n_features = 500, 768
+
+     # Start with noise, gradually add structure
+     embeddings = []
+     for epoch in range(n_epochs):
+         structure_weight = epoch / (n_epochs - 1)  # 0 to 1
+
+         # Structured component
+         latent = np.random.randn(n_samples, 50)
+         projection = np.random.randn(50, n_features)
+         structured = latent @ projection
+         structured = structured / np.std(structured)
+
+         # Random component
+         noise = np.random.randn(n_samples, n_features)
+
+         # Mix based on epoch
+         X = structure_weight * structured + (1 - structure_weight) * noise
+         embeddings.append(X)
+
+     return embeddings
+
+ # Generate trajectory
+ embeddings_over_time = simulate_training_trajectory(n_epochs=10)
+
+ # Track stability
+ stabilities = []
+ for epoch, X in enumerate(embeddings_over_time):
+     stability = shesha.feature_split(X, n_splits=30, seed=320)
+     stabilities.append(stability)
+     print(f"Epoch {epoch}: stability = {stability:.3f}")
+
+ # %%
+ # Plot training trajectory
+ fig, ax = plt.subplots(figsize=(8, 5))
+
+ ax.plot(range(len(stabilities)), stabilities, 'o-', linewidth=2, markersize=8, color='steelblue')
+ ax.fill_between(range(len(stabilities)), stabilities, alpha=0.3, color='steelblue')
+
+ ax.set_xlabel('Epoch')
+ ax.set_ylabel('Feature-Split Stability')
+ ax.set_title('Representation Stability During Training')
+ ax.set_ylim(0, 1)
+ ax.grid(True, alpha=0.3)
+
+ plt.tight_layout()
+ plt.savefig('tutorial_training_drift.png', dpi=150)
+ plt.show()
+
+ # %% [markdown]
+ # ## 8. Unified Interface
+ #
+ # Use `shesha.shesha()` for a single entry point to all variants.
+
+ # %%
+ X = np.random.randn(500, 768)
+ y = np.random.randint(0, 5, 500)
+
+ # All variants through the unified interface
+ print("Unified interface examples:")
+ print(f"  feature_split: {shesha.shesha(X, variant='feature_split', seed=320):.3f}")
+ print(f"  sample_split: {shesha.shesha(X, variant='sample_split', seed=320):.3f}")
+ print(f"  anchor: {shesha.shesha(X, variant='anchor', seed=320):.3f}")
+ print(f"  variance: {shesha.shesha(X, y, variant='variance'):.3f}")
+ print(f"  supervised: {shesha.shesha(X, y, variant='supervised', seed=320):.3f}")
+
+ # %% [markdown]
+ # ## 9. Tips and Best Practices
+ #
+ # ### Choosing a variant:
+ # - **feature_split**: Default choice for unsupervised analysis. Good for drift detection and intrinsic quality.
+ # - **sample_split**: When you care about robustness to sampling. Good for small datasets.
+ # - **anchor_stability**: For very large datasets where feature_split is slow.
+ # - **variance_ratio**: Quick supervised check. Computationally cheap.
+ # - **supervised_alignment**: When you want RDM-based task alignment (more nuanced than variance_ratio).
+ #
+ # ### Parameter recommendations:
+ # - `n_splits=30` is usually sufficient; increase to 50+ for publication-quality results
+ # - `seed` should always be set for reproducibility
+ # - `max_samples` prevents memory issues with large datasets (see the example cell below)
+ #
+ # ### Interpretation:
+ # - **feature_split > 0.7**: Strong internal consistency, distributed structure
+ # - **feature_split 0.3-0.7**: Moderate consistency
+ # - **feature_split < 0.3**: Weak consistency, possibly noisy or sparse structure
+ # - **variance_ratio**: Directly interpretable as the "fraction of variance explained by classes"
+
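+ # %% [markdown]
+ # A single call that follows these recommendations might look like the cell below.
+ # (Treat the exact signature as indicative: `n_splits` and `seed` appear throughout this
+ # tutorial, while `max_samples` is assumed from the parameter tips above.)
+
+ # %%
+ X_large = np.random.randn(5000, 768)
+ score = shesha.feature_split(X_large, n_splits=50, seed=320, max_samples=2000)
+ print(f"Stability on a large dataset: {score:.3f}")
+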
+ # %% [markdown]
+ # ## 10. Summary
+ #
+ # | Variant | Supervised | Best For |
+ # |---------|------------|----------|
+ # | `feature_split` | No | General stability, drift detection |
+ # | `sample_split` | No | Sampling robustness |
+ # | `anchor_stability` | No | Large-scale analysis |
+ # | `variance_ratio` | Yes | Quick separability check |
+ # | `supervised_alignment` | Yes | Task alignment |
+
+ # %%
+ print("Tutorial complete!")
+ print("\nFor more information, see: https://github.com/prashantcraju/shesha")
shesha/__init__.py ADDED
@@ -0,0 +1,59 @@
+ """
+ Shesha: Self-consistency Metrics for Representational Stability
+
+ A framework for measuring geometric stability via self-consistency of
+ Representational Dissimilarity Matrices (RDMs).
+
+ Basic usage:
+     >>> import shesha
+     >>> stability = shesha.feature_split(X, n_splits=30, seed=320)
+
+     >>> # Or with labels
+     >>> alignment = shesha.supervised_alignment(X, y)
+
+     >>> # Unified interface
+     >>> score = shesha.shesha(X, variant='feature_split')
+
+     >>> # Measure drift between representations
+     >>> similarity = shesha.rdm_similarity(X_before, X_after)
+     >>> drift = shesha.rdm_drift(X_before, X_after)
+
+     >>> # Biological perturbation analysis
+     >>> from shesha.bio import perturbation_stability
+     >>> stability = perturbation_stability(X_control, X_perturbed)
+ """
+
+ from .core import (
+     # Main function
+     shesha,
+     # Unsupervised variants
+     feature_split,
+     sample_split,
+     anchor_stability,
+     # Supervised variants
+     variance_ratio,
+     supervised_alignment,
+     # Drift metrics
+     rdm_similarity,
+     rdm_drift,
+     # Utilities
+     compute_rdm,
+ )
+
+ from . import bio
+
+ __version__ = "0.1.0"
+ __author__ = "Prashant Raju"
+
+ __all__ = [
+     "shesha",
+     "feature_split",
+     "sample_split",
+     "anchor_stability",
+     "variance_ratio",
+     "supervised_alignment",
+     "rdm_similarity",
+     "rdm_drift",
+     "compute_rdm",
+     "bio",
+ ]
shesha/bio.py ADDED
@@ -0,0 +1,315 @@
+ """
+ Shesha Bio: Stability metrics for biological perturbation experiments.
+
+ This module provides Shesha variants for single-cell and perturbation biology,
+ measuring the consistency of perturbation effects across individual cells.
+ """
+
+ import numpy as np
+ from typing import Optional, Literal
+
+ __all__ = [
+     "perturbation_stability",
+     "perturbation_effect_size",
+     "compute_stability",
+     "compute_magnitude",
+ ]
+
+ EPS = 1e-12
+
+
+ def perturbation_stability(
+     X_control: np.ndarray,
+     X_perturbed: np.ndarray,
+     metric: Literal["cosine", "euclidean"] = "cosine",
+     seed: Optional[int] = None,
+     max_samples: Optional[int] = 1000,
+ ) -> float:
+     """
+     Perturbation stability: consistency of perturbation effects across samples.
+
+     Measures whether individual perturbed samples shift in a consistent direction
+     relative to the control population. High values indicate that the perturbation
+     has a coherent, reproducible effect; low values suggest heterogeneous or noisy
+     responses.
+
+     The metric computes the mean cosine similarity between each perturbed sample's
+     shift vector (relative to the control centroid) and the mean shift direction.
+
+     Parameters
+     ----------
+     X_control : np.ndarray
+         Control population embeddings, shape (n_control, n_features).
+     X_perturbed : np.ndarray
+         Perturbed population embeddings, shape (n_perturbed, n_features).
+     metric : str
+         How to measure directional consistency:
+         - 'cosine': Cosine similarity of shift vectors to the mean direction (default)
+         - 'euclidean': Normalized Euclidean consistency
+     seed : int, optional
+         Random seed for subsampling reproducibility.
+     max_samples : int, optional
+         Subsample the perturbed population if it exceeds this size.
+
+     Returns
+     -------
+     float
+         Stability score in [-1, 1] for the cosine metric. Higher = more consistent
+         perturbation effect. Values near 1 indicate all samples shift in the
+         same direction; values near 0 indicate random/inconsistent shifts.
+         Returns ``np.nan`` if either population has fewer than 5 samples.
+
+     Examples
+     --------
+     >>> # Control and perturbed cell populations
+     >>> X_ctrl = np.random.randn(500, 50)  # 500 control cells, 50 genes
+     >>>
+     >>> # Coherent perturbation: all cells shift similarly
+     >>> shift = np.random.randn(50)  # consistent direction
+     >>> X_pert_coherent = X_ctrl + shift + np.random.randn(500, 50) * 0.1
+     >>> stability = perturbation_stability(X_ctrl, X_pert_coherent)
+     >>> print(f"Coherent perturbation: {stability:.3f}")  # High value
+     >>>
+     >>> # Incoherent perturbation: cells shift randomly
+     >>> X_pert_random = X_ctrl + np.random.randn(500, 50)
+     >>> stability = perturbation_stability(X_ctrl, X_pert_random)
+     >>> print(f"Random perturbation: {stability:.3f}")  # Low value
+
+     Notes
+     -----
+     This metric is designed for single-cell perturbation experiments (e.g.,
+     Perturb-seq, CRISPR screens) where you want to assess whether a genetic
+     perturbation produces a consistent phenotypic shift across cells.
+
+     The control centroid is used as the reference point. Each perturbed cell's
+     shift vector is computed as (x_perturbed - centroid_control), and these
+     are compared to the mean shift direction.
+     """
+     X_control = np.asarray(X_control, dtype=np.float64)
+     X_perturbed = np.asarray(X_perturbed, dtype=np.float64)
+
+     if X_control.shape[1] != X_perturbed.shape[1]:
+         raise ValueError(
+             f"Feature dimensions must match: control has {X_control.shape[1]}, "
+             f"perturbed has {X_perturbed.shape[1]}"
+         )
+
+     if len(X_control) < 5:
+         return np.nan
+     if len(X_perturbed) < 5:
+         return np.nan
+
+     rng = np.random.default_rng(seed)
+
+     # Subsample perturbed if needed
+     if max_samples is not None and len(X_perturbed) > max_samples:
+         idx = rng.choice(len(X_perturbed), max_samples, replace=False)
+         X_perturbed = X_perturbed[idx]
+
+     # Compute control centroid
+     control_centroid = np.mean(X_control, axis=0)
+
+     # Compute shift vectors for each perturbed sample
+     shift_vectors = X_perturbed - control_centroid
+
+     # Compute mean shift direction
+     mean_shift = np.mean(shift_vectors, axis=0)
+     mean_shift_norm = np.linalg.norm(mean_shift)
+
+     if mean_shift_norm < EPS:
+         # No net shift - perturbation has no coherent effect
+         return 0.0
+
+     if metric == "cosine":
+         # Normalize mean shift
+         mean_shift_unit = mean_shift / mean_shift_norm
+
+         # Compute cosine similarity of each shift to mean direction
+         shift_norms = np.linalg.norm(shift_vectors, axis=1, keepdims=True)
+         shift_norms = np.maximum(shift_norms, EPS)
+         shift_unit = shift_vectors / shift_norms
+
+         # Cosine similarities
+         cosines = shift_unit @ mean_shift_unit
+
+         return float(np.mean(cosines))
+
+     elif metric == "euclidean":
+         # Euclidean-based consistency: how tight are shifts around mean?
+         # Normalized by expected variance under random shifts
+         deviations = shift_vectors - mean_shift
+         deviation_var = np.mean(np.sum(deviations ** 2, axis=1))
+         total_var = np.mean(np.sum(shift_vectors ** 2, axis=1))
+
+         if total_var < EPS:
+             return np.nan
+
+         # 1 - (deviation / total) gives consistency score
+         consistency = 1.0 - (deviation_var / total_var)
+         return float(np.clip(consistency, -1, 1))
+
+     else:
+         raise ValueError(f"Unknown metric: {metric}. Use 'cosine' or 'euclidean'.")
+
+ try:
+     from anndata import AnnData
+ except ImportError:
+     AnnData = None
+
+
+ def compute_stability(
+     adata: "AnnData",
+     perturbation_key: str,
+     control_label: str = "control",
+     layer: Optional[str] = None,
+     **kwargs
+ ) -> dict:
+     """
+     Scanpy-compatible wrapper for perturbation stability.
+
+     Computes stability for all perturbations in an AnnData object.
+
+     Parameters
+     ----------
+     adata : AnnData
+         Annotated data matrix.
+     perturbation_key : str
+         Column in adata.obs containing perturbation labels (e.g. 'guide_id').
+     control_label : str
+         The label in perturbation_key representing control cells (e.g. 'NT').
+     layer : str, optional
+         Layer in adata.layers to use instead of adata.X.
+     **kwargs
+         Additional keyword arguments passed to ``perturbation_stability``.
+
+     Returns
+     -------
+     dict
+         Dictionary mapping perturbation names to stability scores.
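+
+     Examples
+     --------
+     A minimal usage sketch (hypothetical guide labels; assumes ``anndata`` and
+     ``pandas`` are installed):
+
+     >>> import anndata as ad
+     >>> import pandas as pd
+     >>> X = np.random.randn(300, 50)
+     >>> labels = ["control"] * 100 + ["KO_A"] * 100 + ["KO_B"] * 100
+     >>> adata = ad.AnnData(X, obs=pd.DataFrame({"guide": labels}))
+     >>> scores = compute_stability(adata, perturbation_key="guide")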
+     """
+     if AnnData is None or not isinstance(adata, AnnData):
+         raise ImportError("anndata is required for this function.")
+
+     # Get control data
+     ctrl_mask = adata.obs[perturbation_key] == control_label
+     if layer:
+         X_ctrl = adata[ctrl_mask].layers[layer]
+     else:
+         X_ctrl = adata[ctrl_mask].X
+
+     # Handle sparse matrices
+     if hasattr(X_ctrl, "toarray"):
+         X_ctrl = X_ctrl.toarray()
+
+     results = {}
+     perturbations = adata.obs[perturbation_key].unique()
+
+     for pert in perturbations:
+         if pert == control_label:
+             continue
+
+         pert_mask = adata.obs[perturbation_key] == pert
+         if layer:
+             X_pert = adata[pert_mask].layers[layer]
+         else:
+             X_pert = adata[pert_mask].X
+
+         if hasattr(X_pert, "toarray"):
+             X_pert = X_pert.toarray()
+
+         score = perturbation_stability(X_ctrl, X_pert, **kwargs)
+         results[pert] = score
+
+     return results
+
+
+ def perturbation_effect_size(
+     X_control: np.ndarray,
+     X_perturbed: np.ndarray,
+     metric: Literal["euclidean", "cohen"] = "euclidean"
+ ) -> float:
+     """
+     Compute the magnitude of the perturbation effect.
+
+     Parameters
+     ----------
+     X_control : np.ndarray
+         Control population embeddings.
+     X_perturbed : np.ndarray
+         Perturbed population embeddings.
+     metric : str, default="euclidean"
+         - 'euclidean': Raw L2 distance between centroids (Magnitude).
+           Use this for geometric plots (Stability vs Magnitude).
+         - 'cohen': Standardized effect size (Magnitude / Pooled SD).
+           Use this for statistical power analysis.
+
+     Returns
+     -------
+     float
+         The calculated magnitude/effect size.
+     """
+     X_control = np.asarray(X_control, dtype=np.float64)
+     X_perturbed = np.asarray(X_perturbed, dtype=np.float64)
+
+     control_centroid = np.mean(X_control, axis=0)
+     perturbed_centroid = np.mean(X_perturbed, axis=0)
+
+     # 1. Raw magnitude (Euclidean distance between centroids)
+     shift_magnitude = np.linalg.norm(perturbed_centroid - control_centroid)
+
+     if metric == "euclidean":
+         return float(shift_magnitude)
+
+     elif metric == "cohen":
+         # 2. Standardized effect size (Cohen's d-like)
+         # Pooled standard deviation (averaged across features)
+         control_var = np.var(X_control, axis=0, ddof=1)
+         perturbed_var = np.var(X_perturbed, axis=0, ddof=1)
+
+         # Average variance across features to get a scalar scale
+         pooled_var = np.mean((control_var + perturbed_var) / 2)
+         pooled_std = np.sqrt(pooled_var) + EPS
+
+         return float(shift_magnitude / pooled_std)
+
+     else:
+         raise ValueError(f"Unknown metric: {metric}")
+
+
+ def compute_magnitude(
+     adata: "AnnData",
+     perturbation_key: str,
+     control_label: str = "control",
+     metric: str = "euclidean",
+     layer: Optional[str] = None,
+ ) -> dict:
+     """
+     Scanpy-compatible wrapper for perturbation magnitude.
+
+     Follows the same AnnData conventions as ``compute_stability`` and returns a
+     dictionary mapping perturbation names to magnitudes; see
+     ``perturbation_effect_size`` for the available ``metric`` options.
+     """
+     if AnnData is None or not isinstance(adata, AnnData):
+         raise ImportError("anndata is required for this function.")
+
+     # Get control data
+     ctrl_mask = adata.obs[perturbation_key] == control_label
+     if layer:
+         X_ctrl = adata[ctrl_mask].layers[layer]
+     else:
+         X_ctrl = adata[ctrl_mask].X
+
+     if hasattr(X_ctrl, "toarray"):
+         X_ctrl = X_ctrl.toarray()
+
+     results = {}
+     perturbations = adata.obs[perturbation_key].unique()
+
+     for pert in perturbations:
+         if pert == control_label:
+             continue
+
+         pert_mask = adata.obs[perturbation_key] == pert
+         if layer:
+             X_pert = adata[pert_mask].layers[layer]
+         else:
+             X_pert = adata[pert_mask].X
+
+         if hasattr(X_pert, "toarray"):
+             X_pert = X_pert.toarray()
+
+         score = perturbation_effect_size(X_ctrl, X_pert, metric=metric)
+         results[pert] = score
+
+     return results