diff-biophys 0.1.2__tar.gz → 0.1.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diff_biophys-0.1.3/PKG-INFO +182 -0
- diff_biophys-0.1.3/README.md +141 -0
- diff_biophys-0.1.3/diff_biophys/__init__.py +8 -0
- diff_biophys-0.1.3/diff_biophys/cd/kernels.py +87 -0
- diff_biophys-0.1.3/diff_biophys/cryo_em.py +88 -0
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/diff_biophys/ensemble.py +30 -13
- diff_biophys-0.1.3/diff_biophys/geometry/__init__.py +3 -0
- diff_biophys-0.1.3/diff_biophys/geometry/nerf.py +84 -0
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/diff_biophys/geometry/superposition.py +9 -8
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/diff_biophys/geometry/torsions.py +14 -9
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/diff_biophys/nmr/__init__.py +1 -1
- diff_biophys-0.1.3/diff_biophys/nmr/chemical_shifts.py +85 -0
- diff_biophys-0.1.3/diff_biophys/nmr/constants.py +31 -0
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/diff_biophys/nmr/karplus.py +4 -3
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/diff_biophys/nmr/rdc.py +57 -37
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/diff_biophys/nmr/ring_currents.py +14 -12
- diff_biophys-0.1.3/diff_biophys/saxs/kernels.py +82 -0
- diff_biophys-0.1.3/diff_biophys/torch_interop.py +54 -0
- diff_biophys-0.1.3/diff_biophys.egg-info/PKG-INFO +182 -0
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/diff_biophys.egg-info/SOURCES.txt +12 -1
- diff_biophys-0.1.3/diff_biophys.egg-info/requires.txt +22 -0
- diff_biophys-0.1.3/pyproject.toml +136 -0
- diff_biophys-0.1.3/tests/test_cd_parity.py +65 -0
- diff_biophys-0.1.3/tests/test_cryo_em.py +55 -0
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/tests/test_ensemble.py +19 -22
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/tests/test_geometry_parity.py +11 -13
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/tests/test_geometry_reconstruction.py +15 -13
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/tests/test_invariance.py +21 -27
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/tests/test_kabsch_parity.py +26 -33
- diff_biophys-0.1.3/tests/test_nerf_backbone_geometry.py +131 -0
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/tests/test_nmr_advanced.py +23 -18
- diff_biophys-0.1.3/tests/test_random_coil_wishart.py +58 -0
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/tests/test_rdc_fitting.py +13 -12
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/tests/test_rdc_parity.py +11 -8
- diff_biophys-0.1.3/tests/test_ring_current_distance_decay.py +71 -0
- diff_biophys-0.1.3/tests/test_saupe_tensor_properties.py +76 -0
- diff_biophys-0.1.3/tests/test_saxs_forward_limit.py +56 -0
- diff_biophys-0.1.3/tests/test_saxs_guinier.py +39 -0
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/tests/test_saxs_parity.py +24 -28
- diff_biophys-0.1.3/tests/test_science_ca_shifts.py +122 -0
- diff_biophys-0.1.3/tests/test_science_karplus.py +64 -0
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/tests/test_science_rdc.py +5 -2
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/tests/test_science_ring_currents.py +11 -7
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/tests/test_science_saxs_advanced.py +16 -15
- diff_biophys-0.1.3/tests/test_science_saxs_sphere.py +74 -0
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/tests/test_synth_parity.py +39 -27
- diff_biophys-0.1.3/tests/test_torch_interop.py +40 -0
- diff_biophys-0.1.2/PKG-INFO +0 -116
- diff_biophys-0.1.2/README.md +0 -84
- diff_biophys-0.1.2/diff_biophys/__init__.py +0 -2
- diff_biophys-0.1.2/diff_biophys/cd/kernels.py +0 -21
- diff_biophys-0.1.2/diff_biophys/geometry/__init__.py +0 -3
- diff_biophys-0.1.2/diff_biophys/geometry/nerf.py +0 -51
- diff_biophys-0.1.2/diff_biophys/nmr/chemical_shifts.py +0 -44
- diff_biophys-0.1.2/diff_biophys/nmr/constants.py +0 -30
- diff_biophys-0.1.2/diff_biophys/saxs/kernels.py +0 -62
- diff_biophys-0.1.2/diff_biophys.egg-info/PKG-INFO +0 -116
- diff_biophys-0.1.2/diff_biophys.egg-info/requires.txt +0 -11
- diff_biophys-0.1.2/pyproject.toml +0 -47
- diff_biophys-0.1.2/tests/test_cd_parity.py +0 -15
- diff_biophys-0.1.2/tests/test_science_ca_shifts.py +0 -53
- diff_biophys-0.1.2/tests/test_science_karplus.py +0 -37
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/LICENSE +0 -0
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/diff_biophys/cd/__init__.py +0 -0
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/diff_biophys/saxs/__init__.py +0 -0
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/diff_biophys.egg-info/dependency_links.txt +0 -0
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/diff_biophys.egg-info/top_level.txt +0 -0
- {diff_biophys-0.1.2 → diff_biophys-0.1.3}/setup.cfg +0 -0
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: diff-biophys
|
|
3
|
+
Version: 0.1.3
|
|
4
|
+
Summary: Differentiable biophysical modeling in JAX
|
|
5
|
+
Author: George Elkins
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/elkins/diff-biophys
|
|
8
|
+
Project-URL: Repository, https://github.com/elkins/diff-biophys
|
|
9
|
+
Project-URL: Documentation, https://elkins.github.io/diff-biophys/
|
|
10
|
+
Classifier: Development Status :: 4 - Beta
|
|
11
|
+
Classifier: Intended Audience :: Science/Research
|
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
16
|
+
Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
|
|
17
|
+
Classifier: Topic :: Scientific/Engineering :: Physics
|
|
18
|
+
Requires-Python: >=3.10
|
|
19
|
+
Description-Content-Type: text/markdown
|
|
20
|
+
License-File: LICENSE
|
|
21
|
+
Requires-Dist: jax
|
|
22
|
+
Requires-Dist: jaxlib
|
|
23
|
+
Requires-Dist: numpy
|
|
24
|
+
Requires-Dist: biotite
|
|
25
|
+
Provides-Extra: dev
|
|
26
|
+
Requires-Dist: pytest>=7.0.0; extra == "dev"
|
|
27
|
+
Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
|
|
28
|
+
Requires-Dist: ruff>=0.6.0; extra == "dev"
|
|
29
|
+
Requires-Dist: mypy>=1.8.0; extra == "dev"
|
|
30
|
+
Requires-Dist: pre-commit>=3.0.0; extra == "dev"
|
|
31
|
+
Requires-Dist: ipython; extra == "dev"
|
|
32
|
+
Requires-Dist: synth-pdb; extra == "dev"
|
|
33
|
+
Requires-Dist: synth-nmr; extra == "dev"
|
|
34
|
+
Requires-Dist: synth-core; extra == "dev"
|
|
35
|
+
Provides-Extra: torch
|
|
36
|
+
Requires-Dist: torch>=2.0.0; extra == "torch"
|
|
37
|
+
Provides-Extra: examples
|
|
38
|
+
Requires-Dist: optax>=0.1.0; extra == "examples"
|
|
39
|
+
Requires-Dist: matplotlib>=3.5.0; extra == "examples"
|
|
40
|
+
Dynamic: license-file
|
|
41
|
+
|
|
42
|
+
# 🧬 DiffBiophys: Differentiable Biophysics for the AI Era
|
|
43
|
+
|
|
44
|
+
[](https://github.com/elkins/diff-biophys/actions/workflows/test.yml)
|
|
45
|
+
[](https://pypi.org/project/diff-biophys/)
|
|
46
|
+
[](https://pypi.org/project/diff-biophys/)
|
|
47
|
+
[](https://opensource.org/licenses/MIT)
|
|
48
|
+
[](https://codecov.io/gh/elkins/diff-biophys)
|
|
49
|
+
[](https://github.com/google/jax)
|
|
50
|
+
[](https://github.com/astral-sh/ruff)
|
|
51
|
+
[](https://mypy-lang.org/)
|
|
52
|
+
|
|
53
|
+
**DiffBiophys** is a high-performance Python library for differentiable biophysical modeling. Built on **JAX**, it re-implements core structural biology and spectroscopy observables (SAXS, NMR, CD) as hardware-accelerated, auto-differentiable kernels.
|
|
54
|
+
|
|
55
|
+
**[Documentation Website](https://elkins.github.io/diff-biophys/)** | **[Use Cases](https://elkins.github.io/diff-biophys/use_cases/)**
|
|
56
|
+
|
|
57
|
+
---
|
|
58
|
+
|
|
59
|
+
## 🎯 Vision
|
|
60
|
+
|
|
61
|
+
To bridge the gap between static structural models and experimental solution-state data by providing a "differentiable bridge." This allows researchers to:
|
|
62
|
+
1. **Optimize** protein structures directly against experimental spectra via gradient descent.
|
|
63
|
+
2. **Train** machine learning models using physics-informed loss functions.
|
|
64
|
+
3. **Accelerate** large-scale biophysical simulations on GPUs and TPUs.
|
|
65
|
+
|
|
66
|
+
---
|
|
67
|
+
|
|
68
|
+
## 🏗️ Core Components
|
|
69
|
+
|
|
70
|
+
### 1. `diff_biophys.geometry` (Differentiable Structural Engine)
|
|
71
|
+
- **NeRF (Natural Extension Reference Frame):** Differentiable conversion from internal coordinates ($\phi, \psi, \omega$, bond lengths/angles) to Cartesian XYZ.
|
|
72
|
+
- **Kabsch Alignment:** Differentiable optimal superposition using SVD.
|
|
73
|
+
- **Torsion Analysis:** Vectorized calculation of all backbone and side-chain dihedrals.
|
|
74
|
+
|
|
75
|
+
### 2. `diff_biophys.saxs` (Differentiable Scattering)
|
|
76
|
+
- **Debye Formula:** $O(N^2)$ inter-atomic interference summation.
|
|
77
|
+
- **Hydration Shell Correction:** Excluded-volume solvent subtraction (Fraser et al. 1978).
|
|
78
|
+
- **Hardware Acceleration:** GPU-optimized pairwise distance kernels via JAX `vmap`.
|
|
79
|
+
- **Use Case:** Fitting structure compactness and radius of gyration to solution-state X-ray scattering curves.
|
|
80
|
+
|
|
81
|
+
### 3. `diff_biophys.nmr` (Differentiable Spectroscopy)
|
|
82
|
+
- **Residual Dipolar Couplings (RDCs):** Differentiable Saupe tensor alignment and coupling calculation. Includes SVD-based tensor fitting.
|
|
83
|
+
- **Chemical Shifts:** Differentiable ring-current (Johnson-Bovey) shielding and softmax-weighted secondary structure Cα shift predictor.
|
|
84
|
+
- **Karplus J-coupling:** Parameterizable $^3J$ coupling equation (Vuister & Bax 1993 defaults).
|
|
85
|
+
- **Use Case:** Refining side-chain packing and domain orientations against high-resolution NMR data.
|
|
86
|
+
|
|
87
|
+
### 4. `diff_biophys.cd` (Differentiable Dichroism)
|
|
88
|
+
- **Matrix-Method Simulation:** Differentiable simulation of peptide bond transition dipole coupling via DeVoe theory.
|
|
89
|
+
- **Status:** ✅ Implemented. Supports frequency-dependent coupled-oscillator response.
|
|
90
|
+
|
|
91
|
+
---
|
|
92
|
+
|
|
93
|
+
## ⚡ Technical Architecture
|
|
94
|
+
|
|
95
|
+
- **Backend:** JAX (XLA-compiled) — supports CPU, GPU, and TPU.
|
|
96
|
+
- **Parallelism:** Native support for `vmap` (vectorization across ensembles/trajectories) and `pmap` (multi-device execution).
|
|
97
|
+
- **Differentiability:** Forward and reverse-mode autodiff through all kernels.
|
|
98
|
+
- **Interoperability:** JAX arrays are compatible with NumPy and can be exchanged with PyTorch via `dlpack` (user-managed conversion).
|
|
99
|
+
|
|
100
|
+
---
|
|
101
|
+
|
|
102
|
+
## 🚀 Roadmap
|
|
103
|
+
|
|
104
|
+
### Phase 1: Foundations (Alpha)
|
|
105
|
+
- [x] Differentiable NeRF and Kabsch alignment.
|
|
106
|
+
- [x] GPU-accelerated Debye formula for SAXS with hydration shell correction.
|
|
107
|
+
- [x] Unit tests verifying parity with `synth-pdb` NumPy implementations.
|
|
108
|
+
|
|
109
|
+
### Phase 2: NMR & Spectroscopy (Beta)
|
|
110
|
+
- [x] Differentiable RDC and Karplus kernels.
|
|
111
|
+
- [x] Differentiable Johnson-Bovey ring current model.
|
|
112
|
+
- [x] Integration with `synth-nmr` parameter libraries (optional dependency).
|
|
113
|
+
|
|
114
|
+
### Phase 3: Integration & Optimization (v1.0)
|
|
115
|
+
- [x] Full CD matrix-method implementation (DeVoe theory).
|
|
116
|
+
- [ ] Example notebooks for structure refinement via gradient descent.
|
|
117
|
+
- [ ] Plugin for `torch`-based AI models to use biophysical loss functions.
|
|
118
|
+
- [ ] Full support for BinaryCIF streaming.
|
|
119
|
+
|
|
120
|
+
---
|
|
121
|
+
|
|
122
|
+
## 📂 Repository Structure
|
|
123
|
+
|
|
124
|
+
```text
|
|
125
|
+
diff-biophys/
|
|
126
|
+
├── diff_biophys/ # Core package
|
|
127
|
+
│ ├── geometry/ # NeRF, Kabsch, Torsions
|
|
128
|
+
│ ├── saxs/ # Debye kernels, form factors
|
|
129
|
+
│ ├── nmr/ # RDCs, Karplus, Ring Currents, Chemical Shifts
|
|
130
|
+
│ ├── cd/ # CD simulation (DeVoe Matrix Method)
|
|
131
|
+
│ └── ensemble.py # Ensemble averaging API
|
|
132
|
+
├── tests/ # Parity, gradient, and scientific validation checks
|
|
133
|
+
├── examples/ # Jupyter notebooks (Refinement Lab)
|
|
134
|
+
├── docs/ # API and Theory
|
|
135
|
+
├── pyproject.toml # Modern build config
|
|
136
|
+
└── README.md
|
|
137
|
+
```
|
|
138
|
+
|
|
139
|
+
## 🚀 Installation
|
|
140
|
+
|
|
141
|
+
```bash
|
|
142
|
+
pip install diff-biophys
|
|
143
|
+
```
|
|
144
|
+
|
|
145
|
+
For GPU support (CUDA):
|
|
146
|
+
```bash
|
|
147
|
+
pip install "jax[cuda12]" diff-biophys
|
|
148
|
+
```
|
|
149
|
+
|
|
150
|
+
## 🤝 Contributing
|
|
151
|
+
|
|
152
|
+
Contributions are welcome from both ML and structural biology communities! Please open an issue or pull request on [GitHub](https://github.com/elkins/diff-biophys). Run `pre-commit run --all-files` before submitting.
|
|
153
|
+
|
|
154
|
+
## 🔗 Related Projects
|
|
155
|
+
|
|
156
|
+
diff-biophys is the **differentiable engine** powering the higher-level tools in this ecosystem:
|
|
157
|
+
|
|
158
|
+
- [synth-pdb](https://github.com/elkins/synth-pdb) — Synthetic structure generation (uses NumPy implementations)
|
|
159
|
+
- [synth-nmr](https://github.com/elkins/synth-nmr) — NMR observables (optional dependency)
|
|
160
|
+
- [synth-saxs](https://github.com/elkins/synth-saxs) — SAXS profile simulator
|
|
161
|
+
- [diff-fret](https://github.com/elkins/diff-fret) — Differentiable FRET (new)
|
|
162
|
+
- [diff-hdx](https://github.com/elkins/diff-hdx) — Differentiable HDX-MS (new)
|
|
163
|
+
- [diff-epr](https://github.com/elkins/diff-epr) — Differentiable EPR/DEER (new)
|
|
164
|
+
- [diff-ensemble](https://github.com/elkins/diff-ensemble) — IDP ensemble VAE (depends on diff-biophys)
|
|
165
|
+
- [TorsionTuner](https://github.com/elkins/TorsionTuner) — GNN refinement (depends on diff-biophys)
|
|
166
|
+
- [resonance-flow](https://github.com/elkins/resonance-flow) — NMR-guided folding (depends on diff-biophys)
|
|
167
|
+
|
|
168
|
+
## ⚖️ License
|
|
169
|
+
|
|
170
|
+
MIT License — see [LICENSE](LICENSE) for details.
|
|
171
|
+
|
|
172
|
+
## 📖 Citation
|
|
173
|
+
|
|
174
|
+
```bibtex
|
|
175
|
+
@software{diff_biophys,
|
|
176
|
+
author = {Elkins, George},
|
|
177
|
+
title = {diff-biophys: Differentiable biophysics kernels for JAX},
|
|
178
|
+
year = {2024},
|
|
179
|
+
url = {https://github.com/elkins/diff-biophys},
|
|
180
|
+
version = {0.1.3}
|
|
181
|
+
}
|
|
182
|
+
```
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
# 🧬 DiffBiophys: Differentiable Biophysics for the AI Era
|
|
2
|
+
|
|
3
|
+
[](https://github.com/elkins/diff-biophys/actions/workflows/test.yml)
|
|
4
|
+
[](https://pypi.org/project/diff-biophys/)
|
|
5
|
+
[](https://pypi.org/project/diff-biophys/)
|
|
6
|
+
[](https://opensource.org/licenses/MIT)
|
|
7
|
+
[](https://codecov.io/gh/elkins/diff-biophys)
|
|
8
|
+
[](https://github.com/google/jax)
|
|
9
|
+
[](https://github.com/astral-sh/ruff)
|
|
10
|
+
[](https://mypy-lang.org/)
|
|
11
|
+
|
|
12
|
+
**DiffBiophys** is a high-performance Python library for differentiable biophysical modeling. Built on **JAX**, it re-implements core structural biology and spectroscopy observables (SAXS, NMR, CD) as hardware-accelerated, auto-differentiable kernels.
|
|
13
|
+
|
|
14
|
+
**[Documentation Website](https://elkins.github.io/diff-biophys/)** | **[Use Cases](https://elkins.github.io/diff-biophys/use_cases/)**
|
|
15
|
+
|
|
16
|
+
---
|
|
17
|
+
|
|
18
|
+
## 🎯 Vision
|
|
19
|
+
|
|
20
|
+
To bridge the gap between static structural models and experimental solution-state data by providing a "differentiable bridge." This allows researchers to:
|
|
21
|
+
1. **Optimize** protein structures directly against experimental spectra via gradient descent.
|
|
22
|
+
2. **Train** machine learning models using physics-informed loss functions.
|
|
23
|
+
3. **Accelerate** large-scale biophysical simulations on GPUs and TPUs.
|
|
24
|
+
|
|
25
|
+
---
|
|
26
|
+
|
|
27
|
+
## 🏗️ Core Components
|
|
28
|
+
|
|
29
|
+
### 1. `diff_biophys.geometry` (Differentiable Structural Engine)
|
|
30
|
+
- **NeRF (Natural Extension Reference Frame):** Differentiable conversion from internal coordinates ($\phi, \psi, \omega$, bond lengths/angles) to Cartesian XYZ.
|
|
31
|
+
- **Kabsch Alignment:** Differentiable optimal superposition using SVD.
|
|
32
|
+
- **Torsion Analysis:** Vectorized calculation of all backbone and side-chain dihedrals.
|
|
33
|
+
|
|
34
|
+
### 2. `diff_biophys.saxs` (Differentiable Scattering)
|
|
35
|
+
- **Debye Formula:** $O(N^2)$ inter-atomic interference summation.
|
|
36
|
+
- **Hydration Shell Correction:** Excluded-volume solvent subtraction (Fraser et al. 1978).
|
|
37
|
+
- **Hardware Acceleration:** GPU-optimized pairwise distance kernels via JAX `vmap`.
|
|
38
|
+
- **Use Case:** Fitting structure compactness and radius of gyration to solution-state X-ray scattering curves.
|
|
39
|
+
|
|
40
|
+
### 3. `diff_biophys.nmr` (Differentiable Spectroscopy)
|
|
41
|
+
- **Residual Dipolar Couplings (RDCs):** Differentiable Saupe tensor alignment and coupling calculation. Includes SVD-based tensor fitting.
|
|
42
|
+
- **Chemical Shifts:** Differentiable ring-current (Johnson-Bovey) shielding and softmax-weighted secondary structure Cα shift predictor.
|
|
43
|
+
- **Karplus J-coupling:** Parameterizable $^3J$ coupling equation (Vuister & Bax 1993 defaults).
|
|
44
|
+
- **Use Case:** Refining side-chain packing and domain orientations against high-resolution NMR data.
|
|
45
|
+
|
|
46
|
+
### 4. `diff_biophys.cd` (Differentiable Dichroism)
|
|
47
|
+
- **Matrix-Method Simulation:** Differentiable simulation of peptide bond transition dipole coupling via DeVoe theory.
|
|
48
|
+
- **Status:** ✅ Implemented. Supports frequency-dependent coupled-oscillator response.
|
|
49
|
+
|
|
50
|
+
---
|
|
51
|
+
|
|
52
|
+
## ⚡ Technical Architecture
|
|
53
|
+
|
|
54
|
+
- **Backend:** JAX (XLA-compiled) — supports CPU, GPU, and TPU.
|
|
55
|
+
- **Parallelism:** Native support for `vmap` (vectorization across ensembles/trajectories) and `pmap` (multi-device execution).
|
|
56
|
+
- **Differentiability:** Forward and reverse-mode autodiff through all kernels.
|
|
57
|
+
- **Interoperability:** JAX arrays are compatible with NumPy and can be exchanged with PyTorch via `dlpack` (user-managed conversion).
|
|
58
|
+
|
|
59
|
+
---
|
|
60
|
+
|
|
61
|
+
## 🚀 Roadmap
|
|
62
|
+
|
|
63
|
+
### Phase 1: Foundations (Alpha)
|
|
64
|
+
- [x] Differentiable NeRF and Kabsch alignment.
|
|
65
|
+
- [x] GPU-accelerated Debye formula for SAXS with hydration shell correction.
|
|
66
|
+
- [x] Unit tests verifying parity with `synth-pdb` NumPy implementations.
|
|
67
|
+
|
|
68
|
+
### Phase 2: NMR & Spectroscopy (Beta)
|
|
69
|
+
- [x] Differentiable RDC and Karplus kernels.
|
|
70
|
+
- [x] Differentiable Johnson-Bovey ring current model.
|
|
71
|
+
- [x] Integration with `synth-nmr` parameter libraries (optional dependency).
|
|
72
|
+
|
|
73
|
+
### Phase 3: Integration & Optimization (v1.0)
|
|
74
|
+
- [x] Full CD matrix-method implementation (DeVoe theory).
|
|
75
|
+
- [ ] Example notebooks for structure refinement via gradient descent.
|
|
76
|
+
- [ ] Plugin for `torch`-based AI models to use biophysical loss functions.
|
|
77
|
+
- [ ] Full support for BinaryCIF streaming.
|
|
78
|
+
|
|
79
|
+
---
|
|
80
|
+
|
|
81
|
+
## 📂 Repository Structure
|
|
82
|
+
|
|
83
|
+
```text
|
|
84
|
+
diff-biophys/
|
|
85
|
+
├── diff_biophys/ # Core package
|
|
86
|
+
│ ├── geometry/ # NeRF, Kabsch, Torsions
|
|
87
|
+
│ ├── saxs/ # Debye kernels, form factors
|
|
88
|
+
│ ├── nmr/ # RDCs, Karplus, Ring Currents, Chemical Shifts
|
|
89
|
+
│ ├── cd/ # CD simulation (DeVoe Matrix Method)
|
|
90
|
+
│ └── ensemble.py # Ensemble averaging API
|
|
91
|
+
├── tests/ # Parity, gradient, and scientific validation checks
|
|
92
|
+
├── examples/ # Jupyter notebooks (Refinement Lab)
|
|
93
|
+
├── docs/ # API and Theory
|
|
94
|
+
├── pyproject.toml # Modern build config
|
|
95
|
+
└── README.md
|
|
96
|
+
```
|
|
97
|
+
|
|
98
|
+
## 🚀 Installation
|
|
99
|
+
|
|
100
|
+
```bash
|
|
101
|
+
pip install diff-biophys
|
|
102
|
+
```
|
|
103
|
+
|
|
104
|
+
For GPU support (CUDA):
|
|
105
|
+
```bash
|
|
106
|
+
pip install "jax[cuda12]" diff-biophys
|
|
107
|
+
```
|
|
108
|
+
|
|
109
|
+
## 🤝 Contributing
|
|
110
|
+
|
|
111
|
+
Contributions are welcome from both ML and structural biology communities! Please open an issue or pull request on [GitHub](https://github.com/elkins/diff-biophys). Run `pre-commit run --all-files` before submitting.
|
|
112
|
+
|
|
113
|
+
## 🔗 Related Projects
|
|
114
|
+
|
|
115
|
+
diff-biophys is the **differentiable engine** powering the higher-level tools in this ecosystem:
|
|
116
|
+
|
|
117
|
+
- [synth-pdb](https://github.com/elkins/synth-pdb) — Synthetic structure generation (uses NumPy implementations)
|
|
118
|
+
- [synth-nmr](https://github.com/elkins/synth-nmr) — NMR observables (optional dependency)
|
|
119
|
+
- [synth-saxs](https://github.com/elkins/synth-saxs) — SAXS profile simulator
|
|
120
|
+
- [diff-fret](https://github.com/elkins/diff-fret) — Differentiable FRET (new)
|
|
121
|
+
- [diff-hdx](https://github.com/elkins/diff-hdx) — Differentiable HDX-MS (new)
|
|
122
|
+
- [diff-epr](https://github.com/elkins/diff-epr) — Differentiable EPR/DEER (new)
|
|
123
|
+
- [diff-ensemble](https://github.com/elkins/diff-ensemble) — IDP ensemble VAE (depends on diff-biophys)
|
|
124
|
+
- [TorsionTuner](https://github.com/elkins/TorsionTuner) — GNN refinement (depends on diff-biophys)
|
|
125
|
+
- [resonance-flow](https://github.com/elkins/resonance-flow) — NMR-guided folding (depends on diff-biophys)
|
|
126
|
+
|
|
127
|
+
## ⚖️ License
|
|
128
|
+
|
|
129
|
+
MIT License — see [LICENSE](LICENSE) for details.
|
|
130
|
+
|
|
131
|
+
## 📖 Citation
|
|
132
|
+
|
|
133
|
+
```bibtex
|
|
134
|
+
@software{diff_biophys,
|
|
135
|
+
author = {Elkins, George},
|
|
136
|
+
title = {diff-biophys: Differentiable biophysics kernels for JAX},
|
|
137
|
+
year = {2024},
|
|
138
|
+
url = {https://github.com/elkins/diff-biophys},
|
|
139
|
+
version = {0.1.3}
|
|
140
|
+
}
|
|
141
|
+
```
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import jax.numpy as jnp
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def simulate_cd_matrix(
    peptide_positions: jnp.ndarray,
    dipole_orientations: jnp.ndarray,
    wavelengths: jnp.ndarray,
    f_osc: float = 0.2,
    gamma: float = 10.0,
    lambda_0: float = 190.0,
    ellipticity_scale: float = 1e5,
) -> jnp.ndarray:
    """
    Matrix-Method CD Simulation (DeVoe Theory).

    Implements the coupled-oscillator model for transition dipole coupling.
    The dipole-dipole interaction matrix V and the chiral geometric factor
    R_ij = (mu_i x mu_j) . (r_i - r_j) are built once. Then, for each wavelength,
    the coupled response (I - alpha V)^-1 (alpha V) is solved and projected
    onto R_ij; the imaginary part of that projection is the CD signal.

    Args:
        peptide_positions: (N, 3) positions of amide chromophores in Angstroms.
        dipole_orientations: (N, 3) unit vectors for transition dipoles.
        wavelengths: (M,) wavelengths in nm to simulate.
        f_osc: Oscillator strength of the transition (default 0.2 for pi->pi*).
        gamma: Linewidth parameter in nm (default 10.0).
        lambda_0: Resonance wavelength in nm (default 190.0).
        ellipticity_scale: Multiplicative factor applied to the raw response
            (default 1e5). The raw kernel output is in arbitrary units and
            should be calibrated against experimental data.

    Returns:
        (M,) CD spectrum, intended as molar ellipticity [θ] in deg cm^2 / dmol
        once ``ellipticity_scale`` has been calibrated.
    """
    n_chromophores = peptide_positions.shape[0]

    # --- Wavelength-independent geometry (computed once, outside the vmap) ---
    # Pairwise separation vectors r_i - r_j, shape (N, N, 3).
    diff = peptide_positions[:, None, :] - peptide_positions[None, :, :]
    dist_sq = jnp.sum(diff**2, axis=-1)

    # Coincident pairs (including the diagonal i == j) have zero separation.
    # Substitute 1.0 before sqrt/division so neither the forward value nor the
    # gradient contains inf/NaN, then zero those entries with the mask.
    mask = dist_sq > 0
    safe_dist_sq = jnp.where(mask, dist_sq, 1.0)
    r_ij = jnp.sqrt(safe_dist_sq)
    r_ij_inv3 = jnp.where(mask, 1.0 / r_ij**3, 0.0)

    # Unit separation vectors (zero for coincident pairs).
    r_hat = diff * jnp.where(mask[:, :, None], 1.0 / r_ij[:, :, None], 0.0)

    # Dipole-dipole interaction:
    # V_ij = (1/r^3) * [ mu_i . mu_j - 3 (mu_i . r_hat_ij)(mu_j . r_hat_ij) ]
    mu_i_mu_j = jnp.sum(dipole_orientations[:, None, :] * dipole_orientations[None, :, :], axis=-1)
    mu_i_r = jnp.sum(dipole_orientations[:, None, :] * r_hat, axis=-1)
    mu_j_r = jnp.sum(dipole_orientations[None, :, :] * r_hat, axis=-1)
    V = r_ij_inv3 * (mu_i_mu_j - 3 * mu_i_r * mu_j_r)

    # Chiral geometric factor (scalar triple product mu_i x mu_j . r_ij).
    # It does not depend on wavelength, so it is hoisted out of the kernel.
    cross_mu = jnp.cross(dipole_orientations[:, None, :], dipole_orientations[None, :, :])
    R_ij = jnp.sum(cross_mu * diff, axis=-1)

    identity = jnp.eye(n_chromophores)

    def compute_at_wavelength(lmbda: jnp.ndarray) -> jnp.ndarray:
        # Complex Lorentzian-like polarizability; alpha is a scalar because
        # all chromophores are treated as identical.
        denom = (1.0 / lmbda**2 - 1.0 / lambda_0**2) + 1j * (gamma / (lmbda * lambda_0**2))
        alpha = f_osc / denom

        alpha_V = alpha * V
        # Coupled response (I - alpha V)^-1 (alpha V). Solving the linear
        # system is mathematically identical to multiplying by the explicit
        # inverse, but numerically better conditioned.
        coupled_V = jnp.linalg.solve(identity - alpha_V, alpha_V)

        # Total CD response at this wavelength.
        return jnp.imag(jnp.sum(coupled_V * R_ij))

    # Vectorize over wavelengths.
    cd_spectrum = jax.vmap(compute_at_wavelength)(wavelengths)

    return cd_spectrum * ellipticity_scale
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import jax.numpy as jnp
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
@jax.jit
def compute_fsc(
    data1: jax.Array, data2: jax.Array, voxel_size: tuple[float, float, float]
) -> tuple[jax.Array, jax.Array]:
    """
    Compute a differentiable Fourier Shell Correlation (FSC) between two 3D maps.

    Both maps are transformed with ``rfftn``. Each Fourier coefficient is
    assigned to one of ``min(nx, ny, nz) // 2`` equal-width shells spanning
    ``[k_max / (10 * n_bins), k_max)``, so the DC term is excluded. Within
    each shell the value

        FSC = sum(Re(F1 * conj(F2))) / sqrt(sum|F1|^2 * sum|F2|^2)

    is computed and clipped to [-1, 1].

    Intended to match the synth-core implementation, but written with
    jax.numpy so gradients can flow back to the input maps. Shells that are
    empty, or in which either map has zero power, report 0. They also
    contribute a finite (zero) gradient rather than NaN.

    Args:
        data1, data2: 3D maps of shape (nz, ny, nx); both are assumed to have
            the same shape.
        voxel_size: Voxel spacing along (z, y, x); the returned frequencies
            are in reciprocal units of this spacing.

    Returns:
        (freqs, vals): shell-centre spatial frequencies and FSC values,
        each of shape (n_bins,).
    """
    f1 = jnp.fft.rfftn(data1)
    f2 = jnp.fft.rfftn(data2)

    # Real cross-spectrum Re(F1 * conj(F2)) and power spectra, written with
    # real arithmetic to parallel the NumPy implementation.
    cross = (f1.real * f2.real + f1.imag * f2.imag).ravel()
    p1 = (f1.real**2 + f1.imag**2).ravel()
    p2 = (f2.real**2 + f2.imag**2).ravel()

    # Magnitude of the spatial-frequency vector for every rfftn coefficient.
    nz, ny, nx = data1.shape
    kz = jnp.fft.fftfreq(nz, d=voxel_size[0])
    ky = jnp.fft.fftfreq(ny, d=voxel_size[1])
    kx = jnp.fft.rfftfreq(nx, d=voxel_size[2])
    kz_grid, ky_grid, kx_grid = jnp.meshgrid(kz, ky, kx, indexing="ij")
    k = jnp.sqrt(kz_grid**2 + ky_grid**2 + kx_grid**2).ravel()

    # Shell edges. n_bins comes from the (static) shape, so the output size
    # is fixed under jit. The first edge sits just above 0 to exclude DC.
    n_bins = min(nx, ny, nz) // 2
    k_max = jnp.max(k)
    k_eps = k_max / (10 * n_bins)
    bins = jnp.linspace(k_eps, k_max, n_bins + 1)

    def compute_bin(i: jax.Array) -> tuple[jax.Array, jax.Array]:
        bin_start = bins[i]
        bin_end = bins[i + 1]
        # Full-array boolean mask; no sorting is needed because every shell
        # scans all coefficients anyway.
        mask = (k >= bin_start) & (k < bin_end)

        sum_cross = jnp.sum(jnp.where(mask, cross, 0.0))
        sum_p1 = jnp.sum(jnp.where(mask, p1, 0.0))
        sum_p2 = jnp.sum(jnp.where(mask, p2, 0.0))

        # "Double where": feed a harmless 1.0 into sqrt/divide for empty or
        # zero-power shells. A single where around num / sqrt(prod) still
        # back-propagates 0 * inf = NaN through the unselected branch.
        power_prod = sum_p1 * sum_p2
        valid = power_prod > 0
        den = jnp.sqrt(jnp.where(valid, power_prod, 1.0))
        val = jnp.where(valid, sum_cross / den, 0.0)
        val = jnp.clip(val, -1.0, 1.0)

        freq = (bin_start + bin_end) / 2.0
        return freq, val

    # vmap over shell indices keeps the computation loop-free and jit-friendly.
    freqs, vals = jax.vmap(compute_bin)(jnp.arange(n_bins))
    return freqs, vals
|
|
@@ -1,45 +1,62 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
from typing import Any, cast
|
|
3
|
+
|
|
1
4
|
import jax.numpy as jnp
|
|
2
|
-
from jax import
|
|
3
|
-
|
|
5
|
+
from jax import jit, vmap
|
|
6
|
+
|
|
4
7
|
|
|
5
8
|
class Ensemble:
|
|
6
9
|
"""
|
|
7
10
|
High-level API for ensemble-averaged biophysical observables.
|
|
8
11
|
"""
|
|
9
|
-
|
|
12
|
+
|
|
13
|
+
coords: jnp.ndarray
|
|
14
|
+
weights: jnp.ndarray
|
|
15
|
+
m: int
|
|
16
|
+
|
|
17
|
+
def __init__(self, coordinates: jnp.ndarray, weights: jnp.ndarray | None = None):
|
|
10
18
|
"""
|
|
11
19
|
Args:
|
|
12
20
|
coordinates: (M, N, 3) array where M is ensemble size and N is atom count.
|
|
13
21
|
weights: (M,) array of population weights. Defaults to uniform.
|
|
14
22
|
"""
|
|
15
23
|
self.coords = coordinates
|
|
16
|
-
self.m = coordinates.shape[0]
|
|
24
|
+
self.m = int(coordinates.shape[0])
|
|
17
25
|
if weights is None:
|
|
18
26
|
self.weights = jnp.full((self.m,), 1.0 / self.m)
|
|
19
27
|
else:
|
|
20
|
-
self.weights = weights / jnp.sum(weights)
|
|
28
|
+
self.weights = cast(jnp.ndarray, weights / jnp.sum(weights))
|
|
21
29
|
|
|
22
|
-
def calculate_average(
|
|
30
|
+
def calculate_average(
|
|
31
|
+
self,
|
|
32
|
+
observable_fn: Callable[..., jnp.ndarray],
|
|
33
|
+
*args: Any,
|
|
34
|
+
**kwargs: Any,
|
|
35
|
+
) -> jnp.ndarray:
|
|
23
36
|
"""
|
|
24
37
|
Calculate the population-weighted average of an observable.
|
|
25
|
-
|
|
38
|
+
|
|
26
39
|
Args:
|
|
27
40
|
observable_fn: Function that takes (N, 3) coords and returns (D,) observable.
|
|
28
41
|
*args, **kwargs: Additional arguments for the observable_fn.
|
|
29
|
-
|
|
42
|
+
|
|
30
43
|
Returns:
|
|
31
44
|
jnp.ndarray: (D,) averaged observable.
|
|
32
45
|
"""
|
|
33
46
|
# Vectorize the observable function over the ensemble dimension
|
|
34
47
|
v_fn = vmap(lambda c: observable_fn(c, *args, **kwargs))
|
|
35
|
-
ensemble_results = v_fn(self.coords)
|
|
36
|
-
|
|
48
|
+
ensemble_results = v_fn(self.coords) # (M, D)
|
|
49
|
+
|
|
37
50
|
# Weighted average
|
|
38
|
-
return jnp.sum(ensemble_results * self.weights[:, None], axis=0)
|
|
51
|
+
return cast(jnp.ndarray, jnp.sum(ensemble_results * self.weights[:, None], axis=0))
|
|
52
|
+
|
|
39
53
|
|
|
40
54
|
@jit
def calculate_ensemble_saxs(
    coords: jnp.ndarray, weights: jnp.ndarray, q_values: jnp.ndarray, form_factors: jnp.ndarray
) -> jnp.ndarray:
    """Population-weighted SAXS profile of a conformational ensemble.

    ``debye_saxs`` is evaluated for every conformer along the leading axis of
    ``coords`` (vectorized with ``vmap``). The resulting profiles are scaled
    by ``weights`` and summed over the ensemble axis. The weights are applied
    as given; no normalization happens here.
    """
    # Deferred import, kept local to the function body.
    from diff_biophys.saxs import debye_saxs

    def _profile(conformer: jnp.ndarray) -> jnp.ndarray:
        return debye_saxs(conformer, q_values, form_factors)

    per_conformer = vmap(_profile)(coords)  # presumably (M, Q) — one profile per conformer
    weighted = per_conformer * weights[:, None]
    return cast(jnp.ndarray, jnp.sum(weighted, axis=0))
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
from typing import Any, cast
|
|
2
|
+
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
from jax import jit, lax
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@jit
def position_atom_3d(
    p1: jnp.ndarray,
    p2: jnp.ndarray,
    p3: jnp.ndarray,
    bond_length: jnp.ndarray,
    bond_angle_rad: jnp.ndarray,
    dihedral_angle_rad: jnp.ndarray,
) -> jnp.ndarray:
    """
    Differentiable NeRF implementation in JAX for a single atom.

    Places atom p4 given three reference atoms (p1, p2, p3) and the internal
    coordinates (bond length, bond angle, dihedral angle) that define its
    position relative to p3.

    Args:
        p1, p2, p3: (3,) reference atom coordinates.
        bond_length: Scalar distance p3→p4 in Ångströms.
        bond_angle_rad: Scalar bond angle ∠(p2, p3, p4) in radians.
        dihedral_angle_rad: Scalar dihedral angle ∠(p1, p2, p3, p4) in radians.

    Returns:
        jnp.ndarray: (3,) Cartesian coordinates of the new atom p4.

    Notes:
        Vectors shorter than 1e-10 (e.g. the plane normal when p1, p2 and p3
        are collinear) are mapped to zero with a zero gradient. The previous
        approach divided by ``jnp.linalg.norm(v) + eps``, but the derivative of
        the norm at the zero vector is NaN. That NaN would poison every
        gradient flowing through this call.
    """

    def _unit(v: jnp.ndarray) -> jnp.ndarray:
        # "Double where": the inner where keeps sqrt away from 0 so the
        # unselected branch has a finite derivative; the outer where returns
        # the zero vector (with zero gradient) for degenerate input.
        sq_norm = jnp.sum(v * v)
        is_nonzero = sq_norm > 1e-20
        safe_sq_norm = jnp.where(is_nonzero, sq_norm, 1.0)
        return jnp.where(is_nonzero, v / jnp.sqrt(safe_sq_norm), jnp.zeros_like(v))

    # Local frame anchored at p3 (orthonormal for non-degenerate geometry).
    u2 = _unit(p3 - p2)  # unit bond direction p2 -> p3
    n = _unit(jnp.cross(p1 - p2, u2))  # normal of the (p1, p2, p3) plane
    m = jnp.cross(n, u2)  # in-plane axis perpendicular to u2

    sin_theta = jnp.sin(bond_angle_rad)
    p4 = p3 + bond_length * (
        -jnp.cos(bond_angle_rad) * u2
        - sin_theta * jnp.cos(dihedral_angle_rad) * m
        - sin_theta * jnp.sin(dihedral_angle_rad) * n
    )
    return cast(jnp.ndarray, p4)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@jit
def chain_nerf(
    init_coords: jnp.ndarray,
    bond_lengths: jnp.ndarray,
    bond_angles: jnp.ndarray,
    dihedrals: jnp.ndarray,
) -> jnp.ndarray:
    """
    Build a chain of atoms using the NeRF algorithm.

    Each new atom is placed from the three most recently placed atoms: atom
    k + 3 (0-based) uses ``(bond_lengths[k], bond_angles[k], dihedrals[k])``.

    Args:
        init_coords: (3, 3) initial coordinates for the first 3 atoms
        bond_lengths: (N,) bond lengths for atoms 4 to N+3
        bond_angles: (N,) bond angles (in radians) for atoms 4 to N+3
        dihedrals: (N,) dihedral angles (in radians) for atoms 4 to N+3

    Returns:
        jnp.ndarray: (N+3, 3) coordinates for the entire chain

    Raises:
        ValueError: If the three internal-coordinate arrays do not share the
            same leading length (raised by ``lax.scan`` at trace time).
    """

    def body_fun(
        carry: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],
        internal: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],
    ) -> tuple[tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], Any]:
        p1, p2, p3 = carry
        bond_length, bond_angle, dihedral = internal
        p4 = position_atom_3d(p1, p2, p3, bond_length, bond_angle, dihedral)
        # Slide the three-atom reference window forward by one atom.
        return (p2, p3, p4), p4

    # Scan directly over the internal coordinates rather than gathering them
    # by index. lax.scan slices the leading axis itself and rejects arrays of
    # mismatched length; an index gather would silently clamp out-of-range
    # reads (JAX's default for gathers) or ignore trailing entries.
    init_carry = (init_coords[0], init_coords[1], init_coords[2])
    _, new_coords = lax.scan(body_fun, init_carry, (bond_lengths, bond_angles, dihedrals))

    return cast(jnp.ndarray, jnp.concatenate([init_coords, new_coords], axis=0))
|