mlmm-toolkit 0.2.2.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hessian_ff/__init__.py +50 -0
- hessian_ff/analytical_hessian.py +609 -0
- hessian_ff/constants.py +46 -0
- hessian_ff/forcefield.py +339 -0
- hessian_ff/loaders.py +608 -0
- hessian_ff/native/Makefile +8 -0
- hessian_ff/native/__init__.py +28 -0
- hessian_ff/native/analytical_hessian.py +88 -0
- hessian_ff/native/analytical_hessian_ext.cpp +258 -0
- hessian_ff/native/bonded.py +82 -0
- hessian_ff/native/bonded_ext.cpp +640 -0
- hessian_ff/native/loader.py +349 -0
- hessian_ff/native/nonbonded.py +118 -0
- hessian_ff/native/nonbonded_ext.cpp +1150 -0
- hessian_ff/prmtop_parmed.py +23 -0
- hessian_ff/system.py +107 -0
- hessian_ff/terms/__init__.py +14 -0
- hessian_ff/terms/angle.py +73 -0
- hessian_ff/terms/bond.py +44 -0
- hessian_ff/terms/cmap.py +406 -0
- hessian_ff/terms/dihedral.py +141 -0
- hessian_ff/terms/nonbonded.py +209 -0
- hessian_ff/tests/__init__.py +0 -0
- hessian_ff/tests/conftest.py +75 -0
- hessian_ff/tests/data/small/complex.parm7 +1346 -0
- hessian_ff/tests/data/small/complex.pdb +125 -0
- hessian_ff/tests/data/small/complex.rst7 +63 -0
- hessian_ff/tests/test_coords_input.py +44 -0
- hessian_ff/tests/test_energy_force.py +49 -0
- hessian_ff/tests/test_hessian.py +137 -0
- hessian_ff/tests/test_smoke.py +18 -0
- hessian_ff/tests/test_validation.py +40 -0
- hessian_ff/workflows.py +889 -0
- mlmm/__init__.py +36 -0
- mlmm/__main__.py +7 -0
- mlmm/_version.py +34 -0
- mlmm/add_elem_info.py +374 -0
- mlmm/advanced_help.py +91 -0
- mlmm/align_freeze_atoms.py +601 -0
- mlmm/all.py +3535 -0
- mlmm/bond_changes.py +231 -0
- mlmm/bool_compat.py +223 -0
- mlmm/cli.py +574 -0
- mlmm/cli_utils.py +166 -0
- mlmm/default_group.py +337 -0
- mlmm/defaults.py +467 -0
- mlmm/define_layer.py +526 -0
- mlmm/dft.py +1041 -0
- mlmm/energy_diagram.py +253 -0
- mlmm/extract.py +2213 -0
- mlmm/fix_altloc.py +464 -0
- mlmm/freq.py +1406 -0
- mlmm/harmonic_constraints.py +140 -0
- mlmm/hessian_cache.py +44 -0
- mlmm/hessian_calc.py +174 -0
- mlmm/irc.py +638 -0
- mlmm/mlmm_calc.py +2262 -0
- mlmm/mm_parm.py +945 -0
- mlmm/oniom_export.py +1983 -0
- mlmm/oniom_import.py +457 -0
- mlmm/opt.py +1742 -0
- mlmm/path_opt.py +1353 -0
- mlmm/path_search.py +2299 -0
- mlmm/preflight.py +88 -0
- mlmm/py.typed +1 -0
- mlmm/pysis_runner.py +45 -0
- mlmm/scan.py +1047 -0
- mlmm/scan2d.py +1226 -0
- mlmm/scan3d.py +1265 -0
- mlmm/scan_common.py +184 -0
- mlmm/summary_log.py +736 -0
- mlmm/trj2fig.py +448 -0
- mlmm/tsopt.py +2871 -0
- mlmm/utils.py +2309 -0
- mlmm/xtb_embedcharge_correction.py +475 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/METADATA +1159 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/RECORD +372 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/WHEEL +5 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/entry_points.txt +2 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/licenses/LICENSE +674 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/top_level.txt +4 -0
- pysisyphus/Geometry.py +1667 -0
- pysisyphus/LICENSE +674 -0
- pysisyphus/TableFormatter.py +63 -0
- pysisyphus/TablePrinter.py +74 -0
- pysisyphus/__init__.py +12 -0
- pysisyphus/calculators/AFIR.py +452 -0
- pysisyphus/calculators/AnaPot.py +20 -0
- pysisyphus/calculators/AnaPot2.py +48 -0
- pysisyphus/calculators/AnaPot3.py +12 -0
- pysisyphus/calculators/AnaPot4.py +20 -0
- pysisyphus/calculators/AnaPotBase.py +337 -0
- pysisyphus/calculators/AnaPotCBM.py +25 -0
- pysisyphus/calculators/AtomAtomTransTorque.py +154 -0
- pysisyphus/calculators/CFOUR.py +250 -0
- pysisyphus/calculators/Calculator.py +844 -0
- pysisyphus/calculators/CerjanMiller.py +24 -0
- pysisyphus/calculators/Composite.py +123 -0
- pysisyphus/calculators/ConicalIntersection.py +171 -0
- pysisyphus/calculators/DFTBp.py +430 -0
- pysisyphus/calculators/DFTD3.py +66 -0
- pysisyphus/calculators/DFTD4.py +84 -0
- pysisyphus/calculators/Dalton.py +61 -0
- pysisyphus/calculators/Dimer.py +681 -0
- pysisyphus/calculators/Dummy.py +20 -0
- pysisyphus/calculators/EGO.py +76 -0
- pysisyphus/calculators/EnergyMin.py +224 -0
- pysisyphus/calculators/ExternalPotential.py +264 -0
- pysisyphus/calculators/FakeASE.py +35 -0
- pysisyphus/calculators/FourWellAnaPot.py +28 -0
- pysisyphus/calculators/FreeEndNEBPot.py +39 -0
- pysisyphus/calculators/Gaussian09.py +18 -0
- pysisyphus/calculators/Gaussian16.py +726 -0
- pysisyphus/calculators/HardSphere.py +159 -0
- pysisyphus/calculators/IDPPCalculator.py +49 -0
- pysisyphus/calculators/IPIClient.py +133 -0
- pysisyphus/calculators/IPIServer.py +234 -0
- pysisyphus/calculators/LEPSBase.py +24 -0
- pysisyphus/calculators/LEPSExpr.py +139 -0
- pysisyphus/calculators/LennardJones.py +80 -0
- pysisyphus/calculators/MOPAC.py +219 -0
- pysisyphus/calculators/MullerBrownSympyPot.py +51 -0
- pysisyphus/calculators/MultiCalc.py +85 -0
- pysisyphus/calculators/NFK.py +45 -0
- pysisyphus/calculators/OBabel.py +87 -0
- pysisyphus/calculators/ONIOMv2.py +1129 -0
- pysisyphus/calculators/ORCA.py +893 -0
- pysisyphus/calculators/ORCA5.py +6 -0
- pysisyphus/calculators/OpenMM.py +88 -0
- pysisyphus/calculators/OpenMolcas.py +281 -0
- pysisyphus/calculators/OverlapCalculator.py +908 -0
- pysisyphus/calculators/Psi4.py +218 -0
- pysisyphus/calculators/PyPsi4.py +37 -0
- pysisyphus/calculators/PySCF.py +341 -0
- pysisyphus/calculators/PyXTB.py +73 -0
- pysisyphus/calculators/QCEngine.py +106 -0
- pysisyphus/calculators/Rastrigin.py +22 -0
- pysisyphus/calculators/Remote.py +76 -0
- pysisyphus/calculators/Rosenbrock.py +15 -0
- pysisyphus/calculators/SocketCalc.py +97 -0
- pysisyphus/calculators/TIP3P.py +111 -0
- pysisyphus/calculators/TransTorque.py +161 -0
- pysisyphus/calculators/Turbomole.py +965 -0
- pysisyphus/calculators/VRIPot.py +37 -0
- pysisyphus/calculators/WFOWrapper.py +333 -0
- pysisyphus/calculators/WFOWrapper2.py +341 -0
- pysisyphus/calculators/XTB.py +418 -0
- pysisyphus/calculators/__init__.py +81 -0
- pysisyphus/calculators/cosmo_data.py +139 -0
- pysisyphus/calculators/parser.py +150 -0
- pysisyphus/color.py +19 -0
- pysisyphus/config.py +133 -0
- pysisyphus/constants.py +65 -0
- pysisyphus/cos/AdaptiveNEB.py +230 -0
- pysisyphus/cos/ChainOfStates.py +725 -0
- pysisyphus/cos/FreeEndNEB.py +25 -0
- pysisyphus/cos/FreezingString.py +103 -0
- pysisyphus/cos/GrowingChainOfStates.py +71 -0
- pysisyphus/cos/GrowingNT.py +309 -0
- pysisyphus/cos/GrowingString.py +508 -0
- pysisyphus/cos/NEB.py +189 -0
- pysisyphus/cos/SimpleZTS.py +64 -0
- pysisyphus/cos/__init__.py +22 -0
- pysisyphus/cos/stiffness.py +199 -0
- pysisyphus/drivers/__init__.py +17 -0
- pysisyphus/drivers/afir.py +855 -0
- pysisyphus/drivers/barriers.py +271 -0
- pysisyphus/drivers/birkholz.py +138 -0
- pysisyphus/drivers/cluster.py +318 -0
- pysisyphus/drivers/diabatization.py +133 -0
- pysisyphus/drivers/merge.py +368 -0
- pysisyphus/drivers/merge_mol2.py +322 -0
- pysisyphus/drivers/opt.py +375 -0
- pysisyphus/drivers/perf.py +91 -0
- pysisyphus/drivers/pka.py +52 -0
- pysisyphus/drivers/precon_pos_rot.py +669 -0
- pysisyphus/drivers/rates.py +480 -0
- pysisyphus/drivers/replace.py +219 -0
- pysisyphus/drivers/scan.py +212 -0
- pysisyphus/drivers/spectrum.py +166 -0
- pysisyphus/drivers/thermo.py +31 -0
- pysisyphus/dynamics/Gaussian.py +103 -0
- pysisyphus/dynamics/__init__.py +20 -0
- pysisyphus/dynamics/colvars.py +136 -0
- pysisyphus/dynamics/driver.py +297 -0
- pysisyphus/dynamics/helpers.py +256 -0
- pysisyphus/dynamics/lincs.py +105 -0
- pysisyphus/dynamics/mdp.py +364 -0
- pysisyphus/dynamics/rattle.py +121 -0
- pysisyphus/dynamics/thermostats.py +128 -0
- pysisyphus/dynamics/wigner.py +266 -0
- pysisyphus/elem_data.py +3473 -0
- pysisyphus/exceptions.py +2 -0
- pysisyphus/filtertrj.py +69 -0
- pysisyphus/helpers.py +623 -0
- pysisyphus/helpers_pure.py +649 -0
- pysisyphus/init_logging.py +50 -0
- pysisyphus/intcoords/Bend.py +69 -0
- pysisyphus/intcoords/Bend2.py +25 -0
- pysisyphus/intcoords/BondedFragment.py +32 -0
- pysisyphus/intcoords/Cartesian.py +41 -0
- pysisyphus/intcoords/CartesianCoords.py +140 -0
- pysisyphus/intcoords/Coords.py +56 -0
- pysisyphus/intcoords/DLC.py +197 -0
- pysisyphus/intcoords/DistanceFunction.py +34 -0
- pysisyphus/intcoords/DummyImproper.py +70 -0
- pysisyphus/intcoords/DummyTorsion.py +72 -0
- pysisyphus/intcoords/LinearBend.py +105 -0
- pysisyphus/intcoords/LinearDisplacement.py +80 -0
- pysisyphus/intcoords/OutOfPlane.py +59 -0
- pysisyphus/intcoords/PrimTypes.py +286 -0
- pysisyphus/intcoords/Primitive.py +137 -0
- pysisyphus/intcoords/RedundantCoords.py +659 -0
- pysisyphus/intcoords/RobustTorsion.py +59 -0
- pysisyphus/intcoords/Rotation.py +147 -0
- pysisyphus/intcoords/Stretch.py +31 -0
- pysisyphus/intcoords/Torsion.py +101 -0
- pysisyphus/intcoords/Torsion2.py +25 -0
- pysisyphus/intcoords/Translation.py +45 -0
- pysisyphus/intcoords/__init__.py +61 -0
- pysisyphus/intcoords/augment_bonds.py +126 -0
- pysisyphus/intcoords/derivatives.py +10512 -0
- pysisyphus/intcoords/eval.py +80 -0
- pysisyphus/intcoords/exceptions.py +37 -0
- pysisyphus/intcoords/findiffs.py +48 -0
- pysisyphus/intcoords/generate_derivatives.py +414 -0
- pysisyphus/intcoords/helpers.py +235 -0
- pysisyphus/intcoords/logging_conf.py +10 -0
- pysisyphus/intcoords/mp_derivatives.py +10836 -0
- pysisyphus/intcoords/setup.py +962 -0
- pysisyphus/intcoords/setup_fast.py +176 -0
- pysisyphus/intcoords/update.py +272 -0
- pysisyphus/intcoords/valid.py +89 -0
- pysisyphus/interpolate/Geodesic.py +93 -0
- pysisyphus/interpolate/IDPP.py +55 -0
- pysisyphus/interpolate/Interpolator.py +116 -0
- pysisyphus/interpolate/LST.py +70 -0
- pysisyphus/interpolate/Redund.py +152 -0
- pysisyphus/interpolate/__init__.py +9 -0
- pysisyphus/interpolate/helpers.py +34 -0
- pysisyphus/io/__init__.py +22 -0
- pysisyphus/io/aomix.py +178 -0
- pysisyphus/io/cjson.py +24 -0
- pysisyphus/io/crd.py +101 -0
- pysisyphus/io/cube.py +220 -0
- pysisyphus/io/fchk.py +184 -0
- pysisyphus/io/hdf5.py +49 -0
- pysisyphus/io/hessian.py +72 -0
- pysisyphus/io/mol2.py +146 -0
- pysisyphus/io/molden.py +293 -0
- pysisyphus/io/orca.py +189 -0
- pysisyphus/io/pdb.py +269 -0
- pysisyphus/io/psf.py +79 -0
- pysisyphus/io/pubchem.py +31 -0
- pysisyphus/io/qcschema.py +34 -0
- pysisyphus/io/sdf.py +29 -0
- pysisyphus/io/xyz.py +61 -0
- pysisyphus/io/zmat.py +175 -0
- pysisyphus/irc/DWI.py +108 -0
- pysisyphus/irc/DampedVelocityVerlet.py +134 -0
- pysisyphus/irc/Euler.py +22 -0
- pysisyphus/irc/EulerPC.py +345 -0
- pysisyphus/irc/GonzalezSchlegel.py +187 -0
- pysisyphus/irc/IMKMod.py +164 -0
- pysisyphus/irc/IRC.py +878 -0
- pysisyphus/irc/IRCDummy.py +10 -0
- pysisyphus/irc/Instanton.py +307 -0
- pysisyphus/irc/LQA.py +53 -0
- pysisyphus/irc/ModeKill.py +136 -0
- pysisyphus/irc/ParamPlot.py +53 -0
- pysisyphus/irc/RK4.py +36 -0
- pysisyphus/irc/__init__.py +31 -0
- pysisyphus/irc/initial_displ.py +219 -0
- pysisyphus/linalg.py +411 -0
- pysisyphus/line_searches/Backtracking.py +88 -0
- pysisyphus/line_searches/HagerZhang.py +184 -0
- pysisyphus/line_searches/LineSearch.py +232 -0
- pysisyphus/line_searches/StrongWolfe.py +108 -0
- pysisyphus/line_searches/__init__.py +9 -0
- pysisyphus/line_searches/interpol.py +15 -0
- pysisyphus/modefollow/NormalMode.py +40 -0
- pysisyphus/modefollow/__init__.py +10 -0
- pysisyphus/modefollow/davidson.py +199 -0
- pysisyphus/modefollow/lanczos.py +95 -0
- pysisyphus/optimizers/BFGS.py +99 -0
- pysisyphus/optimizers/BacktrackingOptimizer.py +113 -0
- pysisyphus/optimizers/ConjugateGradient.py +98 -0
- pysisyphus/optimizers/CubicNewton.py +75 -0
- pysisyphus/optimizers/FIRE.py +113 -0
- pysisyphus/optimizers/HessianOptimizer.py +1176 -0
- pysisyphus/optimizers/LBFGS.py +228 -0
- pysisyphus/optimizers/LayerOpt.py +411 -0
- pysisyphus/optimizers/MicroOptimizer.py +169 -0
- pysisyphus/optimizers/NCOptimizer.py +90 -0
- pysisyphus/optimizers/Optimizer.py +1084 -0
- pysisyphus/optimizers/PreconLBFGS.py +260 -0
- pysisyphus/optimizers/PreconSteepestDescent.py +7 -0
- pysisyphus/optimizers/QuickMin.py +74 -0
- pysisyphus/optimizers/RFOptimizer.py +181 -0
- pysisyphus/optimizers/RSA.py +99 -0
- pysisyphus/optimizers/StabilizedQNMethod.py +248 -0
- pysisyphus/optimizers/SteepestDescent.py +23 -0
- pysisyphus/optimizers/StringOptimizer.py +173 -0
- pysisyphus/optimizers/__init__.py +41 -0
- pysisyphus/optimizers/closures.py +301 -0
- pysisyphus/optimizers/cls_map.py +58 -0
- pysisyphus/optimizers/exceptions.py +6 -0
- pysisyphus/optimizers/gdiis.py +280 -0
- pysisyphus/optimizers/guess_hessians.py +311 -0
- pysisyphus/optimizers/hessian_updates.py +355 -0
- pysisyphus/optimizers/poly_fit.py +285 -0
- pysisyphus/optimizers/precon.py +153 -0
- pysisyphus/optimizers/restrict_step.py +24 -0
- pysisyphus/pack.py +172 -0
- pysisyphus/peakdetect.py +948 -0
- pysisyphus/plot.py +1031 -0
- pysisyphus/run.py +2106 -0
- pysisyphus/socket_helper.py +74 -0
- pysisyphus/stocastic/FragmentKick.py +132 -0
- pysisyphus/stocastic/Kick.py +81 -0
- pysisyphus/stocastic/Pipeline.py +303 -0
- pysisyphus/stocastic/__init__.py +21 -0
- pysisyphus/stocastic/align.py +127 -0
- pysisyphus/testing.py +96 -0
- pysisyphus/thermo.py +156 -0
- pysisyphus/trj.py +824 -0
- pysisyphus/tsoptimizers/RSIRFOptimizer.py +56 -0
- pysisyphus/tsoptimizers/RSPRFOptimizer.py +182 -0
- pysisyphus/tsoptimizers/TRIM.py +59 -0
- pysisyphus/tsoptimizers/TSHessianOptimizer.py +463 -0
- pysisyphus/tsoptimizers/__init__.py +23 -0
- pysisyphus/wavefunction/Basis.py +239 -0
- pysisyphus/wavefunction/DIIS.py +76 -0
- pysisyphus/wavefunction/__init__.py +25 -0
- pysisyphus/wavefunction/build_ext.py +42 -0
- pysisyphus/wavefunction/cart2sph.py +190 -0
- pysisyphus/wavefunction/diabatization.py +304 -0
- pysisyphus/wavefunction/excited_states.py +435 -0
- pysisyphus/wavefunction/gen_ints.py +1811 -0
- pysisyphus/wavefunction/helpers.py +104 -0
- pysisyphus/wavefunction/ints/__init__.py +0 -0
- pysisyphus/wavefunction/ints/boys.py +193 -0
- pysisyphus/wavefunction/ints/boys_table_N_64_xasym_27.1_step_0.01.npy +0 -0
- pysisyphus/wavefunction/ints/cart_gto3d.py +176 -0
- pysisyphus/wavefunction/ints/coulomb3d.py +25928 -0
- pysisyphus/wavefunction/ints/diag_quadrupole3d.py +10036 -0
- pysisyphus/wavefunction/ints/dipole3d.py +8762 -0
- pysisyphus/wavefunction/ints/int2c2e3d.py +7198 -0
- pysisyphus/wavefunction/ints/int3c2e3d_sph.py +65040 -0
- pysisyphus/wavefunction/ints/kinetic3d.py +8240 -0
- pysisyphus/wavefunction/ints/ovlp3d.py +3777 -0
- pysisyphus/wavefunction/ints/quadrupole3d.py +15054 -0
- pysisyphus/wavefunction/ints/self_ovlp3d.py +198 -0
- pysisyphus/wavefunction/localization.py +458 -0
- pysisyphus/wavefunction/multipole.py +159 -0
- pysisyphus/wavefunction/normalization.py +36 -0
- pysisyphus/wavefunction/pop_analysis.py +134 -0
- pysisyphus/wavefunction/shells.py +1171 -0
- pysisyphus/wavefunction/wavefunction.py +504 -0
- pysisyphus/wrapper/__init__.py +11 -0
- pysisyphus/wrapper/exceptions.py +2 -0
- pysisyphus/wrapper/jmol.py +120 -0
- pysisyphus/wrapper/mwfn.py +169 -0
- pysisyphus/wrapper/packmol.py +71 -0
- pysisyphus/xyzloader.py +168 -0
- pysisyphus/yaml_mods.py +45 -0
- thermoanalysis/LICENSE +674 -0
- thermoanalysis/QCData.py +244 -0
- thermoanalysis/__init__.py +0 -0
- thermoanalysis/config.py +3 -0
- thermoanalysis/constants.py +20 -0
- thermoanalysis/thermo.py +1011 -0
|
@@ -0,0 +1,1150 @@
|
|
|
1
|
+
#include <torch/extension.h>
|
|
2
|
+
|
|
3
|
+
#include <algorithm>
|
|
4
|
+
#include <cmath>
|
|
5
|
+
#include <cstdint>
|
|
6
|
+
#include <cstdlib>
|
|
7
|
+
#include <vector>
|
|
8
|
+
|
|
9
|
+
#ifdef _OPENMP
|
|
10
|
+
#include <omp.h>
|
|
11
|
+
#endif
|
|
12
|
+
|
|
13
|
+
// References:
|
|
14
|
+
// - hessian_ff/terms/nonbonded.py formulas in this repository.
|
|
15
|
+
// - OpenMM NonbondedForce conventions (Coulomb + LJ/HB style terms).
|
|
16
|
+
// - OpenMM CPU threading model/property (reference for thread-count behavior):
|
|
17
|
+
// https://github.com/openmm/openmm/blob/master/platforms/cpu/src/CpuPlatform.cpp
|
|
18
|
+
// https://github.com/openmm/openmm/blob/master/platforms/cpu/src/CpuPlatform.h
|
|
19
|
+
// Note:
|
|
20
|
+
// - The adaptive pair-count-based thread selection implemented below is NOT a
|
|
21
|
+
// byte-for-byte OpenMM policy. It is a hessian_ff-specific optimization
|
|
22
|
+
// inspired by OpenMM's CPU-threaded execution model.
|
|
23
|
+
|
|
24
|
+
namespace {
|
|
25
|
+
|
|
26
|
+
// Electrostatic conversion constant multiplying q_i * q_j / r in the Coulomb
// term below. Presumably kcal*Angstrom/(mol*e^2) (the OpenMM-derived value) —
// NOTE(review): confirm units against the Python reference implementation.
constexpr double kCoulomb{332.0637132991921};
|
|
27
|
+
|
|
28
|
+
// Returns the minimum number of pairs each worker thread should receive.
//
// Precedence:
//   1. `requested` when it is positive (explicit caller choice);
//   2. the HESSIAN_FF_MIN_PAIRS_PER_THREAD environment variable, when it has a
//      parseable positive base-10 integer prefix;
//   3. a built-in default otherwise (unset, empty, non-numeric or <= 0).
int64_t min_pairs_per_thread_threshold(int64_t requested) {
  if (requested > 0) {
    return requested;
  }

  // Default chosen from a benchmark sweep over the bundled small+large data
  // sets; it favours aggregate throughput across representative sizes.
  constexpr int64_t kDefault = 100000;

  const char* const raw = std::getenv("HESSIAN_FF_MIN_PAIRS_PER_THREAD");
  const bool unset = (raw == nullptr) || (*raw == '\0');
  if (unset) {
    return kDefault;
  }

  // strtoll accepts a numeric prefix; `parse_end == raw` means no digits at all.
  char* parse_end = nullptr;
  const long long parsed = std::strtoll(raw, &parse_end, 10);
  const bool usable = (parse_end != raw) && (parsed > 0);
  return usable ? static_cast<int64_t>(parsed) : kDefault;
}
|
|
46
|
+
|
|
47
|
+
// Returns the upper bound on worker threads for the nonbonded kernels.
//
// Without OpenMP this is always 1. With OpenMP:
//   * HESSIAN_FF_NONBONDED_MAX_THREADS, when it has a parseable positive
//     base-10 integer prefix, is used verbatim (clamped to INT32_MAX so the
//     narrowing to int cannot wrap);
//   * otherwise — variable unset, empty, non-numeric or <= 0 — the default is
//     max(omp_get_max_threads(), omp_get_num_procs()). Invalid values are
//     treated exactly like an unset variable, consistent with
//     min_pairs_per_thread_threshold.
int effective_max_threads() {
#ifdef _OPENMP
  const int omp_threads = std::max(1, omp_get_max_threads());
  const int nprocs = std::max(1, omp_get_num_procs());

  const char* env = std::getenv("HESSIAN_FF_NONBONDED_MAX_THREADS");
  if (env != nullptr && env[0] != '\0') {
    char* end = nullptr;
    const long long v = std::strtoll(env, &end, 10);
    if (end != env && v > 0) {
      // Clamp before narrowing: a huge override would otherwise wrap to an
      // arbitrary (possibly tiny or negative) int.
      return static_cast<int>(std::min<long long>(v, INT32_MAX));
    }
    // Unusable override: fall through to the default policy below.
  }

  // Prefer all OpenMP-visible processors by default. This avoids being
  // unintentionally capped by unrelated intra-op thread settings.
  return std::max(omp_threads, nprocs);
#else
  return 1;
#endif
}
|
|
69
|
+
|
|
70
|
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> pair_energy_force_aten(
|
|
71
|
+
const torch::Tensor& coords,
|
|
72
|
+
const torch::Tensor& charge,
|
|
73
|
+
const torch::Tensor& atom_type,
|
|
74
|
+
const torch::Tensor& lj_acoef,
|
|
75
|
+
const torch::Tensor& lj_bcoef,
|
|
76
|
+
const torch::Tensor& hb_acoef,
|
|
77
|
+
const torch::Tensor& hb_bcoef,
|
|
78
|
+
const torch::Tensor& nb_index,
|
|
79
|
+
const torch::Tensor& ii,
|
|
80
|
+
const torch::Tensor& jj,
|
|
81
|
+
const c10::optional<torch::Tensor>& inv_scee_opt,
|
|
82
|
+
const c10::optional<torch::Tensor>& inv_scnb_opt) {
|
|
83
|
+
auto rij = coords.index_select(0, jj) - coords.index_select(0, ii);
|
|
84
|
+
auto r2 = (rij * rij).sum(-1).clamp_min(1.0e-24);
|
|
85
|
+
auto inv_r = torch::rsqrt(r2);
|
|
86
|
+
auto inv_r2 = inv_r * inv_r;
|
|
87
|
+
|
|
88
|
+
auto qq = charge.index_select(0, ii) * charge.index_select(0, jj);
|
|
89
|
+
auto e_coul = qq * (double)kCoulomb * inv_r;
|
|
90
|
+
auto fscale_coul = -(qq * (double)kCoulomb) * inv_r2 * inv_r;
|
|
91
|
+
|
|
92
|
+
auto ti = atom_type.index_select(0, ii);
|
|
93
|
+
auto tj = atom_type.index_select(0, jj);
|
|
94
|
+
auto raw_idx = nb_index.index({ti, tj});
|
|
95
|
+
|
|
96
|
+
auto inv_r4 = inv_r2 * inv_r2;
|
|
97
|
+
auto inv_r6 = inv_r4 * inv_r2;
|
|
98
|
+
auto inv_r8 = inv_r6 * inv_r2;
|
|
99
|
+
auto inv_r10 = inv_r8 * inv_r2;
|
|
100
|
+
auto inv_r12 = inv_r6 * inv_r6;
|
|
101
|
+
auto inv_r14 = inv_r12 * inv_r2;
|
|
102
|
+
|
|
103
|
+
auto e_lj = torch::zeros_like(inv_r);
|
|
104
|
+
auto fscale_lj = torch::zeros_like(inv_r);
|
|
105
|
+
|
|
106
|
+
auto lj_mask = raw_idx > 0;
|
|
107
|
+
if (lj_mask.any().item<bool>()) {
|
|
108
|
+
auto lj_idx = raw_idx.index({lj_mask}) - 1;
|
|
109
|
+
auto a = lj_acoef.index_select(0, lj_idx);
|
|
110
|
+
auto b = lj_bcoef.index_select(0, lj_idx);
|
|
111
|
+
e_lj.index_put_({lj_mask}, a * inv_r12.index({lj_mask}) - b * inv_r6.index({lj_mask}));
|
|
112
|
+
fscale_lj.index_put_(
|
|
113
|
+
{lj_mask},
|
|
114
|
+
-12.0 * a * inv_r14.index({lj_mask}) + 6.0 * b * inv_r8.index({lj_mask}));
|
|
115
|
+
}
|
|
116
|
+
|
|
117
|
+
auto hb_mask = raw_idx < 0;
|
|
118
|
+
if (hb_mask.any().item<bool>() && hb_acoef.numel() > 0 && hb_bcoef.numel() > 0) {
|
|
119
|
+
auto hb_idx_full = (-raw_idx.index({hb_mask})) - 1;
|
|
120
|
+
auto valid = (hb_idx_full >= 0) & (hb_idx_full < hb_acoef.numel()) &
|
|
121
|
+
(hb_idx_full < hb_bcoef.numel());
|
|
122
|
+
if (valid.any().item<bool>()) {
|
|
123
|
+
auto hb_idx = hb_idx_full.index({valid});
|
|
124
|
+
auto hb_pos = torch::nonzero(hb_mask).reshape(-1).index({valid});
|
|
125
|
+
auto a = hb_acoef.index_select(0, hb_idx);
|
|
126
|
+
auto b = hb_bcoef.index_select(0, hb_idx);
|
|
127
|
+
e_lj.index_put_({hb_pos}, a * inv_r12.index({hb_pos}) - b * inv_r10.index({hb_pos}));
|
|
128
|
+
fscale_lj.index_put_(
|
|
129
|
+
{hb_pos},
|
|
130
|
+
-12.0 * a * inv_r14.index({hb_pos}) + 10.0 * b * inv_r12.index({hb_pos}));
|
|
131
|
+
}
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
if (inv_scee_opt.has_value()) {
|
|
135
|
+
auto inv_scee = inv_scee_opt.value();
|
|
136
|
+
e_coul = e_coul * inv_scee;
|
|
137
|
+
fscale_coul = fscale_coul * inv_scee;
|
|
138
|
+
}
|
|
139
|
+
if (inv_scnb_opt.has_value()) {
|
|
140
|
+
auto inv_scnb = inv_scnb_opt.value();
|
|
141
|
+
e_lj = e_lj * inv_scnb;
|
|
142
|
+
fscale_lj = fscale_lj * inv_scnb;
|
|
143
|
+
}
|
|
144
|
+
|
|
145
|
+
auto fij = (fscale_coul + fscale_lj).unsqueeze(-1) * rij;
|
|
146
|
+
return {e_coul.sum(), e_lj.sum(), fij};
|
|
147
|
+
}
|
|
148
|
+
|
|
149
|
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> pair_energy_force_preparam_aten(
|
|
150
|
+
const torch::Tensor& coords,
|
|
151
|
+
const torch::Tensor& ii,
|
|
152
|
+
const torch::Tensor& jj,
|
|
153
|
+
const torch::Tensor& coul_coeff,
|
|
154
|
+
const torch::Tensor& a12_coeff,
|
|
155
|
+
const torch::Tensor& b6_coeff,
|
|
156
|
+
const torch::Tensor& b10_coeff) {
|
|
157
|
+
auto rij = coords.index_select(0, jj) - coords.index_select(0, ii);
|
|
158
|
+
auto r2 = (rij * rij).sum(-1).clamp_min(1.0e-24);
|
|
159
|
+
auto inv_r = torch::rsqrt(r2);
|
|
160
|
+
auto inv_r2 = inv_r * inv_r;
|
|
161
|
+
auto inv_r4 = inv_r2 * inv_r2;
|
|
162
|
+
auto inv_r6 = inv_r4 * inv_r2;
|
|
163
|
+
auto inv_r8 = inv_r6 * inv_r2;
|
|
164
|
+
auto inv_r10 = inv_r8 * inv_r2;
|
|
165
|
+
auto inv_r12 = inv_r6 * inv_r6;
|
|
166
|
+
auto inv_r14 = inv_r12 * inv_r2;
|
|
167
|
+
|
|
168
|
+
auto e_coul = coul_coeff * inv_r;
|
|
169
|
+
auto e_lj = a12_coeff * inv_r12 - b6_coeff * inv_r6 - b10_coeff * inv_r10;
|
|
170
|
+
auto fscale = -coul_coeff * inv_r2 * inv_r - 12.0 * a12_coeff * inv_r14 +
|
|
171
|
+
6.0 * b6_coeff * inv_r8 + 10.0 * b10_coeff * inv_r12;
|
|
172
|
+
auto fij = fscale.unsqueeze(-1) * rij;
|
|
173
|
+
return {e_coul.sum(), e_lj.sum(), fij};
|
|
174
|
+
}
|
|
175
|
+
|
|
176
|
+
std::vector<torch::Tensor> nonbonded_energy_force_aten(
|
|
177
|
+
const torch::Tensor& coords,
|
|
178
|
+
const torch::Tensor& charge,
|
|
179
|
+
const torch::Tensor& atom_type,
|
|
180
|
+
const torch::Tensor& lj_acoef,
|
|
181
|
+
const torch::Tensor& lj_bcoef,
|
|
182
|
+
const torch::Tensor& hb_acoef,
|
|
183
|
+
const torch::Tensor& hb_bcoef,
|
|
184
|
+
const torch::Tensor& nb_index,
|
|
185
|
+
const torch::Tensor& pair_i,
|
|
186
|
+
const torch::Tensor& pair_j,
|
|
187
|
+
const torch::Tensor& pair14_i,
|
|
188
|
+
const torch::Tensor& pair14_j,
|
|
189
|
+
const torch::Tensor& pair14_inv_scee,
|
|
190
|
+
const torch::Tensor& pair14_inv_scnb,
|
|
191
|
+
int64_t chunk_size) {
|
|
192
|
+
TORCH_CHECK(chunk_size > 0, "chunk_size must be > 0");
|
|
193
|
+
TORCH_CHECK(coords.dim() == 2 && coords.size(1) == 3, "coords must be [N,3]");
|
|
194
|
+
TORCH_CHECK(coords.is_floating_point(), "coords must be floating tensor");
|
|
195
|
+
TORCH_CHECK(charge.is_floating_point(), "charge must be floating tensor");
|
|
196
|
+
TORCH_CHECK(
|
|
197
|
+
coords.scalar_type() == charge.scalar_type(),
|
|
198
|
+
"coords/charge dtype mismatch in nonbonded_energy_force");
|
|
199
|
+
TORCH_CHECK(
|
|
200
|
+
coords.scalar_type() == lj_acoef.scalar_type() && coords.scalar_type() == lj_bcoef.scalar_type(),
|
|
201
|
+
"coords/LJ coeff dtype mismatch in nonbonded_energy_force");
|
|
202
|
+
TORCH_CHECK(
|
|
203
|
+
coords.scalar_type() == hb_acoef.scalar_type() && coords.scalar_type() == hb_bcoef.scalar_type(),
|
|
204
|
+
"coords/HB coeff dtype mismatch in nonbonded_energy_force");
|
|
205
|
+
TORCH_CHECK(atom_type.scalar_type() == torch::kInt64, "atom_type must be int64");
|
|
206
|
+
TORCH_CHECK(nb_index.scalar_type() == torch::kInt64, "nb_index must be int64");
|
|
207
|
+
TORCH_CHECK(pair_i.scalar_type() == torch::kInt64, "pair_i must be int64");
|
|
208
|
+
TORCH_CHECK(pair_j.scalar_type() == torch::kInt64, "pair_j must be int64");
|
|
209
|
+
TORCH_CHECK(pair14_i.scalar_type() == torch::kInt64, "pair14_i must be int64");
|
|
210
|
+
TORCH_CHECK(pair14_j.scalar_type() == torch::kInt64, "pair14_j must be int64");
|
|
211
|
+
TORCH_CHECK(
|
|
212
|
+
pair_i.device() == coords.device() && pair_j.device() == coords.device() &&
|
|
213
|
+
pair14_i.device() == coords.device() && pair14_j.device() == coords.device(),
|
|
214
|
+
"pair index tensors must be on same device as coords");
|
|
215
|
+
TORCH_CHECK(
|
|
216
|
+
charge.device() == coords.device() && atom_type.device() == coords.device() &&
|
|
217
|
+
lj_acoef.device() == coords.device() && lj_bcoef.device() == coords.device() &&
|
|
218
|
+
hb_acoef.device() == coords.device() && hb_bcoef.device() == coords.device() &&
|
|
219
|
+
nb_index.device() == coords.device() && pair14_inv_scee.device() == coords.device() &&
|
|
220
|
+
pair14_inv_scnb.device() == coords.device(),
|
|
221
|
+
"nonbonded parameter tensors must be on same device as coords");
|
|
222
|
+
TORCH_CHECK(coords.size(0) == charge.numel(), "coords/charge atom count mismatch");
|
|
223
|
+
TORCH_CHECK(atom_type.numel() == charge.numel(), "atom_type/charge atom count mismatch");
|
|
224
|
+
TORCH_CHECK(pair_i.numel() == pair_j.numel(), "pair_i/pair_j size mismatch");
|
|
225
|
+
TORCH_CHECK(pair14_i.numel() == pair14_j.numel(), "pair14_i/pair14_j size mismatch");
|
|
226
|
+
TORCH_CHECK(
|
|
227
|
+
pair14_i.numel() == pair14_inv_scee.numel() && pair14_i.numel() == pair14_inv_scnb.numel(),
|
|
228
|
+
"1-4 pair/scale size mismatch");
|
|
229
|
+
|
|
230
|
+
auto e_coul = torch::zeros({}, coords.options());
|
|
231
|
+
auto e_lj = torch::zeros({}, coords.options());
|
|
232
|
+
auto e_coul14 = torch::zeros({}, coords.options());
|
|
233
|
+
auto e_lj14 = torch::zeros({}, coords.options());
|
|
234
|
+
auto force = torch::zeros_like(coords);
|
|
235
|
+
|
|
236
|
+
const int64_t n = pair_i.numel();
|
|
237
|
+
for (int64_t start = 0; start < n; start += chunk_size) {
|
|
238
|
+
int64_t end = std::min(start + chunk_size, n);
|
|
239
|
+
auto ii = pair_i.slice(0, start, end);
|
|
240
|
+
auto jj = pair_j.slice(0, start, end);
|
|
241
|
+
auto out = pair_energy_force_aten(
|
|
242
|
+
coords,
|
|
243
|
+
charge,
|
|
244
|
+
atom_type,
|
|
245
|
+
lj_acoef,
|
|
246
|
+
lj_bcoef,
|
|
247
|
+
hb_acoef,
|
|
248
|
+
hb_bcoef,
|
|
249
|
+
nb_index,
|
|
250
|
+
ii,
|
|
251
|
+
jj,
|
|
252
|
+
c10::nullopt,
|
|
253
|
+
c10::nullopt);
|
|
254
|
+
auto ce = std::get<0>(out);
|
|
255
|
+
auto le = std::get<1>(out);
|
|
256
|
+
auto fij = std::get<2>(out);
|
|
257
|
+
e_coul = e_coul + ce;
|
|
258
|
+
e_lj = e_lj + le;
|
|
259
|
+
force.index_add_(0, ii, fij);
|
|
260
|
+
force.index_add_(0, jj, -fij);
|
|
261
|
+
}
|
|
262
|
+
|
|
263
|
+
const int64_t n14 = pair14_i.numel();
|
|
264
|
+
for (int64_t start = 0; start < n14; start += chunk_size) {
|
|
265
|
+
int64_t end = std::min(start + chunk_size, n14);
|
|
266
|
+
auto ii = pair14_i.slice(0, start, end);
|
|
267
|
+
auto jj = pair14_j.slice(0, start, end);
|
|
268
|
+
auto inv_scee = pair14_inv_scee.slice(0, start, end);
|
|
269
|
+
auto inv_scnb = pair14_inv_scnb.slice(0, start, end);
|
|
270
|
+
auto out = pair_energy_force_aten(
|
|
271
|
+
coords,
|
|
272
|
+
charge,
|
|
273
|
+
atom_type,
|
|
274
|
+
lj_acoef,
|
|
275
|
+
lj_bcoef,
|
|
276
|
+
hb_acoef,
|
|
277
|
+
hb_bcoef,
|
|
278
|
+
nb_index,
|
|
279
|
+
ii,
|
|
280
|
+
jj,
|
|
281
|
+
inv_scee,
|
|
282
|
+
inv_scnb);
|
|
283
|
+
auto ce = std::get<0>(out);
|
|
284
|
+
auto le = std::get<1>(out);
|
|
285
|
+
auto fij = std::get<2>(out);
|
|
286
|
+
e_coul14 = e_coul14 + ce;
|
|
287
|
+
e_lj14 = e_lj14 + le;
|
|
288
|
+
force.index_add_(0, ii, fij);
|
|
289
|
+
force.index_add_(0, jj, -fij);
|
|
290
|
+
}
|
|
291
|
+
|
|
292
|
+
return {e_coul, e_lj, e_coul14, e_lj14, force};
|
|
293
|
+
}
|
|
294
|
+
|
|
295
|
+
std::vector<torch::Tensor> nonbonded_energy_force_preparam_aten(
|
|
296
|
+
const torch::Tensor& coords,
|
|
297
|
+
const torch::Tensor& pair_i,
|
|
298
|
+
const torch::Tensor& pair_j,
|
|
299
|
+
const torch::Tensor& pair_coul_coeff,
|
|
300
|
+
const torch::Tensor& pair_a12_coeff,
|
|
301
|
+
const torch::Tensor& pair_b6_coeff,
|
|
302
|
+
const torch::Tensor& pair_b10_coeff,
|
|
303
|
+
const torch::Tensor& pair14_i,
|
|
304
|
+
const torch::Tensor& pair14_j,
|
|
305
|
+
const torch::Tensor& pair14_coul_coeff,
|
|
306
|
+
const torch::Tensor& pair14_a12_coeff,
|
|
307
|
+
const torch::Tensor& pair14_b6_coeff,
|
|
308
|
+
const torch::Tensor& pair14_b10_coeff,
|
|
309
|
+
int64_t chunk_size) {
|
|
310
|
+
TORCH_CHECK(chunk_size > 0, "chunk_size must be > 0");
|
|
311
|
+
TORCH_CHECK(coords.dim() == 2 && coords.size(1) == 3, "coords must be [N,3]");
|
|
312
|
+
TORCH_CHECK(coords.is_floating_point(), "coords must be floating tensor");
|
|
313
|
+
TORCH_CHECK(pair_i.scalar_type() == torch::kInt64, "pair_i must be int64");
|
|
314
|
+
TORCH_CHECK(pair_j.scalar_type() == torch::kInt64, "pair_j must be int64");
|
|
315
|
+
TORCH_CHECK(pair14_i.scalar_type() == torch::kInt64, "pair14_i must be int64");
|
|
316
|
+
TORCH_CHECK(pair14_j.scalar_type() == torch::kInt64, "pair14_j must be int64");
|
|
317
|
+
TORCH_CHECK(
|
|
318
|
+
pair_i.device() == coords.device() && pair_j.device() == coords.device() &&
|
|
319
|
+
pair14_i.device() == coords.device() && pair14_j.device() == coords.device(),
|
|
320
|
+
"pair index tensors must be on same device as coords");
|
|
321
|
+
TORCH_CHECK(
|
|
322
|
+
pair_coul_coeff.device() == coords.device() && pair_a12_coeff.device() == coords.device() &&
|
|
323
|
+
pair_b6_coeff.device() == coords.device() && pair_b10_coeff.device() == coords.device() &&
|
|
324
|
+
pair14_coul_coeff.device() == coords.device() && pair14_a12_coeff.device() == coords.device() &&
|
|
325
|
+
pair14_b6_coeff.device() == coords.device() && pair14_b10_coeff.device() == coords.device(),
|
|
326
|
+
"precomputed pair coeff tensors must be on same device as coords");
|
|
327
|
+
TORCH_CHECK(
|
|
328
|
+
pair_coul_coeff.scalar_type() == coords.scalar_type() &&
|
|
329
|
+
pair_a12_coeff.scalar_type() == coords.scalar_type() &&
|
|
330
|
+
pair_b6_coeff.scalar_type() == coords.scalar_type() &&
|
|
331
|
+
pair_b10_coeff.scalar_type() == coords.scalar_type() &&
|
|
332
|
+
pair14_coul_coeff.scalar_type() == coords.scalar_type() &&
|
|
333
|
+
pair14_a12_coeff.scalar_type() == coords.scalar_type() &&
|
|
334
|
+
pair14_b6_coeff.scalar_type() == coords.scalar_type() &&
|
|
335
|
+
pair14_b10_coeff.scalar_type() == coords.scalar_type(),
|
|
336
|
+
"precomputed pair coeff dtype mismatch with coords");
|
|
337
|
+
TORCH_CHECK(pair_i.numel() == pair_j.numel(), "pair_i/pair_j size mismatch");
|
|
338
|
+
TORCH_CHECK(pair14_i.numel() == pair14_j.numel(), "pair14_i/pair14_j size mismatch");
|
|
339
|
+
TORCH_CHECK(
|
|
340
|
+
pair_i.numel() == pair_coul_coeff.numel() && pair_i.numel() == pair_a12_coeff.numel() &&
|
|
341
|
+
pair_i.numel() == pair_b6_coeff.numel() && pair_i.numel() == pair_b10_coeff.numel(),
|
|
342
|
+
"general pair/coeff size mismatch");
|
|
343
|
+
TORCH_CHECK(
|
|
344
|
+
pair14_i.numel() == pair14_coul_coeff.numel() && pair14_i.numel() == pair14_a12_coeff.numel() &&
|
|
345
|
+
pair14_i.numel() == pair14_b6_coeff.numel() && pair14_i.numel() == pair14_b10_coeff.numel(),
|
|
346
|
+
"1-4 pair/coeff size mismatch");
|
|
347
|
+
|
|
348
|
+
auto e_coul = torch::zeros({}, coords.options());
|
|
349
|
+
auto e_lj = torch::zeros({}, coords.options());
|
|
350
|
+
auto e_coul14 = torch::zeros({}, coords.options());
|
|
351
|
+
auto e_lj14 = torch::zeros({}, coords.options());
|
|
352
|
+
auto force = torch::zeros_like(coords);
|
|
353
|
+
|
|
354
|
+
const int64_t n = pair_i.numel();
|
|
355
|
+
for (int64_t start = 0; start < n; start += chunk_size) {
|
|
356
|
+
int64_t end = std::min(start + chunk_size, n);
|
|
357
|
+
auto ii = pair_i.slice(0, start, end);
|
|
358
|
+
auto jj = pair_j.slice(0, start, end);
|
|
359
|
+
auto cc = pair_coul_coeff.slice(0, start, end);
|
|
360
|
+
auto a12 = pair_a12_coeff.slice(0, start, end);
|
|
361
|
+
auto b6 = pair_b6_coeff.slice(0, start, end);
|
|
362
|
+
auto b10 = pair_b10_coeff.slice(0, start, end);
|
|
363
|
+
auto out = pair_energy_force_preparam_aten(coords, ii, jj, cc, a12, b6, b10);
|
|
364
|
+
auto ce = std::get<0>(out);
|
|
365
|
+
auto le = std::get<1>(out);
|
|
366
|
+
auto fij = std::get<2>(out);
|
|
367
|
+
e_coul = e_coul + ce;
|
|
368
|
+
e_lj = e_lj + le;
|
|
369
|
+
force.index_add_(0, ii, fij);
|
|
370
|
+
force.index_add_(0, jj, -fij);
|
|
371
|
+
}
|
|
372
|
+
|
|
373
|
+
const int64_t n14 = pair14_i.numel();
|
|
374
|
+
for (int64_t start = 0; start < n14; start += chunk_size) {
|
|
375
|
+
int64_t end = std::min(start + chunk_size, n14);
|
|
376
|
+
auto ii = pair14_i.slice(0, start, end);
|
|
377
|
+
auto jj = pair14_j.slice(0, start, end);
|
|
378
|
+
auto cc = pair14_coul_coeff.slice(0, start, end);
|
|
379
|
+
auto a12 = pair14_a12_coeff.slice(0, start, end);
|
|
380
|
+
auto b6 = pair14_b6_coeff.slice(0, start, end);
|
|
381
|
+
auto b10 = pair14_b10_coeff.slice(0, start, end);
|
|
382
|
+
auto out = pair_energy_force_preparam_aten(coords, ii, jj, cc, a12, b6, b10);
|
|
383
|
+
auto ce = std::get<0>(out);
|
|
384
|
+
auto le = std::get<1>(out);
|
|
385
|
+
auto fij = std::get<2>(out);
|
|
386
|
+
e_coul14 = e_coul14 + ce;
|
|
387
|
+
e_lj14 = e_lj14 + le;
|
|
388
|
+
force.index_add_(0, ii, fij);
|
|
389
|
+
force.index_add_(0, jj, -fij);
|
|
390
|
+
}
|
|
391
|
+
|
|
392
|
+
return {e_coul, e_lj, e_coul14, e_lj14, force};
|
|
393
|
+
}
|
|
394
|
+
|
|
395
|
+
// Forward declaration of the hand-written OpenMP CPU kernel for precomputed
// pair coefficients. Its definition appears further down in this file;
// declaring it here lets nonbonded_energy_force_preparam dispatch to it.
std::vector<torch::Tensor> nonbonded_energy_force_preparam_cpu(
    const torch::Tensor& coords,
    const torch::Tensor& pair_i,
    const torch::Tensor& pair_j,
    const torch::Tensor& pair_coul_coeff,
    const torch::Tensor& pair_a12_coeff,
    const torch::Tensor& pair_b6_coeff,
    const torch::Tensor& pair_b10_coeff,
    const torch::Tensor& pair14_i,
    const torch::Tensor& pair14_j,
    const torch::Tensor& pair14_coul_coeff,
    const torch::Tensor& pair14_a12_coeff,
    const torch::Tensor& pair14_b6_coeff,
    const torch::Tensor& pair14_b10_coeff,
    int64_t min_pairs_per_thread);
|
|
410
|
+
|
|
411
|
+
std::vector<torch::Tensor> nonbonded_energy_force_preparam(
|
|
412
|
+
const torch::Tensor& coords,
|
|
413
|
+
const torch::Tensor& pair_i,
|
|
414
|
+
const torch::Tensor& pair_j,
|
|
415
|
+
const torch::Tensor& pair_coul_coeff,
|
|
416
|
+
const torch::Tensor& pair_a12_coeff,
|
|
417
|
+
const torch::Tensor& pair_b6_coeff,
|
|
418
|
+
const torch::Tensor& pair_b10_coeff,
|
|
419
|
+
const torch::Tensor& pair14_i,
|
|
420
|
+
const torch::Tensor& pair14_j,
|
|
421
|
+
const torch::Tensor& pair14_coul_coeff,
|
|
422
|
+
const torch::Tensor& pair14_a12_coeff,
|
|
423
|
+
const torch::Tensor& pair14_b6_coeff,
|
|
424
|
+
const torch::Tensor& pair14_b10_coeff,
|
|
425
|
+
int64_t chunk_size,
|
|
426
|
+
bool cpu_fast,
|
|
427
|
+
int64_t min_pairs_per_thread) {
|
|
428
|
+
if (coords.device().is_cpu() && cpu_fast) {
|
|
429
|
+
return nonbonded_energy_force_preparam_cpu(
|
|
430
|
+
coords,
|
|
431
|
+
pair_i,
|
|
432
|
+
pair_j,
|
|
433
|
+
pair_coul_coeff,
|
|
434
|
+
pair_a12_coeff,
|
|
435
|
+
pair_b6_coeff,
|
|
436
|
+
pair_b10_coeff,
|
|
437
|
+
pair14_i,
|
|
438
|
+
pair14_j,
|
|
439
|
+
pair14_coul_coeff,
|
|
440
|
+
pair14_a12_coeff,
|
|
441
|
+
pair14_b6_coeff,
|
|
442
|
+
pair14_b10_coeff,
|
|
443
|
+
min_pairs_per_thread);
|
|
444
|
+
}
|
|
445
|
+
return nonbonded_energy_force_preparam_aten(
|
|
446
|
+
coords,
|
|
447
|
+
pair_i,
|
|
448
|
+
pair_j,
|
|
449
|
+
pair_coul_coeff,
|
|
450
|
+
pair_a12_coeff,
|
|
451
|
+
pair_b6_coeff,
|
|
452
|
+
pair_b10_coeff,
|
|
453
|
+
pair14_i,
|
|
454
|
+
pair14_j,
|
|
455
|
+
pair14_coul_coeff,
|
|
456
|
+
pair14_a12_coeff,
|
|
457
|
+
pair14_b6_coeff,
|
|
458
|
+
pair14_b10_coeff,
|
|
459
|
+
chunk_size);
|
|
460
|
+
}
|
|
461
|
+
|
|
462
|
+
// Evaluate one nonbonded pair (i, j) from per-atom parameters and accumulate
// its energy and force contributions.
//
// Energy terms:
//   Coulomb : kCoulomb * q_i * q_j / r   (kCoulomb is a constant defined
//                                         outside this function)
//   vdW     : A / r^12 - B / r^6   when nb_index entry raw > 0
//                                  (raw - 1 indexes lj_acoef / lj_bcoef)
//             A / r^12 - B / r^10  when raw < 0 and n_hb > 0
//                                  (-raw - 1 indexes hb_acoef / hb_bcoef)
//             no vdW term          when raw == 0, the hb table is empty,
//                                  or the table index is out of range
// If inv_scee / inv_scnb are non-null, inv_scee[p] scales the Coulomb part
// and inv_scnb[p] scales the vdW part; nullptr means unscaled.
//
// coords and force are flat xyz arrays (atom k occupies [3k, 3k+3)). The pair
// force is added to atom i and subtracted from atom j. Energies are added to
// e_coul / e_lj in double precision regardless of scalar_t.
// NOTE(review): i, j, atom_type values and the nb_index lookup are NOT
// bounds-checked here; callers are responsible for valid indices.
template <typename scalar_t>
inline void accumulate_pair_one(
    const scalar_t* coords,
    const scalar_t* charge,
    const int64_t* atom_type,
    const scalar_t* lj_acoef,
    const scalar_t* lj_bcoef,
    int64_t n_lj,
    const scalar_t* hb_acoef,
    const scalar_t* hb_bcoef,
    int64_t n_hb,
    const int64_t* nb_index,
    int64_t ntypes,
    int64_t i,
    int64_t j,
    const scalar_t* inv_scee,
    const scalar_t* inv_scnb,
    int64_t p,
    scalar_t* force,
    double& e_coul,
    double& e_lj) {
  // Separation d = r_j - r_i (difference taken in scalar_t, then widened to double).
  const double dx = static_cast<double>(coords[3 * j + 0] - coords[3 * i + 0]);
  const double dy = static_cast<double>(coords[3 * j + 1] - coords[3 * i + 1]);
  const double dz = static_cast<double>(coords[3 * j + 2] - coords[3 * i + 2]);
  // Clamp r^2 so coincident atoms do not divide by zero.
  const double r2 = std::max(dx * dx + dy * dy + dz * dz, 1.0e-24);
  const double inv_r = 1.0 / std::sqrt(r2);
  const double inv_r2 = inv_r * inv_r;

  // Coulomb energy and its radial factor (1/r) dE/dr = -kCoulomb*q_i*q_j / r^3.
  const double qq = static_cast<double>(charge[i]) * static_cast<double>(charge[j]);
  double ec = qq * kCoulomb * inv_r;
  double f_coul = -(qq * kCoulomb) * inv_r2 * inv_r;

  // Type-pair lookup: nb_index is read as a row-major ntypes x ntypes table.
  const int64_t ti = atom_type[i];
  const int64_t tj = atom_type[j];
  const int64_t raw = nb_index[ti * ntypes + tj];

  // Even inverse powers of r, built by repeated multiplication.
  const double inv_r4 = inv_r2 * inv_r2;
  const double inv_r6 = inv_r4 * inv_r2;
  const double inv_r8 = inv_r6 * inv_r2;
  const double inv_r10 = inv_r8 * inv_r2;
  const double inv_r12 = inv_r6 * inv_r6;
  const double inv_r14 = inv_r12 * inv_r2;

  double el = 0.0;
  double f_lj = 0.0;
  if (raw > 0) {
    // 12-6 Lennard-Jones; raw is a 1-based index into lj_acoef / lj_bcoef.
    const int64_t idx = raw - 1;
    if (idx >= 0 && idx < n_lj) {
      const double a = static_cast<double>(lj_acoef[idx]);
      const double b = static_cast<double>(lj_bcoef[idx]);
      el = a * inv_r12 - b * inv_r6;
      // (1/r) dE/dr for A/r^12 - B/r^6.
      f_lj = -12.0 * a * inv_r14 + 6.0 * b * inv_r8;
    }
  } else if (raw < 0 && n_hb > 0) {
    // 12-10 term; -raw - 1 indexes hb_acoef / hb_bcoef.
    const int64_t idx = -raw - 1;
    if (idx >= 0 && idx < n_hb) {
      const double a = static_cast<double>(hb_acoef[idx]);
      const double b = static_cast<double>(hb_bcoef[idx]);
      el = a * inv_r12 - b * inv_r10;
      // (1/r) dE/dr for A/r^12 - B/r^10.
      f_lj = -12.0 * a * inv_r14 + 10.0 * b * inv_r12;
    }
  }

  // Optional per-pair scaling (callers in this file pass these for 1-4 pairs).
  if (inv_scee != nullptr) {
    const double s = static_cast<double>(inv_scee[p]);
    ec *= s;
    f_coul *= s;
  }
  if (inv_scnb != nullptr) {
    const double s = static_cast<double>(inv_scnb[p]);
    el *= s;
    f_lj *= s;
  }

  // fs = (1/r) dE/dr, so fs * d = -dE/dr_i is the force on atom i;
  // atom j receives the equal and opposite force.
  const double fs = f_coul + f_lj;
  const double fx = fs * dx;
  const double fy = fs * dy;
  const double fz = fs * dz;

  force[3 * i + 0] += static_cast<scalar_t>(fx);
  force[3 * i + 1] += static_cast<scalar_t>(fy);
  force[3 * i + 2] += static_cast<scalar_t>(fz);
  force[3 * j + 0] -= static_cast<scalar_t>(fx);
  force[3 * j + 1] -= static_cast<scalar_t>(fy);
  force[3 * j + 2] -= static_cast<scalar_t>(fz);

  e_coul += ec;
  e_lj += el;
}
|
|
551
|
+
|
|
552
|
+
template <typename scalar_t>
|
|
553
|
+
void accumulate_pairs_cpu(
|
|
554
|
+
const scalar_t* coords,
|
|
555
|
+
const scalar_t* charge,
|
|
556
|
+
const int64_t* atom_type,
|
|
557
|
+
const scalar_t* lj_acoef,
|
|
558
|
+
const scalar_t* lj_bcoef,
|
|
559
|
+
int64_t n_lj,
|
|
560
|
+
const scalar_t* hb_acoef,
|
|
561
|
+
const scalar_t* hb_bcoef,
|
|
562
|
+
int64_t n_hb,
|
|
563
|
+
const int64_t* nb_index,
|
|
564
|
+
int64_t ntypes,
|
|
565
|
+
const int64_t* ii,
|
|
566
|
+
const int64_t* jj,
|
|
567
|
+
const scalar_t* inv_scee,
|
|
568
|
+
const scalar_t* inv_scnb,
|
|
569
|
+
int64_t npairs,
|
|
570
|
+
int64_t natom,
|
|
571
|
+
int64_t min_pairs_per_thread,
|
|
572
|
+
std::vector<scalar_t>& force_out,
|
|
573
|
+
double& e_coul_out,
|
|
574
|
+
double& e_lj_out) {
|
|
575
|
+
const int64_t stride = natom * 3;
|
|
576
|
+
if (npairs <= 0) {
|
|
577
|
+
return;
|
|
578
|
+
}
|
|
579
|
+
|
|
580
|
+
int nthreads = 1;
|
|
581
|
+
#ifdef _OPENMP
|
|
582
|
+
const int max_threads = effective_max_threads();
|
|
583
|
+
// Adaptive threading:
|
|
584
|
+
// For small pair counts, OpenMP launch/reduction overhead dominates.
|
|
585
|
+
// Use fewer threads (or 1) to keep latency low on toy/small systems.
|
|
586
|
+
// This threshold rule is an hessian_ff heuristic (not OpenMM-identical).
|
|
587
|
+
// Priority: API argument -> env var -> internal default.
|
|
588
|
+
const int64_t kMinPairsPerThread = min_pairs_per_thread_threshold(min_pairs_per_thread);
|
|
589
|
+
nthreads = std::min<int>(
|
|
590
|
+
max_threads,
|
|
591
|
+
std::max<int64_t>(1, npairs / kMinPairsPerThread));
|
|
592
|
+
#endif
|
|
593
|
+
|
|
594
|
+
if (nthreads <= 1) {
|
|
595
|
+
scalar_t* f = force_out.data();
|
|
596
|
+
for (int64_t p = 0; p < npairs; ++p) {
|
|
597
|
+
accumulate_pair_one<scalar_t>(
|
|
598
|
+
coords,
|
|
599
|
+
charge,
|
|
600
|
+
atom_type,
|
|
601
|
+
lj_acoef,
|
|
602
|
+
lj_bcoef,
|
|
603
|
+
n_lj,
|
|
604
|
+
hb_acoef,
|
|
605
|
+
hb_bcoef,
|
|
606
|
+
n_hb,
|
|
607
|
+
nb_index,
|
|
608
|
+
ntypes,
|
|
609
|
+
ii[p],
|
|
610
|
+
jj[p],
|
|
611
|
+
inv_scee,
|
|
612
|
+
inv_scnb,
|
|
613
|
+
p,
|
|
614
|
+
f,
|
|
615
|
+
e_coul_out,
|
|
616
|
+
e_lj_out);
|
|
617
|
+
}
|
|
618
|
+
return;
|
|
619
|
+
}
|
|
620
|
+
|
|
621
|
+
std::vector<scalar_t> force_tls(static_cast<size_t>(nthreads) * static_cast<size_t>(stride), static_cast<scalar_t>(0));
|
|
622
|
+
std::vector<double> ec_tls(static_cast<size_t>(nthreads), 0.0);
|
|
623
|
+
std::vector<double> el_tls(static_cast<size_t>(nthreads), 0.0);
|
|
624
|
+
|
|
625
|
+
#ifdef _OPENMP
|
|
626
|
+
#pragma omp parallel num_threads(nthreads)
|
|
627
|
+
#endif
|
|
628
|
+
{
|
|
629
|
+
int tid = 0;
|
|
630
|
+
#ifdef _OPENMP
|
|
631
|
+
tid = omp_get_thread_num();
|
|
632
|
+
#endif
|
|
633
|
+
auto* f = force_tls.data() + static_cast<size_t>(tid) * static_cast<size_t>(stride);
|
|
634
|
+
double ec_loc = 0.0;
|
|
635
|
+
double el_loc = 0.0;
|
|
636
|
+
#ifdef _OPENMP
|
|
637
|
+
#pragma omp for schedule(static)
|
|
638
|
+
#endif
|
|
639
|
+
for (int64_t p = 0; p < npairs; ++p) {
|
|
640
|
+
accumulate_pair_one<scalar_t>(
|
|
641
|
+
coords,
|
|
642
|
+
charge,
|
|
643
|
+
atom_type,
|
|
644
|
+
lj_acoef,
|
|
645
|
+
lj_bcoef,
|
|
646
|
+
n_lj,
|
|
647
|
+
hb_acoef,
|
|
648
|
+
hb_bcoef,
|
|
649
|
+
n_hb,
|
|
650
|
+
nb_index,
|
|
651
|
+
ntypes,
|
|
652
|
+
ii[p],
|
|
653
|
+
jj[p],
|
|
654
|
+
inv_scee,
|
|
655
|
+
inv_scnb,
|
|
656
|
+
p,
|
|
657
|
+
f,
|
|
658
|
+
ec_loc,
|
|
659
|
+
el_loc);
|
|
660
|
+
}
|
|
661
|
+
ec_tls[static_cast<size_t>(tid)] = ec_loc;
|
|
662
|
+
el_tls[static_cast<size_t>(tid)] = el_loc;
|
|
663
|
+
}
|
|
664
|
+
|
|
665
|
+
for (int tid = 0; tid < nthreads; ++tid) {
|
|
666
|
+
const auto* f = force_tls.data() + static_cast<size_t>(tid) * static_cast<size_t>(stride);
|
|
667
|
+
for (int64_t k = 0; k < stride; ++k) {
|
|
668
|
+
force_out[static_cast<size_t>(k)] += f[static_cast<size_t>(k)];
|
|
669
|
+
}
|
|
670
|
+
e_coul_out += ec_tls[static_cast<size_t>(tid)];
|
|
671
|
+
e_lj_out += el_tls[static_cast<size_t>(tid)];
|
|
672
|
+
}
|
|
673
|
+
}
|
|
674
|
+
|
|
675
|
+
// Evaluate one nonbonded pair (i, j) from precomputed per-pair coefficients
// and accumulate its energy and force contributions.
//
//   E_coul = c / r
//   E_lj   = a12 / r^12 - b6 / r^6 - b10 / r^10
// where c = coul_coeff[p], a12 = a12_coeff[p], b6 = b6_coeff[p] and
// b10 = b10_coeff[p]. No constants or scale factors are applied here.
// NOTE(review): presumably the caller has already folded the Coulomb constant,
// charges and any 1-4 scaling into the coefficients; confirm at the call site.
//
// coords and force are flat xyz arrays (atom k occupies [3k, 3k+3)). The pair
// force is added to atom i and subtracted from atom j. Energies are added to
// e_coul / e_lj in double precision. i, j and p are not bounds-checked here.
template <typename scalar_t>
inline void accumulate_pair_preparam_one(
    const scalar_t* coords,
    int64_t i,
    int64_t j,
    const scalar_t* coul_coeff,
    const scalar_t* a12_coeff,
    const scalar_t* b6_coeff,
    const scalar_t* b10_coeff,
    int64_t p,
    scalar_t* force,
    double& e_coul,
    double& e_lj) {
  // Separation d = r_j - r_i (difference taken in scalar_t, then widened to double).
  const double dx = static_cast<double>(coords[3 * j + 0] - coords[3 * i + 0]);
  const double dy = static_cast<double>(coords[3 * j + 1] - coords[3 * i + 1]);
  const double dz = static_cast<double>(coords[3 * j + 2] - coords[3 * i + 2]);
  // Clamp r^2 so coincident atoms do not divide by zero.
  const double r2 = std::max(dx * dx + dy * dy + dz * dz, 1.0e-24);
  const double inv_r = 1.0 / std::sqrt(r2);
  const double inv_r2 = inv_r * inv_r;
  const double inv_r4 = inv_r2 * inv_r2;
  const double inv_r6 = inv_r4 * inv_r2;
  const double inv_r8 = inv_r6 * inv_r2;
  const double inv_r10 = inv_r8 * inv_r2;
  const double inv_r12 = inv_r6 * inv_r6;
  const double inv_r14 = inv_r12 * inv_r2;

  const double c = static_cast<double>(coul_coeff[p]);
  const double a12 = static_cast<double>(a12_coeff[p]);
  const double b6 = static_cast<double>(b6_coeff[p]);
  const double b10 = static_cast<double>(b10_coeff[p]);

  const double ec = c * inv_r;
  const double el = a12 * inv_r12 - b6 * inv_r6 - b10 * inv_r10;
  // fs = (1/r) dE/dr of E_coul + E_lj; fs * d is then the force on atom i.
  const double fs = -c * inv_r2 * inv_r - 12.0 * a12 * inv_r14 + 6.0 * b6 * inv_r8 + 10.0 * b10 * inv_r12;

  const double fx = fs * dx;
  const double fy = fs * dy;
  const double fz = fs * dz;

  force[3 * i + 0] += static_cast<scalar_t>(fx);
  force[3 * i + 1] += static_cast<scalar_t>(fy);
  force[3 * i + 2] += static_cast<scalar_t>(fz);
  force[3 * j + 0] -= static_cast<scalar_t>(fx);
  force[3 * j + 1] -= static_cast<scalar_t>(fy);
  force[3 * j + 2] -= static_cast<scalar_t>(fz);

  e_coul += ec;
  e_lj += el;
}
|
|
724
|
+
|
|
725
|
+
template <typename scalar_t>
|
|
726
|
+
void accumulate_pairs_preparam_cpu(
|
|
727
|
+
const scalar_t* coords,
|
|
728
|
+
const int64_t* ii,
|
|
729
|
+
const int64_t* jj,
|
|
730
|
+
const scalar_t* coul_coeff,
|
|
731
|
+
const scalar_t* a12_coeff,
|
|
732
|
+
const scalar_t* b6_coeff,
|
|
733
|
+
const scalar_t* b10_coeff,
|
|
734
|
+
int64_t npairs,
|
|
735
|
+
int64_t natom,
|
|
736
|
+
int64_t min_pairs_per_thread,
|
|
737
|
+
std::vector<scalar_t>& force_out,
|
|
738
|
+
double& e_coul_out,
|
|
739
|
+
double& e_lj_out) {
|
|
740
|
+
const int64_t stride = natom * 3;
|
|
741
|
+
if (npairs <= 0) {
|
|
742
|
+
return;
|
|
743
|
+
}
|
|
744
|
+
|
|
745
|
+
int nthreads = 1;
|
|
746
|
+
#ifdef _OPENMP
|
|
747
|
+
const int max_threads = effective_max_threads();
|
|
748
|
+
const int64_t kMinPairsPerThread = min_pairs_per_thread_threshold(min_pairs_per_thread);
|
|
749
|
+
nthreads = std::min<int>(
|
|
750
|
+
max_threads,
|
|
751
|
+
std::max<int64_t>(1, npairs / kMinPairsPerThread));
|
|
752
|
+
#endif
|
|
753
|
+
|
|
754
|
+
if (nthreads <= 1) {
|
|
755
|
+
scalar_t* f = force_out.data();
|
|
756
|
+
for (int64_t p = 0; p < npairs; ++p) {
|
|
757
|
+
accumulate_pair_preparam_one<scalar_t>(
|
|
758
|
+
coords,
|
|
759
|
+
ii[p],
|
|
760
|
+
jj[p],
|
|
761
|
+
coul_coeff,
|
|
762
|
+
a12_coeff,
|
|
763
|
+
b6_coeff,
|
|
764
|
+
b10_coeff,
|
|
765
|
+
p,
|
|
766
|
+
f,
|
|
767
|
+
e_coul_out,
|
|
768
|
+
e_lj_out);
|
|
769
|
+
}
|
|
770
|
+
return;
|
|
771
|
+
}
|
|
772
|
+
|
|
773
|
+
std::vector<scalar_t> force_tls(
|
|
774
|
+
static_cast<size_t>(nthreads) * static_cast<size_t>(stride),
|
|
775
|
+
static_cast<scalar_t>(0));
|
|
776
|
+
std::vector<double> ec_tls(static_cast<size_t>(nthreads), 0.0);
|
|
777
|
+
std::vector<double> el_tls(static_cast<size_t>(nthreads), 0.0);
|
|
778
|
+
|
|
779
|
+
#ifdef _OPENMP
|
|
780
|
+
#pragma omp parallel num_threads(nthreads)
|
|
781
|
+
#endif
|
|
782
|
+
{
|
|
783
|
+
int tid = 0;
|
|
784
|
+
#ifdef _OPENMP
|
|
785
|
+
tid = omp_get_thread_num();
|
|
786
|
+
#endif
|
|
787
|
+
auto* f = force_tls.data() + static_cast<size_t>(tid) * static_cast<size_t>(stride);
|
|
788
|
+
double ec_loc = 0.0;
|
|
789
|
+
double el_loc = 0.0;
|
|
790
|
+
#ifdef _OPENMP
|
|
791
|
+
#pragma omp for schedule(static)
|
|
792
|
+
#endif
|
|
793
|
+
for (int64_t p = 0; p < npairs; ++p) {
|
|
794
|
+
accumulate_pair_preparam_one<scalar_t>(
|
|
795
|
+
coords,
|
|
796
|
+
ii[p],
|
|
797
|
+
jj[p],
|
|
798
|
+
coul_coeff,
|
|
799
|
+
a12_coeff,
|
|
800
|
+
b6_coeff,
|
|
801
|
+
b10_coeff,
|
|
802
|
+
p,
|
|
803
|
+
f,
|
|
804
|
+
ec_loc,
|
|
805
|
+
el_loc);
|
|
806
|
+
}
|
|
807
|
+
ec_tls[static_cast<size_t>(tid)] = ec_loc;
|
|
808
|
+
el_tls[static_cast<size_t>(tid)] = el_loc;
|
|
809
|
+
}
|
|
810
|
+
|
|
811
|
+
for (int tid = 0; tid < nthreads; ++tid) {
|
|
812
|
+
const auto* f = force_tls.data() + static_cast<size_t>(tid) * static_cast<size_t>(stride);
|
|
813
|
+
for (int64_t k = 0; k < stride; ++k) {
|
|
814
|
+
force_out[static_cast<size_t>(k)] += f[static_cast<size_t>(k)];
|
|
815
|
+
}
|
|
816
|
+
e_coul_out += ec_tls[static_cast<size_t>(tid)];
|
|
817
|
+
e_lj_out += el_tls[static_cast<size_t>(tid)];
|
|
818
|
+
}
|
|
819
|
+
}
|
|
820
|
+
|
|
821
|
+
std::vector<torch::Tensor> nonbonded_energy_force_cpu(
|
|
822
|
+
const torch::Tensor& coords,
|
|
823
|
+
const torch::Tensor& charge,
|
|
824
|
+
const torch::Tensor& atom_type,
|
|
825
|
+
const torch::Tensor& lj_acoef,
|
|
826
|
+
const torch::Tensor& lj_bcoef,
|
|
827
|
+
const torch::Tensor& hb_acoef,
|
|
828
|
+
const torch::Tensor& hb_bcoef,
|
|
829
|
+
const torch::Tensor& nb_index,
|
|
830
|
+
const torch::Tensor& pair_i,
|
|
831
|
+
const torch::Tensor& pair_j,
|
|
832
|
+
const torch::Tensor& pair14_i,
|
|
833
|
+
const torch::Tensor& pair14_j,
|
|
834
|
+
const torch::Tensor& pair14_inv_scee,
|
|
835
|
+
const torch::Tensor& pair14_inv_scnb,
|
|
836
|
+
int64_t min_pairs_per_thread) {
|
|
837
|
+
auto c = coords.contiguous();
|
|
838
|
+
auto q = charge.contiguous();
|
|
839
|
+
auto at = atom_type.contiguous();
|
|
840
|
+
auto acoef = lj_acoef.contiguous();
|
|
841
|
+
auto bcoef = lj_bcoef.contiguous();
|
|
842
|
+
auto ha = hb_acoef.contiguous();
|
|
843
|
+
auto hb = hb_bcoef.contiguous();
|
|
844
|
+
auto nbi = nb_index.contiguous();
|
|
845
|
+
auto pi = pair_i.contiguous();
|
|
846
|
+
auto pj = pair_j.contiguous();
|
|
847
|
+
auto p14i = pair14_i.contiguous();
|
|
848
|
+
auto p14j = pair14_j.contiguous();
|
|
849
|
+
auto scee = pair14_inv_scee.contiguous();
|
|
850
|
+
auto scnb = pair14_inv_scnb.contiguous();
|
|
851
|
+
|
|
852
|
+
TORCH_CHECK(c.device().is_cpu(), "CPU kernel expects CPU coords tensor");
|
|
853
|
+
TORCH_CHECK(c.dim() == 2 && c.size(1) == 3, "coords must be [N,3]");
|
|
854
|
+
TORCH_CHECK(c.scalar_type() == q.scalar_type(), "coords/charge dtype mismatch");
|
|
855
|
+
TORCH_CHECK(c.scalar_type() == acoef.scalar_type(), "coords/lj dtype mismatch");
|
|
856
|
+
TORCH_CHECK(c.scalar_type() == bcoef.scalar_type(), "coords/lj dtype mismatch");
|
|
857
|
+
TORCH_CHECK(nbi.scalar_type() == torch::kInt64, "nb_index must be int64");
|
|
858
|
+
TORCH_CHECK(pi.scalar_type() == torch::kInt64, "pair_i must be int64");
|
|
859
|
+
TORCH_CHECK(pj.scalar_type() == torch::kInt64, "pair_j must be int64");
|
|
860
|
+
TORCH_CHECK(p14i.scalar_type() == torch::kInt64, "pair14_i must be int64");
|
|
861
|
+
TORCH_CHECK(p14j.scalar_type() == torch::kInt64, "pair14_j must be int64");
|
|
862
|
+
|
|
863
|
+
const int64_t natom = c.size(0);
|
|
864
|
+
const int64_t ntypes = nbi.size(0);
|
|
865
|
+
auto e_coul = torch::zeros({}, c.options());
|
|
866
|
+
auto e_lj = torch::zeros({}, c.options());
|
|
867
|
+
auto e_coul14 = torch::zeros({}, c.options());
|
|
868
|
+
auto e_lj14 = torch::zeros({}, c.options());
|
|
869
|
+
auto force = torch::zeros_like(c);
|
|
870
|
+
|
|
871
|
+
AT_DISPATCH_FLOATING_TYPES(c.scalar_type(), "nonbonded_energy_force_cpu", [&] {
|
|
872
|
+
std::vector<scalar_t> force_acc(static_cast<size_t>(natom) * 3, static_cast<scalar_t>(0));
|
|
873
|
+
double ec = 0.0;
|
|
874
|
+
double el = 0.0;
|
|
875
|
+
double ec14 = 0.0;
|
|
876
|
+
double el14 = 0.0;
|
|
877
|
+
|
|
878
|
+
const scalar_t* hb_a_ptr = ha.numel() > 0 ? ha.data_ptr<scalar_t>() : nullptr;
|
|
879
|
+
const scalar_t* hb_b_ptr = hb.numel() > 0 ? hb.data_ptr<scalar_t>() : nullptr;
|
|
880
|
+
|
|
881
|
+
accumulate_pairs_cpu<scalar_t>(
|
|
882
|
+
c.data_ptr<scalar_t>(),
|
|
883
|
+
q.data_ptr<scalar_t>(),
|
|
884
|
+
at.data_ptr<int64_t>(),
|
|
885
|
+
acoef.data_ptr<scalar_t>(),
|
|
886
|
+
bcoef.data_ptr<scalar_t>(),
|
|
887
|
+
acoef.numel(),
|
|
888
|
+
hb_a_ptr,
|
|
889
|
+
hb_b_ptr,
|
|
890
|
+
ha.numel(),
|
|
891
|
+
nbi.data_ptr<int64_t>(),
|
|
892
|
+
ntypes,
|
|
893
|
+
pi.data_ptr<int64_t>(),
|
|
894
|
+
pj.data_ptr<int64_t>(),
|
|
895
|
+
nullptr,
|
|
896
|
+
nullptr,
|
|
897
|
+
pi.numel(),
|
|
898
|
+
natom,
|
|
899
|
+
min_pairs_per_thread,
|
|
900
|
+
force_acc,
|
|
901
|
+
ec,
|
|
902
|
+
el);
|
|
903
|
+
|
|
904
|
+
const scalar_t* scee_ptr = scee.numel() > 0 ? scee.data_ptr<scalar_t>() : nullptr;
|
|
905
|
+
const scalar_t* scnb_ptr = scnb.numel() > 0 ? scnb.data_ptr<scalar_t>() : nullptr;
|
|
906
|
+
accumulate_pairs_cpu<scalar_t>(
|
|
907
|
+
c.data_ptr<scalar_t>(),
|
|
908
|
+
q.data_ptr<scalar_t>(),
|
|
909
|
+
at.data_ptr<int64_t>(),
|
|
910
|
+
acoef.data_ptr<scalar_t>(),
|
|
911
|
+
bcoef.data_ptr<scalar_t>(),
|
|
912
|
+
acoef.numel(),
|
|
913
|
+
hb_a_ptr,
|
|
914
|
+
hb_b_ptr,
|
|
915
|
+
ha.numel(),
|
|
916
|
+
nbi.data_ptr<int64_t>(),
|
|
917
|
+
ntypes,
|
|
918
|
+
p14i.data_ptr<int64_t>(),
|
|
919
|
+
p14j.data_ptr<int64_t>(),
|
|
920
|
+
scee_ptr,
|
|
921
|
+
scnb_ptr,
|
|
922
|
+
p14i.numel(),
|
|
923
|
+
natom,
|
|
924
|
+
min_pairs_per_thread,
|
|
925
|
+
force_acc,
|
|
926
|
+
ec14,
|
|
927
|
+
el14);
|
|
928
|
+
|
|
929
|
+
auto* force_ptr = force.data_ptr<scalar_t>();
|
|
930
|
+
std::copy(force_acc.begin(), force_acc.end(), force_ptr);
|
|
931
|
+
e_coul.fill_(static_cast<scalar_t>(ec));
|
|
932
|
+
e_lj.fill_(static_cast<scalar_t>(el));
|
|
933
|
+
e_coul14.fill_(static_cast<scalar_t>(ec14));
|
|
934
|
+
e_lj14.fill_(static_cast<scalar_t>(el14));
|
|
935
|
+
});
|
|
936
|
+
|
|
937
|
+
return {e_coul, e_lj, e_coul14, e_lj14, force};
|
|
938
|
+
}
|
|
939
|
+
|
|
940
|
+
std::vector<torch::Tensor> nonbonded_energy_force_preparam_cpu(
|
|
941
|
+
const torch::Tensor& coords,
|
|
942
|
+
const torch::Tensor& pair_i,
|
|
943
|
+
const torch::Tensor& pair_j,
|
|
944
|
+
const torch::Tensor& pair_coul_coeff,
|
|
945
|
+
const torch::Tensor& pair_a12_coeff,
|
|
946
|
+
const torch::Tensor& pair_b6_coeff,
|
|
947
|
+
const torch::Tensor& pair_b10_coeff,
|
|
948
|
+
const torch::Tensor& pair14_i,
|
|
949
|
+
const torch::Tensor& pair14_j,
|
|
950
|
+
const torch::Tensor& pair14_coul_coeff,
|
|
951
|
+
const torch::Tensor& pair14_a12_coeff,
|
|
952
|
+
const torch::Tensor& pair14_b6_coeff,
|
|
953
|
+
const torch::Tensor& pair14_b10_coeff,
|
|
954
|
+
int64_t min_pairs_per_thread) {
|
|
955
|
+
auto c = coords.contiguous();
|
|
956
|
+
auto pi = pair_i.contiguous();
|
|
957
|
+
auto pj = pair_j.contiguous();
|
|
958
|
+
auto cc = pair_coul_coeff.contiguous();
|
|
959
|
+
auto a12 = pair_a12_coeff.contiguous();
|
|
960
|
+
auto b6 = pair_b6_coeff.contiguous();
|
|
961
|
+
auto b10 = pair_b10_coeff.contiguous();
|
|
962
|
+
auto p14i = pair14_i.contiguous();
|
|
963
|
+
auto p14j = pair14_j.contiguous();
|
|
964
|
+
auto cc14 = pair14_coul_coeff.contiguous();
|
|
965
|
+
auto a1214 = pair14_a12_coeff.contiguous();
|
|
966
|
+
auto b614 = pair14_b6_coeff.contiguous();
|
|
967
|
+
auto b1014 = pair14_b10_coeff.contiguous();
|
|
968
|
+
|
|
969
|
+
TORCH_CHECK(c.device().is_cpu(), "CPU kernel expects CPU coords tensor");
|
|
970
|
+
TORCH_CHECK(c.dim() == 2 && c.size(1) == 3, "coords must be [N,3]");
|
|
971
|
+
TORCH_CHECK(pi.scalar_type() == torch::kInt64, "pair_i must be int64");
|
|
972
|
+
TORCH_CHECK(pj.scalar_type() == torch::kInt64, "pair_j must be int64");
|
|
973
|
+
TORCH_CHECK(p14i.scalar_type() == torch::kInt64, "pair14_i must be int64");
|
|
974
|
+
TORCH_CHECK(p14j.scalar_type() == torch::kInt64, "pair14_j must be int64");
|
|
975
|
+
TORCH_CHECK(cc.scalar_type() == c.scalar_type(), "pair coeff dtype mismatch");
|
|
976
|
+
TORCH_CHECK(a12.scalar_type() == c.scalar_type(), "pair coeff dtype mismatch");
|
|
977
|
+
TORCH_CHECK(b6.scalar_type() == c.scalar_type(), "pair coeff dtype mismatch");
|
|
978
|
+
TORCH_CHECK(b10.scalar_type() == c.scalar_type(), "pair coeff dtype mismatch");
|
|
979
|
+
TORCH_CHECK(cc14.scalar_type() == c.scalar_type(), "pair14 coeff dtype mismatch");
|
|
980
|
+
TORCH_CHECK(a1214.scalar_type() == c.scalar_type(), "pair14 coeff dtype mismatch");
|
|
981
|
+
TORCH_CHECK(b614.scalar_type() == c.scalar_type(), "pair14 coeff dtype mismatch");
|
|
982
|
+
TORCH_CHECK(b1014.scalar_type() == c.scalar_type(), "pair14 coeff dtype mismatch");
|
|
983
|
+
TORCH_CHECK(pi.numel() == pj.numel(), "pair_i/pair_j size mismatch");
|
|
984
|
+
TORCH_CHECK(p14i.numel() == p14j.numel(), "pair14_i/pair14_j size mismatch");
|
|
985
|
+
TORCH_CHECK(
|
|
986
|
+
pi.numel() == cc.numel() && pi.numel() == a12.numel() && pi.numel() == b6.numel() &&
|
|
987
|
+
pi.numel() == b10.numel(),
|
|
988
|
+
"pair/coeff size mismatch");
|
|
989
|
+
TORCH_CHECK(
|
|
990
|
+
p14i.numel() == cc14.numel() && p14i.numel() == a1214.numel() && p14i.numel() == b614.numel() &&
|
|
991
|
+
p14i.numel() == b1014.numel(),
|
|
992
|
+
"pair14/coeff size mismatch");
|
|
993
|
+
|
|
994
|
+
const int64_t natom = c.size(0);
|
|
995
|
+
auto e_coul = torch::zeros({}, c.options());
|
|
996
|
+
auto e_lj = torch::zeros({}, c.options());
|
|
997
|
+
auto e_coul14 = torch::zeros({}, c.options());
|
|
998
|
+
auto e_lj14 = torch::zeros({}, c.options());
|
|
999
|
+
auto force = torch::zeros_like(c);
|
|
1000
|
+
|
|
1001
|
+
AT_DISPATCH_FLOATING_TYPES(c.scalar_type(), "nonbonded_energy_force_preparam_cpu", [&] {
|
|
1002
|
+
std::vector<scalar_t> force_acc(static_cast<size_t>(natom) * 3, static_cast<scalar_t>(0));
|
|
1003
|
+
double ec = 0.0;
|
|
1004
|
+
double el = 0.0;
|
|
1005
|
+
double ec14 = 0.0;
|
|
1006
|
+
double el14 = 0.0;
|
|
1007
|
+
|
|
1008
|
+
accumulate_pairs_preparam_cpu<scalar_t>(
|
|
1009
|
+
c.data_ptr<scalar_t>(),
|
|
1010
|
+
pi.data_ptr<int64_t>(),
|
|
1011
|
+
pj.data_ptr<int64_t>(),
|
|
1012
|
+
cc.data_ptr<scalar_t>(),
|
|
1013
|
+
a12.data_ptr<scalar_t>(),
|
|
1014
|
+
b6.data_ptr<scalar_t>(),
|
|
1015
|
+
b10.data_ptr<scalar_t>(),
|
|
1016
|
+
pi.numel(),
|
|
1017
|
+
natom,
|
|
1018
|
+
min_pairs_per_thread,
|
|
1019
|
+
force_acc,
|
|
1020
|
+
ec,
|
|
1021
|
+
el);
|
|
1022
|
+
|
|
1023
|
+
accumulate_pairs_preparam_cpu<scalar_t>(
|
|
1024
|
+
c.data_ptr<scalar_t>(),
|
|
1025
|
+
p14i.data_ptr<int64_t>(),
|
|
1026
|
+
p14j.data_ptr<int64_t>(),
|
|
1027
|
+
cc14.data_ptr<scalar_t>(),
|
|
1028
|
+
a1214.data_ptr<scalar_t>(),
|
|
1029
|
+
b614.data_ptr<scalar_t>(),
|
|
1030
|
+
b1014.data_ptr<scalar_t>(),
|
|
1031
|
+
p14i.numel(),
|
|
1032
|
+
natom,
|
|
1033
|
+
min_pairs_per_thread,
|
|
1034
|
+
force_acc,
|
|
1035
|
+
ec14,
|
|
1036
|
+
el14);
|
|
1037
|
+
|
|
1038
|
+
auto* force_ptr = force.data_ptr<scalar_t>();
|
|
1039
|
+
std::copy(force_acc.begin(), force_acc.end(), force_ptr);
|
|
1040
|
+
e_coul.fill_(static_cast<scalar_t>(ec));
|
|
1041
|
+
e_lj.fill_(static_cast<scalar_t>(el));
|
|
1042
|
+
e_coul14.fill_(static_cast<scalar_t>(ec14));
|
|
1043
|
+
e_lj14.fill_(static_cast<scalar_t>(el14));
|
|
1044
|
+
});
|
|
1045
|
+
|
|
1046
|
+
return {e_coul, e_lj, e_coul14, e_lj14, force};
|
|
1047
|
+
}
|
|
1048
|
+
|
|
1049
|
+
} // namespace
|
|
1050
|
+
|
|
1051
|
+
std::vector<torch::Tensor> nonbonded_energy_force(
|
|
1052
|
+
const torch::Tensor& coords,
|
|
1053
|
+
const torch::Tensor& charge,
|
|
1054
|
+
const torch::Tensor& atom_type,
|
|
1055
|
+
const torch::Tensor& lj_acoef,
|
|
1056
|
+
const torch::Tensor& lj_bcoef,
|
|
1057
|
+
const torch::Tensor& hb_acoef,
|
|
1058
|
+
const torch::Tensor& hb_bcoef,
|
|
1059
|
+
const torch::Tensor& nb_index,
|
|
1060
|
+
const torch::Tensor& pair_i,
|
|
1061
|
+
const torch::Tensor& pair_j,
|
|
1062
|
+
const torch::Tensor& pair14_i,
|
|
1063
|
+
const torch::Tensor& pair14_j,
|
|
1064
|
+
const torch::Tensor& pair14_inv_scee,
|
|
1065
|
+
const torch::Tensor& pair14_inv_scnb,
|
|
1066
|
+
int64_t chunk_size,
|
|
1067
|
+
bool cpu_fast,
|
|
1068
|
+
int64_t min_pairs_per_thread) {
|
|
1069
|
+
// CPU-specialized path: hand-written loop kernel with OpenMP threading.
|
|
1070
|
+
if (coords.device().is_cpu() && cpu_fast) {
|
|
1071
|
+
return nonbonded_energy_force_cpu(
|
|
1072
|
+
coords,
|
|
1073
|
+
charge,
|
|
1074
|
+
atom_type,
|
|
1075
|
+
lj_acoef,
|
|
1076
|
+
lj_bcoef,
|
|
1077
|
+
hb_acoef,
|
|
1078
|
+
hb_bcoef,
|
|
1079
|
+
nb_index,
|
|
1080
|
+
pair_i,
|
|
1081
|
+
pair_j,
|
|
1082
|
+
pair14_i,
|
|
1083
|
+
pair14_j,
|
|
1084
|
+
pair14_inv_scee,
|
|
1085
|
+
pair14_inv_scnb,
|
|
1086
|
+
min_pairs_per_thread);
|
|
1087
|
+
}
|
|
1088
|
+
// ATen path for autograd-friendly CPU execution.
|
|
1089
|
+
return nonbonded_energy_force_aten(
|
|
1090
|
+
coords,
|
|
1091
|
+
charge,
|
|
1092
|
+
atom_type,
|
|
1093
|
+
lj_acoef,
|
|
1094
|
+
lj_bcoef,
|
|
1095
|
+
hb_acoef,
|
|
1096
|
+
hb_bcoef,
|
|
1097
|
+
nb_index,
|
|
1098
|
+
pair_i,
|
|
1099
|
+
pair_j,
|
|
1100
|
+
pair14_i,
|
|
1101
|
+
pair14_j,
|
|
1102
|
+
pair14_inv_scee,
|
|
1103
|
+
pair14_inv_scnb,
|
|
1104
|
+
chunk_size);
|
|
1105
|
+
}
|
|
1106
|
+
|
|
1107
|
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Brings in the `"name"_a` literal, which yields the same pybind11::arg
  // object as pybind11::arg("name"). `"name"_a = v` yields the same
  // defaulted argument as pybind11::arg("name") = v.
  using namespace pybind11::literals;

  m.doc() = "hessian_ff native nonbonded extension";

  // Topology-driven variant: per-pair parameters are looked up inside the
  // native code from charges, atom types and the LJ/HB coefficient tables.
  // Optional keyword arguments: cpu_fast (default true) and
  // min_pairs_per_thread (default -1).
  m.def("nonbonded_energy_force",
        &nonbonded_energy_force,
        "Compute nonbonded energies and forces (CPU-only backend)",
        "coords"_a,
        "charge"_a,
        "atom_type"_a,
        "lj_acoef"_a,
        "lj_bcoef"_a,
        "hb_acoef"_a,
        "hb_bcoef"_a,
        "nb_index"_a,
        "pair_i"_a,
        "pair_j"_a,
        "pair14_i"_a,
        "pair14_j"_a,
        "pair14_inv_scee"_a,
        "pair14_inv_scnb"_a,
        "chunk_size"_a,
        "cpu_fast"_a = true,
        "min_pairs_per_thread"_a = -1);

  // Pre-parameterized variant: the caller supplies per-pair coefficient
  // tensors directly, for both the regular pair list and the 1-4 pair list.
  // Same optional keyword arguments and defaults as above.
  m.def("nonbonded_energy_force_preparam",
        &nonbonded_energy_force_preparam,
        "Compute nonbonded energies and forces from precomputed pair coefficients (CPU-only)",
        "coords"_a,
        "pair_i"_a,
        "pair_j"_a,
        "pair_coul_coeff"_a,
        "pair_a12_coeff"_a,
        "pair_b6_coeff"_a,
        "pair_b10_coeff"_a,
        "pair14_i"_a,
        "pair14_j"_a,
        "pair14_coul_coeff"_a,
        "pair14_a12_coeff"_a,
        "pair14_b6_coeff"_a,
        "pair14_b10_coeff"_a,
        "chunk_size"_a,
        "cpu_fast"_a = true,
        "min_pairs_per_thread"_a = -1);
}
|