mlmm-toolkit 0.2.2.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hessian_ff/__init__.py +50 -0
- hessian_ff/analytical_hessian.py +609 -0
- hessian_ff/constants.py +46 -0
- hessian_ff/forcefield.py +339 -0
- hessian_ff/loaders.py +608 -0
- hessian_ff/native/Makefile +8 -0
- hessian_ff/native/__init__.py +28 -0
- hessian_ff/native/analytical_hessian.py +88 -0
- hessian_ff/native/analytical_hessian_ext.cpp +258 -0
- hessian_ff/native/bonded.py +82 -0
- hessian_ff/native/bonded_ext.cpp +640 -0
- hessian_ff/native/loader.py +349 -0
- hessian_ff/native/nonbonded.py +118 -0
- hessian_ff/native/nonbonded_ext.cpp +1150 -0
- hessian_ff/prmtop_parmed.py +23 -0
- hessian_ff/system.py +107 -0
- hessian_ff/terms/__init__.py +14 -0
- hessian_ff/terms/angle.py +73 -0
- hessian_ff/terms/bond.py +44 -0
- hessian_ff/terms/cmap.py +406 -0
- hessian_ff/terms/dihedral.py +141 -0
- hessian_ff/terms/nonbonded.py +209 -0
- hessian_ff/tests/__init__.py +0 -0
- hessian_ff/tests/conftest.py +75 -0
- hessian_ff/tests/data/small/complex.parm7 +1346 -0
- hessian_ff/tests/data/small/complex.pdb +125 -0
- hessian_ff/tests/data/small/complex.rst7 +63 -0
- hessian_ff/tests/test_coords_input.py +44 -0
- hessian_ff/tests/test_energy_force.py +49 -0
- hessian_ff/tests/test_hessian.py +137 -0
- hessian_ff/tests/test_smoke.py +18 -0
- hessian_ff/tests/test_validation.py +40 -0
- hessian_ff/workflows.py +889 -0
- mlmm/__init__.py +36 -0
- mlmm/__main__.py +7 -0
- mlmm/_version.py +34 -0
- mlmm/add_elem_info.py +374 -0
- mlmm/advanced_help.py +91 -0
- mlmm/align_freeze_atoms.py +601 -0
- mlmm/all.py +3535 -0
- mlmm/bond_changes.py +231 -0
- mlmm/bool_compat.py +223 -0
- mlmm/cli.py +574 -0
- mlmm/cli_utils.py +166 -0
- mlmm/default_group.py +337 -0
- mlmm/defaults.py +467 -0
- mlmm/define_layer.py +526 -0
- mlmm/dft.py +1041 -0
- mlmm/energy_diagram.py +253 -0
- mlmm/extract.py +2213 -0
- mlmm/fix_altloc.py +464 -0
- mlmm/freq.py +1406 -0
- mlmm/harmonic_constraints.py +140 -0
- mlmm/hessian_cache.py +44 -0
- mlmm/hessian_calc.py +174 -0
- mlmm/irc.py +638 -0
- mlmm/mlmm_calc.py +2262 -0
- mlmm/mm_parm.py +945 -0
- mlmm/oniom_export.py +1983 -0
- mlmm/oniom_import.py +457 -0
- mlmm/opt.py +1742 -0
- mlmm/path_opt.py +1353 -0
- mlmm/path_search.py +2299 -0
- mlmm/preflight.py +88 -0
- mlmm/py.typed +1 -0
- mlmm/pysis_runner.py +45 -0
- mlmm/scan.py +1047 -0
- mlmm/scan2d.py +1226 -0
- mlmm/scan3d.py +1265 -0
- mlmm/scan_common.py +184 -0
- mlmm/summary_log.py +736 -0
- mlmm/trj2fig.py +448 -0
- mlmm/tsopt.py +2871 -0
- mlmm/utils.py +2309 -0
- mlmm/xtb_embedcharge_correction.py +475 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/METADATA +1159 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/RECORD +372 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/WHEEL +5 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/entry_points.txt +2 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/licenses/LICENSE +674 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/top_level.txt +4 -0
- pysisyphus/Geometry.py +1667 -0
- pysisyphus/LICENSE +674 -0
- pysisyphus/TableFormatter.py +63 -0
- pysisyphus/TablePrinter.py +74 -0
- pysisyphus/__init__.py +12 -0
- pysisyphus/calculators/AFIR.py +452 -0
- pysisyphus/calculators/AnaPot.py +20 -0
- pysisyphus/calculators/AnaPot2.py +48 -0
- pysisyphus/calculators/AnaPot3.py +12 -0
- pysisyphus/calculators/AnaPot4.py +20 -0
- pysisyphus/calculators/AnaPotBase.py +337 -0
- pysisyphus/calculators/AnaPotCBM.py +25 -0
- pysisyphus/calculators/AtomAtomTransTorque.py +154 -0
- pysisyphus/calculators/CFOUR.py +250 -0
- pysisyphus/calculators/Calculator.py +844 -0
- pysisyphus/calculators/CerjanMiller.py +24 -0
- pysisyphus/calculators/Composite.py +123 -0
- pysisyphus/calculators/ConicalIntersection.py +171 -0
- pysisyphus/calculators/DFTBp.py +430 -0
- pysisyphus/calculators/DFTD3.py +66 -0
- pysisyphus/calculators/DFTD4.py +84 -0
- pysisyphus/calculators/Dalton.py +61 -0
- pysisyphus/calculators/Dimer.py +681 -0
- pysisyphus/calculators/Dummy.py +20 -0
- pysisyphus/calculators/EGO.py +76 -0
- pysisyphus/calculators/EnergyMin.py +224 -0
- pysisyphus/calculators/ExternalPotential.py +264 -0
- pysisyphus/calculators/FakeASE.py +35 -0
- pysisyphus/calculators/FourWellAnaPot.py +28 -0
- pysisyphus/calculators/FreeEndNEBPot.py +39 -0
- pysisyphus/calculators/Gaussian09.py +18 -0
- pysisyphus/calculators/Gaussian16.py +726 -0
- pysisyphus/calculators/HardSphere.py +159 -0
- pysisyphus/calculators/IDPPCalculator.py +49 -0
- pysisyphus/calculators/IPIClient.py +133 -0
- pysisyphus/calculators/IPIServer.py +234 -0
- pysisyphus/calculators/LEPSBase.py +24 -0
- pysisyphus/calculators/LEPSExpr.py +139 -0
- pysisyphus/calculators/LennardJones.py +80 -0
- pysisyphus/calculators/MOPAC.py +219 -0
- pysisyphus/calculators/MullerBrownSympyPot.py +51 -0
- pysisyphus/calculators/MultiCalc.py +85 -0
- pysisyphus/calculators/NFK.py +45 -0
- pysisyphus/calculators/OBabel.py +87 -0
- pysisyphus/calculators/ONIOMv2.py +1129 -0
- pysisyphus/calculators/ORCA.py +893 -0
- pysisyphus/calculators/ORCA5.py +6 -0
- pysisyphus/calculators/OpenMM.py +88 -0
- pysisyphus/calculators/OpenMolcas.py +281 -0
- pysisyphus/calculators/OverlapCalculator.py +908 -0
- pysisyphus/calculators/Psi4.py +218 -0
- pysisyphus/calculators/PyPsi4.py +37 -0
- pysisyphus/calculators/PySCF.py +341 -0
- pysisyphus/calculators/PyXTB.py +73 -0
- pysisyphus/calculators/QCEngine.py +106 -0
- pysisyphus/calculators/Rastrigin.py +22 -0
- pysisyphus/calculators/Remote.py +76 -0
- pysisyphus/calculators/Rosenbrock.py +15 -0
- pysisyphus/calculators/SocketCalc.py +97 -0
- pysisyphus/calculators/TIP3P.py +111 -0
- pysisyphus/calculators/TransTorque.py +161 -0
- pysisyphus/calculators/Turbomole.py +965 -0
- pysisyphus/calculators/VRIPot.py +37 -0
- pysisyphus/calculators/WFOWrapper.py +333 -0
- pysisyphus/calculators/WFOWrapper2.py +341 -0
- pysisyphus/calculators/XTB.py +418 -0
- pysisyphus/calculators/__init__.py +81 -0
- pysisyphus/calculators/cosmo_data.py +139 -0
- pysisyphus/calculators/parser.py +150 -0
- pysisyphus/color.py +19 -0
- pysisyphus/config.py +133 -0
- pysisyphus/constants.py +65 -0
- pysisyphus/cos/AdaptiveNEB.py +230 -0
- pysisyphus/cos/ChainOfStates.py +725 -0
- pysisyphus/cos/FreeEndNEB.py +25 -0
- pysisyphus/cos/FreezingString.py +103 -0
- pysisyphus/cos/GrowingChainOfStates.py +71 -0
- pysisyphus/cos/GrowingNT.py +309 -0
- pysisyphus/cos/GrowingString.py +508 -0
- pysisyphus/cos/NEB.py +189 -0
- pysisyphus/cos/SimpleZTS.py +64 -0
- pysisyphus/cos/__init__.py +22 -0
- pysisyphus/cos/stiffness.py +199 -0
- pysisyphus/drivers/__init__.py +17 -0
- pysisyphus/drivers/afir.py +855 -0
- pysisyphus/drivers/barriers.py +271 -0
- pysisyphus/drivers/birkholz.py +138 -0
- pysisyphus/drivers/cluster.py +318 -0
- pysisyphus/drivers/diabatization.py +133 -0
- pysisyphus/drivers/merge.py +368 -0
- pysisyphus/drivers/merge_mol2.py +322 -0
- pysisyphus/drivers/opt.py +375 -0
- pysisyphus/drivers/perf.py +91 -0
- pysisyphus/drivers/pka.py +52 -0
- pysisyphus/drivers/precon_pos_rot.py +669 -0
- pysisyphus/drivers/rates.py +480 -0
- pysisyphus/drivers/replace.py +219 -0
- pysisyphus/drivers/scan.py +212 -0
- pysisyphus/drivers/spectrum.py +166 -0
- pysisyphus/drivers/thermo.py +31 -0
- pysisyphus/dynamics/Gaussian.py +103 -0
- pysisyphus/dynamics/__init__.py +20 -0
- pysisyphus/dynamics/colvars.py +136 -0
- pysisyphus/dynamics/driver.py +297 -0
- pysisyphus/dynamics/helpers.py +256 -0
- pysisyphus/dynamics/lincs.py +105 -0
- pysisyphus/dynamics/mdp.py +364 -0
- pysisyphus/dynamics/rattle.py +121 -0
- pysisyphus/dynamics/thermostats.py +128 -0
- pysisyphus/dynamics/wigner.py +266 -0
- pysisyphus/elem_data.py +3473 -0
- pysisyphus/exceptions.py +2 -0
- pysisyphus/filtertrj.py +69 -0
- pysisyphus/helpers.py +623 -0
- pysisyphus/helpers_pure.py +649 -0
- pysisyphus/init_logging.py +50 -0
- pysisyphus/intcoords/Bend.py +69 -0
- pysisyphus/intcoords/Bend2.py +25 -0
- pysisyphus/intcoords/BondedFragment.py +32 -0
- pysisyphus/intcoords/Cartesian.py +41 -0
- pysisyphus/intcoords/CartesianCoords.py +140 -0
- pysisyphus/intcoords/Coords.py +56 -0
- pysisyphus/intcoords/DLC.py +197 -0
- pysisyphus/intcoords/DistanceFunction.py +34 -0
- pysisyphus/intcoords/DummyImproper.py +70 -0
- pysisyphus/intcoords/DummyTorsion.py +72 -0
- pysisyphus/intcoords/LinearBend.py +105 -0
- pysisyphus/intcoords/LinearDisplacement.py +80 -0
- pysisyphus/intcoords/OutOfPlane.py +59 -0
- pysisyphus/intcoords/PrimTypes.py +286 -0
- pysisyphus/intcoords/Primitive.py +137 -0
- pysisyphus/intcoords/RedundantCoords.py +659 -0
- pysisyphus/intcoords/RobustTorsion.py +59 -0
- pysisyphus/intcoords/Rotation.py +147 -0
- pysisyphus/intcoords/Stretch.py +31 -0
- pysisyphus/intcoords/Torsion.py +101 -0
- pysisyphus/intcoords/Torsion2.py +25 -0
- pysisyphus/intcoords/Translation.py +45 -0
- pysisyphus/intcoords/__init__.py +61 -0
- pysisyphus/intcoords/augment_bonds.py +126 -0
- pysisyphus/intcoords/derivatives.py +10512 -0
- pysisyphus/intcoords/eval.py +80 -0
- pysisyphus/intcoords/exceptions.py +37 -0
- pysisyphus/intcoords/findiffs.py +48 -0
- pysisyphus/intcoords/generate_derivatives.py +414 -0
- pysisyphus/intcoords/helpers.py +235 -0
- pysisyphus/intcoords/logging_conf.py +10 -0
- pysisyphus/intcoords/mp_derivatives.py +10836 -0
- pysisyphus/intcoords/setup.py +962 -0
- pysisyphus/intcoords/setup_fast.py +176 -0
- pysisyphus/intcoords/update.py +272 -0
- pysisyphus/intcoords/valid.py +89 -0
- pysisyphus/interpolate/Geodesic.py +93 -0
- pysisyphus/interpolate/IDPP.py +55 -0
- pysisyphus/interpolate/Interpolator.py +116 -0
- pysisyphus/interpolate/LST.py +70 -0
- pysisyphus/interpolate/Redund.py +152 -0
- pysisyphus/interpolate/__init__.py +9 -0
- pysisyphus/interpolate/helpers.py +34 -0
- pysisyphus/io/__init__.py +22 -0
- pysisyphus/io/aomix.py +178 -0
- pysisyphus/io/cjson.py +24 -0
- pysisyphus/io/crd.py +101 -0
- pysisyphus/io/cube.py +220 -0
- pysisyphus/io/fchk.py +184 -0
- pysisyphus/io/hdf5.py +49 -0
- pysisyphus/io/hessian.py +72 -0
- pysisyphus/io/mol2.py +146 -0
- pysisyphus/io/molden.py +293 -0
- pysisyphus/io/orca.py +189 -0
- pysisyphus/io/pdb.py +269 -0
- pysisyphus/io/psf.py +79 -0
- pysisyphus/io/pubchem.py +31 -0
- pysisyphus/io/qcschema.py +34 -0
- pysisyphus/io/sdf.py +29 -0
- pysisyphus/io/xyz.py +61 -0
- pysisyphus/io/zmat.py +175 -0
- pysisyphus/irc/DWI.py +108 -0
- pysisyphus/irc/DampedVelocityVerlet.py +134 -0
- pysisyphus/irc/Euler.py +22 -0
- pysisyphus/irc/EulerPC.py +345 -0
- pysisyphus/irc/GonzalezSchlegel.py +187 -0
- pysisyphus/irc/IMKMod.py +164 -0
- pysisyphus/irc/IRC.py +878 -0
- pysisyphus/irc/IRCDummy.py +10 -0
- pysisyphus/irc/Instanton.py +307 -0
- pysisyphus/irc/LQA.py +53 -0
- pysisyphus/irc/ModeKill.py +136 -0
- pysisyphus/irc/ParamPlot.py +53 -0
- pysisyphus/irc/RK4.py +36 -0
- pysisyphus/irc/__init__.py +31 -0
- pysisyphus/irc/initial_displ.py +219 -0
- pysisyphus/linalg.py +411 -0
- pysisyphus/line_searches/Backtracking.py +88 -0
- pysisyphus/line_searches/HagerZhang.py +184 -0
- pysisyphus/line_searches/LineSearch.py +232 -0
- pysisyphus/line_searches/StrongWolfe.py +108 -0
- pysisyphus/line_searches/__init__.py +9 -0
- pysisyphus/line_searches/interpol.py +15 -0
- pysisyphus/modefollow/NormalMode.py +40 -0
- pysisyphus/modefollow/__init__.py +10 -0
- pysisyphus/modefollow/davidson.py +199 -0
- pysisyphus/modefollow/lanczos.py +95 -0
- pysisyphus/optimizers/BFGS.py +99 -0
- pysisyphus/optimizers/BacktrackingOptimizer.py +113 -0
- pysisyphus/optimizers/ConjugateGradient.py +98 -0
- pysisyphus/optimizers/CubicNewton.py +75 -0
- pysisyphus/optimizers/FIRE.py +113 -0
- pysisyphus/optimizers/HessianOptimizer.py +1176 -0
- pysisyphus/optimizers/LBFGS.py +228 -0
- pysisyphus/optimizers/LayerOpt.py +411 -0
- pysisyphus/optimizers/MicroOptimizer.py +169 -0
- pysisyphus/optimizers/NCOptimizer.py +90 -0
- pysisyphus/optimizers/Optimizer.py +1084 -0
- pysisyphus/optimizers/PreconLBFGS.py +260 -0
- pysisyphus/optimizers/PreconSteepestDescent.py +7 -0
- pysisyphus/optimizers/QuickMin.py +74 -0
- pysisyphus/optimizers/RFOptimizer.py +181 -0
- pysisyphus/optimizers/RSA.py +99 -0
- pysisyphus/optimizers/StabilizedQNMethod.py +248 -0
- pysisyphus/optimizers/SteepestDescent.py +23 -0
- pysisyphus/optimizers/StringOptimizer.py +173 -0
- pysisyphus/optimizers/__init__.py +41 -0
- pysisyphus/optimizers/closures.py +301 -0
- pysisyphus/optimizers/cls_map.py +58 -0
- pysisyphus/optimizers/exceptions.py +6 -0
- pysisyphus/optimizers/gdiis.py +280 -0
- pysisyphus/optimizers/guess_hessians.py +311 -0
- pysisyphus/optimizers/hessian_updates.py +355 -0
- pysisyphus/optimizers/poly_fit.py +285 -0
- pysisyphus/optimizers/precon.py +153 -0
- pysisyphus/optimizers/restrict_step.py +24 -0
- pysisyphus/pack.py +172 -0
- pysisyphus/peakdetect.py +948 -0
- pysisyphus/plot.py +1031 -0
- pysisyphus/run.py +2106 -0
- pysisyphus/socket_helper.py +74 -0
- pysisyphus/stocastic/FragmentKick.py +132 -0
- pysisyphus/stocastic/Kick.py +81 -0
- pysisyphus/stocastic/Pipeline.py +303 -0
- pysisyphus/stocastic/__init__.py +21 -0
- pysisyphus/stocastic/align.py +127 -0
- pysisyphus/testing.py +96 -0
- pysisyphus/thermo.py +156 -0
- pysisyphus/trj.py +824 -0
- pysisyphus/tsoptimizers/RSIRFOptimizer.py +56 -0
- pysisyphus/tsoptimizers/RSPRFOptimizer.py +182 -0
- pysisyphus/tsoptimizers/TRIM.py +59 -0
- pysisyphus/tsoptimizers/TSHessianOptimizer.py +463 -0
- pysisyphus/tsoptimizers/__init__.py +23 -0
- pysisyphus/wavefunction/Basis.py +239 -0
- pysisyphus/wavefunction/DIIS.py +76 -0
- pysisyphus/wavefunction/__init__.py +25 -0
- pysisyphus/wavefunction/build_ext.py +42 -0
- pysisyphus/wavefunction/cart2sph.py +190 -0
- pysisyphus/wavefunction/diabatization.py +304 -0
- pysisyphus/wavefunction/excited_states.py +435 -0
- pysisyphus/wavefunction/gen_ints.py +1811 -0
- pysisyphus/wavefunction/helpers.py +104 -0
- pysisyphus/wavefunction/ints/__init__.py +0 -0
- pysisyphus/wavefunction/ints/boys.py +193 -0
- pysisyphus/wavefunction/ints/boys_table_N_64_xasym_27.1_step_0.01.npy +0 -0
- pysisyphus/wavefunction/ints/cart_gto3d.py +176 -0
- pysisyphus/wavefunction/ints/coulomb3d.py +25928 -0
- pysisyphus/wavefunction/ints/diag_quadrupole3d.py +10036 -0
- pysisyphus/wavefunction/ints/dipole3d.py +8762 -0
- pysisyphus/wavefunction/ints/int2c2e3d.py +7198 -0
- pysisyphus/wavefunction/ints/int3c2e3d_sph.py +65040 -0
- pysisyphus/wavefunction/ints/kinetic3d.py +8240 -0
- pysisyphus/wavefunction/ints/ovlp3d.py +3777 -0
- pysisyphus/wavefunction/ints/quadrupole3d.py +15054 -0
- pysisyphus/wavefunction/ints/self_ovlp3d.py +198 -0
- pysisyphus/wavefunction/localization.py +458 -0
- pysisyphus/wavefunction/multipole.py +159 -0
- pysisyphus/wavefunction/normalization.py +36 -0
- pysisyphus/wavefunction/pop_analysis.py +134 -0
- pysisyphus/wavefunction/shells.py +1171 -0
- pysisyphus/wavefunction/wavefunction.py +504 -0
- pysisyphus/wrapper/__init__.py +11 -0
- pysisyphus/wrapper/exceptions.py +2 -0
- pysisyphus/wrapper/jmol.py +120 -0
- pysisyphus/wrapper/mwfn.py +169 -0
- pysisyphus/wrapper/packmol.py +71 -0
- pysisyphus/xyzloader.py +168 -0
- pysisyphus/yaml_mods.py +45 -0
- thermoanalysis/LICENSE +674 -0
- thermoanalysis/QCData.py +244 -0
- thermoanalysis/__init__.py +0 -0
- thermoanalysis/config.py +3 -0
- thermoanalysis/constants.py +20 -0
- thermoanalysis/thermo.py +1011 -0
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any, Dict, List, Union
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def read_prmtop_with_parmed(prmtop_path: Union[str, Path]) -> Dict[str, List[Any]]:
    """Read an AMBER prmtop file using ParmEd.

    This package requires ParmEd and always uses this route; there is no
    fallback parser.

    Args:
        prmtop_path: Path to the prmtop file, as ``str`` or ``pathlib.Path``.

    Returns:
        A dict mapping each prmtop ``%FLAG`` name to a plain Python list built
        from the corresponding entry of ParmEd's ``parm.parm_data``. Note that
        values are whatever ParmEd stores (per this package's conventions the
        CHARGE entries are already de-scaled to elementary charges), not
        necessarily the raw numbers written in the file.

    Raises:
        ImportError: If ParmEd cannot be imported.
    """
    # Import lazily so this module itself can be imported without ParmEd.
    # Only genuine import failures are translated into the friendly message;
    # any other exception raised while importing ParmEd indicates a broken
    # environment and must surface unchanged rather than be misreported as
    # "not installed".
    try:
        from parmed.amber import AmberParm  # type: ignore
    except ImportError as e:  # pragma: no cover
        raise ImportError(
            "ParmEd is required but not installed. Install it (pip install parmed)."
        ) from e

    prmtop_path = Path(prmtop_path)
    parm = AmberParm(str(prmtop_path))
    # parm.parm_data is a dict: flag -> numpy array/list. Copy each entry into
    # a fresh list so callers get plain sequences independent of ParmEd.
    return {flag: list(values) for flag, values in parm.parm_data.items()}
|
hessian_ff/system.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import dataclasses
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Optional, Tuple
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass(frozen=True)
class AmberSystem:
    """Immutable, tensorized view of an AMBER prmtop topology.

    Floating-point tensors are loaded as ``torch.float64`` by default; use
    :meth:`to` to obtain a copy in another floating dtype (e.g.
    ``torch.float32``) and/or on another device. Coordinates are expected in
    Angstrom.

    Charge units:
    - The prmtop is read through ParmEd, which returns atomic charges already
      de-scaled from the raw ``%FLAG CHARGE`` values, i.e. in units of the
      elementary charge (e).
    - The Coulomb energy in kcal/mol is
          E = 332.0637132991921 * sum_{i<j} q_i q_j / r_ij
      with r_ij in Angstrom.
    """

    # ---- per-atom data ----
    natom: int
    charge: torch.Tensor  # [N] atomic charges in e
    atom_type: torch.Tensor  # [N] Lennard-Jones type index (0-based int64)

    # ---- bonds ----
    bond_i: torch.Tensor  # [Nb]
    bond_j: torch.Tensor  # [Nb]
    bond_k: torch.Tensor  # [Nb] force constant
    bond_r0: torch.Tensor  # [Nb] equilibrium distance

    # ---- angles ----
    angle_i: torch.Tensor  # [Na]
    angle_j: torch.Tensor  # [Na]
    angle_k: torch.Tensor  # [Na]
    angle_k0: torch.Tensor  # [Na] force constant
    angle_t0: torch.Tensor  # [Na] equilibrium angle (radians)

    # ---- dihedrals ----
    dihed_i: torch.Tensor  # [Nd]
    dihed_j: torch.Tensor  # [Nd]
    dihed_k: torch.Tensor  # [Nd]
    dihed_l: torch.Tensor  # [Nd]
    dihed_force: torch.Tensor  # [Nd]
    dihed_period: torch.Tensor  # [Nd]
    dihed_phase: torch.Tensor  # [Nd] (radians)

    # ---- nonbonded parameter tables ----
    lj_acoef: torch.Tensor  # [Nlj] A in A/r^12
    lj_bcoef: torch.Tensor  # [Nlj] B in B/r^6
    nb_index: torch.Tensor  # [ntypes, ntypes] raw NONBONDED_PARM_INDEX (Fortran-style signed int)
    hb_acoef: torch.Tensor  # [Nhb] HBOND A in A/r^12 (used when nb_index<0)
    hb_bcoef: torch.Tensor  # [Nhb] HBOND B in B/r^10 (used when nb_index<0)

    # ---- precomputed pair lists (no-PBC, O(N^2) evaluation) ----
    pair_i: torch.Tensor  # [Np] general nonbonded pairs (i<j), exclusions and 1-4 removed
    pair_j: torch.Tensor  # [Np]

    pair14_i: torch.Tensor  # [N14] 1-4 pairs (i<j)
    pair14_j: torch.Tensor  # [N14]
    pair14_inv_scee: torch.Tensor  # [N14] Coulomb multiplier
    pair14_inv_scnb: torch.Tensor  # [N14] LJ multiplier

    # ---- CMAP (optional CHARMM-style correction map) ----
    cmap_type: torch.Tensor  # [Ncmap] map type index (0-based)
    cmap_i: torch.Tensor  # [Ncmap] first torsion atom i
    cmap_j: torch.Tensor  # [Ncmap] first torsion atom j
    cmap_k: torch.Tensor  # [Ncmap] first torsion atom k
    cmap_l: torch.Tensor  # [Ncmap] first torsion atom l
    cmap_m: torch.Tensor  # [Ncmap] second torsion terminal atom m
    cmap_resolution: Tuple[int, ...]  # one resolution per map type
    cmap_maps: Tuple[torch.Tensor, ...]  # flattened map values (kcal/mol) in OpenMM ordering
    cmap_size: torch.Tensor  # [Nmap] grid size per map type
    cmap_delta: torch.Tensor  # [Nmap] angular grid width (radian)
    cmap_offset: torch.Tensor  # [Nmap] flattened patch row offset
    cmap_coeff: torch.Tensor  # [sum(size*size),16] bicubic coefficients

    def to(
        self,
        device: torch.device | str | None = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "AmberSystem":
        """Return a copy of this system on ``device`` and/or with floating ``dtype``.

        * Tensor fields are moved to ``device`` (if given); only floating-point
          tensors are cast to ``dtype``, so integer index tensors keep theirs.
        * Tuple fields whose first element is a tensor (e.g. ``cmap_maps``) are
          converted element by element.
        * Every other field is carried over as-is.

        With neither argument given, the copy reuses the original tensor
        objects.
        """
        dev = None if device is None else torch.device(device)

        def _convert(t: torch.Tensor) -> torch.Tensor:
            options = {}
            if dev is not None:
                options["device"] = dev
            if dtype is not None and t.is_floating_point():
                options["dtype"] = dtype
            if not options:
                return t
            return t.to(**options)

        updated = {}
        for fld in dataclasses.fields(self):
            value = getattr(self, fld.name)
            if isinstance(value, torch.Tensor):
                value = _convert(value)
            elif (
                isinstance(value, tuple)
                and len(value) > 0
                and isinstance(value[0], torch.Tensor)
            ):
                value = tuple(map(_convert, value))
            updated[fld.name] = value
        return AmberSystem(**updated)
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from .bond import BondTerm
|
|
2
|
+
from .angle import AngleTerm
|
|
3
|
+
from .dihedral import DihedralTerm
|
|
4
|
+
from .cmap import CMapTerm
|
|
5
|
+
from .nonbonded import NonbondedTerm, NonbondedEnergies
|
|
6
|
+
|
|
7
|
+
# Public API of the ``terms`` subpackage: the energy-term classes re-exported
# from the per-term modules imported at the top of this file.
__all__ = [
    "BondTerm",
    "AngleTerm",
    "DihedralTerm",
    "CMapTerm",
    "NonbondedTerm",
    "NonbondedEnergies",
]
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import nn
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class AngleTerm(nn.Module):
    """Harmonic AMBER angle term: ``E = sum k_theta * (theta - theta0)**2``.

    Each angle is the triple ``(i, j, k)`` with ``j`` as the vertex; note that
    the energy has no factor of 1/2 in front of the force constant.
    :meth:`forward` obtains ``theta`` via ``atan2(|u x v|, u . v)`` for
    numerical stability, while :meth:`energy_force` uses a clamped ``acos``
    together with an explicit analytical gradient.
    """

    def __init__(
        self,
        i: torch.Tensor,
        j: torch.Tensor,
        k: torch.Tensor,
        k_theta: torch.Tensor,
        theta0: torch.Tensor,
    ):
        super().__init__()
        # Atom indices are stored as int64 so they can be used for indexing.
        self.register_buffer("i", i.long())
        self.register_buffer("j", j.long())
        self.register_buffer("k", k.long())
        self.register_buffer("k_theta", k_theta)
        self.register_buffer("theta0", theta0)

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        """Return the total angle energy (autograd-friendly path)."""
        if self.i.numel() == 0:
            return coords.new_zeros(())
        vertex = coords[self.j]
        u = coords[self.i] - vertex
        v = coords[self.k] - vertex
        sin_part = torch.linalg.norm(torch.cross(u, v, dim=-1), dim=-1)
        cos_part = (u * v).sum(dim=-1)
        theta = torch.atan2(sin_part, cos_part)
        return (self.k_theta * (theta - self.theta0) ** 2).sum()

    def energy_force(self, coords: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(energy, force)`` with forces computed analytically.

        ``force`` has the shape of ``coords`` and holds ``-dE/dcoords``.
        """
        force = torch.zeros_like(coords)
        if self.i.numel() == 0:
            return coords.new_zeros(()), force

        # OpenMM-style geometry: both arm vectors point *toward* the vertex.
        r_vertex = coords[self.j]
        d0 = r_vertex - coords[self.i]
        d1 = r_vertex - coords[self.k]
        normal = torch.cross(d0, d1, dim=-1)

        # Clamps keep degenerate (zero-length or collinear) angles finite.
        len2_0 = (d0 * d0).sum(dim=-1).clamp_min(1.0e-24)
        len2_1 = (d1 * d1).sum(dim=-1).clamp_min(1.0e-24)
        normal_len = torch.linalg.norm(normal, dim=-1).clamp_min(1.0e-12)

        cos_theta = ((d0 * d1).sum(dim=-1) / torch.sqrt(len2_0 * len2_1)).clamp(-1.0, 1.0)
        dtheta = torch.acos(cos_theta) - self.theta0

        energy = (self.k_theta * dtheta * dtheta).sum()
        # dE/dtheta for E = k (theta - theta0)^2
        de_dtheta = 2.0 * self.k_theta * dtheta

        f_i = torch.cross(d0, normal, dim=-1) * (de_dtheta / (len2_0 * normal_len)).unsqueeze(-1)
        f_k = torch.cross(d1, normal, dim=-1) * (-de_dtheta / (len2_1 * normal_len)).unsqueeze(-1)
        # Translational invariance: the vertex force balances the two ends.
        f_j = -(f_i + f_k)

        for idx, contrib in ((self.i, f_i), (self.j, f_j), (self.k, f_k)):
            force.index_add_(0, idx, contrib)
        return energy, force
|
hessian_ff/terms/bond.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import nn
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class BondTerm(nn.Module):
    """Harmonic AMBER bond term: ``E = sum k * (r - r0)**2`` (no factor of 1/2)."""

    def __init__(self, i: torch.Tensor, j: torch.Tensor, k: torch.Tensor, r0: torch.Tensor):
        super().__init__()
        # Atom indices as int64 for indexing; k / r0 keep their given dtype.
        self.register_buffer("i", i.long())
        self.register_buffer("j", j.long())
        self.register_buffer("k", k)
        self.register_buffer("r0", r0)

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        """Return the total bond energy (autograd-friendly path)."""
        if self.i.numel() == 0:
            return coords.new_zeros(())
        bond_vec = coords[self.j] - coords[self.i]
        length = torch.linalg.norm(bond_vec, dim=-1)
        return (self.k * (length - self.r0) ** 2).sum()

    def energy_force(self, coords: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(energy, force)`` with forces computed analytically.

        ``force`` has the shape of ``coords`` and holds ``-dE/dcoords``.
        """
        force = torch.zeros_like(coords)
        if self.i.numel() == 0:
            return coords.new_zeros(()), force

        bond_vec = coords[self.j] - coords[self.i]  # points from i to j
        # Clamp keeps coincident atoms from producing inf in rsqrt.
        r2 = (bond_vec * bond_vec).sum(dim=-1).clamp_min(1.0e-24)
        inv_len = torch.rsqrt(r2)
        stretch = r2 * inv_len - self.r0

        energy = (self.k * stretch * stretch).sum()

        # dE/dr = 2k(r - r0); the force on i points along +bond_vec when stretched.
        pair_force = (2.0 * self.k * stretch * inv_len).unsqueeze(-1) * bond_vec
        force.index_add_(0, self.i, pair_force)
        force.index_add_(0, self.j, -pair_force)
        return energy, force
|
hessian_ff/terms/cmap.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
from typing import Tuple
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from torch import nn
|
|
8
|
+
|
|
9
|
+
from .dihedral import _accumulate_dihedral_forces, _dihedral_angle
|
|
10
|
+
|
|
11
|
+
# 2*pi: the full period of a torsion angle; the CMAP grid below spans [0, 2*pi].
_TWO_PI = 2.0 * math.pi

# References:
# - OpenMM CMAP coefficient/derivative setup:
#   https://github.com/openmm/openmm/blob/4768436/openmmapi/src/CMAPTorsionForceImpl.cpp
# - OpenMM periodic spline routines:
#   https://github.com/openmm/openmm/blob/4768436/openmmapi/src/SplineFitter.cpp
# - OpenMM reference CMAP patch evaluation:
#   https://github.com/openmm/openmm/blob/4768436/platforms/reference/src/SimTKReference/ReferenceCMAPTorsionIxn.cpp

# OpenMM CMAP bicubic coefficient matrix (CMAPTorsionForceImpl.cpp, wt[k+16*m]).
# The 256 integers are written as one flat sequence. ``.view(16, 16)`` reads
# them row-major and ``.t()`` transposes, so element [k, m] of the resulting
# 16x16 float64 tensor equals flat entry k + 16*m -- i.e. the wt[k+16*m]
# indexing above. (The result is a non-contiguous view.)
_CMAP_WT = (
    torch.tensor(
        [
            1, 0, -3, 2, 0, 0, 0, 0, -3, 0, 9, -6, 2, 0, -6, 4,
            0, 0, 0, 0, 0, 0, 0, 0, 3, 0, -9, 6, -2, 0, 6, -4,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, -6, 0, 0, -6, 4,
            0, 0, 3, -2, 0, 0, 0, 0, 0, 0, -9, 6, 0, 0, 6, -4,
            0, 0, 0, 0, 1, 0, -3, 2, -2, 0, 6, -4, 1, 0, -3, 2,
            0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 3, -2, 1, 0, -3, 2,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -3, 2, 0, 0, 3, -2,
            0, 0, 0, 0, 0, 0, 3, -2, 0, 0, -6, 4, 0, 0, 3, -2,
            0, 1, -2, 1, 0, 0, 0, 0, 0, -3, 6, -3, 0, 2, -4, 2,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 3, -6, 3, 0, -2, 4, -2,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -3, 3, 0, 0, 2, -2,
            0, 0, -1, 1, 0, 0, 0, 0, 0, 0, 3, -3, 0, 0, -2, 2,
            0, 0, 0, 0, 0, 1, -2, 1, 0, -2, 4, -2, 0, 1, -2, 1,
            0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 2, -1, 0, 1, -2, 1,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, -1, 1,
            0, 0, 0, 0, 0, 0, -1, 1, 0, 0, 2, -2, 0, 0, -1, 1,
        ],
        dtype=torch.float64,
    ).view(16, 16).t()
)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _solve_tridiagonal(
    a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, rhs: torch.Tensor
) -> torch.Tensor:
    """Solve ``M @ sol = rhs`` for tridiagonal ``M`` (Thomas algorithm).

    Row ``r`` of ``M`` has ``a[r]`` on the sub-diagonal, ``b[r]`` on the
    diagonal and ``c[r]`` on the super-diagonal; ``a[0]`` and ``c[-1]`` are
    never read. No pivoting is performed, so a zero pivot produces inf/nan.
    The inputs are not modified; for two or more rows a new float64 tensor is
    returned.
    """
    size = int(a.numel())
    if size == 0:
        return rhs.clone()
    if size == 1:
        return rhs / b

    ratio = torch.zeros(size, dtype=torch.float64)
    out = torch.zeros(size, dtype=torch.float64)

    # Forward sweep: eliminate the sub-diagonal row by row.
    pivot = b[0]
    out[0] = rhs[0] / pivot
    for row in range(1, size):
        ratio[row] = c[row - 1] / pivot
        pivot = b[row] - a[row] * ratio[row]
        out[row] = (rhs[row] - a[row] * out[row - 1]) / pivot

    # Back substitution from the second-to-last row up to the first.
    for row in reversed(range(size - 1)):
        out[row] = out[row] - ratio[row + 1] * out[row + 1]
    return out
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _create_periodic_spline(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Fit a periodic cubic spline and return its second derivatives at the knots.

    Mirrors OpenMM ``SplineFitter::createPeriodicSpline``. The last point is
    the periodic image of the first, so ``y[0]`` must equal ``y[-1]``. The
    ``n - 1`` independent knots form a cyclic tridiagonal system, solved here
    with two ordinary tridiagonal solves plus a Sherman-Morrison correction.
    The knots ``x`` are assumed to be strictly increasing (not checked).

    Args:
        x: Knot positions, at least 3 entries.
        y: Values at the knots, same length as ``x``.

    Returns:
        A float64 tensor of length ``n`` holding the spline second
        derivatives; its last entry repeats the first.

    Raises:
        ValueError: If fewer than 3 points are given, if ``x`` and ``y`` differ
            in length, or if ``y[0] != y[-1]`` (the data are not periodic).
    """
    n = int(x.numel())
    if n < 3:
        raise ValueError("Periodic spline requires at least 3 points")
    if y.numel() != n:
        raise ValueError("x/y size mismatch in periodic spline setup")
    # The cyclic system below treats knots 0 and n-1 as the same point, so
    # unequal end values would yield an inconsistent, non-periodic fit.
    # OpenMM's createPeriodicSpline rejects such input as well.
    if bool(y[0] != y[n - 1]):
        raise ValueError(
            "Periodic spline requires the first and last y values to be equal"
        )

    ntri = n - 1  # independent knots; the last point duplicates the first
    a = torch.zeros(ntri, dtype=torch.float64)  # sub-diagonal
    b = torch.zeros(ntri, dtype=torch.float64)  # diagonal
    c = torch.zeros(ntri, dtype=torch.float64)  # super-diagonal
    rhs = torch.zeros(ntri, dtype=torch.float64)

    # Row 0 wraps around: its left neighbour interval is [x[n-2], x[n-1]].
    a[0] = x[n - 1] - x[n - 2]
    b[0] = 2.0 * (x[1] - x[0] + x[n - 1] - x[n - 2])
    c[0] = x[1] - x[0]
    rhs[0] = 6.0 * (
        (y[1] - y[0]) / (x[1] - x[0]) - (y[n - 1] - y[n - 2]) / (x[n - 1] - x[n - 2])
    )
    for i in range(1, ntri):
        a[i] = x[i] - x[i - 1]
        b[i] = 2.0 * (x[i + 1] - x[i - 1])
        c[i] = x[i + 1] - x[i]
        rhs[i] = 6.0 * (
            (y[i + 1] - y[i]) / (x[i + 1] - x[i]) - (y[i] - y[i - 1]) / (x[i] - x[i - 1])
        )

    # Sherman-Morrison: fold the two corner couplings (top-right ``beta`` and
    # bottom-left ``alpha``) into a modified diagonal plus a rank-one update.
    beta = a[0]
    alpha = c[ntri - 1]
    gamma = -b[0]

    b[0] = b[0] - gamma
    b[ntri - 1] = b[ntri - 1] - alpha * beta / gamma
    deriv = _solve_tridiagonal(a, b, c, rhs)

    u = torch.zeros(ntri, dtype=torch.float64)
    u[0] = gamma
    u[ntri - 1] = alpha
    z = _solve_tridiagonal(a, b, c, u)

    scale = (deriv[0] + beta * deriv[ntri - 1] / gamma) / (
        1.0 + z[0] + beta * z[ntri - 1] / gamma
    )
    deriv = deriv - scale * z

    out = torch.zeros(n, dtype=torch.float64)
    out[:ntri] = deriv
    out[ntri] = deriv[0]  # periodic image of the first knot
    return out
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _eval_spline_derivative(
    x: torch.Tensor, y: torch.Tensor, second_deriv: torch.Tensor, t: torch.Tensor
) -> torch.Tensor:
    """Return the first derivative at ``t`` of the cubic spline ``(x, y, second_deriv)``.

    Equivalent to OpenMM ``SplineFitter::evaluateSplineDerivative``: ``x`` are
    ascending knots, ``y`` the knot values and ``second_deriv`` the spline
    second derivatives at the knots. ``t`` is not range-checked.
    """
    # Bisection down to adjacent knots [lo, hi]; when t equals an interior
    # knot exactly, that knot becomes ``lo``.
    lo, hi = 0, int(x.numel()) - 1
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if x[mid] > t:
            hi = mid
        else:
            lo = mid

    h = x[hi] - x[lo]
    w_lo = (x[hi] - t) / h
    w_hi = (t - x[lo]) / h
    secant = (-1.0 / h) * (y[lo] - y[hi])
    curvature = (
        (1.0 - 3.0 * w_lo * w_lo) * second_deriv[lo]
        + (3.0 * w_hi * w_hi - 1.0) * second_deriv[hi]
    ) * h / 6.0
    return secant + curvature
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _calc_map_derivatives(energy: torch.Tensor, size: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Match OpenMM CMAPTorsionForceImpl::calcMapDerivatives.

    ``energy`` is a flat ``size * size`` grid in which entry ``i + size * j``
    is grid point (i, j). The fast index i is walked for d/dphi and the slow
    index j for d/dpsi.

    Each row or column is treated as one period sampled at ``2*pi*k/size``.
    A periodic spline is fitted through it, and its slope is read back at
    every knot.

    Returns:
        ``(d1, d2, d12)``: float64 tensors of length ``size * size`` with the
        d/dphi, d/dpsi and cross (d/dphi of d2) derivatives, laid out like
        ``energy``.
    """
    d1 = torch.zeros(size * size, dtype=torch.float64)
    d2 = torch.zeros(size * size, dtype=torch.float64)
    d12 = torch.zeros(size * size, dtype=torch.float64)

    # Knots 0 .. 2*pi inclusive; the last knot holds the periodic copy of sample 0.
    knots = torch.arange(size + 1, dtype=torch.float64) * (_TWO_PI / float(size))

    def _fill_slopes(target, source, start, stride):
        # Spline through source[start + stride*k] for k = 0..size-1, closed at
        # 2*pi by repeating source[start]. Write the slope at each knot into
        # target at the same flat positions.
        positions = [start + stride * k for k in range(size)]
        samples = torch.zeros(size + 1, dtype=torch.float64)
        for k, pos in enumerate(positions):
            samples[k] = source[pos]
        samples[size] = source[start]
        second_derivs = _create_periodic_spline(knots, samples)
        for k, pos in enumerate(positions):
            target[pos] = _eval_spline_derivative(knots, samples, second_derivs, knots[k])

    # d/dphi: contiguous runs (fast index) with the slow index fixed.
    for row in range(size):
        _fill_slopes(d1, energy, size * row, 1)
    # d/dpsi: strided runs (slow index) with the fast index fixed.
    for col in range(size):
        _fill_slopes(d2, energy, col, size)
    # d2/(dphi dpsi): differentiate d2 along the fast index.
    for row in range(size):
        _fill_slopes(d12, d2, size * row, 1)

    return d1, d2, d12
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def _calc_map_coefficients(map_values: torch.Tensor, size: int) -> torch.Tensor:
    """Create bicubic patch coefficients identical to OpenMM CMAP setup.

    Args:
        map_values: one CMAP energy table with exactly ``size * size`` entries.
            It is flattened, so entry ``i + size * j`` is grid point (i, j).
        size: grid points per axis.

    Returns:
        ``(size * size, 16)`` float64 tensor on the CPU. Row ``i + size * j``
        is ``_CMAP_WT @ rhs`` for the patch whose lower corner is (i, j).

    Raises:
        ValueError: if ``map_values`` does not hold ``size * size`` entries.
    """
    if map_values.numel() != size * size:
        raise ValueError(f"CMAP map size mismatch: expected {size*size}, got {map_values.numel()}")

    energy = map_values.detach().to(dtype=torch.float64, device="cpu").reshape(-1)
    d1, d2, d12 = _calc_map_derivatives(energy, size=size)
    spacing = _TWO_PI / float(size)
    coeff = torch.zeros((size * size, 16), dtype=torch.float64)

    for j in range(size):
        j_next = (j + 1) % size  # wraps: the grid is periodic
        for i in range(size):
            i_next = (i + 1) % size
            # Patch corners in order (i,j), (i+1,j), (i+1,j+1), (i,j+1).
            corners = torch.tensor(
                [i + size * j, i_next + size * j, i_next + size * j_next, i + size * j_next]
            )
            # Values, then derivatives scaled by the spacing so the patch is
            # parameterised on the unit square.
            rhs = torch.cat(
                (
                    energy[corners],
                    d1[corners] * spacing,
                    d2[corners] * spacing,
                    d12[corners] * spacing * spacing,
                )
            )
            coeff[i + size * j] = _CMAP_WT @ rhs
    return coeff
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
class CMapTerm(nn.Module):
|
|
224
|
+
"""CMAP torsion correction implemented with Torch bicubic interpolation.
|
|
225
|
+
|
|
226
|
+
The coefficient generation follows OpenMM's CMAP setup logic. The forward
|
|
227
|
+
path is composed of standard Torch ops, so both gradients and Hessians are
|
|
228
|
+
available through autograd.
|
|
229
|
+
"""
|
|
230
|
+
|
|
231
|
+
def __init__(
|
|
232
|
+
self,
|
|
233
|
+
natom: int,
|
|
234
|
+
cmap_type: torch.Tensor,
|
|
235
|
+
cmap_i: torch.Tensor,
|
|
236
|
+
cmap_j: torch.Tensor,
|
|
237
|
+
cmap_k: torch.Tensor,
|
|
238
|
+
cmap_l: torch.Tensor,
|
|
239
|
+
cmap_m: torch.Tensor,
|
|
240
|
+
cmap_resolution: Tuple[int, ...],
|
|
241
|
+
cmap_maps: Tuple[torch.Tensor, ...],
|
|
242
|
+
*,
|
|
243
|
+
precomputed_size: torch.Tensor | None = None,
|
|
244
|
+
precomputed_delta: torch.Tensor | None = None,
|
|
245
|
+
precomputed_offset: torch.Tensor | None = None,
|
|
246
|
+
precomputed_coeff: torch.Tensor | None = None,
|
|
247
|
+
):
|
|
248
|
+
super().__init__()
|
|
249
|
+
self.natom = int(natom)
|
|
250
|
+
self.register_buffer("cmap_type", cmap_type.long())
|
|
251
|
+
self.register_buffer("cmap_i", cmap_i.long())
|
|
252
|
+
self.register_buffer("cmap_j", cmap_j.long())
|
|
253
|
+
self.register_buffer("cmap_k", cmap_k.long())
|
|
254
|
+
self.register_buffer("cmap_l", cmap_l.long())
|
|
255
|
+
self.register_buffer("cmap_m", cmap_m.long())
|
|
256
|
+
|
|
257
|
+
self._enabled = self.cmap_type.numel() > 0
|
|
258
|
+
if not self._enabled:
|
|
259
|
+
self.register_buffer("map_size", torch.zeros((0,), dtype=torch.int64))
|
|
260
|
+
self.register_buffer("map_delta", torch.zeros((0,), dtype=torch.float64))
|
|
261
|
+
self.register_buffer("map_offset", torch.zeros((0,), dtype=torch.int64))
|
|
262
|
+
self.register_buffer("map_coeff", torch.zeros((0, 16), dtype=torch.float64))
|
|
263
|
+
return
|
|
264
|
+
|
|
265
|
+
# Use precomputed coefficients if available (avoids redundant computation
|
|
266
|
+
# when AmberSystem already carries cmap_size/delta/offset/coeff).
|
|
267
|
+
if precomputed_coeff is not None:
|
|
268
|
+
self.register_buffer("map_size", precomputed_size)
|
|
269
|
+
self.register_buffer("map_delta", precomputed_delta)
|
|
270
|
+
self.register_buffer("map_offset", precomputed_offset)
|
|
271
|
+
self.register_buffer("map_coeff", precomputed_coeff)
|
|
272
|
+
return
|
|
273
|
+
|
|
274
|
+
if len(cmap_resolution) != len(cmap_maps):
|
|
275
|
+
raise ValueError(
|
|
276
|
+
f"CMAP resolution/map count mismatch: {len(cmap_resolution)} vs {len(cmap_maps)}"
|
|
277
|
+
)
|
|
278
|
+
|
|
279
|
+
map_size = torch.tensor([int(x) for x in cmap_resolution], dtype=torch.int64)
|
|
280
|
+
coeff_parts = []
|
|
281
|
+
offsets = []
|
|
282
|
+
cursor = 0
|
|
283
|
+
for ngrid, table in zip(cmap_resolution, cmap_maps):
|
|
284
|
+
offsets.append(cursor)
|
|
285
|
+
c = _calc_map_coefficients(table, size=int(ngrid))
|
|
286
|
+
coeff_parts.append(c)
|
|
287
|
+
cursor += int(ngrid) * int(ngrid)
|
|
288
|
+
|
|
289
|
+
coeff_cat = torch.cat(coeff_parts, dim=0)
|
|
290
|
+
target_dtype = cmap_maps[0].dtype
|
|
291
|
+
target_device = cmap_maps[0].device
|
|
292
|
+
|
|
293
|
+
self.register_buffer("map_size", map_size.to(device=target_device))
|
|
294
|
+
self.register_buffer(
|
|
295
|
+
"map_delta",
|
|
296
|
+
(torch.full_like(map_size, _TWO_PI, dtype=torch.float64) / map_size.to(torch.float64)).to(
|
|
297
|
+
dtype=target_dtype, device=target_device
|
|
298
|
+
),
|
|
299
|
+
)
|
|
300
|
+
self.register_buffer("map_offset", torch.tensor(offsets, dtype=torch.int64, device=target_device))
|
|
301
|
+
self.register_buffer("map_coeff", coeff_cat.to(dtype=target_dtype, device=target_device))
|
|
302
|
+
|
|
303
|
+
def forward(self, coords: torch.Tensor) -> torch.Tensor:
|
|
304
|
+
if not self._enabled:
|
|
305
|
+
return coords.new_zeros(())
|
|
306
|
+
|
|
307
|
+
p_i = coords[self.cmap_i]
|
|
308
|
+
p_j = coords[self.cmap_j]
|
|
309
|
+
p_k = coords[self.cmap_k]
|
|
310
|
+
p_l = coords[self.cmap_l]
|
|
311
|
+
p_m = coords[self.cmap_m]
|
|
312
|
+
|
|
313
|
+
phi = _dihedral_angle(p_i, p_j, p_k, p_l)
|
|
314
|
+
psi = _dihedral_angle(p_j, p_k, p_l, p_m)
|
|
315
|
+
|
|
316
|
+
# OpenMM CMAP interpolation uses [0, 2pi) wrapped angles.
|
|
317
|
+
ang_a = torch.remainder(phi + _TWO_PI, _TWO_PI)
|
|
318
|
+
ang_b = torch.remainder(psi + _TWO_PI, _TWO_PI)
|
|
319
|
+
|
|
320
|
+
tmap = self.cmap_type
|
|
321
|
+
size = self.map_size[tmap]
|
|
322
|
+
delta = self.map_delta[tmap].to(dtype=coords.dtype)
|
|
323
|
+
|
|
324
|
+
u = ang_a / delta
|
|
325
|
+
v = ang_b / delta
|
|
326
|
+
s = torch.floor(u).to(torch.int64)
|
|
327
|
+
t = torch.floor(v).to(torch.int64)
|
|
328
|
+
s = torch.minimum(s, size - 1)
|
|
329
|
+
t = torch.minimum(t, size - 1)
|
|
330
|
+
|
|
331
|
+
da = u - s.to(dtype=coords.dtype)
|
|
332
|
+
db = v - t.to(dtype=coords.dtype)
|
|
333
|
+
|
|
334
|
+
patch = s + size * t
|
|
335
|
+
coeff_row = self.map_offset[tmap] + patch
|
|
336
|
+
coeff = self.map_coeff[coeff_row]
|
|
337
|
+
coeff = coeff.to(dtype=coords.dtype).reshape(-1, 4, 4)
|
|
338
|
+
|
|
339
|
+
# Horner evaluation in db, then da.
|
|
340
|
+
dbu = db.unsqueeze(-1)
|
|
341
|
+
poly_b = ((coeff[:, :, 3] * dbu + coeff[:, :, 2]) * dbu + coeff[:, :, 1]) * dbu + coeff[:, :, 0]
|
|
342
|
+
e = ((poly_b[:, 3] * da + poly_b[:, 2]) * da + poly_b[:, 1]) * da + poly_b[:, 0]
|
|
343
|
+
return torch.sum(e)
|
|
344
|
+
|
|
345
|
+
def energy_force(self, coords: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
346
|
+
"""Return CMAP energy and analytical force."""
|
|
347
|
+
force = torch.zeros_like(coords)
|
|
348
|
+
if not self._enabled:
|
|
349
|
+
return coords.new_zeros(()), force
|
|
350
|
+
|
|
351
|
+
p_i = coords[self.cmap_i]
|
|
352
|
+
p_j = coords[self.cmap_j]
|
|
353
|
+
p_k = coords[self.cmap_k]
|
|
354
|
+
p_l = coords[self.cmap_l]
|
|
355
|
+
p_m = coords[self.cmap_m]
|
|
356
|
+
|
|
357
|
+
phi = _dihedral_angle(p_i, p_j, p_k, p_l)
|
|
358
|
+
psi = _dihedral_angle(p_j, p_k, p_l, p_m)
|
|
359
|
+
|
|
360
|
+
ang_a = torch.remainder(phi + _TWO_PI, _TWO_PI)
|
|
361
|
+
ang_b = torch.remainder(psi + _TWO_PI, _TWO_PI)
|
|
362
|
+
|
|
363
|
+
tmap = self.cmap_type
|
|
364
|
+
size = self.map_size[tmap]
|
|
365
|
+
delta = self.map_delta[tmap].to(dtype=coords.dtype)
|
|
366
|
+
|
|
367
|
+
u = ang_a / delta
|
|
368
|
+
v = ang_b / delta
|
|
369
|
+
s = torch.floor(u).to(torch.int64)
|
|
370
|
+
t = torch.floor(v).to(torch.int64)
|
|
371
|
+
s = torch.minimum(s, size - 1)
|
|
372
|
+
t = torch.minimum(t, size - 1)
|
|
373
|
+
|
|
374
|
+
da = u - s.to(dtype=coords.dtype)
|
|
375
|
+
db = v - t.to(dtype=coords.dtype)
|
|
376
|
+
|
|
377
|
+
patch = s + size * t
|
|
378
|
+
coeff_row = self.map_offset[tmap] + patch
|
|
379
|
+
coeff = self.map_coeff[coeff_row].to(dtype=coords.dtype).reshape(-1, 4, 4)
|
|
380
|
+
|
|
381
|
+
dbu = db.unsqueeze(-1)
|
|
382
|
+
# Row-wise polynomial in db: Pm(db) (m=0..3)
|
|
383
|
+
poly_b = ((coeff[:, :, 3] * dbu + coeff[:, :, 2]) * dbu + coeff[:, :, 1]) * dbu + coeff[:, :, 0]
|
|
384
|
+
# Derivative wrt db: dPm/ddb
|
|
385
|
+
dpoly_b = ((3.0 * coeff[:, :, 3] * dbu + 2.0 * coeff[:, :, 2]) * dbu + coeff[:, :, 1])
|
|
386
|
+
|
|
387
|
+
# E = sum_m Pm(db) * da^m
|
|
388
|
+
e = ((poly_b[:, 3] * da + poly_b[:, 2]) * da + poly_b[:, 1]) * da + poly_b[:, 0]
|
|
389
|
+
energy = torch.sum(e)
|
|
390
|
+
|
|
391
|
+
# dE/dda and dE/ddb
|
|
392
|
+
dE_dda = (3.0 * poly_b[:, 3] * da + 2.0 * poly_b[:, 2]) * da + poly_b[:, 1]
|
|
393
|
+
dE_ddb = ((dpoly_b[:, 3] * da + dpoly_b[:, 2]) * da + dpoly_b[:, 1]) * da + dpoly_b[:, 0]
|
|
394
|
+
|
|
395
|
+
# da = phi/delta - floor(...), db = psi/delta - floor(...)
|
|
396
|
+
# Away from grid boundaries, d(da)/dphi = 1/delta and d(db)/dpsi = 1/delta.
|
|
397
|
+
dE_dphi = dE_dda / delta
|
|
398
|
+
dE_dpsi = dE_ddb / delta
|
|
399
|
+
|
|
400
|
+
_accumulate_dihedral_forces(
|
|
401
|
+
force, coords, self.cmap_i, self.cmap_j, self.cmap_k, self.cmap_l, dE_dphi
|
|
402
|
+
)
|
|
403
|
+
_accumulate_dihedral_forces(
|
|
404
|
+
force, coords, self.cmap_j, self.cmap_k, self.cmap_l, self.cmap_m, dE_dpsi
|
|
405
|
+
)
|
|
406
|
+
return energy, force
|