shesha-geometry 0.1.0__py3-none-any.whl

tests/test_crispr.py ADDED
@@ -0,0 +1,191 @@
+ """
+ Test shesha.bio on CRISPR perturbation data.
+
+ Uses pertpy's Norman et al. 2019 dataset (CRISPRa screen).
+ Install dependencies: pip install pertpy scanpy
+
+ Expected behavior:
+ - Strong perturbations should have higher stability (cells respond consistently)
+ - Weak/noisy perturbations should have lower stability
+ """
+
+ import numpy as np
+ import pandas as pd
+
+ # Check for required dependencies
+ try:
+     import pertpy as pt
+     import scanpy as sc
+ except ImportError:
+     print("This test requires pertpy and scanpy.")
+     print("Install with: pip install pertpy scanpy")
+     raise SystemExit(1)
+
+ from shesha.bio import perturbation_stability, perturbation_effect_size
+
+ # Configuration
+ SEED = 320
+ np.random.seed(SEED)
+ N_PCA_DIMS = 50  # Use PCA-reduced space like in the paper
+
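+
+ # Optional synthetic smoke test (an illustrative sketch, not part of the original
+ # suite). It assumes only the perturbation_stability / perturbation_effect_size
+ # signatures used in test_on_real_data() below, and that a mean-shifted Gaussian
+ # cloud scores a larger effect size than an unshifted one, in line with the
+ # expected behavior described in the module docstring. The helper name is
+ # hypothetical; call it manually, e.g. when the pertpy download is unavailable.
+ def smoke_test_synthetic(n_cells=200, n_dims=20):
+     """Quick check of the metrics on Gaussian toy data (hypothetical helper)."""
+     rng = np.random.default_rng(SEED)
+     X_control = rng.normal(size=(n_cells, n_dims))
+     X_strong = rng.normal(size=(n_cells, n_dims)) + 2.0   # coherent mean shift
+     X_weak = rng.normal(size=(n_cells, n_dims))           # no real perturbation
+     strong_effect = perturbation_effect_size(X_control, X_strong)
+     weak_effect = perturbation_effect_size(X_control, X_weak)
+     strong_stability = perturbation_stability(X_control, X_strong, seed=SEED)
+     print(f"Synthetic effect sizes: shifted={strong_effect:.2f}, unshifted={weak_effect:.2f}")
+     print(f"Synthetic stability (shifted): {strong_stability:.3f}")
+     # Assumes the effect size grows with the size of the shift.
+     assert strong_effect > weak_effect, "Shifted cloud should have the larger effect size"
+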
+
+ def load_norman_2019():
+     """Load Norman 2019 CRISPRa dataset."""
+     print("Loading Norman 2019 dataset...")
+     adata = pt.dt.norman_2019()
+     print(f" Shape: {adata.shape}")
+
+     # Basic preprocessing if not already done
+     if 'X_pca' not in adata.obsm:
+         print(" Running PCA...")
+         sc.pp.normalize_total(adata, target_sum=1e4)
+         sc.pp.log1p(adata)
+         sc.pp.highly_variable_genes(adata, n_top_genes=2000)
+         sc.pp.pca(adata, n_comps=N_PCA_DIMS)
+
+     return adata
+
+
+
+ def get_perturbation_groups(adata):
+     """Extract control and perturbation groups.
+
+     Returns the name of the perturbation column, a boolean mask selecting
+     control cells, and the list of non-control perturbation labels.
+     """
+     # Norman 2019 uses 'guide_ids' or 'perturbation' column
+     if 'perturbation' in adata.obs.columns:
+         pert_col = 'perturbation'
+     elif 'guide_ids' in adata.obs.columns:
+         pert_col = 'guide_ids'
+     else:
+         # Try to find a suitable column
+         candidates = [c for c in adata.obs.columns if 'pert' in c.lower() or 'guide' in c.lower()]
+         if candidates:
+             pert_col = candidates[0]
+         else:
+             raise ValueError(f"Could not find perturbation column. Available: {list(adata.obs.columns)}")
+
+     print(f" Using perturbation column: {pert_col}")
+
+     # Identify control cells
+     all_perts = adata.obs[pert_col].unique()
+     control_keywords = ['control', 'ctrl', 'neg', 'nt', 'non-targeting', 'unperturbed', 'nan']
+
+     control_perts = []
+     for p in all_perts:
+         p_lower = str(p).lower()
+         if any(kw in p_lower for kw in control_keywords) or p_lower == 'nan' or pd.isna(p):
+             control_perts.append(p)
+
+     if not control_perts:
+         # Fallback: look for most common perturbation (often control)
+         counts = adata.obs[pert_col].value_counts()
+         control_perts = [counts.index[0]]
+         print(f" Warning: No obvious control found, using most common: {control_perts}")
+
+     print(f" Control perturbations: {control_perts}")
+
+     # Get control cells
+     control_mask = adata.obs[pert_col].isin(control_perts)
+
+     # Get non-control perturbations
+     other_perts = [p for p in all_perts if p not in control_perts]
+
+     return pert_col, control_mask, other_perts
+
+
+ def test_on_real_data():
+     """Main test function."""
+
+     # Load data
+     adata = load_norman_2019()
+
+     # Get perturbation groups
+     pert_col, control_mask, perturbations = get_perturbation_groups(adata)
+
+     # Get embeddings
+     X_pca = adata.obsm['X_pca'][:, :N_PCA_DIMS]
+     X_control = X_pca[control_mask]
+
+     print(f"\nControl cells: {X_control.shape[0]}")
+     print(f"Perturbations to test: {len(perturbations)}")
+
+     # Test on a subset of perturbations
+     results = []
+     n_test = min(20, len(perturbations))
+
+     print(f"\nTesting {n_test} perturbations...")
+     print("-" * 60)
+
+     for pert in perturbations[:n_test]:
+         pert_mask = adata.obs[pert_col] == pert
+         n_cells = pert_mask.sum()
+
+         if n_cells < 10:
+             continue
+
+         X_pert = X_pca[pert_mask]
+
+         # Compute metrics using shesha.bio
+         stability = perturbation_stability(X_control, X_pert, seed=SEED)
+         effect = perturbation_effect_size(X_control, X_pert)
+
+         results.append({
+             'perturbation': str(pert)[:30],  # Truncate long names
+             'n_cells': n_cells,
+             'stability': stability,
+             'effect_size': effect
+         })
+
+         print(f"{str(pert)[:30]:30s} n={n_cells:4d} stability={stability:.3f} effect={effect:.2f}")
+
+     # Summary statistics
+     df = pd.DataFrame(results)
+
+     print("\n" + "=" * 60)
+     print("SUMMARY")
+     print("=" * 60)
+     print(f"Perturbations tested: {len(df)}")
+     print(f"Stability - mean: {df['stability'].mean():.3f}, std: {df['stability'].std():.3f}")
+     print(f"Stability - min: {df['stability'].min():.3f}, max: {df['stability'].max():.3f}")
+     print(f"Effect size - mean: {df['effect_size'].mean():.2f}, std: {df['effect_size'].std():.2f}")
+
+     # Check correlation between stability and effect size
+     from scipy.stats import spearmanr
+     rho, p = spearmanr(df['stability'], df['effect_size'])
+     print(f"\nCorrelation (stability vs effect): rho={rho:.3f}, p={p:.4f}")
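+
+     # Per the expected behavior noted in the module docstring, perturbations
+     # with a strong, coherent effect should also tend to score higher
+     # stability, so rho is expected to lean positive. This is a soft
+     # expectation added here for context, not a guarantee of the metric,
+     # so it is deliberately not asserted.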
+
+     # Sanity checks
+     print("\n" + "=" * 60)
+     print("SANITY CHECKS")
+     print("=" * 60)
+
+     # Check that stability values are in expected range
+     assert df['stability'].min() >= -1, "Stability below -1"
+     assert df['stability'].max() <= 1, "Stability above 1"
+     print("✓ Stability values in [-1, 1]")
+
+     # Check that effect sizes are non-negative
+     assert df['effect_size'].min() >= 0, "Negative effect size"
+     print("✓ Effect sizes non-negative")
+
+     # Check that we get reasonable variation (not all same value)
+     assert df['stability'].std() > 0.01, "No variation in stability"
+     print("✓ Reasonable variation in stability")
+
+     # Most perturbations should have positive stability (coherent effect)
+     frac_positive = (df['stability'] > 0).mean()
+     print(f"✓ {frac_positive*100:.0f}% of perturbations have positive stability")
+
+     print("\n✓ All sanity checks passed!")
+
+     return df
+
+
+ if __name__ == "__main__":
+     print("=" * 60)
+     print("SHESHA.BIO TEST ON REAL CRISPR DATA")
+     print("=" * 60)
+     print()
+
+     df = test_on_real_data()
+
+     print("\n" + "=" * 60)
+     print("TEST COMPLETE")
+     print("=" * 60)