sclsd 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. sclsd-0.1.0/.gitignore +112 -0
  2. sclsd-0.1.0/LICENSE +21 -0
  3. sclsd-0.1.0/PKG-INFO +331 -0
  4. sclsd-0.1.0/README.md +268 -0
  5. sclsd-0.1.0/notebooks/01_quickstart.ipynb +387 -0
  6. sclsd-0.1.0/notebooks/bonemarrow/postprocessing.ipynb +516 -0
  7. sclsd-0.1.0/notebooks/bonemarrow/train.ipynb +181 -0
  8. sclsd-0.1.0/notebooks/cancer/plasticity_scores.tsv +56150 -0
  9. sclsd-0.1.0/notebooks/cancer/train.ipynb +477 -0
  10. sclsd-0.1.0/notebooks/dentategyrus/postprocessing.ipynb +508 -0
  11. sclsd-0.1.0/notebooks/dentategyrus/train.ipynb +181 -0
  12. sclsd-0.1.0/notebooks/erythroid/postprocessing.ipynb +505 -0
  13. sclsd-0.1.0/notebooks/erythroid/train.ipynb +181 -0
  14. sclsd-0.1.0/notebooks/mouse_cortex/train.ipynb +550 -0
  15. sclsd-0.1.0/notebooks/pancreas/postprocessing.ipynb +485 -0
  16. sclsd-0.1.0/notebooks/pancreas/train.ipynb +391 -0
  17. sclsd-0.1.0/notebooks/tutorial/prior_pseudotime_train.ipynb +402 -0
  18. sclsd-0.1.0/notebooks/unseen_bonemarrow/train.ipynb +416 -0
  19. sclsd-0.1.0/notebooks/zebrafish/postprocessing.ipynb +763 -0
  20. sclsd-0.1.0/notebooks/zebrafish/train.ipynb +224 -0
  21. sclsd-0.1.0/pyproject.toml +137 -0
  22. sclsd-0.1.0/src/sclsd/__init__.py +140 -0
  23. sclsd-0.1.0/src/sclsd/_version.py +4 -0
  24. sclsd-0.1.0/src/sclsd/analysis/__init__.py +15 -0
  25. sclsd-0.1.0/src/sclsd/analysis/metrics.py +319 -0
  26. sclsd-0.1.0/src/sclsd/core/__init__.py +41 -0
  27. sclsd-0.1.0/src/sclsd/core/config.py +311 -0
  28. sclsd-0.1.0/src/sclsd/core/model.py +449 -0
  29. sclsd-0.1.0/src/sclsd/core/networks.py +415 -0
  30. sclsd-0.1.0/src/sclsd/plotting/__init__.py +15 -0
  31. sclsd-0.1.0/src/sclsd/plotting/components.py +122 -0
  32. sclsd-0.1.0/src/sclsd/plotting/streamlines.py +200 -0
  33. sclsd-0.1.0/src/sclsd/plotting/walks.py +297 -0
  34. sclsd-0.1.0/src/sclsd/preprocessing/__init__.py +19 -0
  35. sclsd-0.1.0/src/sclsd/preprocessing/data.py +155 -0
  36. sclsd-0.1.0/src/sclsd/preprocessing/prior.py +695 -0
  37. sclsd-0.1.0/src/sclsd/py.typed +0 -0
  38. sclsd-0.1.0/src/sclsd/train/__init__.py +6 -0
  39. sclsd-0.1.0/src/sclsd/train/trainer.py +905 -0
  40. sclsd-0.1.0/src/sclsd/train/walks.py +150 -0
  41. sclsd-0.1.0/src/sclsd/utils/__init__.py +5 -0
  42. sclsd-0.1.0/src/sclsd/utils/io.py +80 -0
  43. sclsd-0.1.0/src/sclsd/utils/seed.py +111 -0
  44. sclsd-0.1.0/tests/__init__.py +1 -0
  45. sclsd-0.1.0/tests/conftest.py +125 -0
  46. sclsd-0.1.0/tests/fixtures/__init__.py +21 -0
  47. sclsd-0.1.0/tests/fixtures/reference_data/baseline_generation.log +6201 -0
  48. sclsd-0.1.0/tests/fixtures/reference_data/bonemarrow/cell_rep.npy +0 -0
  49. sclsd-0.1.0/tests/fixtures/reference_data/bonemarrow/config.json +17 -0
  50. sclsd-0.1.0/tests/fixtures/reference_data/bonemarrow/diff_rep.npy +0 -0
  51. sclsd-0.1.0/tests/fixtures/reference_data/bonemarrow/entropy.npy +0 -0
  52. sclsd-0.1.0/tests/fixtures/reference_data/bonemarrow/potential.npy +0 -0
  53. sclsd-0.1.0/tests/fixtures/reference_data/bonemarrow/pseudotime.npy +0 -0
  54. sclsd-0.1.0/tests/fixtures/reference_data/bonemarrow/transitions.npz +0 -0
  55. sclsd-0.1.0/tests/fixtures/reference_data/pancreas/cbdir_scores.json +1 -0
  56. sclsd-0.1.0/tests/fixtures/reference_data/pancreas/cell_rep.npy +0 -0
  57. sclsd-0.1.0/tests/fixtures/reference_data/pancreas/config.json +90 -0
  58. sclsd-0.1.0/tests/fixtures/reference_data/pancreas/diff_rep.npy +0 -0
  59. sclsd-0.1.0/tests/fixtures/reference_data/pancreas/entropy.npy +0 -0
  60. sclsd-0.1.0/tests/fixtures/reference_data/pancreas/lsdpy_cell_state.npy +0 -0
  61. sclsd-0.1.0/tests/fixtures/reference_data/pancreas/lsdpy_diff_state.npy +0 -0
  62. sclsd-0.1.0/tests/fixtures/reference_data/pancreas/lsdpy_entropy.npy +0 -0
  63. sclsd-0.1.0/tests/fixtures/reference_data/pancreas/lsdpy_potential.npy +0 -0
  64. sclsd-0.1.0/tests/fixtures/reference_data/pancreas/lsdpy_pseudotime.npy +0 -0
  65. sclsd-0.1.0/tests/fixtures/reference_data/pancreas/potential.npy +0 -0
  66. sclsd-0.1.0/tests/fixtures/reference_data/pancreas/pseudotime.npy +0 -0
  67. sclsd-0.1.0/tests/fixtures/reference_data/pancreas/transitions.npz +0 -0
  68. sclsd-0.1.0/tests/fixtures/reference_data/tutorial/cell_state.npy +0 -0
  69. sclsd-0.1.0/tests/fixtures/reference_data/tutorial/config.json +35 -0
  70. sclsd-0.1.0/tests/fixtures/reference_data/tutorial/diff_state.npy +0 -0
  71. sclsd-0.1.0/tests/fixtures/reference_data/tutorial/entropy.npy +0 -0
  72. sclsd-0.1.0/tests/fixtures/reference_data/tutorial/lsd_pseudotime.npy +0 -0
  73. sclsd-0.1.0/tests/fixtures/reference_data/tutorial/lsdpy_cell_state.npy +0 -0
  74. sclsd-0.1.0/tests/fixtures/reference_data/tutorial/lsdpy_diff_state.npy +0 -0
  75. sclsd-0.1.0/tests/fixtures/reference_data/tutorial/lsdpy_entropy.npy +0 -0
  76. sclsd-0.1.0/tests/fixtures/reference_data/tutorial/lsdpy_potential.npy +0 -0
  77. sclsd-0.1.0/tests/fixtures/reference_data/tutorial/lsdpy_pseudotime.npy +0 -0
  78. sclsd-0.1.0/tests/fixtures/reference_data/tutorial/model/loss_curves.png +0 -0
  79. sclsd-0.1.0/tests/fixtures/reference_data/tutorial/potential.npy +0 -0
  80. sclsd-0.1.0/tests/fixtures/reference_data/tutorial/transitions.npz +0 -0
  81. sclsd-0.1.0/tests/fixtures/reference_data.py +157 -0
  82. sclsd-0.1.0/tests/fixtures/synthetic_data.py +156 -0
  83. sclsd-0.1.0/tests/integration/__init__.py +1 -0
  84. sclsd-0.1.0/tests/integration/test_full_pipeline.py +268 -0
  85. sclsd-0.1.0/tests/reproducibility/__init__.py +1 -0
  86. sclsd-0.1.0/tests/reproducibility/test_parity_all_datasets.py +430 -0
  87. sclsd-0.1.0/tests/reproducibility/test_pyro_sample_order.py +235 -0
  88. sclsd-0.1.0/tests/test_model.py +177 -0
  89. sclsd-0.1.0/tests/test_reproducibility.py +217 -0
  90. sclsd-0.1.0/tests/unit/__init__.py +1 -0
  91. sclsd-0.1.0/tests/unit/test_config.py +473 -0
  92. sclsd-0.1.0/tests/unit/test_metrics.py +304 -0
  93. sclsd-0.1.0/tests/unit/test_model_extended.py +483 -0
  94. sclsd-0.1.0/tests/unit/test_networks.py +533 -0
  95. sclsd-0.1.0/tests/unit/test_trainer.py +474 -0
  96. sclsd-0.1.0/tests/unit/test_walks.py +248 -0
sclsd-0.1.0/.gitignore ADDED
@@ -0,0 +1,112 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ *.manifest
31
+ *.spec
32
+
33
+ # Installer logs
34
+ pip-log.txt
35
+ pip-delete-this-directory.txt
36
+
37
+ # Unit test / coverage reports
38
+ htmlcov/
39
+ .tox/
40
+ .nox/
41
+ .coverage
42
+ .coverage.*
43
+ .cache
44
+ nosetests.xml
45
+ coverage.xml
46
+ *.cover
47
+ *.py,cover
48
+ .hypothesis/
49
+ .pytest_cache/
50
+
51
+ # Translations
52
+ *.mo
53
+ *.pot
54
+
55
+ # Sphinx documentation
56
+ docs/_build/
57
+
58
+ # PyBuilder
59
+ .pybuilder/
60
+ target/
61
+
62
+ # Jupyter Notebook
63
+ .ipynb_checkpoints
64
+
65
+ # IPython
66
+ profile_default/
67
+ ipython_config.py
68
+
69
+ # pyenv
70
+ .python-version
71
+
72
+ # Environments
73
+ .env
74
+ .venv
75
+ env/
76
+ venv/
77
+ ENV/
78
+ env.bak/
79
+ venv.bak/
80
+
81
+ # Spyder project settings
82
+ .spyderproject
83
+ .spyproject
84
+
85
+ # Rope project settings
86
+ .ropeproject
87
+
88
+ # mypy
89
+ .mypy_cache/
90
+ .dmypy.json
91
+ dmypy.json
92
+
93
+ # Ruff
94
+ .ruff_cache/
95
+
96
+ # IDE
97
+ .idea/
98
+ .vscode/
99
+ *.swp
100
+ *.swo
101
+
102
+ # OS
103
+ .DS_Store
104
+ Thumbs.db
105
+
106
+ # Project specific
107
+ data/
108
+ Models/
109
+ checkpoints/
110
+ *.h5ad
111
+ *.pth
112
+ baseline_results/
sclsd-0.1.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 LSD Development Team
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
sclsd-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,331 @@
1
+ Metadata-Version: 2.4
2
+ Name: sclsd
3
+ Version: 0.1.0
4
+ Summary: Latent State Dynamics (LSD) for single-cell trajectory inference via neural ODE gradient flow
5
+ Project-URL: Homepage, https://github.com/csglab/sclsd
6
+ Project-URL: Documentation, https://sclsd.readthedocs.io
7
+ Project-URL: Repository, https://github.com/csglab/sclsd
8
+ Project-URL: Issues, https://github.com/csglab/sclsd/issues
9
+ Author: LSD Development Team
10
+ License-Expression: MIT
11
+ License-File: LICENSE
12
+ Keywords: Waddington-landscape,cell-differentiation,deep-learning,neural-ode,pyro,single-cell,trajectory-inference,variational-inference
13
+ Classifier: Development Status :: 4 - Beta
14
+ Classifier: Intended Audience :: Science/Research
15
+ Classifier: License :: OSI Approved :: MIT License
16
+ Classifier: Operating System :: OS Independent
17
+ Classifier: Programming Language :: Python :: 3
18
+ Classifier: Programming Language :: Python :: 3.9
19
+ Classifier: Programming Language :: Python :: 3.10
20
+ Classifier: Programming Language :: Python :: 3.11
21
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
22
+ Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
23
+ Requires-Python: >=3.9
24
+ Requires-Dist: anndata>=0.9.0
25
+ Requires-Dist: cellrank>=2.0.0
26
+ Requires-Dist: leidenalg
27
+ Requires-Dist: matplotlib>=3.7.0
28
+ Requires-Dist: numpy<2.0.0,>=1.23.0
29
+ Requires-Dist: pandas>=2.0.0
30
+ Requires-Dist: pyro-ppl>=1.8.0
31
+ Requires-Dist: scanpy>=1.9.0
32
+ Requires-Dist: scikit-learn>=1.3.0
33
+ Requires-Dist: scipy>=1.10.0
34
+ Requires-Dist: seaborn>=0.12.0
35
+ Requires-Dist: torch>=2.0.0
36
+ Requires-Dist: torchdiffeq>=0.2.0
37
+ Requires-Dist: tqdm>=4.60.0
38
+ Requires-Dist: umap-learn>=0.5.0
39
+ Provides-Extra: all
40
+ Requires-Dist: black>=23.0.0; extra == 'all'
41
+ Requires-Dist: mypy>=1.0.0; extra == 'all'
42
+ Requires-Dist: myst-parser>=2.0.0; extra == 'all'
43
+ Requires-Dist: nbsphinx>=0.9.0; extra == 'all'
44
+ Requires-Dist: pre-commit>=3.0.0; extra == 'all'
45
+ Requires-Dist: pytest-cov>=4.0.0; extra == 'all'
46
+ Requires-Dist: pytest>=7.0.0; extra == 'all'
47
+ Requires-Dist: ruff>=0.1.0; extra == 'all'
48
+ Requires-Dist: sphinx-rtd-theme>=1.0.0; extra == 'all'
49
+ Requires-Dist: sphinx>=6.0.0; extra == 'all'
50
+ Provides-Extra: dev
51
+ Requires-Dist: black>=23.0.0; extra == 'dev'
52
+ Requires-Dist: mypy>=1.0.0; extra == 'dev'
53
+ Requires-Dist: pre-commit>=3.0.0; extra == 'dev'
54
+ Requires-Dist: pytest-cov>=4.0.0; extra == 'dev'
55
+ Requires-Dist: pytest>=7.0.0; extra == 'dev'
56
+ Requires-Dist: ruff>=0.1.0; extra == 'dev'
57
+ Provides-Extra: docs
58
+ Requires-Dist: myst-parser>=2.0.0; extra == 'docs'
59
+ Requires-Dist: nbsphinx>=0.9.0; extra == 'docs'
60
+ Requires-Dist: sphinx-rtd-theme>=1.0.0; extra == 'docs'
61
+ Requires-Dist: sphinx>=6.0.0; extra == 'docs'
62
+ Description-Content-Type: text/markdown
63
+
64
+ # LSDpy
65
+
66
+ **Latent State Dynamics for Single-Cell Trajectory Inference**
67
+
68
+ [![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/)
69
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
70
+ [![PyPI version](https://badge.fury.io/py/sclsd.svg)](https://badge.fury.io/py/sclsd)
71
+
72
+ LSDpy (distributed on PyPI and imported as `sclsd`) is a deep learning framework for inferring cell differentiation trajectories from single-cell RNA sequencing data. It combines neural ODEs with variational inference to model the Waddington landscape of cellular differentiation.
73
+
74
+ ## Key Features
75
+
76
+ - **Neural ODE Dynamics**: Model cell state evolution as gradient flow in a learned potential landscape
77
+ - **Variational Inference**: Probabilistic modeling with Pyro for uncertainty quantification
78
+ - **Trajectory Inference**: Infer pseudotime and cell fate predictions
79
+ - **Reproducible**: Comprehensive RNG management for identical results across runs
80
+ - **GPU Accelerated**: Full CUDA support for efficient training
81
+
82
+ ## Installation
83
+
84
+ ### From PyPI (recommended)
85
+
86
+ ```bash
87
+ pip install sclsd
88
+ ```
89
+
90
+ ### From Source
91
+
92
+ ```bash
93
+ git clone https://github.com/csglab/sclsd.git
94
+ cd sclsd
95
+ pip install -e .
96
+ ```
97
+
98
+ ### With Conda Environment
99
+
100
+ ```bash
101
+ conda env create -f environment.yml
102
+ conda activate lsd
103
+ pip install -e .
104
+ ```
105
+
106
+ ## Quick Start
107
+
108
+ ```python
109
+ import scanpy as sc
110
+ import torch
111
+ from sclsd import LSD, LSDConfig, prepare_data_dict
112
+
113
+ # Load and preprocess data
114
+ adata = sc.read("my_data.h5ad")
115
+ data_dict = prepare_data_dict(adata, n_top_genes=5000)
116
+
117
+ # Configure model
118
+ cfg = LSDConfig()
119
+ cfg.walks.path_len = 50
120
+ cfg.walks.num_walks = 10000
121
+ cfg.model.z_dim = 10
122
+ cfg.model.B_dim = 2
123
+
124
+ # Create model
125
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
126
+ lsd = LSD(data_dict["adata"], cfg, device=device)
127
+
128
+ # Set prior transition matrix based on pseudotime
129
+ lsd.set_prior_transition(prior_time_key="dpt_pseudotime")
130
+
131
+ # Generate random walks
132
+ lsd.prepare_walks()
133
+
134
+ # Train model
135
+ lsd.train(num_epochs=100, random_state=42)
136
+
137
+ # Get results
138
+ result = lsd.get_adata()
139
+ print(result.obs["lsd_pseudotime"])
140
+ print(result.obsm["cell_rep"]) # Latent cell state
141
+ print(result.obsm["diff_rep"]) # Differentiation state
142
+ ```
143
+
144
+ ## Model Architecture
145
+
146
+ LSDpy models cellular differentiation using:
147
+
148
+ 1. **Cell State Encoder** (`XEncoder`): Maps gene expression to latent cell state `z`
149
+ 2. **Differentiation State Encoder** (`ZEncoder`): Maps cell state to differentiation state `B`
150
+ 3. **Potential Network**: Learns the Waddington landscape potential `V(z)`
151
+ 4. **Neural ODE**: Evolves cell states as gradient descent on the potential
152
+ 5. **Decoder** (`ZDecoder`): Reconstructs gene expression from latent state
153
+
154
+ The model is trained using stochastic variational inference with a Zero-Inflated Negative Binomial likelihood for count data.
155
+
156
+ ## Configuration
157
+
158
+ ```python
159
+ from sclsd import LSDConfig
160
+
161
+ cfg = LSDConfig()
162
+
163
+ # Model architecture
164
+ cfg.model.z_dim = 10 # Latent cell state dimension
165
+ cfg.model.B_dim = 2 # Differentiation state dimension
166
+ cfg.model.V_coeff = 0.0 # Potential regularization
167
+
168
+ # Random walks
169
+ cfg.walks.path_len = 50 # Steps per walk
170
+ cfg.walks.num_walks = 10000 # Number of training walks
171
+ cfg.walks.batch_size = 256 # Batch size
172
+
173
+ # Optimizer
174
+ cfg.optimizer.adam.lr = 1e-3
175
+ cfg.optimizer.adam.T_0 = 50 # Cosine annealing period
176
+
177
+ # KL annealing
178
+ cfg.optimizer.kl_schedule.min_af = 0.0
179
+ cfg.optimizer.kl_schedule.max_af = 1.0
180
+ cfg.optimizer.kl_schedule.max_epoch = 50
181
+ ```
182
+
183
+ ## Prior Pseudotime
184
+
185
+ LSDpy requires a prior pseudotime or transition matrix to guide training:
186
+
187
+ ```python
188
+ # Option 1: Use existing pseudotime (e.g., from diffusion pseudotime)
189
+ lsd.set_prior_transition(prior_time_key="dpt_pseudotime")
190
+
191
+ # Option 2: Infer prior pseudotime automatically
192
+ from sclsd import infer_prior_time
193
+ adata = infer_prior_time(data_dict, device, origin_cluster="Stem")
194
+
195
+ # Option 3: Use phylogeny-guided transitions
196
+ lsd.set_phylogeny(
197
+ phylogeny={"Stem": ["Prog1", "Prog2"], "Prog1": ["Mature1"]},
198
+ cluster_key="clusters"
199
+ )
200
+ lsd.set_prior_transition(prior_time_key="pseudotime")
201
+ ```
202
+
203
+ ## Cell Fate Prediction
204
+
205
+ ```python
206
+ # Predict cell fates by propagating through the potential landscape
207
+ result = lsd.get_cell_fates(
208
+ adata=result,
209
+ time_range=10.0,
210
+ dt=0.5,
211
+ cluster_key="clusters",
212
+ return_paths=True
213
+ )
214
+
215
+ print(result.obs["fate"]) # Predicted terminal state for each cell
216
+ ```
217
+
218
+ ## Evaluation Metrics
219
+
220
+ ```python
221
+ from sclsd import cross_boundary_correctness, inner_cluster_coh
222
+
223
+ # Define expected transitions
224
+ edges = [("Stem", "Prog"), ("Prog", "Mature")]
225
+
226
+ # Cross-boundary correctness
227
+ scores, mean_score = cross_boundary_correctness(
228
+ adata, "clusters", "velocity", edges
229
+ )
230
+ print(f"Cross-boundary score: {mean_score:.3f}")
231
+
232
+ # In-cluster coherence
233
+ scores, mean_score = inner_cluster_coh(adata, "clusters", "velocity")
234
+ print(f"In-cluster coherence: {mean_score:.3f}")
235
+ ```
236
+
237
+ ## Visualization
238
+
239
+ ```python
240
+ from sclsd import plot_random_walks, plot_z_components, visualize_random_walks_on_umap
241
+
242
+ # Plot random walks on UMAP
243
+ plot_random_walks(result, walks[:10], rep="X_umap")
244
+
245
+ # Visualize ODE trajectories
246
+ plot_z_components(lsd.z_sol[:, :10, :], t_max=10.0)
247
+
248
+ # Visualize walks from specific clusters
249
+ visualize_random_walks_on_umap(
250
+ result, lsd.paths,
251
+ target_clusters=["Stem"],
252
+ cluster_key="clusters"
253
+ )
254
+ ```
255
+
256
+ ## Reproducibility
257
+
258
+ LSDpy ensures reproducible results through comprehensive RNG management:
259
+
260
+ ```python
261
+ from sclsd import set_all_seeds
262
+
263
+ # Set all random seeds before training
264
+ set_all_seeds(42)
265
+
266
+ # Train model - results will be identical across runs
267
+ lsd.train(num_epochs=100, random_state=42)
268
+ ```
269
+
270
+ **Important**: The order of `pyro.sample()` calls in the model determines the random number sequence. The implementation preserves the exact sampling order to ensure reproducibility.
271
+
272
+ ## API Reference
273
+
274
+ ### Main Classes
275
+
276
+ - `LSD`: Main trainer class for model training and inference
277
+ - `LSDConfig`: Configuration dataclass with nested configs for model, optimizer, and walks
278
+ - `LSDModel`: Neural network model implementing the Pyro generative model and guide
279
+
280
+ ### Preprocessing
281
+
282
+ - `prepare_data_dict()`: Prepare AnnData for training
283
+ - `infer_prior_time()`: Automatically infer prior pseudotime
284
+ - `get_prior_transition()`: Compute prior transition matrix
285
+
286
+ ### Analysis
287
+
288
+ - `cross_boundary_correctness()`: Evaluate velocity direction correctness
289
+ - `inner_cluster_coh()`: Evaluate velocity coherence within clusters
290
+ - `evaluate()`: Run all evaluation metrics
291
+
292
+ ### Plotting
293
+
294
+ - `plot_random_walks()`: Visualize random walks on embeddings
295
+ - `plot_z_components()`: Plot latent component trajectories
296
+ - `plot_streamlines()`: Visualize velocity streamlines
297
+ - `visualize_random_walks_on_umap()`: Enhanced walk visualization
298
+
299
+ ## Requirements
300
+
301
+ - Python >= 3.9
302
+ - PyTorch >= 2.0.0
303
+ - Pyro-PPL >= 1.8.0
304
+ - torchdiffeq >= 0.2.0
305
+ - scanpy >= 1.9.0
306
+ - anndata >= 0.9.0
307
+
308
+ ## Citation
309
+
310
+ If you use LSDpy in your research, please cite:
311
+
312
+ ```bibtex
313
+ @article{lsd2024,
314
+ title={Latent State Dynamics for Single-Cell Trajectory Inference},
315
+ author={LSD Development Team},
316
+ journal={},
317
+ year={2024}
318
+ }
319
+ ```
320
+
321
+ ## License
322
+
323
+ MIT License - see [LICENSE](LICENSE) for details.
324
+
325
+ ## Contributing
326
+
327
+ Contributions are welcome! Please open an issue or pull request at https://github.com/csglab/sclsd.
328
+
329
+ ## Support
330
+
331
+ - **Issues**: https://github.com/csglab/sclsd/issues