diff-biophys 0.1.2__tar.gz → 0.1.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. diff_biophys-0.1.3/PKG-INFO +182 -0
  2. diff_biophys-0.1.3/README.md +141 -0
  3. diff_biophys-0.1.3/diff_biophys/__init__.py +8 -0
  4. diff_biophys-0.1.3/diff_biophys/cd/kernels.py +87 -0
  5. diff_biophys-0.1.3/diff_biophys/cryo_em.py +88 -0
  6. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/diff_biophys/ensemble.py +30 -13
  7. diff_biophys-0.1.3/diff_biophys/geometry/__init__.py +3 -0
  8. diff_biophys-0.1.3/diff_biophys/geometry/nerf.py +84 -0
  9. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/diff_biophys/geometry/superposition.py +9 -8
  10. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/diff_biophys/geometry/torsions.py +14 -9
  11. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/diff_biophys/nmr/__init__.py +1 -1
  12. diff_biophys-0.1.3/diff_biophys/nmr/chemical_shifts.py +85 -0
  13. diff_biophys-0.1.3/diff_biophys/nmr/constants.py +31 -0
  14. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/diff_biophys/nmr/karplus.py +4 -3
  15. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/diff_biophys/nmr/rdc.py +57 -37
  16. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/diff_biophys/nmr/ring_currents.py +14 -12
  17. diff_biophys-0.1.3/diff_biophys/saxs/kernels.py +82 -0
  18. diff_biophys-0.1.3/diff_biophys/torch_interop.py +54 -0
  19. diff_biophys-0.1.3/diff_biophys.egg-info/PKG-INFO +182 -0
  20. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/diff_biophys.egg-info/SOURCES.txt +12 -1
  21. diff_biophys-0.1.3/diff_biophys.egg-info/requires.txt +22 -0
  22. diff_biophys-0.1.3/pyproject.toml +136 -0
  23. diff_biophys-0.1.3/tests/test_cd_parity.py +65 -0
  24. diff_biophys-0.1.3/tests/test_cryo_em.py +55 -0
  25. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/tests/test_ensemble.py +19 -22
  26. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/tests/test_geometry_parity.py +11 -13
  27. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/tests/test_geometry_reconstruction.py +15 -13
  28. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/tests/test_invariance.py +21 -27
  29. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/tests/test_kabsch_parity.py +26 -33
  30. diff_biophys-0.1.3/tests/test_nerf_backbone_geometry.py +131 -0
  31. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/tests/test_nmr_advanced.py +23 -18
  32. diff_biophys-0.1.3/tests/test_random_coil_wishart.py +58 -0
  33. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/tests/test_rdc_fitting.py +13 -12
  34. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/tests/test_rdc_parity.py +11 -8
  35. diff_biophys-0.1.3/tests/test_ring_current_distance_decay.py +71 -0
  36. diff_biophys-0.1.3/tests/test_saupe_tensor_properties.py +76 -0
  37. diff_biophys-0.1.3/tests/test_saxs_forward_limit.py +56 -0
  38. diff_biophys-0.1.3/tests/test_saxs_guinier.py +39 -0
  39. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/tests/test_saxs_parity.py +24 -28
  40. diff_biophys-0.1.3/tests/test_science_ca_shifts.py +122 -0
  41. diff_biophys-0.1.3/tests/test_science_karplus.py +64 -0
  42. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/tests/test_science_rdc.py +5 -2
  43. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/tests/test_science_ring_currents.py +11 -7
  44. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/tests/test_science_saxs_advanced.py +16 -15
  45. diff_biophys-0.1.3/tests/test_science_saxs_sphere.py +74 -0
  46. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/tests/test_synth_parity.py +39 -27
  47. diff_biophys-0.1.3/tests/test_torch_interop.py +40 -0
  48. diff_biophys-0.1.2/PKG-INFO +0 -116
  49. diff_biophys-0.1.2/README.md +0 -84
  50. diff_biophys-0.1.2/diff_biophys/__init__.py +0 -2
  51. diff_biophys-0.1.2/diff_biophys/cd/kernels.py +0 -21
  52. diff_biophys-0.1.2/diff_biophys/geometry/__init__.py +0 -3
  53. diff_biophys-0.1.2/diff_biophys/geometry/nerf.py +0 -51
  54. diff_biophys-0.1.2/diff_biophys/nmr/chemical_shifts.py +0 -44
  55. diff_biophys-0.1.2/diff_biophys/nmr/constants.py +0 -30
  56. diff_biophys-0.1.2/diff_biophys/saxs/kernels.py +0 -62
  57. diff_biophys-0.1.2/diff_biophys.egg-info/PKG-INFO +0 -116
  58. diff_biophys-0.1.2/diff_biophys.egg-info/requires.txt +0 -11
  59. diff_biophys-0.1.2/pyproject.toml +0 -47
  60. diff_biophys-0.1.2/tests/test_cd_parity.py +0 -15
  61. diff_biophys-0.1.2/tests/test_science_ca_shifts.py +0 -53
  62. diff_biophys-0.1.2/tests/test_science_karplus.py +0 -37
  63. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/LICENSE +0 -0
  64. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/diff_biophys/cd/__init__.py +0 -0
  65. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/diff_biophys/saxs/__init__.py +0 -0
  66. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/diff_biophys.egg-info/dependency_links.txt +0 -0
  67. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/diff_biophys.egg-info/top_level.txt +0 -0
  68. {diff_biophys-0.1.2 → diff_biophys-0.1.3}/setup.cfg +0 -0
@@ -0,0 +1,182 @@
1
+ Metadata-Version: 2.4
2
+ Name: diff-biophys
3
+ Version: 0.1.3
4
+ Summary: Differentiable biophysical modeling in JAX
5
+ Author: George Elkins
6
+ License-Expression: MIT
7
+ Project-URL: Homepage, https://github.com/elkins/diff-biophys
8
+ Project-URL: Repository, https://github.com/elkins/diff-biophys
9
+ Project-URL: Documentation, https://elkins.github.io/diff-biophys/
10
+ Classifier: Development Status :: 4 - Beta
11
+ Classifier: Intended Audience :: Science/Research
12
+ Classifier: Programming Language :: Python :: 3
13
+ Classifier: Programming Language :: Python :: 3.10
14
+ Classifier: Programming Language :: Python :: 3.11
15
+ Classifier: Programming Language :: Python :: 3.12
16
+ Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
17
+ Classifier: Topic :: Scientific/Engineering :: Physics
18
+ Requires-Python: >=3.10
19
+ Description-Content-Type: text/markdown
20
+ License-File: LICENSE
21
+ Requires-Dist: jax
22
+ Requires-Dist: jaxlib
23
+ Requires-Dist: numpy
24
+ Requires-Dist: biotite
25
+ Provides-Extra: dev
26
+ Requires-Dist: pytest>=7.0.0; extra == "dev"
27
+ Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
28
+ Requires-Dist: ruff>=0.6.0; extra == "dev"
29
+ Requires-Dist: mypy>=1.8.0; extra == "dev"
30
+ Requires-Dist: pre-commit>=3.0.0; extra == "dev"
31
+ Requires-Dist: ipython; extra == "dev"
32
+ Requires-Dist: synth-pdb; extra == "dev"
33
+ Requires-Dist: synth-nmr; extra == "dev"
34
+ Requires-Dist: synth-core; extra == "dev"
35
+ Provides-Extra: torch
36
+ Requires-Dist: torch>=2.0.0; extra == "torch"
37
+ Provides-Extra: examples
38
+ Requires-Dist: optax>=0.1.0; extra == "examples"
39
+ Requires-Dist: matplotlib>=3.5.0; extra == "examples"
40
+ Dynamic: license-file
41
+
42
+ # 🧬 DiffBiophys: Differentiable Biophysics for the AI Era
43
+
44
+ [![Tests](https://github.com/elkins/diff-biophys/actions/workflows/test.yml/badge.svg)](https://github.com/elkins/diff-biophys/actions/workflows/test.yml)
45
+ [![PyPI version](https://img.shields.io/pypi/v/diff-biophys.svg)](https://pypi.org/project/diff-biophys/)
46
+ [![Python 3.10+](https://img.shields.io/pypi/pyversions/diff-biophys.svg)](https://pypi.org/project/diff-biophys/)
47
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
48
+ [![codecov](https://codecov.io/gh/elkins/diff-biophys/branch/main/graph/badge.svg)](https://codecov.io/gh/elkins/diff-biophys)
49
+ [![JAX](https://img.shields.io/badge/backend-JAX-9cf.svg)](https://github.com/google/jax)
50
+ [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
51
+ [![Checked with mypy](https://img.shields.io/badge/type%20checked-mypy-blue)](https://mypy-lang.org/)
52
+
53
+ **DiffBiophys** is a high-performance Python library for differentiable biophysical modeling. Built on **JAX**, it re-implements core structural biology and spectroscopy observables (SAXS, NMR, CD) as hardware-accelerated, auto-differentiable kernels.
54
+
55
+ **[Documentation Website](https://elkins.github.io/diff-biophys/)** | **[Use Cases](https://elkins.github.io/diff-biophys/use_cases/)**
56
+
57
+ ---
58
+
59
+ ## 🎯 Vision
60
+
61
+ To bridge the gap between static structural models and experimental solution-state data by providing a "differentiable bridge." This allows researchers to:
62
+ 1. **Optimize** protein structures directly against experimental spectra via gradient descent.
63
+ 2. **Train** machine learning models using physics-informed loss functions.
64
+ 3. **Accelerate** large-scale biophysical simulations on GPUs and TPUs.
65
+
66
+ ---
67
+
68
+ ## 🏗️ Core Components
69
+
70
+ ### 1. `diff_biophys.geometry` (Differentiable Structural Engine)
71
+ - **NeRF (Natural Extension Reference Frame):** Differentiable conversion from internal coordinates ($\phi, \psi, \omega$, bond lengths/angles) to Cartesian XYZ.
72
+ - **Kabsch Alignment:** Differentiable optimal superposition using SVD.
73
+ - **Torsion Analysis:** Vectorized calculation of all backbone and side-chain dihedrals.
74
+
75
+ ### 2. `diff_biophys.saxs` (Differentiable Scattering)
76
+ - **Debye Formula:** $O(N^2)$ inter-atomic interference summation.
77
+ - **Hydration Shell Correction:** Excluded-volume solvent subtraction (Fraser et al. 1978).
78
+ - **Hardware Acceleration:** GPU-optimized pairwise distance kernels via JAX `vmap`.
79
+ - **Use Case:** Fitting structure compactness and radius of gyration to solution-state X-ray scattering curves.
80
+
81
+ ### 3. `diff_biophys.nmr` (Differentiable Spectroscopy)
82
+ - **Residual Dipolar Couplings (RDCs):** Differentiable Saupe tensor alignment and coupling calculation. Includes SVD-based tensor fitting.
83
+ - **Chemical Shifts:** Differentiable ring-current (Johnson-Bovey) shielding and softmax-weighted secondary structure Cα shift predictor.
84
+ - **Karplus J-coupling:** Parameterizable Karplus equation for three-bond ($^3J$) couplings (Vuister & Bax 1993 defaults).
85
+ - **Use Case:** Refining side-chain packing and domain orientations against high-resolution NMR data.
86
+
87
+ ### 4. `diff_biophys.cd` (Differentiable Circular Dichroism)
88
+ - **Matrix-Method Simulation:** Differentiable simulation of peptide bond transition dipole coupling via DeVoe theory.
89
+ - **Status:** ✅ Implemented. Supports frequency-dependent coupled-oscillator response.
90
+
91
+ ---
92
+
93
+ ## ⚡ Technical Architecture
94
+
95
+ - **Backend:** JAX (XLA-compiled) — supports CPU, GPU, and TPU.
96
+ - **Parallelism:** Native support for `vmap` (vectorization across ensembles/trajectories) and `pmap` (multi-device execution).
97
+ - **Differentiability:** Forward and reverse-mode autodiff through all kernels.
98
+ - **Interoperability:** JAX arrays are compatible with NumPy and can be exchanged with PyTorch via `dlpack` (user-managed conversion).
99
+
100
+ ---
101
+
102
+ ## 🚀 Roadmap
103
+
104
+ ### Phase 1: Foundations (Alpha)
105
+ - [x] Differentiable NeRF and Kabsch alignment.
106
+ - [x] GPU-accelerated Debye formula for SAXS with hydration shell correction.
107
+ - [x] Unit tests verifying parity with `synth-pdb` NumPy implementations.
108
+
109
+ ### Phase 2: NMR & Spectroscopy (Beta)
110
+ - [x] Differentiable RDC and Karplus kernels.
111
+ - [x] Differentiable Johnson-Bovey ring current model.
112
+ - [x] Integration with `synth-nmr` parameter libraries (optional dependency).
113
+
114
+ ### Phase 3: Integration & Optimization (v1.0)
115
+ - [x] Full CD matrix-method implementation (DeVoe theory).
116
+ - [ ] Example notebooks for structure refinement via gradient descent.
117
+ - [ ] Plugin for `torch`-based AI models to use biophysical loss functions.
118
+ - [ ] Full support for BinaryCIF streaming.
119
+
120
+ ---
121
+
122
+ ## 📂 Repository Structure
123
+
124
+ ```text
125
+ diff-biophys/
126
+ ├── diff_biophys/ # Core package
127
+ │ ├── geometry/ # NeRF, Kabsch, Torsions
128
+ │ ├── saxs/ # Debye kernels, form factors
129
+ │ ├── nmr/ # RDCs, Karplus, Ring Currents, Chemical Shifts
130
+ │ ├── cd/ # CD simulation (DeVoe Matrix Method)
131
+ │ └── ensemble.py # Ensemble averaging API
132
+ ├── tests/ # Parity, gradient, and scientific validation checks
133
+ ├── examples/ # Jupyter notebooks (Refinement Lab)
134
+ ├── docs/ # API and Theory
135
+ ├── pyproject.toml # Modern build config
136
+ └── README.md
137
+ ```
138
+
139
+ ## 🚀 Installation
140
+
141
+ ```bash
142
+ pip install diff-biophys
143
+ ```
144
+
145
+ For GPU support (CUDA):
146
+ ```bash
147
+ pip install "jax[cuda12]" diff-biophys
148
+ ```
149
+
150
+ ## 🤝 Contributing
151
+
152
+ Contributions are welcome from both ML and structural biology communities! Please open an issue or pull request on [GitHub](https://github.com/elkins/diff-biophys). Run `pre-commit run --all-files` before submitting.
153
+
154
+ ## 🔗 Related Projects
155
+
156
+ diff-biophys is the **differentiable engine** powering the higher-level tools in this ecosystem:
157
+
158
+ - [synth-pdb](https://github.com/elkins/synth-pdb) — Synthetic structure generation (uses NumPy implementations)
159
+ - [synth-nmr](https://github.com/elkins/synth-nmr) — NMR observables (optional dependency)
160
+ - [synth-saxs](https://github.com/elkins/synth-saxs) — SAXS profile simulator
161
+ - [diff-fret](https://github.com/elkins/diff-fret) — Differentiable FRET (new)
162
+ - [diff-hdx](https://github.com/elkins/diff-hdx) — Differentiable HDX-MS (new)
163
+ - [diff-epr](https://github.com/elkins/diff-epr) — Differentiable EPR/DEER (new)
164
+ - [diff-ensemble](https://github.com/elkins/diff-ensemble) — IDP ensemble VAE (depends on diff-biophys)
165
+ - [TorsionTuner](https://github.com/elkins/TorsionTuner) — GNN refinement (depends on diff-biophys)
166
+ - [resonance-flow](https://github.com/elkins/resonance-flow) — NMR-guided folding (depends on diff-biophys)
167
+
168
+ ## ⚖️ License
169
+
170
+ MIT License — see [LICENSE](LICENSE) for details.
171
+
172
+ ## 📖 Citation
173
+
174
+ ```bibtex
175
+ @software{diff_biophys,
176
+ author = {Elkins, George},
177
+ title = {diff-biophys: Differentiable biophysics kernels for JAX},
178
+ year = {2024},
179
+ url = {https://github.com/elkins/diff-biophys},
180
+ version = {0.1.3}
181
+ }
182
+ ```
@@ -0,0 +1,141 @@
1
+ # 🧬 DiffBiophys: Differentiable Biophysics for the AI Era
2
+
3
+ [![Tests](https://github.com/elkins/diff-biophys/actions/workflows/test.yml/badge.svg)](https://github.com/elkins/diff-biophys/actions/workflows/test.yml)
4
+ [![PyPI version](https://img.shields.io/pypi/v/diff-biophys.svg)](https://pypi.org/project/diff-biophys/)
5
+ [![Python 3.10+](https://img.shields.io/pypi/pyversions/diff-biophys.svg)](https://pypi.org/project/diff-biophys/)
6
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
7
+ [![codecov](https://codecov.io/gh/elkins/diff-biophys/branch/main/graph/badge.svg)](https://codecov.io/gh/elkins/diff-biophys)
8
+ [![JAX](https://img.shields.io/badge/backend-JAX-9cf.svg)](https://github.com/google/jax)
9
+ [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
10
+ [![Checked with mypy](https://img.shields.io/badge/type%20checked-mypy-blue)](https://mypy-lang.org/)
11
+
12
+ **DiffBiophys** is a high-performance Python library for differentiable biophysical modeling. Built on **JAX**, it re-implements core structural biology and spectroscopy observables (SAXS, NMR, CD) as hardware-accelerated, auto-differentiable kernels.
13
+
14
+ **[Documentation Website](https://elkins.github.io/diff-biophys/)** | **[Use Cases](https://elkins.github.io/diff-biophys/use_cases/)**
15
+
16
+ ---
17
+
18
+ ## 🎯 Vision
19
+
20
+ To bridge the gap between static structural models and experimental solution-state data by providing a "differentiable bridge." This allows researchers to:
21
+ 1. **Optimize** protein structures directly against experimental spectra via gradient descent.
22
+ 2. **Train** machine learning models using physics-informed loss functions.
23
+ 3. **Accelerate** large-scale biophysical simulations on GPUs and TPUs.
24
+
25
+ ---
26
+
27
+ ## 🏗️ Core Components
28
+
29
+ ### 1. `diff_biophys.geometry` (Differentiable Structural Engine)
30
+ - **NeRF (Natural Extension Reference Frame):** Differentiable conversion from internal coordinates ($\phi, \psi, \omega$, bond lengths/angles) to Cartesian XYZ.
31
+ - **Kabsch Alignment:** Differentiable optimal superposition using SVD.
32
+ - **Torsion Analysis:** Vectorized calculation of all backbone and side-chain dihedrals.
33
+
34
+ ### 2. `diff_biophys.saxs` (Differentiable Scattering)
35
+ - **Debye Formula:** $O(N^2)$ inter-atomic interference summation.
36
+ - **Hydration Shell Correction:** Excluded-volume solvent subtraction (Fraser et al. 1978).
37
+ - **Hardware Acceleration:** GPU-optimized pairwise distance kernels via JAX `vmap`.
38
+ - **Use Case:** Fitting structure compactness and radius of gyration to solution-state X-ray scattering curves.
39
+
40
+ ### 3. `diff_biophys.nmr` (Differentiable Spectroscopy)
41
+ - **Residual Dipolar Couplings (RDCs):** Differentiable Saupe tensor alignment and coupling calculation. Includes SVD-based tensor fitting.
42
+ - **Chemical Shifts:** Differentiable ring-current (Johnson-Bovey) shielding and softmax-weighted secondary structure Cα shift predictor.
43
+ - **Karplus J-coupling:** Parameterizable Karplus equation for three-bond ($^3J$) couplings (Vuister & Bax 1993 defaults).
44
+ - **Use Case:** Refining side-chain packing and domain orientations against high-resolution NMR data.
45
+
46
+ ### 4. `diff_biophys.cd` (Differentiable Circular Dichroism)
47
+ - **Matrix-Method Simulation:** Differentiable simulation of peptide bond transition dipole coupling via DeVoe theory.
48
+ - **Status:** ✅ Implemented. Supports frequency-dependent coupled-oscillator response.
49
+
50
+ ---
51
+
52
+ ## ⚡ Technical Architecture
53
+
54
+ - **Backend:** JAX (XLA-compiled) — supports CPU, GPU, and TPU.
55
+ - **Parallelism:** Native support for `vmap` (vectorization across ensembles/trajectories) and `pmap` (multi-device execution).
56
+ - **Differentiability:** Forward and reverse-mode autodiff through all kernels.
57
+ - **Interoperability:** JAX arrays are compatible with NumPy and can be exchanged with PyTorch via `dlpack` (user-managed conversion).
58
+
59
+ ---
60
+
61
+ ## 🚀 Roadmap
62
+
63
+ ### Phase 1: Foundations (Alpha)
64
+ - [x] Differentiable NeRF and Kabsch alignment.
65
+ - [x] GPU-accelerated Debye formula for SAXS with hydration shell correction.
66
+ - [x] Unit tests verifying parity with `synth-pdb` NumPy implementations.
67
+
68
+ ### Phase 2: NMR & Spectroscopy (Beta)
69
+ - [x] Differentiable RDC and Karplus kernels.
70
+ - [x] Differentiable Johnson-Bovey ring current model.
71
+ - [x] Integration with `synth-nmr` parameter libraries (optional dependency).
72
+
73
+ ### Phase 3: Integration & Optimization (v1.0)
74
+ - [x] Full CD matrix-method implementation (DeVoe theory).
75
+ - [ ] Example notebooks for structure refinement via gradient descent.
76
+ - [ ] Plugin for `torch`-based AI models to use biophysical loss functions.
77
+ - [ ] Full support for BinaryCIF streaming.
78
+
79
+ ---
80
+
81
+ ## 📂 Repository Structure
82
+
83
+ ```text
84
+ diff-biophys/
85
+ ├── diff_biophys/ # Core package
86
+ │ ├── geometry/ # NeRF, Kabsch, Torsions
87
+ │ ├── saxs/ # Debye kernels, form factors
88
+ │ ├── nmr/ # RDCs, Karplus, Ring Currents, Chemical Shifts
89
+ │ ├── cd/ # CD simulation (DeVoe Matrix Method)
90
+ │ └── ensemble.py # Ensemble averaging API
91
+ ├── tests/ # Parity, gradient, and scientific validation checks
92
+ ├── examples/ # Jupyter notebooks (Refinement Lab)
93
+ ├── docs/ # API and Theory
94
+ ├── pyproject.toml # Modern build config
95
+ └── README.md
96
+ ```
97
+
98
+ ## 🚀 Installation
99
+
100
+ ```bash
101
+ pip install diff-biophys
102
+ ```
103
+
104
+ For GPU support (CUDA):
105
+ ```bash
106
+ pip install "jax[cuda12]" diff-biophys
107
+ ```
108
+
109
+ ## 🤝 Contributing
110
+
111
+ Contributions are welcome from both ML and structural biology communities! Please open an issue or pull request on [GitHub](https://github.com/elkins/diff-biophys). Run `pre-commit run --all-files` before submitting.
112
+
113
+ ## 🔗 Related Projects
114
+
115
+ diff-biophys is the **differentiable engine** powering the higher-level tools in this ecosystem:
116
+
117
+ - [synth-pdb](https://github.com/elkins/synth-pdb) — Synthetic structure generation (uses NumPy implementations)
118
+ - [synth-nmr](https://github.com/elkins/synth-nmr) — NMR observables (optional dependency)
119
+ - [synth-saxs](https://github.com/elkins/synth-saxs) — SAXS profile simulator
120
+ - [diff-fret](https://github.com/elkins/diff-fret) — Differentiable FRET (new)
121
+ - [diff-hdx](https://github.com/elkins/diff-hdx) — Differentiable HDX-MS (new)
122
+ - [diff-epr](https://github.com/elkins/diff-epr) — Differentiable EPR/DEER (new)
123
+ - [diff-ensemble](https://github.com/elkins/diff-ensemble) — IDP ensemble VAE (depends on diff-biophys)
124
+ - [TorsionTuner](https://github.com/elkins/TorsionTuner) — GNN refinement (depends on diff-biophys)
125
+ - [resonance-flow](https://github.com/elkins/resonance-flow) — NMR-guided folding (depends on diff-biophys)
126
+
127
+ ## ⚖️ License
128
+
129
+ MIT License — see [LICENSE](LICENSE) for details.
130
+
131
+ ## 📖 Citation
132
+
133
+ ```bibtex
134
+ @software{diff_biophys,
135
+ author = {Elkins, George},
136
+ title = {diff-biophys: Differentiable biophysics kernels for JAX},
137
+ year = {2024},
138
+ url = {https://github.com/elkins/diff-biophys},
139
+ version = {0.1.3}
140
+ }
141
+ ```
@@ -0,0 +1,8 @@
1
+ __version__ = "0.1.3"
2
+ from . import cryo_em
3
+ from .ensemble import Ensemble
4
+
5
+ try:
6
+ from . import torch_interop
7
+ except ImportError:
8
+ pass
@@ -0,0 +1,87 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+
4
+
5
def simulate_cd_matrix(
    peptide_positions: jnp.ndarray,
    dipole_orientations: jnp.ndarray,
    wavelengths: jnp.ndarray,
    f_osc: float = 0.2,
    gamma: float = 10.0,
    lambda_0: float = 190.0,
) -> jnp.ndarray:
    """
    Matrix-Method CD Simulation (DeVoe Theory).

    Implements a coupled-oscillator model for transition dipole coupling:

    1. Build the point-dipole interaction matrix
       V_ij = (1/r_ij^3) [mu_i . mu_j - 3 (mu_i . r_hat_ij)(mu_j . r_hat_ij)].
    2. At each wavelength, form a scalar complex polarizability alpha(lambda)
       (identical for all chromophores) and solve the coupled response
       X = (I - alpha V)^{-1} (alpha V).
    3. Contract Im(X) with the chiral geometric factor
       R_ij = (mu_i x mu_j) . (r_i - r_j).

    Args:
        peptide_positions: (N, 3) positions of amide chromophores in Angstroms.
        dipole_orientations: (N, 3) unit vectors for transition dipoles.
        wavelengths: (M,) wavelengths in nm to simulate.
        f_osc: Oscillator strength of the transition (default 0.2 for pi->pi*).
        gamma: Linewidth parameter in nm (default 10.0).
        lambda_0: Resonance wavelength in nm (default 190.0).

    Returns:
        (M,) CD signal, one value per wavelength. The raw response is multiplied
        by a fixed factor of 1e5, so the values are in arbitrary units. They must
        be calibrated against experimental data before being read as molar
        ellipticity (deg cm^2 / dmol).
    """
    n_chromophores = peptide_positions.shape[0]

    # 1. Dipole-dipole interaction matrix V (N, N).
    diff = peptide_positions[:, None, :] - peptide_positions[None, :, :]
    dist_sq = jnp.sum(diff**2, axis=-1)

    # Double-where pattern: coincident pairs (including the diagonal) get a
    # placeholder distance of 1.0 *before* the sqrt / division. Neither the
    # forward pass nor the gradient then sees 0/0; the outer where() zeroes
    # those entries afterwards.
    mask = dist_sq > 0
    safe_dist_sq = jnp.where(mask, dist_sq, 1.0)
    r_ij = jnp.sqrt(safe_dist_sq)
    r_ij_inv3 = jnp.where(mask, 1.0 / r_ij**3, 0.0)

    # Unit vectors between chromophores (zero on the diagonal).
    r_hat = diff * jnp.where(mask[:, :, None], 1.0 / r_ij[:, :, None], 0.0)

    mu_i_mu_j = jnp.sum(dipole_orientations[:, None, :] * dipole_orientations[None, :, :], axis=-1)
    mu_i_r = jnp.sum(dipole_orientations[:, None, :] * r_hat, axis=-1)
    mu_j_r = jnp.sum(dipole_orientations[None, :, :] * r_hat, axis=-1)

    V = r_ij_inv3 * (mu_i_mu_j - 3 * mu_i_r * mu_j_r)

    # 2. Chiral geometric factor R_ij = (mu_i x mu_j) . (r_i - r_j).
    # It does not depend on wavelength, so compute it once here rather than
    # once per wavelength inside the vmapped body.
    cross_mu = jnp.cross(dipole_orientations[:, None, :], dipole_orientations[None, :, :])
    R_ij = jnp.sum(cross_mu * diff, axis=-1)

    identity = jnp.eye(n_chromophores)

    # 3. Frequency-dependent coupled response.
    def compute_at_wavelength(lmbda: jnp.ndarray) -> jnp.ndarray:
        # Lorentzian-like complex polarizability alpha(lambda).
        denom = (1.0 / lmbda**2 - 1.0 / lambda_0**2) + 1j * (gamma / (lmbda * lambda_0**2))
        alpha = f_osc / denom
        alpha_V = alpha * V

        # X = (I - alpha V)^{-1} (alpha V). Solving the linear system is cheaper
        # and better conditioned than forming the explicit inverse, and it is
        # equally differentiable.
        coupled_V = jnp.linalg.solve(identity - alpha_V, alpha_V)
        return jnp.imag(jnp.sum(coupled_V * R_ij))

    cd_spectrum = jax.vmap(compute_at_wavelength)(wavelengths)

    # Fixed output scale (arbitrary units); calibrate against experiment.
    return cd_spectrum * 1e5
@@ -0,0 +1,88 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+
4
+
5
@jax.jit
def compute_fsc(
    data1: jax.Array, data2: jax.Array, voxel_size: tuple[float, float, float]
) -> tuple[jax.Array, jax.Array]:
    """
    Compute a fully differentiable Fourier Shell Correlation (FSC) between two 3D maps.

    Intended to mirror the NumPy FSC in synth-core, but written with jax.numpy
    so that gradients flow through the FSC calculation back to the input maps.

    Args:
        data1, data2: Real-valued 3D maps of identical shape (nz, ny, nx).
        voxel_size: (dz, dy, dx) voxel spacing. Frequencies are returned in
            cycles per unit of this spacing.

    Returns:
        (freqs, vals): two arrays of length n_bins = min(nz, ny, nx) // 2.
        ``freqs`` are shell-centre frequencies. ``vals`` are FSC values clipped
        to [-1, 1]. Shells that contain no Fourier samples, or that have zero
        power in either map, report 0.

    Raises:
        ValueError: if the two maps do not have the same shape.
    """
    if data1.shape != data2.shape:
        raise ValueError(
            f"compute_fsc requires maps of identical shape, got {data1.shape} and {data2.shape}"
        )

    f1 = jnp.fft.rfftn(data1)
    f2 = jnp.fft.rfftn(data2)

    # Cross-spectral density and power spectra in real arithmetic. This
    # avoids complex multiplication, matching the NumPy memory-stability fix.
    cross = (f1.real * f2.real + f1.imag * f2.imag).ravel()
    p1 = (f1.real**2 + f1.imag**2).ravel()
    p2 = (f2.real**2 + f2.imag**2).ravel()

    # Radial frequency of every rfftn sample.
    nz, ny, nx = data1.shape
    kz = jnp.fft.fftfreq(nz, d=voxel_size[0])
    ky = jnp.fft.fftfreq(ny, d=voxel_size[1])
    kx = jnp.fft.rfftfreq(nx, d=voxel_size[2])
    kz_grid, ky_grid, kx_grid = jnp.meshgrid(kz, ky, kx, indexing="ij")
    k = jnp.sqrt(kz_grid**2 + ky_grid**2 + kx_grid**2).ravel()

    # n_bins depends only on static shapes, so it is a Python int under jit.
    n_bins = min(nx, ny, nz) // 2
    k_max = jnp.max(k)
    # The first edge sits slightly above 0, so the DC term falls in no shell.
    k_eps = k_max / (10 * n_bins)
    bins = jnp.linspace(k_eps, k_max, n_bins + 1)

    # Shell i collects samples with bins[i] <= k < bins[i + 1].
    # searchsorted(side="right") - 1 gives exactly that index. Out-of-range
    # samples (the DC term, and k == k_max) are routed to an overflow segment
    # n_bins, which is discarded. This is O(K) instead of building one
    # full-size mask per shell.
    shell = jnp.searchsorted(bins, k, side="right") - 1
    shell = jnp.where((shell >= 0) & (shell < n_bins), shell, n_bins)

    def shell_sum(values: jax.Array) -> jax.Array:
        return jax.ops.segment_sum(values, shell, num_segments=n_bins + 1)[:n_bins]

    sum_cross = shell_sum(cross)
    sum_p1 = shell_sum(p1)
    sum_p2 = shell_sum(p2)

    # Double-where: substitute 1.0 before the sqrt/division so empty or
    # zero-power shells yield a 0 value *and* a finite (zero) gradient,
    # instead of NaN from 0/0.
    power = sum_p1 * sum_p2
    valid = power > 0
    den = jnp.sqrt(jnp.where(valid, power, 1.0))
    vals = jnp.clip(jnp.where(valid, sum_cross / den, 0.0), -1.0, 1.0)

    freqs = 0.5 * (bins[:-1] + bins[1:])
    return freqs, vals
@@ -1,45 +1,62 @@
1
+ from collections.abc import Callable
2
+ from typing import Any, cast
3
+
1
4
  import jax.numpy as jnp
2
- from jax import vmap, jit
3
- from typing import Callable, Any
5
+ from jax import jit, vmap
6
+
4
7
 
5
8
  class Ensemble:
6
9
  """
7
10
  High-level API for ensemble-averaged biophysical observables.
8
11
  """
9
- def __init__(self, coordinates: jnp.ndarray, weights: jnp.ndarray = None):
12
+
13
+ coords: jnp.ndarray
14
+ weights: jnp.ndarray
15
+ m: int
16
+
17
+ def __init__(self, coordinates: jnp.ndarray, weights: jnp.ndarray | None = None):
10
18
  """
11
19
  Args:
12
20
  coordinates: (M, N, 3) array where M is ensemble size and N is atom count.
13
21
  weights: (M,) array of population weights. Defaults to uniform.
14
22
  """
15
23
  self.coords = coordinates
16
- self.m = coordinates.shape[0]
24
+ self.m = int(coordinates.shape[0])
17
25
  if weights is None:
18
26
  self.weights = jnp.full((self.m,), 1.0 / self.m)
19
27
  else:
20
- self.weights = weights / jnp.sum(weights)
28
+ self.weights = cast(jnp.ndarray, weights / jnp.sum(weights))
21
29
 
22
- def calculate_average(self, observable_fn: Callable[[jnp.ndarray], jnp.ndarray], *args, **kwargs) -> jnp.ndarray:
30
+ def calculate_average(
31
+ self,
32
+ observable_fn: Callable[..., jnp.ndarray],
33
+ *args: Any,
34
+ **kwargs: Any,
35
+ ) -> jnp.ndarray:
23
36
  """
24
37
  Calculate the population-weighted average of an observable.
25
-
38
+
26
39
  Args:
27
40
  observable_fn: Function that takes (N, 3) coords and returns (D,) observable.
28
41
  *args, **kwargs: Additional arguments for the observable_fn.
29
-
42
+
30
43
  Returns:
31
44
  jnp.ndarray: (D,) averaged observable.
32
45
  """
33
46
  # Vectorize the observable function over the ensemble dimension
34
47
  v_fn = vmap(lambda c: observable_fn(c, *args, **kwargs))
35
- ensemble_results = v_fn(self.coords) # (M, D)
36
-
48
+ ensemble_results = v_fn(self.coords) # (M, D)
49
+
37
50
  # Weighted average
38
- return jnp.sum(ensemble_results * self.weights[:, None], axis=0)
51
+ return cast(jnp.ndarray, jnp.sum(ensemble_results * self.weights[:, None], axis=0))
52
+
39
53
 
40
54
@jit
def calculate_ensemble_saxs(
    coords: jnp.ndarray, weights: jnp.ndarray, q_values: jnp.ndarray, form_factors: jnp.ndarray
) -> jnp.ndarray:
    """Utility for fast ensemble SAXS.

    Evaluates ``debye_saxs`` for each conformer along the leading axis of
    ``coords`` (vectorized with ``vmap``). It then returns the sum of those
    per-conformer outputs, each scaled by its entry in ``weights``.
    ``weights`` are used exactly as given; they are not renormalized here.
    """
    from diff_biophys.saxs import debye_saxs

    def single_conformer(conformer: jnp.ndarray) -> jnp.ndarray:
        return debye_saxs(conformer, q_values, form_factors)

    per_conformer = vmap(single_conformer)(coords)
    weighted = per_conformer * weights[:, None]
    return cast(jnp.ndarray, jnp.sum(weighted, axis=0))
@@ -0,0 +1,3 @@
1
+ from .nerf import chain_nerf, position_atom_3d
2
+ from .superposition import kabsch_alignment
3
+ from .torsions import compute_bond_angles, compute_bond_lengths, compute_dihedrals
@@ -0,0 +1,84 @@
1
+ from typing import Any, cast
2
+
3
+ import jax.numpy as jnp
4
+ from jax import jit, lax
5
+
6
+
7
def _safe_normalize(v: jnp.ndarray, eps: float = 1e-10) -> jnp.ndarray:
    """Return v / |v| with a gradient that stays finite when v is the zero vector.

    The derivative of jnp.linalg.norm at 0 is 0/0 (NaN), so the norm is taken
    as sqrt(|v|^2 + eps^2) instead. For non-degenerate vectors eps^2 is far
    below float32 resolution, so the forward value is unchanged.
    """
    return v / jnp.sqrt(jnp.sum(v * v) + eps * eps)


@jit
def position_atom_3d(
    p1: jnp.ndarray,
    p2: jnp.ndarray,
    p3: jnp.ndarray,
    bond_length: jnp.ndarray,
    bond_angle_rad: jnp.ndarray,
    dihedral_angle_rad: jnp.ndarray,
) -> jnp.ndarray:
    """
    Differentiable NeRF implementation in JAX for a single atom.

    Places atom p4 given three reference atoms (p1, p2, p3) and the internal
    coordinates (bond length, bond angle, dihedral angle) that define its
    position relative to p3.

    Args:
        p1, p2, p3: (3,) reference atom coordinates.
        bond_length: Scalar distance p3→p4 in Ångströms.
        bond_angle_rad: Scalar bond angle ∠(p2, p3, p4) in radians.
        dihedral_angle_rad: Scalar dihedral angle ∠(p1, p2, p3, p4) in radians.

    Returns:
        jnp.ndarray: (3,) Cartesian coordinates of the new atom p4.

    Notes:
        If p1, p2, p3 are collinear, the dihedral is undefined. The normal
        vector then collapses to zero and p4 is placed on the p2→p3 axis
        (using only the bond angle). Gradients stay finite in that case
        instead of becoming NaN.
    """
    v1 = p1 - p2
    u2 = _safe_normalize(p3 - p2)  # unit vector along the p2 -> p3 bond

    # Unit normal of the (p1, p2, p3) plane and the in-plane perpendicular.
    n = _safe_normalize(jnp.cross(v1, u2))
    m = jnp.cross(n, u2)

    p4 = p3 + bond_length * (
        -jnp.cos(bond_angle_rad) * u2
        - jnp.sin(bond_angle_rad) * jnp.cos(dihedral_angle_rad) * m
        - jnp.sin(bond_angle_rad) * jnp.sin(dihedral_angle_rad) * n
    )
    return cast(jnp.ndarray, p4)
48
+
49
+
50
@jit
def chain_nerf(
    init_coords: jnp.ndarray,
    bond_lengths: jnp.ndarray,
    bond_angles: jnp.ndarray,
    dihedrals: jnp.ndarray,
) -> jnp.ndarray:
    """
    Build a chain of atoms using the NeRF algorithm.

    Args:
        init_coords: (3, 3) initial coordinates for the first 3 atoms
        bond_lengths: (N,) bond lengths for atoms 4 to N+3
        bond_angles: (N,) bond angles (in radians) for atoms 4 to N+3
        dihedrals: (N,) dihedral angles (in radians) for atoms 4 to N+3

    Returns:
        jnp.ndarray: (N+3, 3) coordinates for the entire chain

    Raises:
        ValueError: if bond_lengths, bond_angles and dihedrals do not share the
            same leading length N.
    """
    # Shapes are static under jit, so this check runs once at trace time.
    # Without it, mismatched inputs would be indexed out of bounds, and JAX
    # clamps such gathers, silently reusing the last angle/dihedral.
    n_new = bond_lengths.shape[0]
    if bond_angles.shape[0] != n_new or dihedrals.shape[0] != n_new:
        raise ValueError(
            "bond_lengths, bond_angles and dihedrals must have the same length, got "
            f"{bond_lengths.shape[0]}, {bond_angles.shape[0]} and {dihedrals.shape[0]}"
        )

    def body_fun(
        carry: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],
        params: tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],
    ) -> tuple[tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray]:
        # Slide the three-atom reference window forward by one atom.
        p1, p2, p3 = carry
        length, angle, torsion = params
        p4 = position_atom_3d(p1, p2, p3, length, angle, torsion)
        return (p2, p3, p4), p4

    init_carry = (init_coords[0], init_coords[1], init_coords[2])
    # Scan directly over the per-atom internal coordinates.
    _, new_coords = lax.scan(body_fun, init_carry, (bond_lengths, bond_angles, dihedrals))

    return cast(jnp.ndarray, jnp.concatenate([init_coords, new_coords], axis=0))