shesha-geometry 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,396 @@
1
+ Metadata-Version: 2.1
2
+ Name: shesha-geometry
3
+ Version: 0.1.0
4
+ Summary: Self-consistency metrics for representational stability analysis
5
+ Author-email: Prashant Raju <rajuprashant@gmail.com>
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/prashantcraju/shesha
8
+ Project-URL: Documentation, https://github.com/prashantcraju/shesha#readme
9
+ Project-URL: Repository, https://github.com/prashantcraju/shesha
10
+ Project-URL: Issues, https://github.com/prashantcraju/shesha/issues
11
+ Keywords: representation learning,neural networks,geometric stability,geometric analysis,manifold analysis,latent space,single-cell,crispr,perturb-seq,anndata,functional genomics,phenotypic stability,computational biology,scanpy,llm alignment,concept drift,model steering,ai safety,constitutional ai,perturbation analysis
12
+ Classifier: Development Status :: 4 - Beta
13
+ Classifier: Intended Audience :: Science/Research
14
+ Classifier: License :: OSI Approved :: MIT License
15
+ Classifier: Programming Language :: Python :: 3
16
+ Classifier: Programming Language :: Python :: 3.8
17
+ Classifier: Programming Language :: Python :: 3.9
18
+ Classifier: Programming Language :: Python :: 3.10
19
+ Classifier: Programming Language :: Python :: 3.11
20
+ Classifier: Programming Language :: Python :: 3.12
21
+ Classifier: Topic :: Scientific/Engineering
22
+ Requires-Python: >=3.8
23
+ Description-Content-Type: text/markdown
24
+ License-File: LICENSE
25
+ Requires-Dist: numpy>=1.20
26
+ Requires-Dist: scipy>=1.7
27
+ Requires-Dist: anndata>=0.8
28
+ Provides-Extra: dev
29
+ Requires-Dist: pytest>=7.0; extra == "dev"
30
+ Requires-Dist: pytest-cov; extra == "dev"
31
+ Requires-Dist: black; extra == "dev"
32
+ Requires-Dist: ruff; extra == "dev"
33
+
34
+ [![DOI](https://zenodo.org/badge/1133185691.svg)](https://doi.org/10.5281/zenodo.18227453)
35
+ <p align="center">
36
+ <img src="https://i.imgur.com/oJ5YhBo.jpg" alt="Shesha Logo" width="300">
37
+ </p>
38
+
39
+ # Shesha
40
+
41
+ Self-consistency metrics for representational stability analysis.
42
+
43
+ Shesha measures the geometric stability of high-dimensional representations by quantifying the self-consistency of their pairwise distance structure (representational dissimilarity matrices, or RDMs) under controlled internal perturbations.
44
+
45
+ ## Installation
46
+
47
+ ```bash
48
+ pip install shesha-geometry
49
+ ```
50
+
51
+ ## Quick Start
52
+
53
+ ```python
54
+ import numpy as np
55
+ import shesha
56
+
57
+ # Your embeddings: (n_samples, n_features)
58
+ X = np.random.randn(500, 768)
59
+
60
+ # Feature-split stability (unsupervised)
61
+ stability = shesha.feature_split(X, n_splits=30, seed=320)
62
+ print(f"Feature-split stability: {stability:.3f}")
63
+ ```
64
+
65
+ With labels:
66
+
67
+ ```python
68
+ y = np.random.randint(0, 10, 500)
69
+ alignment = shesha.supervised_alignment(X, y)
70
+ print(f"Supervised alignment: {alignment:.3f}")
71
+ ```
72
+
73
+ Measuring drift between representations:
74
+
75
+ ```python
76
+ X_before = np.random.randn(100, 256)
77
+ X_after = X_before + np.random.randn(100, 256) * 0.3 # Add noise
78
+
79
+ # Compare the two representations (e.g., before/after fine-tuning)
80
+ similarity = shesha.rdm_similarity(X_before, X_after)
81
+ drift = shesha.rdm_drift(X_before, X_after)
82
+ print(f"RDM similarity: {similarity:.3f}, drift: {drift:.3f}")
83
+ ```
84
+
85
+
86
+ ## Variants
87
+
88
+ ### Unsupervised (no labels required)
89
+
90
+ **`feature_split(X, n_splits=30, metric='cosine', seed=None)`**
91
+
92
+ Correlates RDMs from random feature partitions. Use for internal consistency and drift detection.
93
+
94
+ **`sample_split(X, n_splits=30, subsample_fraction=0.4, seed=None)`**
95
+
96
+ Correlates RDMs from bootstrap samples. Use for robustness to sampling.
97
+
98
+ **`anchor_stability(X, n_splits=30, n_anchors=100, seed=None)`**
99
+
100
+ Distance profile consistency from fixed anchors. Use for large-scale stability.
101
+
102
+ ### Supervised (labels required)
103
+
104
+ **`variance_ratio(X, y)`**
105
+
106
+ Between-class / total variance. Use for quick separability check.
107
+
108
+ **`supervised_alignment(X, y, metric='correlation', seed=None)`**
109
+
110
+ Correlation with ideal label RDM. Use for task alignment.
111
+
112
+ ### Drift Metrics (comparing two representations)
113
+
114
+ **`rdm_similarity(X, Y, method='spearman', metric='cosine')`**
115
+
116
+ RDM correlation between two representations. Use for comparing models or tracking changes.
117
+
118
+ **`rdm_drift(X, Y, method='spearman', metric='cosine')`**
119
+
120
+ Representational drift (1 - similarity). Use for quantifying how much geometry has changed.
121
+
122
+ ## Examples
123
+
124
+ ### Comparing model stability
125
+
126
+ ```python
127
+ import numpy as np
128
+ import shesha
129
+
130
+ # Example embeddings from two different models
131
+ embeddings_a = np.random.randn(500, 768) # Model A embeddings
132
+ embeddings_b = np.random.randn(500, 768) # Model B embeddings
133
+
134
+ models = {'model_a': embeddings_a, 'model_b': embeddings_b}
135
+
136
+ for name, X in models.items():
137
+ fs = shesha.feature_split(X, seed=320)
138
+ print(f"{name}: {fs:.3f}")
139
+ ```
140
+
141
+ ### Monitoring fine-tuning drift
142
+
143
+ ```python
144
+ import shesha
145
+
146
+ X_initial = model.encode(data)
147
+
148
+ for epoch in range(10):
149
+ train_one_epoch(model)
150
+ X_current = model.encode(data)
151
+
152
+ # Internal stability
153
+ stability = shesha.feature_split(X_current, seed=320)
154
+
155
+ # Drift from initial
156
+ drift = shesha.rdm_drift(X_initial, X_current)
157
+
158
+ print(f"Epoch {epoch}: stability={stability:.3f}, drift={drift:.3f}")
159
+ ```
160
+
161
+ ### Comparing two models
162
+
163
+ ```python
164
+ import shesha
165
+
166
+ X_model1 = model1.encode(data)
167
+ X_model2 = model2.encode(data)
168
+
169
+ # How similar are their geometric structures?
170
+ similarity = shesha.rdm_similarity(X_model1, X_model2)
171
+ print(f"Model similarity: {similarity:.3f}")
172
+ ```
173
+
174
+ ### Analyzing single-cell perturbations
175
+
176
+ Measure the geometric consistency of CRISPR/drug screens directly from `AnnData` objects:
177
+
178
+ ```python
179
+ import numpy as np
180
+ from shesha.bio import compute_stability, compute_magnitude
181
+ from anndata import AnnData
182
+
183
+
184
+ # 1. Setup mock single-cell data (1000 cells, 50 PCA features)
185
+ n_cells = 1000
186
+ n_genes = 2000 # Original feature space (genes)
187
+ n_pcs = 50
188
+
189
+ # Create a mock AnnData object
190
+ # Note: Shesha works best on PCA coordinates (latent space), not raw counts
191
+ adata = AnnData(X=np.random.randn(n_cells, n_genes)) # Raw counts (unused)
192
+ adata.obsm['X_pca'] = np.random.randn(n_cells, n_pcs) # PCA embeddings
193
+ adata.obs['guide_id'] = ['NT'] * 800 + ['KLF1'] * 200 # Metadata
194
+
195
+
196
+ # Wrap the PCA coordinates in a new AnnData so they become `.X` (recommended for robust geometry)
197
+ adata_pca = AnnData(X=adata.obsm['X_pca'], obs=adata.obs)
198
+
199
+
200
+ # Compute Stability (Consistency of the phenotype)
201
+ stability = compute_stability(
202
+ adata_pca,
203
+ perturbation_key='guide_id',
204
+ control_label='NT',
205
+ metric='cosine'
206
+ )
207
+
208
+ # Compute Magnitude (Strength of the phenotype)
209
+ magnitude = compute_magnitude(
210
+ adata_pca,
211
+ perturbation_key='guide_id',
212
+ control_label='NT',
213
+ metric='euclidean'
214
+ )
215
+
216
+ print(f"KLF1 Stability: {stability['KLF1']:.3f}") # e.g., 0.85 (High = Consistent)
217
+ print(f"KLF1 Magnitude: {magnitude['KLF1']:.3f}") # e.g., 2.40 (High = Strong)
218
+ ```
219
+
220
+
221
+ ## API Reference
222
+
223
+ ### `shesha.feature_split(X, n_splits=30, metric='cosine', seed=None, max_samples=1600)`
224
+
225
+ Measures internal geometric consistency by correlating RDMs computed from random, disjoint subsets of feature dimensions.
226
+
227
+ **Parameters:**
228
+ - `X` - array of shape (n_samples, n_features)
229
+ - `n_splits` - number of random partitions to average
230
+ - `metric` - 'cosine' or 'correlation'
231
+ - `seed` - random seed for reproducibility
232
+ - `max_samples` - subsample if exceeded
233
+
234
+ **Returns:** float in [-1, 1], higher = more stable
235
+
236
+ ### `shesha.sample_split(X, n_splits=30, subsample_fraction=0.4, metric='cosine', seed=None, max_samples=1500)`
237
+
238
+ Measures robustness to input variation via bootstrap resampling.
239
+
240
+ **Parameters:**
241
+ - `X` - array of shape (n_samples, n_features)
242
+ - `n_splits` - number of bootstrap iterations
243
+ - `subsample_fraction` - fraction of samples per bootstrap
244
+ - `metric` - 'cosine' or 'correlation'
245
+ - `seed` - random seed for reproducibility
246
+ - `max_samples` - subsample if exceeded
247
+
248
+ **Returns:** float in [-1, 1], higher = more stable
249
+
250
+ ### `shesha.anchor_stability(X, n_splits=30, n_anchors=100, n_per_split=200, metric='cosine', rank_normalize=True, seed=None, max_samples=1500)`
251
+
252
+ Measures stability of distance profiles from fixed anchor points.
253
+
254
+ **Parameters:**
255
+ - `X` - array of shape (n_samples, n_features)
256
+ - `n_splits` - number of random splits
257
+ - `n_anchors` - number of fixed anchor points
258
+ - `n_per_split` - samples per split
259
+ - `metric` - 'cosine' or 'euclidean'
260
+ - `rank_normalize` - rank-normalize distances within each anchor
261
+ - `seed` - random seed for reproducibility
262
+ - `max_samples` - subsample if exceeded
263
+
264
+ **Returns:** float in [-1, 1], higher = more stable
265
+
266
+ ### `shesha.variance_ratio(X, y)`
267
+
268
+ Ratio of between-class to total variance.
269
+
270
+ **Parameters:**
271
+ - `X` - array of shape (n_samples, n_features)
272
+ - `y` - array of shape (n_samples,) with class labels
273
+
274
+ **Returns:** float in [0, 1], higher = better class separation
275
+
276
+ ### `shesha.supervised_alignment(X, y, metric='correlation', seed=None, max_samples=300)`
277
+
278
+ Spearman correlation between model RDM and ideal label-based RDM.
279
+
280
+ **Parameters:**
281
+ - `X` - array of shape (n_samples, n_features)
282
+ - `y` - array of shape (n_samples,) with class labels
283
+ - `metric` - 'cosine' or 'correlation'
284
+ - `seed` - random seed for reproducibility
285
+ - `max_samples` - subsample if exceeded (RDM is O(n^2))
286
+
287
+ **Returns:** float in [-1, 1], higher = better task alignment
288
+
289
+ ### `shesha.rdm_similarity(X, Y, method='spearman', metric='cosine')`
290
+
291
+ Computes RDM correlation between two representations. Useful for comparing models, tracking drift during training, or measuring the effect of interventions.
292
+
293
+ **Parameters:**
294
+ - `X` - array of shape (n_samples, n_features_x), first representation
295
+ - `Y` - array of shape (n_samples, n_features_y), second representation (same n_samples)
296
+ - `method` - 'spearman' (rank-based, default) or 'pearson' (linear)
297
+ - `metric` - 'cosine', 'correlation', or 'euclidean'
298
+
299
+ **Returns:** float in [-1, 1], higher = more similar geometric structure
300
+
301
+ ### `shesha.rdm_drift(X, Y, method='spearman', metric='cosine')`
302
+
303
+ Computes representational drift as 1 - rdm_similarity. Useful for quantifying how much a representation has changed.
304
+
305
+ **Parameters:**
306
+ - `X` - array of shape (n_samples, n_features_x), baseline representation
307
+ - `Y` - array of shape (n_samples, n_features_y), comparison representation
308
+ - `method` - 'spearman' (rank-based, default) or 'pearson' (linear)
309
+ - `metric` - 'cosine', 'correlation', or 'euclidean'
310
+
311
+ **Returns:** float in [0, 2], where 0 = identical, 1 = uncorrelated, 2 = inverted
312
+
313
+ ## Biological Perturbation Analysis
314
+
315
+ The `shesha.bio` module provides metrics for single-cell perturbation experiments (e.g., Perturb-seq, CRISPR screens).
316
+
317
+ ### `shesha.bio.perturbation_stability(X_control, X_perturbed, metric='cosine', seed=None, max_samples=1000)`
318
+
319
+ Measures consistency of perturbation effects across samples. High values indicate coherent, reproducible perturbation effects.
320
+
321
+ **Parameters:**
322
+ - `X_control` - array of shape (n_control, n_features), control population
323
+ - `X_perturbed` - array of shape (n_perturbed, n_features), perturbed population
324
+ - `metric` - 'cosine' (default) or 'euclidean'
325
+ - `seed` - random seed for reproducibility
326
+ - `max_samples` - subsample perturbed population if exceeded
327
+
328
+ **Returns:** float in [-1, 1], higher = more consistent perturbation
329
+
330
+ ### `shesha.bio.perturbation_effect_size(X_control, X_perturbed)`
331
+
332
+ Cohen's d-like effect size measuring magnitude of perturbation shift.
333
+
334
+ **Parameters:**
335
+ - `X_control` - array of shape (n_control, n_features)
336
+ - `X_perturbed` - array of shape (n_perturbed, n_features)
337
+
338
+ **Returns:** float >= 0, higher = larger perturbation effect
339
+
340
+ ### Scanpy / AnnData Integration
341
+
342
+ For single-cell analysis, Shesha provides high-level wrappers that work directly with `AnnData` objects.
343
+
344
+ ### `shesha.bio.compute_stability(adata, perturbation_key, control_label, layer=None, metric='cosine')`
345
+
346
+ Computes the geometric stability for every perturbation in the dataset.
347
+
348
+ **Parameters:**
349
+ - `adata` - AnnData object.
350
+ - `perturbation_key` - Column in `adata.obs` identifying the perturbation (e.g., `'guide_id'`).
351
+ - `control_label` - The label in that column representing control cells (e.g., `'NT'`).
352
+ - `layer` - (Optional) Layer to use (e.g., `'pca'`). If None, uses `.X`.
353
+ - `metric` - `'cosine'` (default) or `'euclidean'`.
354
+
355
+ **Returns:** Dictionary `{perturbation_name: stability_score}`.
356
+
357
+ ### `shesha.bio.compute_magnitude(adata, perturbation_key, control_label, layer=None, metric='euclidean')`
358
+
359
+ Computes the magnitude (effect size) for every perturbation.
360
+
361
+ **Parameters:**
362
+ - `adata` - AnnData object.
+ - `perturbation_key`, `control_label`, `layer` - same as for `compute_stability`.
363
+ - `metric` - `'euclidean'` (default, raw distance) or `'cohen'` (standardized effect size).
364
+
365
+ **Returns:** Dictionary `{perturbation_name: magnitude_score}`.
366
+
367
+
368
+ ## Citation
369
+
370
+ If you use `shesha-geometry`, please cite:
371
+ ```bibtex
372
+ @software{shesha2026,
373
+ title = {Shesha: Self-consistency Metrics for Representational Stability},
374
+ author = {Prashant C. Raju},
375
+ year = {2026},
376
+ url = {https://github.com/prashantcraju/shesha},
377
+ publisher = {Zenodo},
378
+ doi = {10.5281/zenodo.18227454},
379
+ note = {Python package version 0.1.0}
380
+ }
381
+
382
+ @article{raju2026geometric,
383
+ title={Geometric Stability: The Missing Axis of Representations},
384
+ author={Prashant C. Raju},
385
+ journal={arXiv},
386
+ year={2026}
387
+ }
388
+ ```
389
+
390
+ ## License
391
+
392
+ MIT
393
+
394
+ ---
395
+
396
+ <sub>Logo generated by [Nano Banana Pro](https://nanobananapro.com)</sub>
@@ -0,0 +1,13 @@
1
+ examples/tutorial.py,sha256=HoDNK_NgCEW7_PkfOkKswNH2mitCHPJ56yuldiH8skI,12167
2
+ shesha/__init__.py,sha256=KaDsrk_j30OeFn7M70nHmCXp5MJ-cumS-eiT7PZyJQY,1390
3
+ shesha/bio.py,sha256=4fZHfLfkMIbAYEJ80MJreP8Hg9ApoOQUtk0uDuX1ejo,10465
4
+ shesha/core.py,sha256=_M--xVtbOBRujpCeSNu_Ala_gTjDhfJUbv8BvaaUHGM,19986
5
+ tests/__init__.py,sha256=YYq097nXxIO0mWUP8sqw6YS-CSjFRWVEa1lPN7ATA0U,27
6
+ tests/test_bio.py,sha256=71poE9FbfjWnAQFYFHAiPsGbdyGurRmxX6qRxHuEths,2142
7
+ tests/test_core.py,sha256=UmVV56VkmM71ZPySsOEKEG4_sXt9D0dvXX5cfB-EGrI,5545
8
+ tests/test_crispr.py,sha256=tUSOgXLyvd9PkEuXZ7NeqYWMWcu0aUMcXOdP7ErVLMg,6202
9
+ shesha_geometry-0.1.0.dist-info/LICENSE,sha256=HWzfciMdWWnz_MNV5xEni15hw7Jnks3FWBmhRbb-bFk,1072
10
+ shesha_geometry-0.1.0.dist-info/METADATA,sha256=J_LirSGGwV0amvXS5i3F7kP9tgHzFZhanT1eHRvvrE8,13036
11
+ shesha_geometry-0.1.0.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
12
+ shesha_geometry-0.1.0.dist-info/top_level.txt,sha256=30aOCV4WW_bVgWOawbrRZPahW9ilJ-S9jd-oj4-Rucw,22
13
+ shesha_geometry-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (75.3.2)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,3 @@
1
+ examples
2
+ shesha
3
+ tests
tests/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Tests for shesha package
tests/test_bio.py ADDED
@@ -0,0 +1,49 @@
1
"""Tests for the perturbation metrics in ``shesha.bio``.

Each scenario is a real pytest test with assertions (previously these were
module-level prints that pytest collected as zero tests). Synthetic data is
drawn from a seeded generator so any failure reproduces exactly.
"""

import numpy as np
import pytest

from shesha.bio import perturbation_stability, perturbation_effect_size


@pytest.fixture
def rng():
    """Seeded NumPy generator for reproducible synthetic populations."""
    return np.random.default_rng(320)


def test_basic_functionality(rng):
    """Unrelated random populations still yield in-range scalar metrics."""
    X_ctrl = rng.standard_normal((200, 50))
    X_pert = rng.standard_normal((200, 50))
    stability = perturbation_stability(X_ctrl, X_pert, seed=320)
    effect = perturbation_effect_size(X_ctrl, X_pert)
    # Documented ranges: stability in [-1, 1], effect size >= 0.
    assert -1 <= stability <= 1
    assert effect >= 0


def test_coherent_perturbation_has_high_stability(rng):
    """Every cell shifted along the same direction -> consistent effect."""
    X_ctrl = rng.standard_normal((200, 50))
    shift = rng.standard_normal(50) * 3  # Same direction for all cells
    X_pert = X_ctrl + shift + rng.standard_normal((200, 50)) * 0.1  # Small noise
    stability = perturbation_stability(X_ctrl, X_pert, seed=320)
    assert stability > 0.8


def test_incoherent_perturbation_has_low_stability(rng):
    """Each cell shifted in its own random direction -> inconsistent effect."""
    X_ctrl = rng.standard_normal((200, 50))
    X_pert = X_ctrl + rng.standard_normal((200, 50))  # Each cell shifts randomly
    stability = perturbation_stability(X_ctrl, X_pert, seed=320)
    assert stability < 0.5


def test_effect_size_grows_with_shift(rng):
    """A larger uniform shift must produce a larger effect size."""
    X_ctrl = rng.standard_normal((200, 50))
    effect_small = perturbation_effect_size(X_ctrl, X_ctrl + 0.1)
    effect_large = perturbation_effect_size(X_ctrl, X_ctrl + 5.0)
    assert effect_small >= 0
    assert effect_large > effect_small


def test_stability_is_deterministic(rng):
    """The same seed must reproduce the same stability score exactly."""
    X_ctrl = rng.standard_normal((200, 50))
    X_pert = rng.standard_normal((200, 50))
    r1 = perturbation_stability(X_ctrl, X_pert, seed=320)
    r2 = perturbation_stability(X_ctrl, X_pert, seed=320)
    assert r1 == r2
tests/test_core.py ADDED
@@ -0,0 +1,164 @@
1
+ """Basic tests for Shesha metrics."""
2
+
3
+ import numpy as np
4
+ import pytest
5
+ import shesha
6
+
7
+
8
class TestFeatureSplit:
    """Tests for the feature_split variant.

    Input data comes from locally seeded generators rather than the global
    NumPy RNG, so threshold-based assertions reproduce exactly on re-run.
    """

    def test_basic_usage(self):
        """A seeded call returns a Python float inside [-1, 1]."""
        X = np.random.default_rng(0).standard_normal((100, 64))
        result = shesha.feature_split(X, n_splits=10, seed=320)
        assert isinstance(result, float)
        assert -1 <= result <= 1

    def test_determinism(self):
        """The same seed must reproduce the same score."""
        X = np.random.default_rng(1).standard_normal((100, 64))
        r1 = shesha.feature_split(X, seed=320)
        r2 = shesha.feature_split(X, seed=320)
        assert r1 == r2

    def test_structured_data_high_stability(self):
        """Low-rank data is redundant across features, so stability is high."""
        rng = np.random.default_rng(2)
        # Every feature is a mix of the same 10 latent factors, so any subset
        # of features sees approximately the same sample geometry.
        latent = rng.standard_normal((100, 10))
        projection = rng.standard_normal((10, 128))
        X = latent @ projection

        result = shesha.feature_split(X, n_splits=30, seed=320)
        assert result > 0.5, "Structured data should have high stability"

    def test_random_data_low_stability(self):
        """IID features share no structure, so stability should be low."""
        X = np.random.default_rng(3).standard_normal((100, 128))
        result = shesha.feature_split(X, n_splits=30, seed=320)
        assert -1 <= result <= 1
        # Disjoint subsets of iid features are statistically independent, so
        # their RDMs are uncorrelated in expectation; the score must stay well
        # below the "high stability" bar used for structured data.
        assert result < 0.5, "IID random data should not look stable"

    def test_too_few_features(self):
        """Should return NaN with too few features."""
        X = np.random.default_rng(4).standard_normal((100, 3))
        result = shesha.feature_split(X)
        assert np.isnan(result)

    def test_too_few_samples(self):
        """Should return NaN with too few samples."""
        X = np.random.default_rng(5).standard_normal((3, 64))
        result = shesha.feature_split(X)
        assert np.isnan(result)
53
+
54
+
55
class TestSampleSplit:
    """Checks for the sample_split (bootstrap resampling) variant."""

    def test_basic_usage(self):
        """A seeded call yields a Python float within [-1, 1]."""
        data = np.random.randn(200, 64)
        score = shesha.sample_split(data, n_splits=10, seed=320)
        assert isinstance(score, float)
        assert -1 <= score <= 1

    def test_determinism(self):
        """Two calls with an identical seed must agree exactly."""
        data = np.random.randn(200, 64)
        first, second = (shesha.sample_split(data, seed=320) for _ in range(2))
        assert first == second
69
+
70
+
71
class TestAnchorStability:
    """Checks for the anchor_stability variant."""

    def test_basic_usage(self):
        """A 500-sample input produces a float within [-1, 1]."""
        data = np.random.randn(500, 64)
        score = shesha.anchor_stability(data, n_splits=10, seed=320)
        assert isinstance(score, float)
        assert -1 <= score <= 1

    def test_small_data_handling(self):
        """A 50-sample input must not raise."""
        data = np.random.randn(50, 64)
        score = shesha.anchor_stability(data, seed=320)
        # NaN is an acceptable answer here; any other value must be in range.
        is_valid = np.isnan(score) or (-1 <= score <= 1)
        assert is_valid
86
+
87
+
88
class TestVarianceRatio:
    """Checks for the variance_ratio variant."""

    def test_basic_usage(self):
        """Random labels give a float ratio within [0, 1]."""
        data = np.random.randn(100, 64)
        labels = np.random.randint(0, 5, 100)
        ratio = shesha.variance_ratio(data, labels)
        assert isinstance(ratio, float)
        assert 0 <= ratio <= 1

    def test_perfect_separation(self):
        """Two clusters far apart relative to their spread give a high ratio."""
        offset = np.full(64, 10)
        cluster_pos = np.random.randn(50, 64) + offset
        cluster_neg = np.random.randn(50, 64) - offset
        data = np.vstack([cluster_pos, cluster_neg])
        labels = np.repeat([0, 1], 50)

        ratio = shesha.variance_ratio(data, labels)
        assert ratio > 0.8, "Well-separated classes should have high variance ratio"

    def test_single_class(self):
        """A dataset carrying a single label should yield NaN."""
        data = np.random.randn(100, 64)
        labels = np.zeros(100)
        ratio = shesha.variance_ratio(data, labels)
        assert np.isnan(ratio)
116
+
117
+
118
class TestSupervisedAlignment:
    """Checks for the supervised_alignment variant."""

    def test_basic_usage(self):
        """A seeded call with integer labels yields a float within [-1, 1]."""
        data = np.random.randn(100, 64)
        labels = np.random.randint(0, 5, 100)
        score = shesha.supervised_alignment(data, labels, seed=320)
        assert isinstance(score, float)
        assert -1 <= score <= 1

    def test_determinism(self):
        """Repeating the call with the same seed reproduces the score."""
        data = np.random.randn(100, 64)
        labels = np.random.randint(0, 5, 100)
        runs = [shesha.supervised_alignment(data, labels, seed=320) for _ in range(2)]
        assert runs[0] == runs[1]
134
+
135
+
136
class TestUnifiedInterface:
    """Checks that the shesha.shesha() dispatcher matches the dedicated
    metric functions and rejects invalid arguments."""

    def test_feature_split_variant(self):
        """variant='feature_split' equals feature_split with the same seed."""
        data = np.random.randn(100, 64)
        via_dispatch = shesha.shesha(data, variant='feature_split', seed=320)
        assert via_dispatch == shesha.feature_split(data, seed=320)

    def test_supervised_variant(self):
        """variant='supervised' equals supervised_alignment with the same seed."""
        data = np.random.randn(100, 64)
        labels = np.random.randint(0, 5, 100)
        via_dispatch = shesha.shesha(data, labels, variant='supervised', seed=320)
        assert via_dispatch == shesha.supervised_alignment(data, labels, seed=320)

    def test_missing_labels_error(self):
        """The supervised variant without labels raises ValueError."""
        data = np.random.randn(100, 64)
        with pytest.raises(ValueError, match="Labels required"):
            shesha.shesha(data, variant='supervised')

    def test_unknown_variant_error(self):
        """An unrecognised variant name raises ValueError."""
        data = np.random.randn(100, 64)
        with pytest.raises(ValueError, match="Unknown variant"):
            shesha.shesha(data, variant='nonexistent')
161
+
162
+
163
if __name__ == "__main__":
    # pytest.main returns the session exit status; raise it so that running
    # this file directly exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))