diff-biophys 0.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diff_biophys-0.1.2/LICENSE +21 -0
- diff_biophys-0.1.2/PKG-INFO +116 -0
- diff_biophys-0.1.2/README.md +84 -0
- diff_biophys-0.1.2/diff_biophys/__init__.py +2 -0
- diff_biophys-0.1.2/diff_biophys/cd/__init__.py +1 -0
- diff_biophys-0.1.2/diff_biophys/cd/kernels.py +21 -0
- diff_biophys-0.1.2/diff_biophys/ensemble.py +45 -0
- diff_biophys-0.1.2/diff_biophys/geometry/__init__.py +3 -0
- diff_biophys-0.1.2/diff_biophys/geometry/nerf.py +51 -0
- diff_biophys-0.1.2/diff_biophys/geometry/superposition.py +32 -0
- diff_biophys-0.1.2/diff_biophys/geometry/torsions.py +52 -0
- diff_biophys-0.1.2/diff_biophys/nmr/__init__.py +3 -0
- diff_biophys-0.1.2/diff_biophys/nmr/chemical_shifts.py +44 -0
- diff_biophys-0.1.2/diff_biophys/nmr/constants.py +30 -0
- diff_biophys-0.1.2/diff_biophys/nmr/karplus.py +19 -0
- diff_biophys-0.1.2/diff_biophys/nmr/rdc.py +100 -0
- diff_biophys-0.1.2/diff_biophys/nmr/ring_currents.py +37 -0
- diff_biophys-0.1.2/diff_biophys/saxs/__init__.py +1 -0
- diff_biophys-0.1.2/diff_biophys/saxs/kernels.py +62 -0
- diff_biophys-0.1.2/diff_biophys.egg-info/PKG-INFO +116 -0
- diff_biophys-0.1.2/diff_biophys.egg-info/SOURCES.txt +40 -0
- diff_biophys-0.1.2/diff_biophys.egg-info/dependency_links.txt +1 -0
- diff_biophys-0.1.2/diff_biophys.egg-info/requires.txt +11 -0
- diff_biophys-0.1.2/diff_biophys.egg-info/top_level.txt +1 -0
- diff_biophys-0.1.2/pyproject.toml +47 -0
- diff_biophys-0.1.2/setup.cfg +4 -0
- diff_biophys-0.1.2/tests/test_cd_parity.py +15 -0
- diff_biophys-0.1.2/tests/test_ensemble.py +74 -0
- diff_biophys-0.1.2/tests/test_geometry_parity.py +35 -0
- diff_biophys-0.1.2/tests/test_geometry_reconstruction.py +31 -0
- diff_biophys-0.1.2/tests/test_invariance.py +69 -0
- diff_biophys-0.1.2/tests/test_kabsch_parity.py +72 -0
- diff_biophys-0.1.2/tests/test_nmr_advanced.py +77 -0
- diff_biophys-0.1.2/tests/test_rdc_fitting.py +32 -0
- diff_biophys-0.1.2/tests/test_rdc_parity.py +34 -0
- diff_biophys-0.1.2/tests/test_saxs_parity.py +59 -0
- diff_biophys-0.1.2/tests/test_science_ca_shifts.py +53 -0
- diff_biophys-0.1.2/tests/test_science_karplus.py +37 -0
- diff_biophys-0.1.2/tests/test_science_rdc.py +31 -0
- diff_biophys-0.1.2/tests/test_science_ring_currents.py +48 -0
- diff_biophys-0.1.2/tests/test_science_saxs_advanced.py +63 -0
- diff_biophys-0.1.2/tests/test_synth_parity.py +127 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 George Elkins
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: diff-biophys
|
|
3
|
+
Version: 0.1.2
|
|
4
|
+
Summary: Differentiable biophysical modeling in JAX
|
|
5
|
+
Author: George Elkins
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/elkins/diff-biophys
|
|
8
|
+
Project-URL: Repository, https://github.com/elkins/diff-biophys
|
|
9
|
+
Project-URL: Documentation, https://elkins.github.io/diff-biophys/
|
|
10
|
+
Classifier: Development Status :: 4 - Beta
|
|
11
|
+
Classifier: Intended Audience :: Science/Research
|
|
12
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
13
|
+
Classifier: Programming Language :: Python :: 3
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
17
|
+
Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
|
|
18
|
+
Classifier: Topic :: Scientific/Engineering :: Physics
|
|
19
|
+
Description-Content-Type: text/markdown
|
|
20
|
+
License-File: LICENSE
|
|
21
|
+
Requires-Dist: jax
|
|
22
|
+
Requires-Dist: jaxlib
|
|
23
|
+
Requires-Dist: numpy
|
|
24
|
+
Requires-Dist: biotite
|
|
25
|
+
Provides-Extra: dev
|
|
26
|
+
Requires-Dist: pytest; extra == "dev"
|
|
27
|
+
Requires-Dist: pytest-cov; extra == "dev"
|
|
28
|
+
Requires-Dist: ipython; extra == "dev"
|
|
29
|
+
Requires-Dist: synth-pdb; extra == "dev"
|
|
30
|
+
Requires-Dist: synth-nmr; extra == "dev"
|
|
31
|
+
Dynamic: license-file
|
|
32
|
+
|
|
33
|
+
# 🧬 DiffBiophys: Differentiable Biophysics for the AI Era
|
|
34
|
+
|
|
35
|
+
**DiffBiophys** is a high-performance Python library for differentiable biophysical modeling. Built on **JAX**, it re-implements core structural biology and spectroscopy observables (SAXS, NMR, CD) as hardware-accelerated, auto-differentiable kernels.
|
|
36
|
+
|
|
37
|
+
**[Documentation Website](https://elkins.github.io/diff-biophys/)** | **[Use Cases](https://elkins.github.io/diff-biophys/use_cases/)**
|
|
38
|
+
|
|
39
|
+
---
|
|
40
|
+
|
|
41
|
+
## 🎯 Vision
|
|
42
|
+
|
|
43
|
+
To bridge the gap between static structural models and experimental solution-state data by providing a "differentiable bridge." This allows researchers to:
|
|
44
|
+
1. **Optimize** protein structures directly against experimental spectra via gradient descent.
|
|
45
|
+
2. **Train** machine learning models using physics-informed loss functions.
|
|
46
|
+
3. **Accelerate** large-scale biophysical simulations on GPUs and TPUs.
|
|
47
|
+
|
|
48
|
+
---
|
|
49
|
+
|
|
50
|
+
## 🏗️ Core Components
|
|
51
|
+
|
|
52
|
+
### 1. `diff_biophys.geometry` (Differentiable Structural Engine)
|
|
53
|
+
- **NeRF (Natural Extension Reference Frame):** Differentiable conversion from internal coordinates ($\phi, \psi, \omega$, bond lengths/angles) to Cartesian XYZ.
|
|
54
|
+
- **Kabsch Alignment:** Differentiable optimal superposition using SVD.
|
|
55
|
+
- **Torsion Analysis:** Vectorized calculation of all backbone and side-chain dihedrals.
|
|
56
|
+
|
|
57
|
+
### 2. `diff_biophys.saxs` (Differentiable Scattering)
|
|
58
|
+
- **Debye Formula:** $O(N^2)$ inter-atomic interference summation.
|
|
59
|
+
- **Hardware Acceleration:** GPU-optimized pairwise distance kernels.
|
|
60
|
+
- **Use Case:** Fitting structure "compactness" and "radius of gyration" to solution-state X-ray scattering curves.
|
|
61
|
+
|
|
62
|
+
### 3. `diff_biophys.nmr` (Differentiable Spectroscopy)
|
|
63
|
+
- **Residual Dipolar Couplings (RDCs):** Differentiable Saupe tensor alignment and coupling calculation.
|
|
64
|
+
- **Chemical Shifts & J-Couplings:** Differentiable ring-current (Johnson-Bovey) shielding for chemical shifts and Karplus kernels for scalar J-couplings.
|
|
65
|
+
- **Use Case:** Refining side-chain packing and domain orientations against high-resolution NMR data.
|
|
66
|
+
|
|
67
|
+
### 4. `diff_biophys.cd` (Differentiable Dichroism)
|
|
68
|
+
- **Matrix-Method Simulation:** Differentiable simulation of peptide bond transition dipole coupling.
|
|
69
|
+
- **Use Case:** Predicting secondary structure content and verifying fold stability.
|
|
70
|
+
|
|
71
|
+
---
|
|
72
|
+
|
|
73
|
+
## ⚡ Technical Architecture
|
|
74
|
+
|
|
75
|
+
- **Backend:** JAX (XLA-compiled).
|
|
76
|
+
- **Parallelism:** Native support for `vmap` (vectorization across ensembles/trajectories) and `pmap` (multi-device execution).
|
|
77
|
+
- **Differentiability:** Support for both Forward and Reverse-mode autodiff.
|
|
78
|
+
- **Interoperability:** Seamless integration with PyTorch/TensorFlow (via DLPack) and standard structural formats (mmCIF/BCIF).
|
|
79
|
+
|
|
80
|
+
---
|
|
81
|
+
|
|
82
|
+
## 🚀 Roadmap
|
|
83
|
+
|
|
84
|
+
### Phase 1: Foundations (Alpha)
|
|
85
|
+
- [x] Differentiable NeRF and Kabsch alignment.
|
|
86
|
+
- [x] GPU-accelerated Debye formula for SAXS.
|
|
87
|
+
- [x] Unit tests verifying parity with `synth-pdb` NumPy implementations.
|
|
88
|
+
|
|
89
|
+
### Phase 2: NMR & Spectroscopy (Beta)
|
|
90
|
+
- [x] Differentiable RDC and Karplus kernels.
|
|
91
|
+
- [x] Differentiable Johnson-Bovey ring current model.
|
|
92
|
+
- [ ] Integration with `synth-nmr` parameter libraries.
|
|
93
|
+
|
|
94
|
+
### Phase 3: Integration & Optimization (v1.0)
|
|
95
|
+
- [ ] Example notebooks for structure refinement via gradient descent.
|
|
96
|
+
- [ ] Plugin for `torch`-based AI models to use biophysical loss functions.
|
|
97
|
+
- [ ] Full support for BinaryCIF streaming.
|
|
98
|
+
|
|
99
|
+
---
|
|
100
|
+
|
|
101
|
+
## 📂 Repository Structure (Proposed)
|
|
102
|
+
|
|
103
|
+
```text
|
|
104
|
+
diff-biophys/
|
|
105
|
+
├── diff_biophys/ # Core package
|
|
106
|
+
│ ├── geometry/ # NeRF, Kabsch, Torsions
|
|
107
|
+
│ ├── saxs/ # Debye kernels, form factors
|
|
108
|
+
│ ├── nmr/ # RDCs, Karplus, Ring Currents
|
|
109
|
+
│ ├── cd/ # CD simulation
|
|
110
|
+
│ └── utils/ # Constants, JAX-NumPy shims
|
|
111
|
+
├── tests/ # Parity and gradient checks
|
|
112
|
+
├── examples/ # Jupyter notebooks (Refinement Lab)
|
|
113
|
+
├── docs/ # API and Theory
|
|
114
|
+
├── pyproject.toml # Modern build config
|
|
115
|
+
└── README.md
|
|
116
|
+
```
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
# 🧬 DiffBiophys: Differentiable Biophysics for the AI Era
|
|
2
|
+
|
|
3
|
+
**DiffBiophys** is a high-performance Python library for differentiable biophysical modeling. Built on **JAX**, it re-implements core structural biology and spectroscopy observables (SAXS, NMR, CD) as hardware-accelerated, auto-differentiable kernels.
|
|
4
|
+
|
|
5
|
+
**[Documentation Website](https://elkins.github.io/diff-biophys/)** | **[Use Cases](https://elkins.github.io/diff-biophys/use_cases/)**
|
|
6
|
+
|
|
7
|
+
---
|
|
8
|
+
|
|
9
|
+
## 🎯 Vision
|
|
10
|
+
|
|
11
|
+
To bridge the gap between static structural models and experimental solution-state data by providing a "differentiable bridge." This allows researchers to:
|
|
12
|
+
1. **Optimize** protein structures directly against experimental spectra via gradient descent.
|
|
13
|
+
2. **Train** machine learning models using physics-informed loss functions.
|
|
14
|
+
3. **Accelerate** large-scale biophysical simulations on GPUs and TPUs.
|
|
15
|
+
|
|
16
|
+
---
|
|
17
|
+
|
|
18
|
+
## 🏗️ Core Components
|
|
19
|
+
|
|
20
|
+
### 1. `diff_biophys.geometry` (Differentiable Structural Engine)
|
|
21
|
+
- **NeRF (Natural Extension Reference Frame):** Differentiable conversion from internal coordinates ($\phi, \psi, \omega$, bond lengths/angles) to Cartesian XYZ.
|
|
22
|
+
- **Kabsch Alignment:** Differentiable optimal superposition using SVD.
|
|
23
|
+
- **Torsion Analysis:** Vectorized calculation of all backbone and side-chain dihedrals.
|
|
24
|
+
|
|
25
|
+
### 2. `diff_biophys.saxs` (Differentiable Scattering)
|
|
26
|
+
- **Debye Formula:** $O(N^2)$ inter-atomic interference summation.
|
|
27
|
+
- **Hardware Acceleration:** GPU-optimized pairwise distance kernels.
|
|
28
|
+
- **Use Case:** Fitting structure "compactness" and "radius of gyration" to solution-state X-ray scattering curves.
|
|
29
|
+
|
|
30
|
+
### 3. `diff_biophys.nmr` (Differentiable Spectroscopy)
|
|
31
|
+
- **Residual Dipolar Couplings (RDCs):** Differentiable Saupe tensor alignment and coupling calculation.
|
|
32
|
+
- **Chemical Shifts & J-Couplings:** Differentiable ring-current (Johnson-Bovey) shielding for chemical shifts and Karplus kernels for scalar J-couplings.
|
|
33
|
+
- **Use Case:** Refining side-chain packing and domain orientations against high-resolution NMR data.
|
|
34
|
+
|
|
35
|
+
### 4. `diff_biophys.cd` (Differentiable Dichroism)
|
|
36
|
+
- **Matrix-Method Simulation:** Differentiable simulation of peptide bond transition dipole coupling.
|
|
37
|
+
- **Use Case:** Predicting secondary structure content and verifying fold stability.
|
|
38
|
+
|
|
39
|
+
---
|
|
40
|
+
|
|
41
|
+
## ⚡ Technical Architecture
|
|
42
|
+
|
|
43
|
+
- **Backend:** JAX (XLA-compiled).
|
|
44
|
+
- **Parallelism:** Native support for `vmap` (vectorization across ensembles/trajectories) and `pmap` (multi-device execution).
|
|
45
|
+
- **Differentiability:** Support for both Forward and Reverse-mode autodiff.
|
|
46
|
+
- **Interoperability:** Seamless integration with PyTorch/TensorFlow (via DLPack) and standard structural formats (mmCIF/BCIF).
|
|
47
|
+
|
|
48
|
+
---
|
|
49
|
+
|
|
50
|
+
## 🚀 Roadmap
|
|
51
|
+
|
|
52
|
+
### Phase 1: Foundations (Alpha)
|
|
53
|
+
- [x] Differentiable NeRF and Kabsch alignment.
|
|
54
|
+
- [x] GPU-accelerated Debye formula for SAXS.
|
|
55
|
+
- [x] Unit tests verifying parity with `synth-pdb` NumPy implementations.
|
|
56
|
+
|
|
57
|
+
### Phase 2: NMR & Spectroscopy (Beta)
|
|
58
|
+
- [x] Differentiable RDC and Karplus kernels.
|
|
59
|
+
- [x] Differentiable Johnson-Bovey ring current model.
|
|
60
|
+
- [ ] Integration with `synth-nmr` parameter libraries.
|
|
61
|
+
|
|
62
|
+
### Phase 3: Integration & Optimization (v1.0)
|
|
63
|
+
- [ ] Example notebooks for structure refinement via gradient descent.
|
|
64
|
+
- [ ] Plugin for `torch`-based AI models to use biophysical loss functions.
|
|
65
|
+
- [ ] Full support for BinaryCIF streaming.
|
|
66
|
+
|
|
67
|
+
---
|
|
68
|
+
|
|
69
|
+
## 📂 Repository Structure (Proposed)
|
|
70
|
+
|
|
71
|
+
```text
|
|
72
|
+
diff-biophys/
|
|
73
|
+
├── diff_biophys/ # Core package
|
|
74
|
+
│ ├── geometry/ # NeRF, Kabsch, Torsions
|
|
75
|
+
│ ├── saxs/ # Debye kernels, form factors
|
|
76
|
+
│ ├── nmr/ # RDCs, Karplus, Ring Currents
|
|
77
|
+
│ ├── cd/ # CD simulation
|
|
78
|
+
│ └── utils/ # Constants, JAX-NumPy shims
|
|
79
|
+
├── tests/ # Parity and gradient checks
|
|
80
|
+
├── examples/ # Jupyter notebooks (Refinement Lab)
|
|
81
|
+
├── docs/ # API and Theory
|
|
82
|
+
├── pyproject.toml # Modern build config
|
|
83
|
+
└── README.md
|
|
84
|
+
```
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .kernels import simulate_cd_matrix
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
from jax import jit
|
|
3
|
+
|
|
4
|
+
def simulate_cd_matrix(peptide_positions, dipole_orientations, wavelengths):
    """
    Stub for the matrix-method (DeVoe theory) CD simulation.

    The physical model is not implemented yet. The chromophore positions and
    dipole orientations are accepted but never read, and the returned
    spectrum is all zeros.

    Args:
        peptide_positions: (N, 3) positions of amide chromophores (unused).
        dipole_orientations: (N, 3) unit vectors for transition dipoles (unused).
        wavelengths: (M,) wavelengths in nm.

    Returns:
        Zeros with the same shape and dtype as ``wavelengths``.
    """
    # Planned stages, none of which exist yet:
    #   1. interaction matrix (N, N) with V_ij = dipole-dipole coupling energy
    #   2. polarizability tensor
    #   3. solve for the complex ellipticity
    empty_spectrum = jnp.zeros_like(wavelengths)
    return empty_spectrum
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
from jax import vmap, jit
|
|
3
|
+
from typing import Callable, Any
|
|
4
|
+
|
|
5
|
+
class Ensemble:
    """
    High-level API for ensemble-averaged biophysical observables.

    Holds a stack of conformers together with population weights that are
    normalized to sum to one at construction time.
    """

    def __init__(self, coordinates: jnp.ndarray, weights: jnp.ndarray = None):
        """
        Args:
            coordinates: (M, N, 3) array where M is ensemble size and N is atom count.
            weights: (M,) array-like of population weights. They are rescaled
                to sum to 1. Defaults to uniform weights of 1/M.
        """
        self.coords = coordinates
        self.m = coordinates.shape[0]
        if weights is None:
            self.weights = jnp.full((self.m,), 1.0 / self.m)
        else:
            # Accept plain sequences as well as arrays before normalizing.
            weights = jnp.asarray(weights)
            self.weights = weights / jnp.sum(weights)

    def calculate_average(self, observable_fn: Callable[[jnp.ndarray], jnp.ndarray], *args, **kwargs) -> jnp.ndarray:
        """
        Calculate the population-weighted average of an observable.

        Args:
            observable_fn: Function that takes (N, 3) coords and returns an
                array observable of any fixed shape (scalar, (D,), (D1, D2), ...).
            *args, **kwargs: Additional arguments for the observable_fn; they
                are passed unchanged to every ensemble member.

        Returns:
            jnp.ndarray: Averaged observable with the same shape as a single
            ``observable_fn`` result.
        """
        # Vectorize the observable function over the ensemble dimension.
        per_member = vmap(lambda c: observable_fn(c, *args, **kwargs))(self.coords)

        # Contract only the ensemble axis. Multiplying by weights[:, None]
        # would mis-broadcast for scalar observables ((M,) * (M, 1) -> (M, M))
        # and for observables with more than one trailing dimension.
        return jnp.tensordot(self.weights, per_member, axes=1)
|
|
39
|
+
|
|
40
|
+
@jit
def calculate_ensemble_saxs(coords: jnp.ndarray, weights: jnp.ndarray, q_values: jnp.ndarray, form_factors: jnp.ndarray):
    """
    Weighted sum of per-conformer SAXS profiles.

    Each entry along the leading axis of ``coords`` is passed to
    ``debye_saxs`` with the shared ``q_values`` and ``form_factors``. The
    resulting profiles are scaled by ``weights`` and summed over the
    ensemble axis. The weights are used exactly as given and are not
    normalized here.
    """
    # Imported lazily, as in the original, rather than at module import time.
    from diff_biophys.saxs import debye_saxs

    def single_profile(conformer):
        return debye_saxs(conformer, q_values, form_factors)

    profiles = vmap(single_profile)(coords)
    # weights[:, None] broadcasts one weight per conformer across its profile;
    # presumably profiles is (M, Q) - confirm against debye_saxs.
    weighted = profiles * weights[:, None]
    return jnp.sum(weighted, axis=0)
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
from jax import jit, lax
|
|
3
|
+
|
|
4
|
+
@jit
def position_atom_3d(p1: jnp.ndarray, p2: jnp.ndarray, p3: jnp.ndarray,
                     bond_length: jnp.ndarray, bond_angle_rad: jnp.ndarray, dihedral_angle_rad: jnp.ndarray) -> jnp.ndarray:
    """
    Place one atom from three predecessors (differentiable NeRF step).

    A local orthogonal frame is built at ``p3`` from the p2->p3 direction and
    the normal of the (p1, p2, p3) plane. The new atom is then positioned in
    that frame from the bond length, bond angle and dihedral angle (radians).

    Returns:
        Coordinates of the new atom, same shape as ``p3``.
    """
    tiny = 1e-10  # keeps the normalizations finite for degenerate geometry

    toward_p1 = p1 - p2
    toward_p3 = p3 - p2

    # Unit vector along the p2 -> p3 bond.
    bond_axis = toward_p3 / (jnp.linalg.norm(toward_p3) + tiny)

    # Unit normal of the plane spanned by the three reference atoms.
    plane_normal = jnp.cross(toward_p1, bond_axis)
    plane_normal = plane_normal / (jnp.linalg.norm(plane_normal) + tiny)

    # Third frame axis: cross product of the plane normal and the bond axis.
    in_plane = jnp.cross(plane_normal, bond_axis)

    sin_angle = jnp.sin(bond_angle_rad)
    direction = (
        -jnp.cos(bond_angle_rad) * bond_axis
        - sin_angle * jnp.cos(dihedral_angle_rad) * in_plane
        - sin_angle * jnp.sin(dihedral_angle_rad) * plane_normal
    )
    return p3 + bond_length * direction
|
|
26
|
+
|
|
27
|
+
@jit
def chain_nerf(init_coords: jnp.ndarray, bond_lengths: jnp.ndarray,
               bond_angles: jnp.ndarray, dihedrals: jnp.ndarray) -> jnp.ndarray:
    """
    Grow a chain atom by atom with the NeRF placement rule.

    Starting from three seed atoms, each scan step places one new atom with
    ``position_atom_3d`` using the three most recent atoms as reference, then
    slides that three-atom window forward by one.

    Args:
        init_coords: (3, 3) initial coordinates for the first 3 atoms
        bond_lengths: (N,) bond lengths for atoms 4 to N+3
        bond_angles: (N,) bond angles (in radians) for atoms 4 to N+3
        dihedrals: (N,) dihedral angles (in radians) for atoms 4 to N+3

    Returns:
        jnp.ndarray: (N+3, 3) coordinates for the entire chain
    """
    def place_next(window, idx):
        oldest, middle, newest = window
        placed = position_atom_3d(
            oldest, middle, newest,
            bond_lengths[idx], bond_angles[idx], dihedrals[idx],
        )
        # Carry the shifted window; emit the freshly placed atom.
        return (middle, newest, placed), placed

    seed_window = (init_coords[0], init_coords[1], init_coords[2])
    step_ids = jnp.arange(len(bond_lengths))
    _, grown = lax.scan(place_next, seed_window, step_ids)

    return jnp.concatenate([init_coords, grown], axis=0)
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
from jax import jit
|
|
3
|
+
|
|
4
|
+
@jit
def kabsch_alignment(P: jnp.ndarray, Q: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Kabsch superposition of mobile points ``P`` onto reference points ``Q``.

    Both point sets are centred on their centroids, the 3x3 covariance is
    decomposed with an SVD, and a sign correction on the last axis keeps the
    result a proper rotation (determinant +1) rather than a reflection.

    Args:
        P: (N, 3) mobile coordinates
        Q: (N, 3) reference coordinates

    Returns:
        tuple[jnp.ndarray, jnp.ndarray]: ``(R, t)`` with ``R`` a 3x3 rotation
        matrix and ``t`` a 3-element translation, such that a mobile point
        ``p`` maps to ``R @ p + t``.
    """
    centroid_p = jnp.mean(P, axis=0)
    centroid_q = jnp.mean(Q, axis=0)

    # Cross-covariance of the centred point sets.
    covariance = (P - centroid_p).T @ (Q - centroid_q)

    U, _, Vt = jnp.linalg.svd(covariance)
    V = Vt.T

    # A negative determinant means V @ U.T would be a reflection; flip the
    # last axis in that case.
    handedness = jnp.linalg.det(V @ U.T)
    last_sign = jnp.where(handedness > 0, 1.0, -1.0)
    correction = jnp.eye(3).at[2, 2].set(last_sign)

    R = V @ (correction @ U.T)
    t = centroid_q - R @ centroid_p

    return R, t
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
from jax import jit
|
|
3
|
+
|
|
4
|
+
@jit
def compute_bond_lengths(coords: jnp.ndarray) -> jnp.ndarray:
    """
    Distances between consecutive atoms along the leading axis.

    For ``coords`` with K rows, the result has K-1 entries: the Euclidean
    norm of each ``coords[i + 1] - coords[i]``.
    """
    steps = jnp.diff(coords, axis=0)
    lengths = jnp.linalg.norm(steps, axis=-1)
    return lengths
|
|
11
|
+
|
|
12
|
+
@jit
def compute_bond_angles(coords: jnp.ndarray) -> jnp.ndarray:
    """
    Angles (radians) at each interior atom of a consecutive atom triple.

    For ``coords`` with K rows, the result has K-2 entries. Entry ``i`` is
    the angle at atom ``i + 1`` between the directions to atoms ``i`` and
    ``i + 2``.
    """
    tiny = 1e-10   # guards the normalization against zero-length bonds
    margin = 1e-7  # keeps the arccos argument strictly inside (-1, 1)

    centre = coords[1:-1]
    to_prev = coords[:-2] - centre
    to_next = coords[2:] - centre

    unit_prev = to_prev / (jnp.linalg.norm(to_prev, axis=-1, keepdims=True) + tiny)
    unit_next = to_next / (jnp.linalg.norm(to_next, axis=-1, keepdims=True) + tiny)

    cosine = jnp.sum(unit_prev * unit_next, axis=-1)
    bounded = jnp.clip(cosine, -1.0 + margin, 1.0 - margin)
    return jnp.arccos(bounded)
|
|
25
|
+
|
|
26
|
+
@jit
def compute_dihedrals(coords: jnp.ndarray) -> jnp.ndarray:
    """
    Signed dihedral angles (radians) for every run of four consecutive atoms.

    For ``coords`` with K rows, the result has K-3 entries in (-pi, pi].
    The angle comes from ``arctan2`` applied to the two flanking bonds after
    their components along the central bond have been removed. The original
    author describes this as the "Praxeolitic" formula following the IUPAC
    convention, intended to match synth-pdb.
    """
    first, second, third, fourth = coords[:-3], coords[1:-2], coords[2:-1], coords[3:]

    # Bond vectors: second -> first, second -> third, third -> fourth.
    backward = first - second
    central = third - second
    forward = fourth - third

    # Unit vector along the central bond.
    central_len = jnp.linalg.norm(central, axis=-1, keepdims=True)
    axis = central / (central_len + 1e-10)

    def reject_from_axis(vec):
        """Remove the component of ``vec`` that lies along the central bond."""
        return vec - jnp.sum(vec * axis, axis=-1, keepdims=True) * axis

    perp_backward = reject_from_axis(backward)
    perp_forward = reject_from_axis(forward)

    # Cosine-like and sine-like projections of the torsion.
    cos_part = jnp.sum(perp_backward * perp_forward, axis=-1)
    sin_part = jnp.sum(jnp.cross(axis, perp_backward) * perp_forward, axis=-1)

    return jnp.arctan2(sin_part, cos_part)
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
from jax import jit, vmap
|
|
3
|
+
|
|
4
|
+
# Baseline Random Coil Shifts (Wishart et al. 1995)
# Using a simplified set for CA (Alpha Carbon) for demonstration
RANDOM_COIL_CA = {
    "ALA": 52.5, "ARG": 56.0, "ASN": 53.1, "ASP": 54.2, "CYS": 58.2,
    "GLN": 55.7, "GLU": 56.6, "GLY": 45.1, "HIS": 55.0, "ILE": 61.1,
    "LEU": 55.1, "LYS": 56.2, "MET": 55.3, "PHE": 57.7, "PRO": 63.3,
    "SER": 58.3, "THR": 61.8, "TRP": 57.5, "TYR": 57.9, "VAL": 62.2,
}

# Statistical Secondary Structure Offsets for CA
# Alpha Helix: ~ +3.1 ppm, Beta Sheet: ~ -1.5 ppm
OFFSET_HELIX = 3.1
OFFSET_SHEET = -1.5


def _wrap_angle(delta: jnp.ndarray) -> jnp.ndarray:
    """Map an angular difference in radians onto the principal range (-pi, pi]."""
    return jnp.arctan2(jnp.sin(delta), jnp.cos(delta))


@jit
def predict_ca_shifts(phi: jnp.ndarray, psi: jnp.ndarray, rc_shifts: jnp.ndarray) -> jnp.ndarray:
    """
    Differentiable CA Chemical Shift prediction based on Backbone Torsions.

    This uses a "soft" classification of secondary structure based on Phi/Psi
    to apply SPARTA-like offsets. Each residue gets a Gaussian membership in
    a helix basin and a sheet basin of the Ramachandran plane, and those
    memberships scale ``OFFSET_HELIX`` and ``OFFSET_SHEET``.

    Torsion differences are wrapped onto (-pi, pi] before squaring, so the
    prediction is 2*pi-periodic in both angles and continuous across the
    +/-pi boundary.

    Args:
        phi, psi: (N,) backbone dihedrals in radians.
        rc_shifts: (N,) baseline random coil shifts.

    Returns:
        jnp.ndarray: (N,) predicted CA shifts.
    """
    # 1. Soft-classify secondary structure
    # Alpha Helix: Phi ~ -60 deg (-1.05 rad), Psi ~ -45 deg (-0.78 rad)
    helix_dist_sq = _wrap_angle(phi + 1.05) ** 2 + _wrap_angle(psi + 0.78) ** 2
    is_helix = jnp.exp(-helix_dist_sq / 0.5)  # Soft mask

    # Beta Sheet: Phi ~ -120 deg (-2.09 rad), Psi ~ 135 deg (2.35 rad)
    sheet_dist_sq = _wrap_angle(phi + 2.09) ** 2 + _wrap_angle(psi - 2.35) ** 2
    is_sheet = jnp.exp(-sheet_dist_sq / 0.5)  # Soft mask

    # 2. Combine offsets
    # Shift = RC + (is_helix * OFFSET_HELIX) + (is_sheet * OFFSET_SHEET)
    return rc_shifts + (is_helix * OFFSET_HELIX) + (is_sheet * OFFSET_SHEET)
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
|
|
3
|
+
# Default NMR parameters (Vuister & Bax 1993)
KARPLUS_A = 6.51
KARPLUS_B = -1.76
KARPLUS_C = 1.60

# Default Ring Current Intensities (Consistent with synth-nmr/SHIFTX2)
RING_INTENSITIES = {
    'PHE': 1.2,
    'TYR': 1.2,
    'TRP': 1.3,
    'HIS': 0.5,
    'HID': 0.5,
    'HIE': 0.5,
    'HIP': 0.5,
}

# Try to pull from synth-nmr if installed.
# Only the imports sit inside the try, so a missing package is the sole
# condition that is silently ignored.
try:
    import synth_nmr.j_coupling as sj
    import synth_nmr.chemical_shifts as sc
except ImportError:
    pass
else:
    # Read each table with a default: a synth-nmr release that lacks one of
    # them then falls back to the built-in values instead of raising
    # AttributeError while this module is being imported.
    _karplus_overrides = getattr(sj, 'KARPLUS_PARAMS', {})
    KARPLUS_A = _karplus_overrides.get('A', KARPLUS_A)
    KARPLUS_B = _karplus_overrides.get('B', KARPLUS_B)
    KARPLUS_C = _karplus_overrides.get('C', KARPLUS_C)

    RING_INTENSITIES.update(getattr(sc, 'RING_INTENSITIES', {}))
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
from jax import jit
|
|
3
|
+
|
|
4
|
+
@jit
def calculate_karplus_j(theta: jnp.ndarray, a: float, b: float, c: float) -> jnp.ndarray:
    """
    Evaluate the Karplus relation for 3J scalar couplings.

        J(theta) = a * cos(theta)^2 + b * cos(theta) + c

    Args:
        theta: (N,) Dihedral angles in radians.
        a, b, c: Empirical Karplus parameters.

    Returns:
        jnp.ndarray: (N,) Calculated J-couplings.
    """
    cosine = jnp.cos(theta)
    quadratic_term = a * (cosine**2)
    linear_term = b * cosine
    return quadratic_term + linear_term + c
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
from jax import jit
|
|
3
|
+
|
|
4
|
+
@jit
def calculate_rdc_from_tensor(bond_vectors: jnp.ndarray, saupe_tensor: jnp.ndarray, d_max: float = 1.0) -> jnp.ndarray:
    """
    Residual dipolar couplings from a full 3x3 Saupe alignment tensor.

    For every bond vector v the quadratic form v^T S v is evaluated and
    scaled by ``d_max``.

    Args:
        bond_vectors: (N, 3) unit vectors
        saupe_tensor: (3, 3) symmetric traceless Saupe tensor
        d_max: Maximum dipolar coupling constant (Hz)

    Returns:
        jnp.ndarray: Calculated RDCs (N,)
    """
    # Single einsum over all bonds: sum_ij v_i * S_ij * v_j for each row n.
    quadratic_form = jnp.einsum(
        'ni,ij,nj->n', bond_vectors, saupe_tensor, bond_vectors
    )
    return d_max * quadratic_form
|
|
20
|
+
|
|
21
|
+
@jit
def fit_saupe_tensor(bond_vectors: jnp.ndarray, experimental_rdcs: jnp.ndarray, d_max: float = 1.0) -> jnp.ndarray:
    """
    Least-squares fit of a Saupe alignment tensor to measured RDCs.

    Because the tensor is symmetric and traceless it has five independent
    components, s = [Sxx, Syy, Sxy, Sxz, Syz], with Szz = -Sxx - Syy.
    Substituting Szz into the RDC expression makes it linear in s:

        D = d_max * [ Sxx(x^2 - z^2) + Syy(y^2 - z^2)
                      + 2 Sxy xy + 2 Sxz xz + 2 Syz yz ]

    so s is obtained from ``jnp.linalg.lstsq`` on that design matrix.

    Args:
        bond_vectors: (N, 3) unit vectors
        experimental_rdcs: (N,) measured RDCs in Hz
        d_max: Maximum dipolar coupling constant (Hz)

    Returns:
        jnp.ndarray: (3, 3) Fitted Saupe tensor
    """
    vx = bond_vectors[:, 0]
    vy = bond_vectors[:, 1]
    vz = bond_vectors[:, 2]

    # One column per independent tensor component, one row per bond.
    basis_columns = [
        vx**2 - vz**2,
        vy**2 - vz**2,
        2 * vx * vy,
        2 * vx * vz,
        2 * vy * vz,
    ]
    design = d_max * jnp.stack(basis_columns, axis=1)

    # Solve design @ s = experimental_rdcs in the least-squares sense.
    components = jnp.linalg.lstsq(design, experimental_rdcs)[0]
    s_xx, s_yy, s_xy, s_xz, s_yz = components

    # Recover the dependent diagonal element from the traceless condition.
    s_zz = -(s_xx + s_yy)

    return jnp.array([
        [s_xx, s_xy, s_xz],
        [s_xy, s_yy, s_yz],
        [s_xz, s_yz, s_zz],
    ])
|
|
67
|
+
|
|
68
|
+
@jit
def calculate_q_factor(calculated_rdcs: jnp.ndarray, experimental_rdcs: jnp.ndarray) -> jnp.ndarray:
    """
    RDC quality factor Q (Cornilescu et al., 1998).

        Q = sqrt( sum((D_calc - D_exp)^2) / sum(D_exp^2) )

    A constant 1e-10 is added to the denominator, so all-zero experimental
    data yields a finite value instead of a division by zero.

    Args:
        calculated_rdcs: (N,) calculated couplings.
        experimental_rdcs: (N,) measured couplings.

    Returns:
        jnp.ndarray: Scalar Q-factor.
    """
    residuals = calculated_rdcs - experimental_rdcs
    misfit = jnp.sum(residuals**2)
    scale = jnp.sum(experimental_rdcs**2) + 1e-10
    return jnp.sqrt(misfit / scale)
|
|
84
|
+
|
|
85
|
+
@jit
def calculate_rdc(bond_vectors: jnp.ndarray, da: float, r: float) -> jnp.ndarray:
    """
    RDCs for bond vectors expressed in the alignment tensor's principal frame.

    Evaluates ``da * [(3 z^2 - 1) + 1.5 * r * (x^2 - y^2)]`` for each row of
    ``bond_vectors``.

    Args:
        bond_vectors: (N, 3) unit vectors in the tensor's principal frame
        da: Axial component in Hz
        r: Rhombicity (0 <= R <= 2/3)

    Returns:
        jnp.ndarray: (N,) couplings, one per bond vector.
    """
    cx = bond_vectors[:, 0]
    cy = bond_vectors[:, 1]
    cz = bond_vectors[:, 2]

    axial_term = 3.0 * cz**2 - 1.0
    rhombic_term = 1.5 * r * (cx**2 - cy**2)

    return da * (axial_term + rhombic_term)
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
from jax import jit
|
|
3
|
+
|
|
4
|
+
@jit
def calculate_ring_current_shift(coords: jnp.ndarray,
                                 ring_center: jnp.ndarray,
                                 ring_normal: jnp.ndarray,
                                 intensity: float) -> jnp.ndarray:
    """
    Ring-current contribution to chemical shifts in the dipolar approximation
    (labelled Johnson-Bovey by the original author).

        delta = intensity * (1 - 3*cos^2(theta)) / r^3

    Here r is the distance from the ring centre to the nucleus and theta is
    the angle between that displacement and the ring normal.

    Args:
        coords: (N, 3) coordinates of the nuclei being shielded.
        ring_center: (3,) coordinates of the aromatic ring center.
        ring_normal: (3,) unit vector normal to the ring plane. It is NOT
            re-normalized here, so a non-unit vector scales cos(theta).
        intensity: Scaling factor (proportional to ring area and current).

    Returns:
        jnp.ndarray: (N,) shielding values in ppm.
    """
    tiny = 1e-10  # avoids division by zero for a nucleus at the ring centre

    # Displacement of every nucleus from the ring centre, and its length.
    displacement = coords - ring_center
    distance = jnp.linalg.norm(displacement, axis=-1)

    # Projection on the normal divided by the distance gives cos(theta).
    along_normal = jnp.sum(displacement * ring_normal, axis=-1)
    cos_theta = along_normal / (distance + tiny)

    # Geometric factor of the point-dipole field.
    geometric = 1.0 - 3.0 * cos_theta**2
    return intensity * geometric / (distance**3 + tiny)
|