shesha-geometry 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/tutorial.py +336 -0
- shesha/__init__.py +59 -0
- shesha/bio.py +315 -0
- shesha/core.py +648 -0
- shesha_geometry-0.1.0.dist-info/LICENSE +21 -0
- shesha_geometry-0.1.0.dist-info/METADATA +396 -0
- shesha_geometry-0.1.0.dist-info/RECORD +13 -0
- shesha_geometry-0.1.0.dist-info/WHEEL +5 -0
- shesha_geometry-0.1.0.dist-info/top_level.txt +3 -0
- tests/__init__.py +1 -0
- tests/test_bio.py +49 -0
- tests/test_core.py +164 -0
- tests/test_crispr.py +191 -0
tests/test_crispr.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Test shesha.bio on CRISPR perturbation data.
|
|
3
|
+
|
|
4
|
+
Uses pertpy's Norman et al. (2019) dataset (CRISPRa screen).
|
|
5
|
+
Install dependencies: pip install pertpy scanpy
|
|
6
|
+
|
|
7
|
+
Expected behavior:
|
|
8
|
+
- Strong perturbations should have higher stability (cells respond consistently)
|
|
9
|
+
- Weak/noisy perturbations should have lower stability
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
import pandas as pd
|
|
14
|
+
|
|
15
|
+
# Check for required dependencies.
# The `exit()` builtin is injected by the `site` module for interactive use
# only, so use `sys.exit` to abort reliably when the optional deps are missing.
try:
    import pertpy as pt
    import scanpy as sc
except ImportError:
    import sys

    print("This test requires pertpy and scanpy.", file=sys.stderr)
    print("Install with: pip install pertpy scanpy", file=sys.stderr)
    sys.exit(1)
|
|
23
|
+
|
|
24
|
+
from shesha.bio import perturbation_stability, perturbation_effect_size
|
|
25
|
+
|
|
26
|
+
# Configuration
N_PCA_DIMS = 50  # Use PCA-reduced space like in the paper

# Fixed seed for reproducibility; also seeds NumPy's global RNG at import time.
SEED = 320
np.random.seed(SEED)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def load_norman_2019():
    """Load the Norman 2019 CRISPRa dataset through pertpy.

    When the loaded object has no ``'X_pca'`` entry in ``obsm``, it is
    preprocessed in place with scanpy: total-count normalisation to 1e4,
    log1p, selection of the top 2000 highly variable genes, and PCA with
    ``N_PCA_DIMS`` components.

    Returns:
        The dataset object returned by ``pt.dt.norman_2019()`` (presumably an
        AnnData, since it is passed to scanpy's ``pp`` functions).
    """
    print("Loading Norman 2019 dataset...")
    adata = pt.dt.norman_2019()
    print(f" Shape: {adata.shape}")

    # An embedding is already present: nothing to preprocess.
    if 'X_pca' in adata.obsm:
        return adata

    print(" Running PCA...")
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    sc.pp.highly_variable_genes(adata, n_top_genes=2000)
    sc.pp.pca(adata, n_comps=N_PCA_DIMS)
    return adata
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# Control phrases distinctive enough to be matched anywhere inside a label.
_CONTROL_SUBSTRINGS = (
    'control', 'ctrl', 'non-targeting', 'nontargeting', 'non_targeting', 'unperturbed',
)
# Short control markers. These are matched only as whole alphanumeric tokens,
# because as substrings they occur inside gene names (e.g. "nt" in CNTN1,
# "nan" in NANOG, "neg" in NEGR1).
_CONTROL_TOKENS = frozenset({'neg', 'negative', 'nt', 'ntc', 'nan'})


def _is_control_label(label):
    """Return True if a perturbation label is missing or looks like a control."""
    if pd.isna(label):
        return True
    text = str(label).lower()
    if any(phrase in text for phrase in _CONTROL_SUBSTRINGS):
        return True
    # Split on any non-alphanumeric character ('_', '+', '-', spaces, ...).
    tokens = ''.join(ch if ch.isalnum() else ' ' for ch in text).split()
    return any(tok in _CONTROL_TOKENS for tok in tokens)


def get_perturbation_groups(adata):
    """Extract control and perturbation groups.

    The perturbation column is ``'perturbation'`` if present, else
    ``'guide_ids'``, else the first ``obs`` column whose name contains
    "pert" or "guide" (case-insensitive).

    A label is treated as a control when it is missing (NaN/None), contains a
    distinctive control phrase ("control", "ctrl", "non-targeting",
    "unperturbed", ...), or has a whole token equal to a short marker
    ("neg", "nt", "ntc", "nan", ...). If no label looks like a control, the
    most frequent label is used as the control instead (with a warning).

    Args:
        adata: Object exposing an ``obs`` pandas DataFrame (e.g. AnnData).

    Returns:
        Tuple ``(pert_col, control_mask, other_perts)``: the chosen column
        name, a boolean Series over cells marking control cells, and a list of
        the non-control labels in order of first appearance.

    Raises:
        ValueError: If no perturbation column can be found, or the column
            holds no labels at all.
    """
    columns = adata.obs.columns
    # Norman 2019 uses 'guide_ids' or 'perturbation' column (per original note)
    if 'perturbation' in columns:
        pert_col = 'perturbation'
    elif 'guide_ids' in columns:
        pert_col = 'guide_ids'
    else:
        # Try to find a suitable column
        candidates = [
            c for c in columns
            if 'pert' in str(c).lower() or 'guide' in str(c).lower()
        ]
        if not candidates:
            raise ValueError(f"Could not find perturbation column. Available: {list(adata.obs.columns)}")
        pert_col = candidates[0]

    print(f" Using perturbation column: {pert_col}")

    labels = adata.obs[pert_col]
    all_perts = labels.unique()

    # Classify each distinct label exactly once. Splitting on the same flags
    # (rather than `p not in control_perts`) keeps NaN labels, which never
    # compare equal to each other, out of the perturbation list.
    flags = [_is_control_label(p) for p in all_perts]
    control_perts = [p for p, is_ctrl in zip(all_perts, flags) if is_ctrl]
    other_perts = [p for p, is_ctrl in zip(all_perts, flags) if not is_ctrl]

    if not control_perts:
        # Fallback: look for most common perturbation (often control)
        counts = labels.value_counts()
        if counts.empty:
            raise ValueError(f"Perturbation column {pert_col!r} contains no labels.")
        fallback = counts.index[0]
        control_perts = [fallback]
        other_perts = [p for p in other_perts if p != fallback]
        print(f" Warning: No obvious control found, using most common: {control_perts}")

    print(f" Control perturbations: {control_perts}")

    control_mask = labels.isin(control_perts)
    return pert_col, control_mask, other_perts
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def test_on_real_data():
    """Main test function: score Norman 2019 perturbations with shesha.bio.

    Evaluates up to 20 non-control perturbations that have at least 10 cells.
    Skipped (too small) perturbations do not count toward the cap. For each
    one, ``perturbation_stability`` and ``perturbation_effect_size`` are
    computed against the control cells in the first ``N_PCA_DIMS`` PCA
    dimensions. It then prints summary statistics and asserts:

    - at least two perturbations were scored;
    - stability lies in [-1, 1];
    - effect sizes are non-negative;
    - stability varies (std > 0.01).

    Returns:
        pandas.DataFrame with columns ``perturbation``, ``n_cells``,
        ``stability`` and ``effect_size``, one row per scored perturbation.
        NOTE(review): pytest warns when a test returns non-None; the return
        value is kept because the ``__main__`` block assigns it.

    Raises:
        AssertionError: If too few perturbations qualify or a sanity check fails.
    """
    from scipy.stats import spearmanr

    max_tested = 20  # cap on the number of perturbations scored
    min_cells = 10   # perturbations with fewer cells are skipped

    # Load data
    adata = load_norman_2019()

    # Get perturbation groups
    pert_col, control_mask, perturbations = get_perturbation_groups(adata)
    labels = adata.obs[pert_col]

    # Get embeddings
    X_pca = adata.obsm['X_pca'][:, :N_PCA_DIMS]
    X_control = X_pca[np.asarray(control_mask)]

    print(f"\nControl cells: {X_control.shape[0]}")
    print(f"Perturbations to test: {len(perturbations)}")

    # Test on a subset of perturbations
    n_test = min(max_tested, len(perturbations))

    print(f"\nTesting {n_test} perturbations...")
    print("-" * 60)

    results = []
    for pert in perturbations:
        # Stop once enough eligible perturbations have been scored.
        if len(results) >= n_test:
            break

        pert_mask = np.asarray(labels == pert)
        n_cells = int(pert_mask.sum())
        if n_cells < min_cells:
            continue

        X_pert = X_pca[pert_mask]

        # Compute metrics using shesha.bio
        stability = perturbation_stability(X_control, X_pert, seed=SEED)
        effect = perturbation_effect_size(X_control, X_pert)

        name = str(pert)[:30]  # Truncate long names
        results.append({
            'perturbation': name,
            'n_cells': n_cells,
            'stability': stability,
            'effect_size': effect,
        })

        print(f"{name:30s} n={n_cells:4d} stability={stability:.3f} effect={effect:.2f}")

    # Summary statistics. Explicit columns keep the schema even if nothing qualified.
    df = pd.DataFrame(results, columns=['perturbation', 'n_cells', 'stability', 'effect_size'])

    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"Perturbations tested: {len(df)}")

    # std() and spearmanr() need at least two points. Fail with a clear message
    # instead of a KeyError (empty frame) or a NaN-driven "No variation" failure.
    assert len(df) >= 2, (
        f"Only {len(df)} perturbation(s) had >= {min_cells} cells; need at least 2"
    )

    print(f"Stability - mean: {df['stability'].mean():.3f}, std: {df['stability'].std():.3f}")
    print(f"Stability - min: {df['stability'].min():.3f}, max: {df['stability'].max():.3f}")
    print(f"Effect size - mean: {df['effect_size'].mean():.2f}, std: {df['effect_size'].std():.2f}")

    # Check correlation between stability and effect size
    rho, p = spearmanr(df['stability'], df['effect_size'])
    print(f"\nCorrelation (stability vs effect): rho={rho:.3f}, p={p:.4f}")

    # Sanity checks
    print("\n" + "=" * 60)
    print("SANITY CHECKS")
    print("=" * 60)

    # Check that stability values are in expected range
    assert df['stability'].min() >= -1, "Stability below -1"
    assert df['stability'].max() <= 1, "Stability above 1"
    print("✓ Stability values in [-1, 1]")

    # Check that effect sizes are non-negative
    assert df['effect_size'].min() >= 0, "Negative effect size"
    print("✓ Effect sizes non-negative")

    # Check that we get reasonable variation (not all same value)
    assert df['stability'].std() > 0.01, "No variation in stability"
    print("✓ Reasonable variation in stability")

    # Most perturbations are expected to have positive stability (coherent
    # effect). This is reported only; it is not asserted.
    frac_positive = (df['stability'] > 0).mean()
    print(f"✓ {frac_positive*100:.0f}% of perturbations have positive stability")

    print("\n✓ All sanity checks passed!")

    return df
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
if __name__ == "__main__":
    # Horizontal rule framing the header and footer banners.
    rule = "=" * 60

    # Header: rule, title, rule, then a blank line.
    print(rule, "SHESHA.BIO TEST ON REAL CRISPR DATA", rule, "", sep="\n")

    df = test_on_real_data()

    # Footer: blank line, rule, status, rule.
    print("", rule, "TEST COMPLETE", rule, sep="\n")
|