mlmm-toolkit 0.2.2.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hessian_ff/__init__.py +50 -0
- hessian_ff/analytical_hessian.py +609 -0
- hessian_ff/constants.py +46 -0
- hessian_ff/forcefield.py +339 -0
- hessian_ff/loaders.py +608 -0
- hessian_ff/native/Makefile +8 -0
- hessian_ff/native/__init__.py +28 -0
- hessian_ff/native/analytical_hessian.py +88 -0
- hessian_ff/native/analytical_hessian_ext.cpp +258 -0
- hessian_ff/native/bonded.py +82 -0
- hessian_ff/native/bonded_ext.cpp +640 -0
- hessian_ff/native/loader.py +349 -0
- hessian_ff/native/nonbonded.py +118 -0
- hessian_ff/native/nonbonded_ext.cpp +1150 -0
- hessian_ff/prmtop_parmed.py +23 -0
- hessian_ff/system.py +107 -0
- hessian_ff/terms/__init__.py +14 -0
- hessian_ff/terms/angle.py +73 -0
- hessian_ff/terms/bond.py +44 -0
- hessian_ff/terms/cmap.py +406 -0
- hessian_ff/terms/dihedral.py +141 -0
- hessian_ff/terms/nonbonded.py +209 -0
- hessian_ff/tests/__init__.py +0 -0
- hessian_ff/tests/conftest.py +75 -0
- hessian_ff/tests/data/small/complex.parm7 +1346 -0
- hessian_ff/tests/data/small/complex.pdb +125 -0
- hessian_ff/tests/data/small/complex.rst7 +63 -0
- hessian_ff/tests/test_coords_input.py +44 -0
- hessian_ff/tests/test_energy_force.py +49 -0
- hessian_ff/tests/test_hessian.py +137 -0
- hessian_ff/tests/test_smoke.py +18 -0
- hessian_ff/tests/test_validation.py +40 -0
- hessian_ff/workflows.py +889 -0
- mlmm/__init__.py +36 -0
- mlmm/__main__.py +7 -0
- mlmm/_version.py +34 -0
- mlmm/add_elem_info.py +374 -0
- mlmm/advanced_help.py +91 -0
- mlmm/align_freeze_atoms.py +601 -0
- mlmm/all.py +3535 -0
- mlmm/bond_changes.py +231 -0
- mlmm/bool_compat.py +223 -0
- mlmm/cli.py +574 -0
- mlmm/cli_utils.py +166 -0
- mlmm/default_group.py +337 -0
- mlmm/defaults.py +467 -0
- mlmm/define_layer.py +526 -0
- mlmm/dft.py +1041 -0
- mlmm/energy_diagram.py +253 -0
- mlmm/extract.py +2213 -0
- mlmm/fix_altloc.py +464 -0
- mlmm/freq.py +1406 -0
- mlmm/harmonic_constraints.py +140 -0
- mlmm/hessian_cache.py +44 -0
- mlmm/hessian_calc.py +174 -0
- mlmm/irc.py +638 -0
- mlmm/mlmm_calc.py +2262 -0
- mlmm/mm_parm.py +945 -0
- mlmm/oniom_export.py +1983 -0
- mlmm/oniom_import.py +457 -0
- mlmm/opt.py +1742 -0
- mlmm/path_opt.py +1353 -0
- mlmm/path_search.py +2299 -0
- mlmm/preflight.py +88 -0
- mlmm/py.typed +1 -0
- mlmm/pysis_runner.py +45 -0
- mlmm/scan.py +1047 -0
- mlmm/scan2d.py +1226 -0
- mlmm/scan3d.py +1265 -0
- mlmm/scan_common.py +184 -0
- mlmm/summary_log.py +736 -0
- mlmm/trj2fig.py +448 -0
- mlmm/tsopt.py +2871 -0
- mlmm/utils.py +2309 -0
- mlmm/xtb_embedcharge_correction.py +475 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/METADATA +1159 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/RECORD +372 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/WHEEL +5 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/entry_points.txt +2 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/licenses/LICENSE +674 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/top_level.txt +4 -0
- pysisyphus/Geometry.py +1667 -0
- pysisyphus/LICENSE +674 -0
- pysisyphus/TableFormatter.py +63 -0
- pysisyphus/TablePrinter.py +74 -0
- pysisyphus/__init__.py +12 -0
- pysisyphus/calculators/AFIR.py +452 -0
- pysisyphus/calculators/AnaPot.py +20 -0
- pysisyphus/calculators/AnaPot2.py +48 -0
- pysisyphus/calculators/AnaPot3.py +12 -0
- pysisyphus/calculators/AnaPot4.py +20 -0
- pysisyphus/calculators/AnaPotBase.py +337 -0
- pysisyphus/calculators/AnaPotCBM.py +25 -0
- pysisyphus/calculators/AtomAtomTransTorque.py +154 -0
- pysisyphus/calculators/CFOUR.py +250 -0
- pysisyphus/calculators/Calculator.py +844 -0
- pysisyphus/calculators/CerjanMiller.py +24 -0
- pysisyphus/calculators/Composite.py +123 -0
- pysisyphus/calculators/ConicalIntersection.py +171 -0
- pysisyphus/calculators/DFTBp.py +430 -0
- pysisyphus/calculators/DFTD3.py +66 -0
- pysisyphus/calculators/DFTD4.py +84 -0
- pysisyphus/calculators/Dalton.py +61 -0
- pysisyphus/calculators/Dimer.py +681 -0
- pysisyphus/calculators/Dummy.py +20 -0
- pysisyphus/calculators/EGO.py +76 -0
- pysisyphus/calculators/EnergyMin.py +224 -0
- pysisyphus/calculators/ExternalPotential.py +264 -0
- pysisyphus/calculators/FakeASE.py +35 -0
- pysisyphus/calculators/FourWellAnaPot.py +28 -0
- pysisyphus/calculators/FreeEndNEBPot.py +39 -0
- pysisyphus/calculators/Gaussian09.py +18 -0
- pysisyphus/calculators/Gaussian16.py +726 -0
- pysisyphus/calculators/HardSphere.py +159 -0
- pysisyphus/calculators/IDPPCalculator.py +49 -0
- pysisyphus/calculators/IPIClient.py +133 -0
- pysisyphus/calculators/IPIServer.py +234 -0
- pysisyphus/calculators/LEPSBase.py +24 -0
- pysisyphus/calculators/LEPSExpr.py +139 -0
- pysisyphus/calculators/LennardJones.py +80 -0
- pysisyphus/calculators/MOPAC.py +219 -0
- pysisyphus/calculators/MullerBrownSympyPot.py +51 -0
- pysisyphus/calculators/MultiCalc.py +85 -0
- pysisyphus/calculators/NFK.py +45 -0
- pysisyphus/calculators/OBabel.py +87 -0
- pysisyphus/calculators/ONIOMv2.py +1129 -0
- pysisyphus/calculators/ORCA.py +893 -0
- pysisyphus/calculators/ORCA5.py +6 -0
- pysisyphus/calculators/OpenMM.py +88 -0
- pysisyphus/calculators/OpenMolcas.py +281 -0
- pysisyphus/calculators/OverlapCalculator.py +908 -0
- pysisyphus/calculators/Psi4.py +218 -0
- pysisyphus/calculators/PyPsi4.py +37 -0
- pysisyphus/calculators/PySCF.py +341 -0
- pysisyphus/calculators/PyXTB.py +73 -0
- pysisyphus/calculators/QCEngine.py +106 -0
- pysisyphus/calculators/Rastrigin.py +22 -0
- pysisyphus/calculators/Remote.py +76 -0
- pysisyphus/calculators/Rosenbrock.py +15 -0
- pysisyphus/calculators/SocketCalc.py +97 -0
- pysisyphus/calculators/TIP3P.py +111 -0
- pysisyphus/calculators/TransTorque.py +161 -0
- pysisyphus/calculators/Turbomole.py +965 -0
- pysisyphus/calculators/VRIPot.py +37 -0
- pysisyphus/calculators/WFOWrapper.py +333 -0
- pysisyphus/calculators/WFOWrapper2.py +341 -0
- pysisyphus/calculators/XTB.py +418 -0
- pysisyphus/calculators/__init__.py +81 -0
- pysisyphus/calculators/cosmo_data.py +139 -0
- pysisyphus/calculators/parser.py +150 -0
- pysisyphus/color.py +19 -0
- pysisyphus/config.py +133 -0
- pysisyphus/constants.py +65 -0
- pysisyphus/cos/AdaptiveNEB.py +230 -0
- pysisyphus/cos/ChainOfStates.py +725 -0
- pysisyphus/cos/FreeEndNEB.py +25 -0
- pysisyphus/cos/FreezingString.py +103 -0
- pysisyphus/cos/GrowingChainOfStates.py +71 -0
- pysisyphus/cos/GrowingNT.py +309 -0
- pysisyphus/cos/GrowingString.py +508 -0
- pysisyphus/cos/NEB.py +189 -0
- pysisyphus/cos/SimpleZTS.py +64 -0
- pysisyphus/cos/__init__.py +22 -0
- pysisyphus/cos/stiffness.py +199 -0
- pysisyphus/drivers/__init__.py +17 -0
- pysisyphus/drivers/afir.py +855 -0
- pysisyphus/drivers/barriers.py +271 -0
- pysisyphus/drivers/birkholz.py +138 -0
- pysisyphus/drivers/cluster.py +318 -0
- pysisyphus/drivers/diabatization.py +133 -0
- pysisyphus/drivers/merge.py +368 -0
- pysisyphus/drivers/merge_mol2.py +322 -0
- pysisyphus/drivers/opt.py +375 -0
- pysisyphus/drivers/perf.py +91 -0
- pysisyphus/drivers/pka.py +52 -0
- pysisyphus/drivers/precon_pos_rot.py +669 -0
- pysisyphus/drivers/rates.py +480 -0
- pysisyphus/drivers/replace.py +219 -0
- pysisyphus/drivers/scan.py +212 -0
- pysisyphus/drivers/spectrum.py +166 -0
- pysisyphus/drivers/thermo.py +31 -0
- pysisyphus/dynamics/Gaussian.py +103 -0
- pysisyphus/dynamics/__init__.py +20 -0
- pysisyphus/dynamics/colvars.py +136 -0
- pysisyphus/dynamics/driver.py +297 -0
- pysisyphus/dynamics/helpers.py +256 -0
- pysisyphus/dynamics/lincs.py +105 -0
- pysisyphus/dynamics/mdp.py +364 -0
- pysisyphus/dynamics/rattle.py +121 -0
- pysisyphus/dynamics/thermostats.py +128 -0
- pysisyphus/dynamics/wigner.py +266 -0
- pysisyphus/elem_data.py +3473 -0
- pysisyphus/exceptions.py +2 -0
- pysisyphus/filtertrj.py +69 -0
- pysisyphus/helpers.py +623 -0
- pysisyphus/helpers_pure.py +649 -0
- pysisyphus/init_logging.py +50 -0
- pysisyphus/intcoords/Bend.py +69 -0
- pysisyphus/intcoords/Bend2.py +25 -0
- pysisyphus/intcoords/BondedFragment.py +32 -0
- pysisyphus/intcoords/Cartesian.py +41 -0
- pysisyphus/intcoords/CartesianCoords.py +140 -0
- pysisyphus/intcoords/Coords.py +56 -0
- pysisyphus/intcoords/DLC.py +197 -0
- pysisyphus/intcoords/DistanceFunction.py +34 -0
- pysisyphus/intcoords/DummyImproper.py +70 -0
- pysisyphus/intcoords/DummyTorsion.py +72 -0
- pysisyphus/intcoords/LinearBend.py +105 -0
- pysisyphus/intcoords/LinearDisplacement.py +80 -0
- pysisyphus/intcoords/OutOfPlane.py +59 -0
- pysisyphus/intcoords/PrimTypes.py +286 -0
- pysisyphus/intcoords/Primitive.py +137 -0
- pysisyphus/intcoords/RedundantCoords.py +659 -0
- pysisyphus/intcoords/RobustTorsion.py +59 -0
- pysisyphus/intcoords/Rotation.py +147 -0
- pysisyphus/intcoords/Stretch.py +31 -0
- pysisyphus/intcoords/Torsion.py +101 -0
- pysisyphus/intcoords/Torsion2.py +25 -0
- pysisyphus/intcoords/Translation.py +45 -0
- pysisyphus/intcoords/__init__.py +61 -0
- pysisyphus/intcoords/augment_bonds.py +126 -0
- pysisyphus/intcoords/derivatives.py +10512 -0
- pysisyphus/intcoords/eval.py +80 -0
- pysisyphus/intcoords/exceptions.py +37 -0
- pysisyphus/intcoords/findiffs.py +48 -0
- pysisyphus/intcoords/generate_derivatives.py +414 -0
- pysisyphus/intcoords/helpers.py +235 -0
- pysisyphus/intcoords/logging_conf.py +10 -0
- pysisyphus/intcoords/mp_derivatives.py +10836 -0
- pysisyphus/intcoords/setup.py +962 -0
- pysisyphus/intcoords/setup_fast.py +176 -0
- pysisyphus/intcoords/update.py +272 -0
- pysisyphus/intcoords/valid.py +89 -0
- pysisyphus/interpolate/Geodesic.py +93 -0
- pysisyphus/interpolate/IDPP.py +55 -0
- pysisyphus/interpolate/Interpolator.py +116 -0
- pysisyphus/interpolate/LST.py +70 -0
- pysisyphus/interpolate/Redund.py +152 -0
- pysisyphus/interpolate/__init__.py +9 -0
- pysisyphus/interpolate/helpers.py +34 -0
- pysisyphus/io/__init__.py +22 -0
- pysisyphus/io/aomix.py +178 -0
- pysisyphus/io/cjson.py +24 -0
- pysisyphus/io/crd.py +101 -0
- pysisyphus/io/cube.py +220 -0
- pysisyphus/io/fchk.py +184 -0
- pysisyphus/io/hdf5.py +49 -0
- pysisyphus/io/hessian.py +72 -0
- pysisyphus/io/mol2.py +146 -0
- pysisyphus/io/molden.py +293 -0
- pysisyphus/io/orca.py +189 -0
- pysisyphus/io/pdb.py +269 -0
- pysisyphus/io/psf.py +79 -0
- pysisyphus/io/pubchem.py +31 -0
- pysisyphus/io/qcschema.py +34 -0
- pysisyphus/io/sdf.py +29 -0
- pysisyphus/io/xyz.py +61 -0
- pysisyphus/io/zmat.py +175 -0
- pysisyphus/irc/DWI.py +108 -0
- pysisyphus/irc/DampedVelocityVerlet.py +134 -0
- pysisyphus/irc/Euler.py +22 -0
- pysisyphus/irc/EulerPC.py +345 -0
- pysisyphus/irc/GonzalezSchlegel.py +187 -0
- pysisyphus/irc/IMKMod.py +164 -0
- pysisyphus/irc/IRC.py +878 -0
- pysisyphus/irc/IRCDummy.py +10 -0
- pysisyphus/irc/Instanton.py +307 -0
- pysisyphus/irc/LQA.py +53 -0
- pysisyphus/irc/ModeKill.py +136 -0
- pysisyphus/irc/ParamPlot.py +53 -0
- pysisyphus/irc/RK4.py +36 -0
- pysisyphus/irc/__init__.py +31 -0
- pysisyphus/irc/initial_displ.py +219 -0
- pysisyphus/linalg.py +411 -0
- pysisyphus/line_searches/Backtracking.py +88 -0
- pysisyphus/line_searches/HagerZhang.py +184 -0
- pysisyphus/line_searches/LineSearch.py +232 -0
- pysisyphus/line_searches/StrongWolfe.py +108 -0
- pysisyphus/line_searches/__init__.py +9 -0
- pysisyphus/line_searches/interpol.py +15 -0
- pysisyphus/modefollow/NormalMode.py +40 -0
- pysisyphus/modefollow/__init__.py +10 -0
- pysisyphus/modefollow/davidson.py +199 -0
- pysisyphus/modefollow/lanczos.py +95 -0
- pysisyphus/optimizers/BFGS.py +99 -0
- pysisyphus/optimizers/BacktrackingOptimizer.py +113 -0
- pysisyphus/optimizers/ConjugateGradient.py +98 -0
- pysisyphus/optimizers/CubicNewton.py +75 -0
- pysisyphus/optimizers/FIRE.py +113 -0
- pysisyphus/optimizers/HessianOptimizer.py +1176 -0
- pysisyphus/optimizers/LBFGS.py +228 -0
- pysisyphus/optimizers/LayerOpt.py +411 -0
- pysisyphus/optimizers/MicroOptimizer.py +169 -0
- pysisyphus/optimizers/NCOptimizer.py +90 -0
- pysisyphus/optimizers/Optimizer.py +1084 -0
- pysisyphus/optimizers/PreconLBFGS.py +260 -0
- pysisyphus/optimizers/PreconSteepestDescent.py +7 -0
- pysisyphus/optimizers/QuickMin.py +74 -0
- pysisyphus/optimizers/RFOptimizer.py +181 -0
- pysisyphus/optimizers/RSA.py +99 -0
- pysisyphus/optimizers/StabilizedQNMethod.py +248 -0
- pysisyphus/optimizers/SteepestDescent.py +23 -0
- pysisyphus/optimizers/StringOptimizer.py +173 -0
- pysisyphus/optimizers/__init__.py +41 -0
- pysisyphus/optimizers/closures.py +301 -0
- pysisyphus/optimizers/cls_map.py +58 -0
- pysisyphus/optimizers/exceptions.py +6 -0
- pysisyphus/optimizers/gdiis.py +280 -0
- pysisyphus/optimizers/guess_hessians.py +311 -0
- pysisyphus/optimizers/hessian_updates.py +355 -0
- pysisyphus/optimizers/poly_fit.py +285 -0
- pysisyphus/optimizers/precon.py +153 -0
- pysisyphus/optimizers/restrict_step.py +24 -0
- pysisyphus/pack.py +172 -0
- pysisyphus/peakdetect.py +948 -0
- pysisyphus/plot.py +1031 -0
- pysisyphus/run.py +2106 -0
- pysisyphus/socket_helper.py +74 -0
- pysisyphus/stocastic/FragmentKick.py +132 -0
- pysisyphus/stocastic/Kick.py +81 -0
- pysisyphus/stocastic/Pipeline.py +303 -0
- pysisyphus/stocastic/__init__.py +21 -0
- pysisyphus/stocastic/align.py +127 -0
- pysisyphus/testing.py +96 -0
- pysisyphus/thermo.py +156 -0
- pysisyphus/trj.py +824 -0
- pysisyphus/tsoptimizers/RSIRFOptimizer.py +56 -0
- pysisyphus/tsoptimizers/RSPRFOptimizer.py +182 -0
- pysisyphus/tsoptimizers/TRIM.py +59 -0
- pysisyphus/tsoptimizers/TSHessianOptimizer.py +463 -0
- pysisyphus/tsoptimizers/__init__.py +23 -0
- pysisyphus/wavefunction/Basis.py +239 -0
- pysisyphus/wavefunction/DIIS.py +76 -0
- pysisyphus/wavefunction/__init__.py +25 -0
- pysisyphus/wavefunction/build_ext.py +42 -0
- pysisyphus/wavefunction/cart2sph.py +190 -0
- pysisyphus/wavefunction/diabatization.py +304 -0
- pysisyphus/wavefunction/excited_states.py +435 -0
- pysisyphus/wavefunction/gen_ints.py +1811 -0
- pysisyphus/wavefunction/helpers.py +104 -0
- pysisyphus/wavefunction/ints/__init__.py +0 -0
- pysisyphus/wavefunction/ints/boys.py +193 -0
- pysisyphus/wavefunction/ints/boys_table_N_64_xasym_27.1_step_0.01.npy +0 -0
- pysisyphus/wavefunction/ints/cart_gto3d.py +176 -0
- pysisyphus/wavefunction/ints/coulomb3d.py +25928 -0
- pysisyphus/wavefunction/ints/diag_quadrupole3d.py +10036 -0
- pysisyphus/wavefunction/ints/dipole3d.py +8762 -0
- pysisyphus/wavefunction/ints/int2c2e3d.py +7198 -0
- pysisyphus/wavefunction/ints/int3c2e3d_sph.py +65040 -0
- pysisyphus/wavefunction/ints/kinetic3d.py +8240 -0
- pysisyphus/wavefunction/ints/ovlp3d.py +3777 -0
- pysisyphus/wavefunction/ints/quadrupole3d.py +15054 -0
- pysisyphus/wavefunction/ints/self_ovlp3d.py +198 -0
- pysisyphus/wavefunction/localization.py +458 -0
- pysisyphus/wavefunction/multipole.py +159 -0
- pysisyphus/wavefunction/normalization.py +36 -0
- pysisyphus/wavefunction/pop_analysis.py +134 -0
- pysisyphus/wavefunction/shells.py +1171 -0
- pysisyphus/wavefunction/wavefunction.py +504 -0
- pysisyphus/wrapper/__init__.py +11 -0
- pysisyphus/wrapper/exceptions.py +2 -0
- pysisyphus/wrapper/jmol.py +120 -0
- pysisyphus/wrapper/mwfn.py +169 -0
- pysisyphus/wrapper/packmol.py +71 -0
- pysisyphus/xyzloader.py +168 -0
- pysisyphus/yaml_mods.py +45 -0
- thermoanalysis/LICENSE +674 -0
- thermoanalysis/QCData.py +244 -0
- thermoanalysis/__init__.py +0 -0
- thermoanalysis/config.py +3 -0
- thermoanalysis/constants.py +20 -0
- thermoanalysis/thermo.py +1011 -0
|
@@ -0,0 +1,640 @@
|
|
|
1
|
+
#include <torch/extension.h>
|
|
2
|
+
|
|
3
|
+
#include <algorithm>
|
|
4
|
+
#include <cmath>
|
|
5
|
+
#include <cstdint>
|
|
6
|
+
#include <vector>
|
|
7
|
+
|
|
8
|
+
#ifdef _OPENMP
|
|
9
|
+
#include <omp.h>
|
|
10
|
+
#endif
|
|
11
|
+
|
|
12
|
+
// References:
|
|
13
|
+
// - hessian_ff term definitions and analytical force equations:
|
|
14
|
+
// - hessian_ff/terms/bond.py
|
|
15
|
+
// - hessian_ff/terms/angle.py
|
|
16
|
+
// - hessian_ff/terms/dihedral.py
|
|
17
|
+
// - hessian_ff/terms/cmap.py
|
|
18
|
+
// - OpenMM conventions used for compatibility:
|
|
19
|
+
// - Periodic torsion / dihedral sign conventions:
|
|
20
|
+
// https://github.com/openmm/openmm/blob/master/platforms/reference/src/SimTKReference/ReferenceProperDihedralBond.cpp
|
|
21
|
+
// - CMAP bicubic map conventions:
|
|
22
|
+
// https://github.com/openmm/openmm/blob/master/openmmapi/src/CMAPTorsionForceImpl.cpp
|
|
23
|
+
|
|
24
|
+
namespace {
|
|
25
|
+
|
|
26
|
+
// 2*pi (one full turn in radians); used as the period when wrapping angles.
constexpr double kTwoPi = 6.283185307179586476925286766559;
|
|
27
|
+
|
|
28
|
+
// Restricts x to the closed interval [lo, hi].
//
// The comparisons mirror std::max(lo, std::min(hi, x)) exactly, including the
// outcome for unordered (NaN) operands: a NaN x compares false against hi and
// therefore yields hi.
inline double clamp(double x, double lo, double hi) {
  // Upper bound first: keep x only when it is strictly below hi.
  const double capped = (x < hi) ? x : hi;
  // Then the lower bound: keep the capped value only when it exceeds lo.
  return (lo < capped) ? capped : lo;
}
|
|
31
|
+
|
|
32
|
+
inline double wrap_0_2pi(double x) {
|
|
33
|
+
double y = std::fmod(x, kTwoPi);
|
|
34
|
+
if (y < 0.0) {
|
|
35
|
+
y += kTwoPi;
|
|
36
|
+
}
|
|
37
|
+
return y;
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
// Writes the cross product a x b of two 3-vectors into (cx, cy, cz).
// Inputs are passed by value, so the outputs may safely be any variables.
inline void cross3(
    double ax, double ay, double az,
    double bx, double by, double bz,
    double& cx, double& cy, double& cz) {
  const double x_comp = ay * bz - az * by;
  const double y_comp = az * bx - ax * bz;
  const double z_comp = ax * by - ay * bx;
  cx = x_comp;
  cy = y_comp;
  cz = z_comp;
}
|
|
54
|
+
|
|
55
|
+
// Returns the dot product a . b of two 3-vectors given by components.
// The sum is accumulated left to right: (x term + y term) + z term.
inline double dot3(double ax, double ay, double az, double bx, double by, double bz) {
  return ax * bx + ay * by + az * bz;
}
|
|
64
|
+
|
|
65
|
+
// Accumulates the vector (fx, fy, fz) onto the entry of `atom` in a flat force
// array laid out as [x0, y0, z0, x1, y1, z1, ...]. Each component is converted
// to scalar_t before being added.
template <typename scalar_t>
inline void add_force(scalar_t* force, int64_t atom, double fx, double fy, double fz) {
  scalar_t* const slot = force + 3 * atom;
  slot[0] += static_cast<scalar_t>(fx);
  slot[1] += static_cast<scalar_t>(fy);
  slot[2] += static_cast<scalar_t>(fz);
}
|
|
71
|
+
|
|
72
|
+
template <typename scalar_t>
|
|
73
|
+
inline double dihedral_angle_from_points(
|
|
74
|
+
const scalar_t* coords,
|
|
75
|
+
int64_t i,
|
|
76
|
+
int64_t j,
|
|
77
|
+
int64_t k,
|
|
78
|
+
int64_t l,
|
|
79
|
+
double* out_cross12,
|
|
80
|
+
double* out_cross23,
|
|
81
|
+
double* out_v1,
|
|
82
|
+
double* out_v2,
|
|
83
|
+
double* out_v3) {
|
|
84
|
+
// OpenMM-compatible signed dihedral convention used in hessian_ff/terms/dihedral.py.
|
|
85
|
+
const double p0x = static_cast<double>(coords[3 * i + 0]);
|
|
86
|
+
const double p0y = static_cast<double>(coords[3 * i + 1]);
|
|
87
|
+
const double p0z = static_cast<double>(coords[3 * i + 2]);
|
|
88
|
+
const double p1x = static_cast<double>(coords[3 * j + 0]);
|
|
89
|
+
const double p1y = static_cast<double>(coords[3 * j + 1]);
|
|
90
|
+
const double p1z = static_cast<double>(coords[3 * j + 2]);
|
|
91
|
+
const double p2x = static_cast<double>(coords[3 * k + 0]);
|
|
92
|
+
const double p2y = static_cast<double>(coords[3 * k + 1]);
|
|
93
|
+
const double p2z = static_cast<double>(coords[3 * k + 2]);
|
|
94
|
+
const double p3x = static_cast<double>(coords[3 * l + 0]);
|
|
95
|
+
const double p3y = static_cast<double>(coords[3 * l + 1]);
|
|
96
|
+
const double p3z = static_cast<double>(coords[3 * l + 2]);
|
|
97
|
+
|
|
98
|
+
const double v1x = p0x - p1x;
|
|
99
|
+
const double v1y = p0y - p1y;
|
|
100
|
+
const double v1z = p0z - p1z;
|
|
101
|
+
const double v2x = p2x - p1x;
|
|
102
|
+
const double v2y = p2y - p1y;
|
|
103
|
+
const double v2z = p2z - p1z;
|
|
104
|
+
const double v3x = p2x - p3x;
|
|
105
|
+
const double v3y = p2y - p3y;
|
|
106
|
+
const double v3z = p2z - p3z;
|
|
107
|
+
|
|
108
|
+
double c12x, c12y, c12z;
|
|
109
|
+
double c23x, c23y, c23z;
|
|
110
|
+
cross3(v1x, v1y, v1z, v2x, v2y, v2z, c12x, c12y, c12z);
|
|
111
|
+
cross3(v2x, v2y, v2z, v3x, v3y, v3z, c23x, c23y, c23z);
|
|
112
|
+
|
|
113
|
+
const double n12 = std::max(std::sqrt(dot3(c12x, c12y, c12z, c12x, c12y, c12z)), 1.0e-12);
|
|
114
|
+
const double n23 = std::max(std::sqrt(dot3(c23x, c23y, c23z, c23x, c23y, c23z)), 1.0e-12);
|
|
115
|
+
const double cos_phi = clamp(dot3(c12x, c12y, c12z, c23x, c23y, c23z) / (n12 * n23), -1.0, 1.0);
|
|
116
|
+
double phi = std::acos(cos_phi);
|
|
117
|
+
|
|
118
|
+
const double sign_probe = dot3(v1x, v1y, v1z, c23x, c23y, c23z);
|
|
119
|
+
if (sign_probe < 0.0) {
|
|
120
|
+
phi = -phi;
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
if (out_cross12 != nullptr) {
|
|
124
|
+
out_cross12[0] = c12x;
|
|
125
|
+
out_cross12[1] = c12y;
|
|
126
|
+
out_cross12[2] = c12z;
|
|
127
|
+
}
|
|
128
|
+
if (out_cross23 != nullptr) {
|
|
129
|
+
out_cross23[0] = c23x;
|
|
130
|
+
out_cross23[1] = c23y;
|
|
131
|
+
out_cross23[2] = c23z;
|
|
132
|
+
}
|
|
133
|
+
if (out_v1 != nullptr) {
|
|
134
|
+
out_v1[0] = v1x;
|
|
135
|
+
out_v1[1] = v1y;
|
|
136
|
+
out_v1[2] = v1z;
|
|
137
|
+
}
|
|
138
|
+
if (out_v2 != nullptr) {
|
|
139
|
+
out_v2[0] = v2x;
|
|
140
|
+
out_v2[1] = v2y;
|
|
141
|
+
out_v2[2] = v2z;
|
|
142
|
+
}
|
|
143
|
+
if (out_v3 != nullptr) {
|
|
144
|
+
out_v3[0] = v3x;
|
|
145
|
+
out_v3[1] = v3y;
|
|
146
|
+
out_v3[2] = v3z;
|
|
147
|
+
}
|
|
148
|
+
return phi;
|
|
149
|
+
}
|
|
150
|
+
|
|
151
|
+
template <typename scalar_t>
|
|
152
|
+
inline void accumulate_dihedral_force(
|
|
153
|
+
const scalar_t* coords,
|
|
154
|
+
int64_t i,
|
|
155
|
+
int64_t j,
|
|
156
|
+
int64_t k,
|
|
157
|
+
int64_t l,
|
|
158
|
+
double dE_dphi,
|
|
159
|
+
scalar_t* force) {
|
|
160
|
+
double cross12[3], cross23[3], v1[3], v2[3], v3[3];
|
|
161
|
+
(void)dihedral_angle_from_points(
|
|
162
|
+
coords,
|
|
163
|
+
i,
|
|
164
|
+
j,
|
|
165
|
+
k,
|
|
166
|
+
l,
|
|
167
|
+
cross12,
|
|
168
|
+
cross23,
|
|
169
|
+
v1,
|
|
170
|
+
v2,
|
|
171
|
+
v3);
|
|
172
|
+
|
|
173
|
+
const double norm_cross12_sq = std::max(dot3(cross12[0], cross12[1], cross12[2], cross12[0], cross12[1], cross12[2]), 1.0e-24);
|
|
174
|
+
const double norm_cross23_sq = std::max(dot3(cross23[0], cross23[1], cross23[2], cross23[0], cross23[1], cross23[2]), 1.0e-24);
|
|
175
|
+
const double norm_v2_sq = std::max(dot3(v2[0], v2[1], v2[2], v2[0], v2[1], v2[2]), 1.0e-24);
|
|
176
|
+
const double norm_v2 = std::max(std::sqrt(norm_v2_sq), 1.0e-12);
|
|
177
|
+
|
|
178
|
+
const double f0 = (-dE_dphi * norm_v2) / norm_cross12_sq;
|
|
179
|
+
const double f3 = (dE_dphi * norm_v2) / norm_cross23_sq;
|
|
180
|
+
const double f1 = dot3(v1[0], v1[1], v1[2], v2[0], v2[1], v2[2]) / norm_v2_sq;
|
|
181
|
+
const double f2 = dot3(v3[0], v3[1], v3[2], v2[0], v2[1], v2[2]) / norm_v2_sq;
|
|
182
|
+
|
|
183
|
+
const double ff0x = f0 * cross12[0];
|
|
184
|
+
const double ff0y = f0 * cross12[1];
|
|
185
|
+
const double ff0z = f0 * cross12[2];
|
|
186
|
+
const double ff3x = f3 * cross23[0];
|
|
187
|
+
const double ff3y = f3 * cross23[1];
|
|
188
|
+
const double ff3z = f3 * cross23[2];
|
|
189
|
+
|
|
190
|
+
const double sx = f1 * ff0x - f2 * ff3x;
|
|
191
|
+
const double sy = f1 * ff0y - f2 * ff3y;
|
|
192
|
+
const double sz = f1 * ff0z - f2 * ff3z;
|
|
193
|
+
|
|
194
|
+
const double ff1x = ff0x - sx;
|
|
195
|
+
const double ff1y = ff0y - sy;
|
|
196
|
+
const double ff1z = ff0z - sz;
|
|
197
|
+
const double ff2x = ff3x + sx;
|
|
198
|
+
const double ff2y = ff3y + sy;
|
|
199
|
+
const double ff2z = ff3z + sz;
|
|
200
|
+
|
|
201
|
+
add_force(force, i, ff0x, ff0y, ff0z);
|
|
202
|
+
add_force(force, j, -ff1x, -ff1y, -ff1z);
|
|
203
|
+
add_force(force, k, -ff2x, -ff2y, -ff2z);
|
|
204
|
+
add_force(force, l, ff3x, ff3y, ff3z);
|
|
205
|
+
}
|
|
206
|
+
|
|
207
|
+
std::vector<torch::Tensor> bonded_energy_force_cpu(
|
|
208
|
+
const torch::Tensor& coords,
|
|
209
|
+
const torch::Tensor& bond_i,
|
|
210
|
+
const torch::Tensor& bond_j,
|
|
211
|
+
const torch::Tensor& bond_k,
|
|
212
|
+
const torch::Tensor& bond_r0,
|
|
213
|
+
const torch::Tensor& angle_i,
|
|
214
|
+
const torch::Tensor& angle_j,
|
|
215
|
+
const torch::Tensor& angle_k,
|
|
216
|
+
const torch::Tensor& angle_k0,
|
|
217
|
+
const torch::Tensor& angle_t0,
|
|
218
|
+
const torch::Tensor& dihed_i,
|
|
219
|
+
const torch::Tensor& dihed_j,
|
|
220
|
+
const torch::Tensor& dihed_k,
|
|
221
|
+
const torch::Tensor& dihed_l,
|
|
222
|
+
const torch::Tensor& dihed_force,
|
|
223
|
+
const torch::Tensor& dihed_period,
|
|
224
|
+
const torch::Tensor& dihed_phase,
|
|
225
|
+
const torch::Tensor& cmap_type,
|
|
226
|
+
const torch::Tensor& cmap_i,
|
|
227
|
+
const torch::Tensor& cmap_j,
|
|
228
|
+
const torch::Tensor& cmap_k,
|
|
229
|
+
const torch::Tensor& cmap_l,
|
|
230
|
+
const torch::Tensor& cmap_m,
|
|
231
|
+
const torch::Tensor& cmap_size,
|
|
232
|
+
const torch::Tensor& cmap_delta,
|
|
233
|
+
const torch::Tensor& cmap_offset,
|
|
234
|
+
const torch::Tensor& cmap_coeff) {
|
|
235
|
+
auto c = coords.contiguous();
|
|
236
|
+
TORCH_CHECK(c.device().is_cpu(), "bonded_energy_force_cpu expects CPU coords");
|
|
237
|
+
TORCH_CHECK(c.dim() == 2 && c.size(1) == 3, "coords must be [N,3]");
|
|
238
|
+
|
|
239
|
+
auto bi = bond_i.contiguous();
|
|
240
|
+
auto bj = bond_j.contiguous();
|
|
241
|
+
auto bk = bond_k.contiguous();
|
|
242
|
+
auto br0 = bond_r0.contiguous();
|
|
243
|
+
|
|
244
|
+
auto ai = angle_i.contiguous();
|
|
245
|
+
auto aj = angle_j.contiguous();
|
|
246
|
+
auto ak = angle_k.contiguous();
|
|
247
|
+
auto ak0 = angle_k0.contiguous();
|
|
248
|
+
auto at0 = angle_t0.contiguous();
|
|
249
|
+
|
|
250
|
+
auto di = dihed_i.contiguous();
|
|
251
|
+
auto dj = dihed_j.contiguous();
|
|
252
|
+
auto dk = dihed_k.contiguous();
|
|
253
|
+
auto dl = dihed_l.contiguous();
|
|
254
|
+
auto df = dihed_force.contiguous();
|
|
255
|
+
auto dp = dihed_period.contiguous();
|
|
256
|
+
auto dph = dihed_phase.contiguous();
|
|
257
|
+
|
|
258
|
+
auto ct = cmap_type.contiguous();
|
|
259
|
+
auto ci = cmap_i.contiguous();
|
|
260
|
+
auto cj = cmap_j.contiguous();
|
|
261
|
+
auto ck = cmap_k.contiguous();
|
|
262
|
+
auto cl = cmap_l.contiguous();
|
|
263
|
+
auto cm = cmap_m.contiguous();
|
|
264
|
+
auto csize = cmap_size.contiguous();
|
|
265
|
+
auto cdelta = cmap_delta.contiguous();
|
|
266
|
+
auto coff = cmap_offset.contiguous();
|
|
267
|
+
auto ccoef = cmap_coeff.contiguous();
|
|
268
|
+
|
|
269
|
+
TORCH_CHECK(bi.scalar_type() == torch::kInt64, "bond_i must be int64");
|
|
270
|
+
TORCH_CHECK(ai.scalar_type() == torch::kInt64, "angle_i must be int64");
|
|
271
|
+
TORCH_CHECK(di.scalar_type() == torch::kInt64, "dihed_i must be int64");
|
|
272
|
+
TORCH_CHECK(ct.scalar_type() == torch::kInt64, "cmap_type must be int64");
|
|
273
|
+
TORCH_CHECK(csize.scalar_type() == torch::kInt64, "cmap_size must be int64");
|
|
274
|
+
TORCH_CHECK(coff.scalar_type() == torch::kInt64, "cmap_offset must be int64");
|
|
275
|
+
|
|
276
|
+
auto force = torch::zeros_like(c);
|
|
277
|
+
auto e_bond = torch::zeros({}, c.options());
|
|
278
|
+
auto e_angle = torch::zeros({}, c.options());
|
|
279
|
+
auto e_dihed = torch::zeros({}, c.options());
|
|
280
|
+
auto e_cmap = torch::zeros({}, c.options());
|
|
281
|
+
|
|
282
|
+
AT_DISPATCH_FLOATING_TYPES(c.scalar_type(), "bonded_energy_force_cpu", [&] {
|
|
283
|
+
const scalar_t* xyz = c.data_ptr<scalar_t>();
|
|
284
|
+
scalar_t* f = force.data_ptr<scalar_t>();
|
|
285
|
+
const int64_t natom = c.size(0);
|
|
286
|
+
const int64_t stride = 3 * natom;
|
|
287
|
+
const int64_t nbond = bi.numel();
|
|
288
|
+
const int64_t nangle = ai.numel();
|
|
289
|
+
const int64_t ndihed = di.numel();
|
|
290
|
+
const int64_t ncmap = ct.numel();
|
|
291
|
+
|
|
292
|
+
const int64_t* bi_ptr = bi.data_ptr<int64_t>();
|
|
293
|
+
const int64_t* bj_ptr = bj.data_ptr<int64_t>();
|
|
294
|
+
const scalar_t* bk_ptr = bk.data_ptr<scalar_t>();
|
|
295
|
+
const scalar_t* br0_ptr = br0.data_ptr<scalar_t>();
|
|
296
|
+
|
|
297
|
+
const int64_t* ai_ptr = ai.data_ptr<int64_t>();
|
|
298
|
+
const int64_t* aj_ptr = aj.data_ptr<int64_t>();
|
|
299
|
+
const int64_t* ak_ptr = ak.data_ptr<int64_t>();
|
|
300
|
+
const scalar_t* ak0_ptr = ak0.data_ptr<scalar_t>();
|
|
301
|
+
const scalar_t* at0_ptr = at0.data_ptr<scalar_t>();
|
|
302
|
+
|
|
303
|
+
const int64_t* di_ptr = di.data_ptr<int64_t>();
|
|
304
|
+
const int64_t* dj_ptr = dj.data_ptr<int64_t>();
|
|
305
|
+
const int64_t* dk_ptr = dk.data_ptr<int64_t>();
|
|
306
|
+
const int64_t* dl_ptr = dl.data_ptr<int64_t>();
|
|
307
|
+
const scalar_t* df_ptr = df.data_ptr<scalar_t>();
|
|
308
|
+
const scalar_t* dp_ptr = dp.data_ptr<scalar_t>();
|
|
309
|
+
const scalar_t* dph_ptr = dph.data_ptr<scalar_t>();
|
|
310
|
+
|
|
311
|
+
const int64_t* ct_ptr = ct.data_ptr<int64_t>();
|
|
312
|
+
const int64_t* ci_ptr = ci.data_ptr<int64_t>();
|
|
313
|
+
const int64_t* cj_ptr = cj.data_ptr<int64_t>();
|
|
314
|
+
const int64_t* ck_ptr = ck.data_ptr<int64_t>();
|
|
315
|
+
const int64_t* cl_ptr = cl.data_ptr<int64_t>();
|
|
316
|
+
const int64_t* cm_ptr = cm.data_ptr<int64_t>();
|
|
317
|
+
const int64_t* csize_ptr = csize.data_ptr<int64_t>();
|
|
318
|
+
const scalar_t* cdelta_ptr = cdelta.data_ptr<scalar_t>();
|
|
319
|
+
const int64_t* coff_ptr = coff.data_ptr<int64_t>();
|
|
320
|
+
const scalar_t* ccoef_ptr = ccoef.data_ptr<scalar_t>();
|
|
321
|
+
|
|
322
|
+
int nthreads = 1;
|
|
323
|
+
#ifdef _OPENMP
|
|
324
|
+
nthreads = std::max(1, omp_get_max_threads());
|
|
325
|
+
#endif
|
|
326
|
+
std::vector<scalar_t> force_tls(
|
|
327
|
+
static_cast<size_t>(nthreads) * static_cast<size_t>(stride),
|
|
328
|
+
static_cast<scalar_t>(0));
|
|
329
|
+
std::vector<double> eb_tls(static_cast<size_t>(nthreads), 0.0);
|
|
330
|
+
std::vector<double> ea_tls(static_cast<size_t>(nthreads), 0.0);
|
|
331
|
+
std::vector<double> ed_tls(static_cast<size_t>(nthreads), 0.0);
|
|
332
|
+
std::vector<double> ec_tls(static_cast<size_t>(nthreads), 0.0);
|
|
333
|
+
|
|
334
|
+
// Bond
|
|
335
|
+
#ifdef _OPENMP
|
|
336
|
+
#pragma omp parallel num_threads(nthreads)
|
|
337
|
+
#endif
|
|
338
|
+
{
|
|
339
|
+
int tid = 0;
|
|
340
|
+
#ifdef _OPENMP
|
|
341
|
+
tid = omp_get_thread_num();
|
|
342
|
+
#endif
|
|
343
|
+
scalar_t* fl = force_tls.data() + static_cast<size_t>(tid) * static_cast<size_t>(stride);
|
|
344
|
+
double eb_local = 0.0;
|
|
345
|
+
#ifdef _OPENMP
|
|
346
|
+
#pragma omp for schedule(static)
|
|
347
|
+
#endif
|
|
348
|
+
for (int64_t p = 0; p < nbond; ++p) {
|
|
349
|
+
const int64_t i = bi_ptr[p];
|
|
350
|
+
const int64_t j = bj_ptr[p];
|
|
351
|
+
const double k = static_cast<double>(bk_ptr[p]);
|
|
352
|
+
const double r0 = static_cast<double>(br0_ptr[p]);
|
|
353
|
+
|
|
354
|
+
const double dx = static_cast<double>(xyz[3 * j + 0] - xyz[3 * i + 0]);
|
|
355
|
+
const double dy = static_cast<double>(xyz[3 * j + 1] - xyz[3 * i + 1]);
|
|
356
|
+
const double dz = static_cast<double>(xyz[3 * j + 2] - xyz[3 * i + 2]);
|
|
357
|
+
const double r2 = std::max(dx * dx + dy * dy + dz * dz, 1.0e-24);
|
|
358
|
+
const double inv_r = 1.0 / std::sqrt(r2);
|
|
359
|
+
const double r = r2 * inv_r;
|
|
360
|
+
const double dr = r - r0;
|
|
361
|
+
eb_local += k * dr * dr;
|
|
362
|
+
const double fs = 2.0 * k * dr * inv_r;
|
|
363
|
+
add_force(fl, i, fs * dx, fs * dy, fs * dz);
|
|
364
|
+
add_force(fl, j, -fs * dx, -fs * dy, -fs * dz);
|
|
365
|
+
}
|
|
366
|
+
eb_tls[static_cast<size_t>(tid)] += eb_local;
|
|
367
|
+
}
|
|
368
|
+
|
|
369
|
+
// Angle
|
|
370
|
+
#ifdef _OPENMP
|
|
371
|
+
#pragma omp parallel num_threads(nthreads)
|
|
372
|
+
#endif
|
|
373
|
+
{
|
|
374
|
+
int tid = 0;
|
|
375
|
+
#ifdef _OPENMP
|
|
376
|
+
tid = omp_get_thread_num();
|
|
377
|
+
#endif
|
|
378
|
+
scalar_t* fl = force_tls.data() + static_cast<size_t>(tid) * static_cast<size_t>(stride);
|
|
379
|
+
double ea_local = 0.0;
|
|
380
|
+
#ifdef _OPENMP
|
|
381
|
+
#pragma omp for schedule(static)
|
|
382
|
+
#endif
|
|
383
|
+
for (int64_t p = 0; p < nangle; ++p) {
|
|
384
|
+
const int64_t i = ai_ptr[p];
|
|
385
|
+
const int64_t j = aj_ptr[p];
|
|
386
|
+
const int64_t k = ak_ptr[p];
|
|
387
|
+
const double ktheta = static_cast<double>(ak0_ptr[p]);
|
|
388
|
+
const double theta0 = static_cast<double>(at0_ptr[p]);
|
|
389
|
+
|
|
390
|
+
const double d0x = static_cast<double>(xyz[3 * j + 0] - xyz[3 * i + 0]);
|
|
391
|
+
const double d0y = static_cast<double>(xyz[3 * j + 1] - xyz[3 * i + 1]);
|
|
392
|
+
const double d0z = static_cast<double>(xyz[3 * j + 2] - xyz[3 * i + 2]);
|
|
393
|
+
const double d1x = static_cast<double>(xyz[3 * j + 0] - xyz[3 * k + 0]);
|
|
394
|
+
const double d1y = static_cast<double>(xyz[3 * j + 1] - xyz[3 * k + 1]);
|
|
395
|
+
const double d1z = static_cast<double>(xyz[3 * j + 2] - xyz[3 * k + 2]);
|
|
396
|
+
|
|
397
|
+
double px, py, pz;
|
|
398
|
+
cross3(d0x, d0y, d0z, d1x, d1y, d1z, px, py, pz);
|
|
399
|
+
|
|
400
|
+
const double r20 = std::max(dot3(d0x, d0y, d0z, d0x, d0y, d0z), 1.0e-24);
|
|
401
|
+
const double r21 = std::max(dot3(d1x, d1y, d1z, d1x, d1y, d1z), 1.0e-24);
|
|
402
|
+
const double rp = std::max(std::sqrt(dot3(px, py, pz, px, py, pz)), 1.0e-12);
|
|
403
|
+
const double dot = dot3(d0x, d0y, d0z, d1x, d1y, d1z);
|
|
404
|
+
const double cos_theta = clamp(dot / std::sqrt(r20 * r21), -1.0, 1.0);
|
|
405
|
+
const double theta = std::acos(cos_theta);
|
|
406
|
+
|
|
407
|
+
const double dtheta = theta - theta0;
|
|
408
|
+
ea_local += ktheta * dtheta * dtheta;
|
|
409
|
+
const double dE_dtheta = 2.0 * ktheta * dtheta;
|
|
410
|
+
|
|
411
|
+
const double term_i = dE_dtheta / (r20 * rp);
|
|
412
|
+
const double term_k = -dE_dtheta / (r21 * rp);
|
|
413
|
+
|
|
414
|
+
double cix, ciy, ciz;
|
|
415
|
+
double ckx, cky, ckz;
|
|
416
|
+
cross3(d0x, d0y, d0z, px, py, pz, cix, ciy, ciz);
|
|
417
|
+
cross3(d1x, d1y, d1z, px, py, pz, ckx, cky, ckz);
|
|
418
|
+
|
|
419
|
+
const double fix = cix * term_i;
|
|
420
|
+
const double fiy = ciy * term_i;
|
|
421
|
+
const double fiz = ciz * term_i;
|
|
422
|
+
const double fkx = ckx * term_k;
|
|
423
|
+
const double fky = cky * term_k;
|
|
424
|
+
const double fkz = ckz * term_k;
|
|
425
|
+
const double fjx = -(fix + fkx);
|
|
426
|
+
const double fjy = -(fiy + fky);
|
|
427
|
+
const double fjz = -(fiz + fkz);
|
|
428
|
+
|
|
429
|
+
add_force(fl, i, fix, fiy, fiz);
|
|
430
|
+
add_force(fl, j, fjx, fjy, fjz);
|
|
431
|
+
add_force(fl, k, fkx, fky, fkz);
|
|
432
|
+
}
|
|
433
|
+
ea_tls[static_cast<size_t>(tid)] += ea_local;
|
|
434
|
+
}
|
|
435
|
+
|
|
436
|
+
// Dihedral
|
|
437
|
+
#ifdef _OPENMP
|
|
438
|
+
#pragma omp parallel num_threads(nthreads)
|
|
439
|
+
#endif
|
|
440
|
+
{
|
|
441
|
+
int tid = 0;
|
|
442
|
+
#ifdef _OPENMP
|
|
443
|
+
tid = omp_get_thread_num();
|
|
444
|
+
#endif
|
|
445
|
+
scalar_t* fl = force_tls.data() + static_cast<size_t>(tid) * static_cast<size_t>(stride);
|
|
446
|
+
double ed_local = 0.0;
|
|
447
|
+
#ifdef _OPENMP
|
|
448
|
+
#pragma omp for schedule(static)
|
|
449
|
+
#endif
|
|
450
|
+
for (int64_t p = 0; p < ndihed; ++p) {
|
|
451
|
+
const int64_t i = di_ptr[p];
|
|
452
|
+
const int64_t j = dj_ptr[p];
|
|
453
|
+
const int64_t k = dk_ptr[p];
|
|
454
|
+
const int64_t l = dl_ptr[p];
|
|
455
|
+
const double kf = static_cast<double>(df_ptr[p]);
|
|
456
|
+
const double n = std::abs(static_cast<double>(dp_ptr[p]));
|
|
457
|
+
const double phase = static_cast<double>(dph_ptr[p]);
|
|
458
|
+
|
|
459
|
+
const double phi = dihedral_angle_from_points<scalar_t>(
|
|
460
|
+
xyz, i, j, k, l, nullptr, nullptr, nullptr, nullptr, nullptr);
|
|
461
|
+
const double delta = n * phi - phase;
|
|
462
|
+
ed_local += kf * (1.0 + std::cos(delta));
|
|
463
|
+
const double dE_dphi = -kf * n * std::sin(delta);
|
|
464
|
+
accumulate_dihedral_force<scalar_t>(xyz, i, j, k, l, dE_dphi, fl);
|
|
465
|
+
}
|
|
466
|
+
ed_tls[static_cast<size_t>(tid)] += ed_local;
|
|
467
|
+
}
|
|
468
|
+
|
|
469
|
+
// CMAP
|
|
470
|
+
#ifdef _OPENMP
|
|
471
|
+
#pragma omp parallel num_threads(nthreads)
|
|
472
|
+
#endif
|
|
473
|
+
{
|
|
474
|
+
int tid = 0;
|
|
475
|
+
#ifdef _OPENMP
|
|
476
|
+
tid = omp_get_thread_num();
|
|
477
|
+
#endif
|
|
478
|
+
scalar_t* fl = force_tls.data() + static_cast<size_t>(tid) * static_cast<size_t>(stride);
|
|
479
|
+
double ec_local = 0.0;
|
|
480
|
+
#ifdef _OPENMP
|
|
481
|
+
#pragma omp for schedule(static)
|
|
482
|
+
#endif
|
|
483
|
+
for (int64_t p = 0; p < ncmap; ++p) {
|
|
484
|
+
const int64_t tmap = ct_ptr[p];
|
|
485
|
+
const int64_t i = ci_ptr[p];
|
|
486
|
+
const int64_t j = cj_ptr[p];
|
|
487
|
+
const int64_t k = ck_ptr[p];
|
|
488
|
+
const int64_t l = cl_ptr[p];
|
|
489
|
+
const int64_t m = cm_ptr[p];
|
|
490
|
+
|
|
491
|
+
const double phi = dihedral_angle_from_points<scalar_t>(
|
|
492
|
+
xyz, i, j, k, l, nullptr, nullptr, nullptr, nullptr, nullptr);
|
|
493
|
+
const double psi = dihedral_angle_from_points<scalar_t>(
|
|
494
|
+
xyz, j, k, l, m, nullptr, nullptr, nullptr, nullptr, nullptr);
|
|
495
|
+
const double delta = static_cast<double>(cdelta_ptr[tmap]);
|
|
496
|
+
const int64_t size = csize_ptr[tmap];
|
|
497
|
+
|
|
498
|
+
const double ang_a = wrap_0_2pi(phi + kTwoPi);
|
|
499
|
+
const double ang_b = wrap_0_2pi(psi + kTwoPi);
|
|
500
|
+
const double u = ang_a / delta;
|
|
501
|
+
const double v = ang_b / delta;
|
|
502
|
+
int64_t su = static_cast<int64_t>(std::floor(u));
|
|
503
|
+
int64_t sv = static_cast<int64_t>(std::floor(v));
|
|
504
|
+
su = std::min(su, size - 1);
|
|
505
|
+
sv = std::min(sv, size - 1);
|
|
506
|
+
const double da = u - static_cast<double>(su);
|
|
507
|
+
const double db = v - static_cast<double>(sv);
|
|
508
|
+
|
|
509
|
+
const int64_t patch = su + size * sv;
|
|
510
|
+
const int64_t coeff_row = coff_ptr[tmap] + patch;
|
|
511
|
+
const scalar_t* coeff = ccoef_ptr + 16 * coeff_row;
|
|
512
|
+
|
|
513
|
+
const double db2 = db * db;
|
|
514
|
+
const double da2 = da * da;
|
|
515
|
+
const double da3 = da2 * da;
|
|
516
|
+
|
|
517
|
+
double ppoly[4];
|
|
518
|
+
double pd[4];
|
|
519
|
+
for (int r = 0; r < 4; ++r) {
|
|
520
|
+
const double c0 = static_cast<double>(coeff[4 * r + 0]);
|
|
521
|
+
const double c1 = static_cast<double>(coeff[4 * r + 1]);
|
|
522
|
+
const double c2 = static_cast<double>(coeff[4 * r + 2]);
|
|
523
|
+
const double c3 = static_cast<double>(coeff[4 * r + 3]);
|
|
524
|
+
ppoly[r] = c0 + c1 * db + c2 * db2 + c3 * db2 * db;
|
|
525
|
+
pd[r] = c1 + 2.0 * c2 * db + 3.0 * c3 * db2;
|
|
526
|
+
}
|
|
527
|
+
|
|
528
|
+
const double e_val = ppoly[0] + ppoly[1] * da + ppoly[2] * da2 + ppoly[3] * da3;
|
|
529
|
+
ec_local += e_val;
|
|
530
|
+
|
|
531
|
+
const double dE_dda = ppoly[1] + 2.0 * ppoly[2] * da + 3.0 * ppoly[3] * da2;
|
|
532
|
+
const double dE_ddb = pd[0] + pd[1] * da + pd[2] * da2 + pd[3] * da3;
|
|
533
|
+
const double dE_dphi = dE_dda / delta;
|
|
534
|
+
const double dE_dpsi = dE_ddb / delta;
|
|
535
|
+
|
|
536
|
+
accumulate_dihedral_force<scalar_t>(xyz, i, j, k, l, dE_dphi, fl);
|
|
537
|
+
accumulate_dihedral_force<scalar_t>(xyz, j, k, l, m, dE_dpsi, fl);
|
|
538
|
+
}
|
|
539
|
+
ec_tls[static_cast<size_t>(tid)] += ec_local;
|
|
540
|
+
}
|
|
541
|
+
|
|
542
|
+
double eb = 0.0;
|
|
543
|
+
double ea = 0.0;
|
|
544
|
+
double ed = 0.0;
|
|
545
|
+
double ec = 0.0;
|
|
546
|
+
for (int tid = 0; tid < nthreads; ++tid) {
|
|
547
|
+
eb += eb_tls[static_cast<size_t>(tid)];
|
|
548
|
+
ea += ea_tls[static_cast<size_t>(tid)];
|
|
549
|
+
ed += ed_tls[static_cast<size_t>(tid)];
|
|
550
|
+
ec += ec_tls[static_cast<size_t>(tid)];
|
|
551
|
+
}
|
|
552
|
+
|
|
553
|
+
for (int64_t idx = 0; idx < stride; ++idx) {
|
|
554
|
+
double acc = 0.0;
|
|
555
|
+
for (int tid = 0; tid < nthreads; ++tid) {
|
|
556
|
+
acc += static_cast<double>(
|
|
557
|
+
force_tls[static_cast<size_t>(tid) * static_cast<size_t>(stride) + static_cast<size_t>(idx)]);
|
|
558
|
+
}
|
|
559
|
+
f[idx] = static_cast<scalar_t>(acc);
|
|
560
|
+
}
|
|
561
|
+
|
|
562
|
+
e_bond.fill_(static_cast<scalar_t>(eb));
|
|
563
|
+
e_angle.fill_(static_cast<scalar_t>(ea));
|
|
564
|
+
e_dihed.fill_(static_cast<scalar_t>(ed));
|
|
565
|
+
e_cmap.fill_(static_cast<scalar_t>(ec));
|
|
566
|
+
});
|
|
567
|
+
|
|
568
|
+
return {e_bond, e_angle, e_dihed, e_cmap, force};
|
|
569
|
+
}
|
|
570
|
+
|
|
571
|
+
} // namespace
|
|
572
|
+
|
|
573
|
+
std::vector<torch::Tensor> bonded_energy_force(
|
|
574
|
+
const torch::Tensor& coords,
|
|
575
|
+
const torch::Tensor& bond_i,
|
|
576
|
+
const torch::Tensor& bond_j,
|
|
577
|
+
const torch::Tensor& bond_k,
|
|
578
|
+
const torch::Tensor& bond_r0,
|
|
579
|
+
const torch::Tensor& angle_i,
|
|
580
|
+
const torch::Tensor& angle_j,
|
|
581
|
+
const torch::Tensor& angle_k,
|
|
582
|
+
const torch::Tensor& angle_k0,
|
|
583
|
+
const torch::Tensor& angle_t0,
|
|
584
|
+
const torch::Tensor& dihed_i,
|
|
585
|
+
const torch::Tensor& dihed_j,
|
|
586
|
+
const torch::Tensor& dihed_k,
|
|
587
|
+
const torch::Tensor& dihed_l,
|
|
588
|
+
const torch::Tensor& dihed_force,
|
|
589
|
+
const torch::Tensor& dihed_period,
|
|
590
|
+
const torch::Tensor& dihed_phase,
|
|
591
|
+
const torch::Tensor& cmap_type,
|
|
592
|
+
const torch::Tensor& cmap_i,
|
|
593
|
+
const torch::Tensor& cmap_j,
|
|
594
|
+
const torch::Tensor& cmap_k,
|
|
595
|
+
const torch::Tensor& cmap_l,
|
|
596
|
+
const torch::Tensor& cmap_m,
|
|
597
|
+
const torch::Tensor& cmap_size,
|
|
598
|
+
const torch::Tensor& cmap_delta,
|
|
599
|
+
const torch::Tensor& cmap_offset,
|
|
600
|
+
const torch::Tensor& cmap_coeff) {
|
|
601
|
+
if (!coords.device().is_cpu()) {
|
|
602
|
+
TORCH_CHECK(false, "bonded_energy_force currently supports CPU tensors only");
|
|
603
|
+
}
|
|
604
|
+
return bonded_energy_force_cpu(
|
|
605
|
+
coords,
|
|
606
|
+
bond_i,
|
|
607
|
+
bond_j,
|
|
608
|
+
bond_k,
|
|
609
|
+
bond_r0,
|
|
610
|
+
angle_i,
|
|
611
|
+
angle_j,
|
|
612
|
+
angle_k,
|
|
613
|
+
angle_k0,
|
|
614
|
+
angle_t0,
|
|
615
|
+
dihed_i,
|
|
616
|
+
dihed_j,
|
|
617
|
+
dihed_k,
|
|
618
|
+
dihed_l,
|
|
619
|
+
dihed_force,
|
|
620
|
+
dihed_period,
|
|
621
|
+
dihed_phase,
|
|
622
|
+
cmap_type,
|
|
623
|
+
cmap_i,
|
|
624
|
+
cmap_j,
|
|
625
|
+
cmap_k,
|
|
626
|
+
cmap_l,
|
|
627
|
+
cmap_m,
|
|
628
|
+
cmap_size,
|
|
629
|
+
cmap_delta,
|
|
630
|
+
cmap_offset,
|
|
631
|
+
cmap_coeff);
|
|
632
|
+
}
|
|
633
|
+
|
|
634
|
+
// Python bindings for this extension. The module name comes from the
// TORCH_EXTENSION_NAME macro; the single exported callable is
// bonded_energy_force.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
  mod.doc() = "hessian_ff native bonded extension";
  mod.def("bonded_energy_force", &bonded_energy_force,
          "Compute bonded energy/force (bond/angle/dihedral/cmap)");
}
|