diff-biophys 0.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. diff_biophys-0.1.2/LICENSE +21 -0
  2. diff_biophys-0.1.2/PKG-INFO +116 -0
  3. diff_biophys-0.1.2/README.md +84 -0
  4. diff_biophys-0.1.2/diff_biophys/__init__.py +2 -0
  5. diff_biophys-0.1.2/diff_biophys/cd/__init__.py +1 -0
  6. diff_biophys-0.1.2/diff_biophys/cd/kernels.py +21 -0
  7. diff_biophys-0.1.2/diff_biophys/ensemble.py +45 -0
  8. diff_biophys-0.1.2/diff_biophys/geometry/__init__.py +3 -0
  9. diff_biophys-0.1.2/diff_biophys/geometry/nerf.py +51 -0
  10. diff_biophys-0.1.2/diff_biophys/geometry/superposition.py +32 -0
  11. diff_biophys-0.1.2/diff_biophys/geometry/torsions.py +52 -0
  12. diff_biophys-0.1.2/diff_biophys/nmr/__init__.py +3 -0
  13. diff_biophys-0.1.2/diff_biophys/nmr/chemical_shifts.py +44 -0
  14. diff_biophys-0.1.2/diff_biophys/nmr/constants.py +30 -0
  15. diff_biophys-0.1.2/diff_biophys/nmr/karplus.py +19 -0
  16. diff_biophys-0.1.2/diff_biophys/nmr/rdc.py +100 -0
  17. diff_biophys-0.1.2/diff_biophys/nmr/ring_currents.py +37 -0
  18. diff_biophys-0.1.2/diff_biophys/saxs/__init__.py +1 -0
  19. diff_biophys-0.1.2/diff_biophys/saxs/kernels.py +62 -0
  20. diff_biophys-0.1.2/diff_biophys.egg-info/PKG-INFO +116 -0
  21. diff_biophys-0.1.2/diff_biophys.egg-info/SOURCES.txt +40 -0
  22. diff_biophys-0.1.2/diff_biophys.egg-info/dependency_links.txt +1 -0
  23. diff_biophys-0.1.2/diff_biophys.egg-info/requires.txt +11 -0
  24. diff_biophys-0.1.2/diff_biophys.egg-info/top_level.txt +1 -0
  25. diff_biophys-0.1.2/pyproject.toml +47 -0
  26. diff_biophys-0.1.2/setup.cfg +4 -0
  27. diff_biophys-0.1.2/tests/test_cd_parity.py +15 -0
  28. diff_biophys-0.1.2/tests/test_ensemble.py +74 -0
  29. diff_biophys-0.1.2/tests/test_geometry_parity.py +35 -0
  30. diff_biophys-0.1.2/tests/test_geometry_reconstruction.py +31 -0
  31. diff_biophys-0.1.2/tests/test_invariance.py +69 -0
  32. diff_biophys-0.1.2/tests/test_kabsch_parity.py +72 -0
  33. diff_biophys-0.1.2/tests/test_nmr_advanced.py +77 -0
  34. diff_biophys-0.1.2/tests/test_rdc_fitting.py +32 -0
  35. diff_biophys-0.1.2/tests/test_rdc_parity.py +34 -0
  36. diff_biophys-0.1.2/tests/test_saxs_parity.py +59 -0
  37. diff_biophys-0.1.2/tests/test_science_ca_shifts.py +53 -0
  38. diff_biophys-0.1.2/tests/test_science_karplus.py +37 -0
  39. diff_biophys-0.1.2/tests/test_science_rdc.py +31 -0
  40. diff_biophys-0.1.2/tests/test_science_ring_currents.py +48 -0
  41. diff_biophys-0.1.2/tests/test_science_saxs_advanced.py +63 -0
  42. diff_biophys-0.1.2/tests/test_synth_parity.py +127 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 George Elkins
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,116 @@
1
+ Metadata-Version: 2.4
2
+ Name: diff-biophys
3
+ Version: 0.1.2
4
+ Summary: Differentiable biophysical modeling in JAX
5
+ Author: George Elkins
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/elkins/diff-biophys
8
+ Project-URL: Repository, https://github.com/elkins/diff-biophys
9
+ Project-URL: Documentation, https://elkins.github.io/diff-biophys/
10
+ Classifier: Development Status :: 4 - Beta
11
+ Classifier: Intended Audience :: Science/Research
12
+ Classifier: License :: OSI Approved :: MIT License
13
+ Classifier: Programming Language :: Python :: 3
14
+ Classifier: Programming Language :: Python :: 3.10
15
+ Classifier: Programming Language :: Python :: 3.11
16
+ Classifier: Programming Language :: Python :: 3.12
17
+ Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
18
+ Classifier: Topic :: Scientific/Engineering :: Physics
19
+ Description-Content-Type: text/markdown
20
+ License-File: LICENSE
21
+ Requires-Dist: jax
22
+ Requires-Dist: jaxlib
23
+ Requires-Dist: numpy
24
+ Requires-Dist: biotite
25
+ Provides-Extra: dev
26
+ Requires-Dist: pytest; extra == "dev"
27
+ Requires-Dist: pytest-cov; extra == "dev"
28
+ Requires-Dist: ipython; extra == "dev"
29
+ Requires-Dist: synth-pdb; extra == "dev"
30
+ Requires-Dist: synth-nmr; extra == "dev"
31
+ Dynamic: license-file
32
+
33
+ # 🧬 DiffBiophys: Differentiable Biophysics for the AI Era
34
+
35
+ **DiffBiophys** is a high-performance Python library for differentiable biophysical modeling. Built on **JAX**, it re-implements core structural biology and spectroscopy observables (SAXS, NMR, CD) as hardware-accelerated, auto-differentiable kernels.
36
+
37
+ **[Documentation Website](https://elkins.github.io/diff-biophys/)** | **[Use Cases](https://elkins.github.io/diff-biophys/use_cases/)**
38
+
39
+ ---
40
+
41
+ ## 🎯 Vision
42
+
43
+ To bridge the gap between static structural models and experimental solution-state data by providing a "differentiable bridge." This allows researchers to:
44
+ 1. **Optimize** protein structures directly against experimental spectra via gradient descent.
45
+ 2. **Train** machine learning models using physics-informed loss functions.
46
+ 3. **Accelerate** large-scale biophysical simulations on GPUs and TPUs.
47
+
48
+ ---
49
+
50
+ ## 🏗️ Core Components
51
+
52
+ ### 1. `diff_biophys.geometry` (Differentiable Structural Engine)
53
+ - **NeRF (Natural Extension Reference Frame):** Differentiable conversion from internal coordinates ($\phi, \psi, \omega$, bond lengths/angles) to Cartesian XYZ.
54
+ - **Kabsch Alignment:** Differentiable optimal superposition using SVD.
55
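+ - **Torsion Analysis:** Vectorized calculation of bond lengths, bond angles, and dihedral angles for consecutive atoms along a chain (e.g., backbone $\phi, \psi, \omega$ or side-chain $\chi$ angles, depending on the atoms supplied).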
56
+
57
+ ### 2. `diff_biophys.saxs` (Differentiable Scattering)
58
+ - **Debye Formula:** $O(N^2)$ inter-atomic interference summation.
59
+ - **Hardware Acceleration:** GPU-optimized pairwise distance kernels.
60
+ - **Use Case:** Fitting structure "compactness" and "radius of gyration" to solution-state X-ray scattering curves.
61
+
62
+ ### 3. `diff_biophys.nmr` (Differentiable Spectroscopy)
63
+ - **Residual Dipolar Couplings (RDCs):** Differentiable Saupe tensor alignment and coupling calculation.
64
+ - **Chemical Shifts & Couplings:** Differentiable ring-current shielding (Johnson-Bovey point-dipole approximation), torsion-based Cα shift prediction, and Karplus J-coupling kernels.
65
+ - **Use Case:** Refining side-chain packing and domain orientations against high-resolution NMR data.
66
+
67
+ ### 4. `diff_biophys.cd` (Differentiable Dichroism)
68
+ - **Matrix-Method Simulation (placeholder):** API for DeVoe-style simulation of peptide-bond transition-dipole coupling; the current kernel is a stub that returns a zero spectrum.
69
+ - **Use Case:** Predicting secondary structure content and verifying fold stability.
70
+
71
+ ---
72
+
73
+ ## ⚡ Technical Architecture
74
+
75
+ - **Backend:** JAX (XLA-compiled).
76
+ - **Parallelism:** Native support for `vmap` (vectorization across ensembles/trajectories) and `pmap` (multi-device execution).
77
+ - **Differentiability:** Support for both Forward and Reverse-mode autodiff.
78
+ - **Interoperability:** Seamless integration with PyTorch/TensorFlow (via DLPack) and standard structural formats (mmCIF/BCIF).
79
+
80
+ ---
81
+
82
+ ## 🚀 Roadmap
83
+
84
+ ### Phase 1: Foundations (Alpha)
85
+ - [x] Differentiable NeRF and Kabsch alignment.
86
+ - [x] GPU-accelerated Debye formula for SAXS.
87
+ - [x] Unit tests verifying parity with `synth-pdb` NumPy implementations.
88
+
89
+ ### Phase 2: NMR & Spectroscopy (Beta)
90
+ - [x] Differentiable RDC and Karplus kernels.
91
+ - [x] Differentiable Johnson-Bovey ring current model.
92
+ - [ ] Integration with `synth-nmr` parameter libraries.
93
+
94
+ ### Phase 3: Integration & Optimization (v1.0)
95
+ - [ ] Example notebooks for structure refinement via gradient descent.
96
+ - [ ] Plugin for `torch`-based AI models to use biophysical loss functions.
97
+ - [ ] Full support for BinaryCIF streaming.
98
+
99
+ ---
100
+
101
+ ## 📂 Repository Structure (Proposed)
102
+
103
+ ```text
104
+ diff-biophys/
105
+ ├── diff_biophys/ # Core package
106
+ │ ├── geometry/ # NeRF, Kabsch, Torsions
107
+ │ ├── saxs/ # Debye kernels, form factors
108
+ │ ├── nmr/ # RDCs, Karplus, Ring Currents
109
+ │ ├── cd/ # CD simulation
110
+ │ └── utils/ # Constants, JAX-NumPy shims
111
+ ├── tests/ # Parity and gradient checks
112
+ ├── examples/ # Jupyter notebooks (Refinement Lab)
113
+ ├── docs/ # API and Theory
114
+ ├── pyproject.toml # Modern build config
115
+ └── README.md
116
+ ```
@@ -0,0 +1,84 @@
1
+ # 🧬 DiffBiophys: Differentiable Biophysics for the AI Era
2
+
3
+ **DiffBiophys** is a high-performance Python library for differentiable biophysical modeling. Built on **JAX**, it re-implements core structural biology and spectroscopy observables (SAXS, NMR, CD) as hardware-accelerated, auto-differentiable kernels.
4
+
5
+ **[Documentation Website](https://elkins.github.io/diff-biophys/)** | **[Use Cases](https://elkins.github.io/diff-biophys/use_cases/)**
6
+
7
+ ---
8
+
9
+ ## 🎯 Vision
10
+
11
+ To bridge the gap between static structural models and experimental solution-state data by providing a "differentiable bridge." This allows researchers to:
12
+ 1. **Optimize** protein structures directly against experimental spectra via gradient descent.
13
+ 2. **Train** machine learning models using physics-informed loss functions.
14
+ 3. **Accelerate** large-scale biophysical simulations on GPUs and TPUs.
15
+
16
+ ---
17
+
18
+ ## 🏗️ Core Components
19
+
20
+ ### 1. `diff_biophys.geometry` (Differentiable Structural Engine)
21
+ - **NeRF (Natural Extension Reference Frame):** Differentiable conversion from internal coordinates ($\phi, \psi, \omega$, bond lengths/angles) to Cartesian XYZ.
22
+ - **Kabsch Alignment:** Differentiable optimal superposition using SVD.
23
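+ - **Torsion Analysis:** Vectorized calculation of bond lengths, bond angles, and dihedral angles for consecutive atoms along a chain (e.g., backbone $\phi, \psi, \omega$ or side-chain $\chi$ angles, depending on the atoms supplied).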
24
+
25
+ ### 2. `diff_biophys.saxs` (Differentiable Scattering)
26
+ - **Debye Formula:** $O(N^2)$ inter-atomic interference summation.
27
+ - **Hardware Acceleration:** GPU-optimized pairwise distance kernels.
28
+ - **Use Case:** Fitting structure "compactness" and "radius of gyration" to solution-state X-ray scattering curves.
29
+
30
+ ### 3. `diff_biophys.nmr` (Differentiable Spectroscopy)
31
+ - **Residual Dipolar Couplings (RDCs):** Differentiable Saupe tensor alignment and coupling calculation.
32
+ - **Chemical Shifts & Couplings:** Differentiable ring-current shielding (Johnson-Bovey point-dipole approximation), torsion-based Cα shift prediction, and Karplus J-coupling kernels.
33
+ - **Use Case:** Refining side-chain packing and domain orientations against high-resolution NMR data.
34
+
35
+ ### 4. `diff_biophys.cd` (Differentiable Dichroism)
36
+ - **Matrix-Method Simulation (placeholder):** API for DeVoe-style simulation of peptide-bond transition-dipole coupling; the current kernel is a stub that returns a zero spectrum.
37
+ - **Use Case:** Predicting secondary structure content and verifying fold stability.
38
+
39
+ ---
40
+
41
+ ## ⚡ Technical Architecture
42
+
43
+ - **Backend:** JAX (XLA-compiled).
44
+ - **Parallelism:** Native support for `vmap` (vectorization across ensembles/trajectories) and `pmap` (multi-device execution).
45
+ - **Differentiability:** Support for both Forward and Reverse-mode autodiff.
46
+ - **Interoperability:** Seamless integration with PyTorch/TensorFlow (via DLPack) and standard structural formats (mmCIF/BCIF).
47
+
48
+ ---
49
+
50
+ ## 🚀 Roadmap
51
+
52
+ ### Phase 1: Foundations (Alpha)
53
+ - [x] Differentiable NeRF and Kabsch alignment.
54
+ - [x] GPU-accelerated Debye formula for SAXS.
55
+ - [x] Unit tests verifying parity with `synth-pdb` NumPy implementations.
56
+
57
+ ### Phase 2: NMR & Spectroscopy (Beta)
58
+ - [x] Differentiable RDC and Karplus kernels.
59
+ - [x] Differentiable Johnson-Bovey ring current model.
60
+ - [ ] Integration with `synth-nmr` parameter libraries.
61
+
62
+ ### Phase 3: Integration & Optimization (v1.0)
63
+ - [ ] Example notebooks for structure refinement via gradient descent.
64
+ - [ ] Plugin for `torch`-based AI models to use biophysical loss functions.
65
+ - [ ] Full support for BinaryCIF streaming.
66
+
67
+ ---
68
+
69
+ ## 📂 Repository Structure (Proposed)
70
+
71
+ ```text
72
+ diff-biophys/
73
+ ├── diff_biophys/ # Core package
74
+ │ ├── geometry/ # NeRF, Kabsch, Torsions
75
+ │ ├── saxs/ # Debye kernels, form factors
76
+ │ ├── nmr/ # RDCs, Karplus, Ring Currents
77
+ │ ├── cd/ # CD simulation
78
+ │ └── utils/ # Constants, JAX-NumPy shims
79
+ ├── tests/ # Parity and gradient checks
80
+ ├── examples/ # Jupyter notebooks (Refinement Lab)
81
+ ├── docs/ # API and Theory
82
+ ├── pyproject.toml # Modern build config
83
+ └── README.md
84
+ ```
@@ -0,0 +1,2 @@
1
# Must match the released distribution version (PKG-INFO / pyproject.toml).
__version__ = "0.1.2"
2
+ from .ensemble import Ensemble
@@ -0,0 +1 @@
1
+ from .kernels import simulate_cd_matrix
@@ -0,0 +1,21 @@
1
+ import jax.numpy as jnp
2
+ from jax import jit
3
+
4
def simulate_cd_matrix(peptide_positions, dipole_orientations, wavelengths):
    """
    Placeholder for a matrix-method CD simulation (DeVoe theory).

    The physical model is not implemented yet. The geometric inputs are
    accepted but ignored, and an all-zero spectrum is returned so that callers
    can already be written against this signature.

    Intended pipeline once implemented:
      1. Interaction matrix (N, N): V_ij = dipole-dipole coupling energy.
      2. Polarizability tensor.
      3. Solve for the complex ellipticity.

    Args:
        peptide_positions: (N, 3) positions of amide chromophores.
        dipole_orientations: (N, 3) unit vectors for transition dipoles.
        wavelengths: (M,) wavelengths in nm.

    Returns:
        jnp.ndarray: Zeros with the same shape and dtype as ``wavelengths``.
    """
    del peptide_positions, dipole_orientations  # unused until the model exists
    spectrum = jnp.zeros_like(wavelengths)
    return spectrum
@@ -0,0 +1,45 @@
1
+ import jax.numpy as jnp
2
+ from jax import vmap, jit
3
+ from typing import Callable, Any
4
+
5
class Ensemble:
    """
    High-level API for ensemble-averaged biophysical observables.

    Stores a stack of M conformers together with one population weight per
    conformer (rescaled to sum to one), and averages any per-conformer
    observable over the ensemble.
    """

    def __init__(self, coordinates: jnp.ndarray, weights: jnp.ndarray = None):
        """
        Args:
            coordinates: (M, N, 3) array where M is ensemble size and N is atom count.
            weights: (M,) array of population weights. Defaults to uniform.
                Supplied weights are rescaled so that they sum to one.

        Raises:
            ValueError: If the ensemble is empty, or if ``weights`` does not
                have shape (M,).
        """
        self.coords = coordinates
        self.m = coordinates.shape[0]
        if self.m == 0:
            # Uniform weights would divide by zero; an empty ensemble has no average.
            raise ValueError("Ensemble must contain at least one conformer.")
        if weights is None:
            self.weights = jnp.full((self.m,), 1.0 / self.m)
        else:
            weights = jnp.asarray(weights)
            # A mis-shaped weight vector would otherwise broadcast silently
            # against the per-conformer results and give a wrong average.
            if weights.shape != (self.m,):
                raise ValueError(
                    f"weights must have shape ({self.m},), got {weights.shape}."
                )
            self.weights = weights / jnp.sum(weights)

    def calculate_average(self, observable_fn: Callable[[jnp.ndarray], jnp.ndarray], *args, **kwargs) -> jnp.ndarray:
        """
        Calculate the population-weighted average of an observable.

        Args:
            observable_fn: Function that takes (N, 3) coords and returns an
                observable of any shape (scalar, (D,), (D1, D2), ...).
            *args, **kwargs: Additional arguments for the observable_fn.

        Returns:
            jnp.ndarray: The averaged observable, with the same shape as the
            observable of a single conformer.
        """
        # Vectorize the observable function over the ensemble dimension.
        per_conformer = vmap(lambda c: observable_fn(c, *args, **kwargs))(self.coords)  # (M, ...)

        # Contract the leading ensemble axis against the weights. Unlike
        # broadcasting with weights[:, None], this is correct for observables
        # of any rank, including scalars.
        return jnp.tensordot(self.weights, per_conformer, axes=1)
39
+
40
@jit
def calculate_ensemble_saxs(coords: jnp.ndarray, weights: jnp.ndarray, q_values: jnp.ndarray, form_factors: jnp.ndarray):
    """
    Compute a population-weighted SAXS profile for an ensemble (jit-compiled utility).

    Args:
        coords: Stacked conformer coordinates; axis 0 indexes conformers.
        weights: Per-conformer weights, applied exactly as given (this
            function does not renormalize them).
        q_values: Scattering-vector values forwarded to ``debye_saxs``.
        form_factors: Form factors forwarded to ``debye_saxs``.

    Returns:
        jnp.ndarray: Weighted sum over conformers of the ``debye_saxs``
        profiles (assumes each profile is one-dimensional — TODO confirm).
    """
    # Deferred import kept inside the function as in the original;
    # presumably this avoids an import cycle with the saxs package (TODO confirm).
    from diff_biophys.saxs import debye_saxs

    def _profile(frame):
        return debye_saxs(frame, q_values, form_factors)

    profiles = vmap(_profile)(coords)
    return (profiles * weights[:, None]).sum(axis=0)
@@ -0,0 +1,3 @@
1
+ from .nerf import position_atom_3d, chain_nerf
2
+ from .superposition import kabsch_alignment
3
+ from .torsions import compute_bond_lengths, compute_bond_angles, compute_dihedrals
@@ -0,0 +1,51 @@
1
+ import jax.numpy as jnp
2
+ from jax import jit, lax
3
+
4
@jit
def position_atom_3d(p1: jnp.ndarray, p2: jnp.ndarray, p3: jnp.ndarray,
                     bond_length: jnp.ndarray, bond_angle_rad: jnp.ndarray, dihedral_angle_rad: jnp.ndarray) -> jnp.ndarray:
    """
    Place one new atom from three reference atoms with differentiable NeRF.

    The new atom p4 is bonded to p3 such that:
      * |p4 - p3| == bond_length
      * angle(p2, p3, p4) == bond_angle_rad
      * dihedral(p1, p2, p3, p4) == dihedral_angle_rad (0 is cis to p1)

    Args:
        p1, p2, p3: Positions of the three preceding atoms, oldest first.
        bond_length: Distance from p3 to the new atom.
        bond_angle_rad: Bond angle at p3, in radians.
        dihedral_angle_rad: Torsion about the p2-p3 bond, in radians.

    Returns:
        jnp.ndarray: Position of the new atom.
    """
    # Unit vector along the p2 -> p3 bond (epsilon guards a zero-length bond).
    bond_axis = p3 - p2
    bond_axis = bond_axis / (jnp.linalg.norm(bond_axis) + 1e-10)

    # Normal of the (p1, p2, p3) plane; epsilon guards collinear reference atoms.
    plane_normal = jnp.cross(p1 - p2, bond_axis)
    plane_normal = plane_normal / (jnp.linalg.norm(plane_normal) + 1e-10)

    # Third axis of the local right-handed frame.
    in_plane = jnp.cross(plane_normal, bond_axis)

    sin_angle = jnp.sin(bond_angle_rad)
    direction = -(jnp.cos(bond_angle_rad) * bond_axis
                  + sin_angle * jnp.cos(dihedral_angle_rad) * in_plane
                  + sin_angle * jnp.sin(dihedral_angle_rad) * plane_normal)
    return p3 + bond_length * direction
26
+
27
@jit
def chain_nerf(init_coords: jnp.ndarray, bond_lengths: jnp.ndarray,
               bond_angles: jnp.ndarray, dihedrals: jnp.ndarray) -> jnp.ndarray:
    """
    Grow a linear chain of atoms with the NeRF algorithm.

    Each new atom is placed by ``position_atom_3d`` from the three most
    recently placed atoms, starting from ``init_coords``.

    Args:
        init_coords: (3, 3) initial coordinates for the first 3 atoms.
        bond_lengths: (N,) bond lengths for atoms 4 to N+3.
        bond_angles: (N,) bond angles (in radians) for atoms 4 to N+3.
        dihedrals: (N,) dihedral angles (in radians) for atoms 4 to N+3.

    Returns:
        jnp.ndarray: (N+3, 3) coordinates for the entire chain.
    """
    def _extend(window, k):
        # window = the three most recently placed atoms, oldest first.
        oldest, middle, newest = window
        placed = position_atom_3d(oldest, middle, newest,
                                  bond_lengths[k], bond_angles[k], dihedrals[k])
        return (middle, newest, placed), placed

    seed = (init_coords[0], init_coords[1], init_coords[2])
    steps = jnp.arange(len(bond_lengths))
    _, grown = lax.scan(_extend, seed, steps)
    return jnp.concatenate([init_coords, grown], axis=0)
@@ -0,0 +1,32 @@
1
+ import jax.numpy as jnp
2
+ from jax import jit
3
+
4
@jit
def kabsch_alignment(P: jnp.ndarray, Q: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Find the proper rotation and translation that best superimpose P onto Q.

    Uses the SVD-based Kabsch algorithm. The returned transform maps a mobile
    point p to ``R @ p + t`` (equivalently ``P @ R.T + t`` for row-stacked
    points). A reflection is never returned: det(R) is forced to +1.

    Args:
        P: (N, 3) mobile coordinates.
        Q: (N, 3) reference coordinates.

    Returns:
        tuple[jnp.ndarray, jnp.ndarray]: (3x3 rotation matrix, 3-element translation vector).
    """
    centroid_p = P.mean(axis=0)
    centroid_q = Q.mean(axis=0)

    # Cross-covariance of the centred point sets.
    cov = (P - centroid_p).T @ (Q - centroid_q)
    u, _, vt = jnp.linalg.svd(cov)
    v = vt.T

    # If V U^T is a reflection, flip the least significant singular direction.
    handedness = jnp.linalg.det(v @ u.T)
    correction = jnp.diag(jnp.array([1.0, 1.0, jnp.where(handedness > 0, 1.0, -1.0)]))

    rotation = v @ (correction @ u.T)
    translation = centroid_q - rotation @ centroid_p
    return rotation, translation
@@ -0,0 +1,52 @@
1
+ import jax.numpy as jnp
2
+ from jax import jit
3
+
4
@jit
def compute_bond_lengths(coords: jnp.ndarray) -> jnp.ndarray:
    """
    Return the distance between each pair of consecutive atoms.

    Args:
        coords: Atom positions in chain order; axis 0 indexes atoms and the
            last axis holds the Cartesian components (typically (N, 3)).

    Returns:
        jnp.ndarray: One fewer entry than atoms; entry i is
        |coords[i + 1] - coords[i]|.
    """
    steps = jnp.diff(coords, axis=0)
    return jnp.linalg.norm(steps, axis=-1)
11
+
12
@jit
def compute_bond_angles(coords: jnp.ndarray) -> jnp.ndarray:
    """
    Return the bond angle (radians) at every interior atom of a chain.

    Entry i is the angle coords[i] - coords[i+1] - coords[i+2]. The cosine is
    clipped slightly inside [-1, 1] so that arccos (and its gradient) stays
    finite for perfectly straight or folded-back triplets.
    """
    vertex = coords[1:-1]
    to_prev = coords[:-2] - vertex
    to_next = coords[2:] - vertex

    # Epsilon guards zero-length bonds.
    unit_prev = to_prev / (jnp.linalg.norm(to_prev, axis=-1, keepdims=True) + 1e-10)
    unit_next = to_next / (jnp.linalg.norm(to_next, axis=-1, keepdims=True) + 1e-10)

    cosine = jnp.clip(jnp.sum(unit_prev * unit_next, axis=-1), -1.0 + 1e-7, 1.0 - 1e-7)
    return jnp.arccos(cosine)
25
+
26
@jit
def compute_dihedrals(coords: jnp.ndarray) -> jnp.ndarray:
    """
    Return the dihedral angle (radians) for every window of four consecutive atoms.

    Follows the IUPAC sign convention (0 = cis, +/-pi = trans) and matches
    synth-pdb. Uses the numerically robust "Praxeolitic" atan2 formulation.
    """
    a0, a1, a2, a3 = coords[:-3], coords[1:-2], coords[2:-1], coords[3:]
    b0 = a0 - a1
    b1 = a2 - a1
    b2 = a3 - a2

    # Unit vector along the central bond.
    axis = b1 / (jnp.linalg.norm(b1, axis=-1, keepdims=True) + 1e-10)

    def _reject(vec):
        # Drop the component of vec that lies along the central bond axis.
        return vec - jnp.sum(vec * axis, axis=-1, keepdims=True) * axis

    v = _reject(b0)
    w = _reject(b2)

    cos_part = jnp.sum(v * w, axis=-1)
    sin_part = jnp.sum(jnp.cross(axis, v) * w, axis=-1)
    return jnp.arctan2(sin_part, cos_part)
@@ -0,0 +1,3 @@
1
+ from .rdc import calculate_rdc, calculate_rdc_from_tensor, fit_saupe_tensor, calculate_q_factor
2
+ from .karplus import calculate_karplus_j
3
+ from .ring_currents import calculate_ring_current_shift
@@ -0,0 +1,44 @@
1
+ import jax.numpy as jnp
2
+ from jax import jit, vmap
3
+
4
# Baseline Random Coil Shifts (Wishart et al. 1995), in ppm.
# Using a simplified set for CA (Alpha Carbon) for demonstration,
# keyed by three-letter residue code.
RANDOM_COIL_CA = {
    "ALA": 52.5, "ARG": 56.0, "ASN": 53.1, "ASP": 54.2, "CYS": 58.2,
    "GLN": 55.7, "GLU": 56.6, "GLY": 45.1, "HIS": 55.0, "ILE": 61.1,
    "LEU": 55.1, "LYS": 56.2, "MET": 55.3, "PHE": 57.7, "PRO": 63.3,
    "SER": 58.3, "THR": 61.8, "TRP": 57.5, "TYR": 57.9, "VAL": 62.2
}

# Statistical Secondary Structure Offsets for CA (ppm)
# Alpha Helix: ~ +3.1 ppm, Beta Sheet: ~ -1.5 ppm
OFFSET_HELIX = 3.1
OFFSET_SHEET = -1.5


def _wrap_angle(delta: jnp.ndarray) -> jnp.ndarray:
    """Map an angle difference (radians) onto (-pi, pi] smoothly and differentiably."""
    return jnp.arctan2(jnp.sin(delta), jnp.cos(delta))


@jit
def predict_ca_shifts(phi: jnp.ndarray, psi: jnp.ndarray, rc_shifts: jnp.ndarray) -> jnp.ndarray:
    """
    Differentiable CA Chemical Shift prediction based on Backbone Torsions.

    This uses a "soft" classification of secondary structure based on Phi/Psi
    to apply SPARTA-like offsets. Each residue gets a Gaussian helix weight and
    a Gaussian sheet weight (width parameter 0.5 rad^2), which scale
    OFFSET_HELIX and OFFSET_SHEET respectively.

    Torsion differences to the helix/sheet centres are wrapped onto (-pi, pi]
    first. Dihedrals are periodic, so e.g. psi = +179 deg and psi = -181 deg
    must score identically, and the prediction must be continuous across the
    +/-180 deg seam.

    Args:
        phi, psi: (N,) backbone dihedrals in radians.
        rc_shifts: (N,) baseline random coil shifts.

    Returns:
        jnp.ndarray: (N,) predicted CA shifts.
    """
    # 1. Soft-classify secondary structure
    # Alpha Helix: Phi ~ -60 deg (-1.05 rad), Psi ~ -45 deg (-0.78 rad)
    helix_dist_sq = _wrap_angle(phi + 1.05) ** 2 + _wrap_angle(psi + 0.78) ** 2
    is_helix = jnp.exp(-helix_dist_sq / 0.5)  # Soft mask

    # Beta Sheet: Phi ~ -120 deg (-2.09 rad), Psi ~ 135 deg (2.35 rad)
    sheet_dist_sq = _wrap_angle(phi + 2.09) ** 2 + _wrap_angle(psi - 2.35) ** 2
    is_sheet = jnp.exp(-sheet_dist_sq / 0.5)  # Soft mask

    # 2. Combine offsets
    # Shift = RC + (is_helix * OFFSET_HELIX) + (is_sheet * OFFSET_SHEET)
    return rc_shifts + (is_helix * OFFSET_HELIX) + (is_sheet * OFFSET_SHEET)
@@ -0,0 +1,30 @@
1
+ import jax.numpy as jnp
2
+
3
from collections.abc import Mapping as _Mapping

# Default Karplus parameters (Vuister & Bax 1993), in Hz.
KARPLUS_A = 6.51
KARPLUS_B = -1.76
KARPLUS_C = 1.60

# Default Ring Current Intensities, keyed by residue name
# (Consistent with synth-nmr/SHIFTX2). HID/HIE/HIP are histidine
# protonation-state residue names and share the HIS value.
RING_INTENSITIES = {
    'PHE': 1.2,
    'TYR': 1.2,
    'TRP': 1.3,
    'HIS': 0.5,
    'HID': 0.5,
    'HIE': 0.5,
    'HIP': 0.5
}

# Optionally override the defaults with synth-nmr's tables when it is installed.
# This is strictly best-effort. Each submodule is tried independently. A
# missing package, a missing attribute, or an attribute that is not a mapping
# all leave the defaults above in place instead of breaking this module's import.
try:
    import synth_nmr.j_coupling as sj
except ImportError:
    sj = None

_karplus_params = getattr(sj, "KARPLUS_PARAMS", None)
if isinstance(_karplus_params, _Mapping):
    KARPLUS_A = _karplus_params.get('A', KARPLUS_A)
    KARPLUS_B = _karplus_params.get('B', KARPLUS_B)
    KARPLUS_C = _karplus_params.get('C', KARPLUS_C)

try:
    import synth_nmr.chemical_shifts as sc
except ImportError:
    sc = None

_ring_params = getattr(sc, "RING_INTENSITIES", None)
if isinstance(_ring_params, _Mapping):
    RING_INTENSITIES.update(_ring_params)
@@ -0,0 +1,19 @@
1
+ import jax.numpy as jnp
2
+ from jax import jit
3
+
4
@jit
def calculate_karplus_j(theta: jnp.ndarray, a: float, b: float, c: float) -> jnp.ndarray:
    """
    Evaluate the Karplus relation for 3J couplings.

        J(theta) = a * cos^2(theta) + b * cos(theta) + c

    Args:
        theta: (N,) Dihedral angles in radians.
        a, b, c: Empirical Karplus parameters.

    Returns:
        jnp.ndarray: (N,) J-couplings, in the same units as a, b and c.
    """
    c_theta = jnp.cos(theta)
    quadratic_term = a * jnp.square(c_theta)
    linear_term = b * c_theta
    return quadratic_term + linear_term + c
@@ -0,0 +1,100 @@
1
+ import jax.numpy as jnp
2
+ from jax import jit
3
+
4
@jit
def calculate_rdc_from_tensor(bond_vectors: jnp.ndarray, saupe_tensor: jnp.ndarray, d_max: float = 1.0) -> jnp.ndarray:
    """
    Predict RDCs from a full 3x3 Saupe alignment tensor.

        D_k = d_max * v_k^T S v_k

    Args:
        bond_vectors: (N, 3) unit vectors
        saupe_tensor: (3, 3) symmetric traceless Saupe tensor
        d_max: Maximum dipolar coupling constant (Hz)

    Returns:
        jnp.ndarray: Calculated RDCs (N,)
    """
    # Quadratic form v^T S v, evaluated for every bond vector in one contraction.
    quadratic_form = jnp.einsum('ka,ab,kb->k', bond_vectors, saupe_tensor, bond_vectors)
    return d_max * quadratic_form
20
+
21
@jit
def fit_saupe_tensor(bond_vectors: jnp.ndarray, experimental_rdcs: jnp.ndarray, d_max: float = 1.0) -> jnp.ndarray:
    """
    Least-squares fit of a Saupe alignment tensor to measured RDCs.

    Tracelessness (Szz = -Sxx - Syy) leaves five free parameters
    s = [Sxx, Syy, Sxy, Sxz, Syz], and the RDCs are linear in them:

        D = d_max * [ Sxx(x^2 - z^2) + Syy(y^2 - z^2) + 2Sxy*xy + 2Sxz*xz + 2Syz*yz ]

    Args:
        bond_vectors: (N, 3) unit vectors
        experimental_rdcs: (N,) measured RDCs in Hz
        d_max: Maximum dipolar coupling constant (Hz)

    Returns:
        jnp.ndarray: (3, 3) Fitted symmetric, traceless Saupe tensor
    """
    bx, by, bz = bond_vectors[:, 0], bond_vectors[:, 1], bond_vectors[:, 2]

    # Design matrix: one column per independent tensor component.
    basis_columns = (
        bx**2 - bz**2,
        by**2 - bz**2,
        2 * bx * by,
        2 * bx * bz,
        2 * by * bz,
    )
    design = d_max * jnp.stack(basis_columns, axis=1)

    solution = jnp.linalg.lstsq(design, experimental_rdcs)[0]
    s_xx, s_yy, s_xy, s_xz, s_yz = solution
    s_zz = -(s_xx + s_yy)

    return jnp.array([
        [s_xx, s_xy, s_xz],
        [s_xy, s_yy, s_yz],
        [s_xz, s_yz, s_zz],
    ])
67
+
68
@jit
def calculate_q_factor(calculated_rdcs: jnp.ndarray, experimental_rdcs: jnp.ndarray) -> jnp.ndarray:
    """
    RDC quality factor Q (Cornilescu et al., 1998).

        Q = sqrt( sum((D_calc - D_exp)^2) / sum(D_exp^2) )

    A tiny epsilon in the denominator keeps Q finite when all measured
    couplings are zero.

    Args:
        calculated_rdcs: (N,) calculated couplings.
        experimental_rdcs: (N,) measured couplings.

    Returns:
        jnp.ndarray: Scalar Q-factor.
    """
    residual_power = jnp.sum(jnp.square(calculated_rdcs - experimental_rdcs))
    reference_power = jnp.sum(jnp.square(experimental_rdcs))
    ratio = residual_power / (reference_power + 1e-10)
    return jnp.sqrt(ratio)
84
+
85
@jit
def calculate_rdc(bond_vectors: jnp.ndarray, da: float, r: float) -> jnp.ndarray:
    """
    Differentiable RDC calculation in the alignment tensor's principal frame.

        D = Da * [ (3 z^2 - 1) + 1.5 * R * (x^2 - y^2) ]

    Args:
        bond_vectors: (N, 3) unit vectors in the tensor's principal frame
        da: Axial component in Hz
        r: Rhombicity (0 <= R <= 2/3)

    Returns:
        jnp.ndarray: (N,) RDCs, in the units of ``da``.
    """
    bx = bond_vectors[:, 0]
    by = bond_vectors[:, 1]
    bz = bond_vectors[:, 2]

    axial_term = 3.0 * bz**2 - 1.0
    rhombic_term = 1.5 * r * (bx**2 - by**2)
    return da * (axial_term + rhombic_term)
@@ -0,0 +1,37 @@
1
+ import jax.numpy as jnp
2
+ from jax import jit
3
+
4
@jit
def calculate_ring_current_shift(coords: jnp.ndarray,
                                 ring_center: jnp.ndarray,
                                 ring_normal: jnp.ndarray,
                                 intensity: float) -> jnp.ndarray:
    """
    Aromatic ring-current chemical-shift contribution (Johnson-Bovey dipolar approximation).

        delta = intensity * (1 - 3*cos^2(theta)) / r^3

    where r is the distance from the ring centre and theta is the angle
    between the displacement vector and the ring normal. Small epsilons
    keep the expression finite for a nucleus at the ring centre.

    Args:
        coords: (N, 3) coordinates of the nuclei being shielded.
        ring_center: (3,) coordinates of the aromatic ring center.
        ring_normal: (3,) unit vector normal to the ring plane (not renormalized here).
        intensity: Scaling factor (proportional to ring area and current).

    Returns:
        jnp.ndarray: (N,) shielding values in ppm.
    """
    displacement = coords - ring_center
    distance = jnp.linalg.norm(displacement, axis=-1)

    # Projection onto the (assumed unit) normal divided by |r| gives cos(theta).
    projection = jnp.sum(displacement * ring_normal, axis=-1)
    cos_t = projection / (distance + 1e-10)

    geometric_factor = 1.0 - 3.0 * jnp.square(cos_t)
    return intensity * geometric_factor / (distance**3 + 1e-10)