mlmm-toolkit 0.2.2.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (372)
  1. hessian_ff/__init__.py +50 -0
  2. hessian_ff/analytical_hessian.py +609 -0
  3. hessian_ff/constants.py +46 -0
  4. hessian_ff/forcefield.py +339 -0
  5. hessian_ff/loaders.py +608 -0
  6. hessian_ff/native/Makefile +8 -0
  7. hessian_ff/native/__init__.py +28 -0
  8. hessian_ff/native/analytical_hessian.py +88 -0
  9. hessian_ff/native/analytical_hessian_ext.cpp +258 -0
  10. hessian_ff/native/bonded.py +82 -0
  11. hessian_ff/native/bonded_ext.cpp +640 -0
  12. hessian_ff/native/loader.py +349 -0
  13. hessian_ff/native/nonbonded.py +118 -0
  14. hessian_ff/native/nonbonded_ext.cpp +1150 -0
  15. hessian_ff/prmtop_parmed.py +23 -0
  16. hessian_ff/system.py +107 -0
  17. hessian_ff/terms/__init__.py +14 -0
  18. hessian_ff/terms/angle.py +73 -0
  19. hessian_ff/terms/bond.py +44 -0
  20. hessian_ff/terms/cmap.py +406 -0
  21. hessian_ff/terms/dihedral.py +141 -0
  22. hessian_ff/terms/nonbonded.py +209 -0
  23. hessian_ff/tests/__init__.py +0 -0
  24. hessian_ff/tests/conftest.py +75 -0
  25. hessian_ff/tests/data/small/complex.parm7 +1346 -0
  26. hessian_ff/tests/data/small/complex.pdb +125 -0
  27. hessian_ff/tests/data/small/complex.rst7 +63 -0
  28. hessian_ff/tests/test_coords_input.py +44 -0
  29. hessian_ff/tests/test_energy_force.py +49 -0
  30. hessian_ff/tests/test_hessian.py +137 -0
  31. hessian_ff/tests/test_smoke.py +18 -0
  32. hessian_ff/tests/test_validation.py +40 -0
  33. hessian_ff/workflows.py +889 -0
  34. mlmm/__init__.py +36 -0
  35. mlmm/__main__.py +7 -0
  36. mlmm/_version.py +34 -0
  37. mlmm/add_elem_info.py +374 -0
  38. mlmm/advanced_help.py +91 -0
  39. mlmm/align_freeze_atoms.py +601 -0
  40. mlmm/all.py +3535 -0
  41. mlmm/bond_changes.py +231 -0
  42. mlmm/bool_compat.py +223 -0
  43. mlmm/cli.py +574 -0
  44. mlmm/cli_utils.py +166 -0
  45. mlmm/default_group.py +337 -0
  46. mlmm/defaults.py +467 -0
  47. mlmm/define_layer.py +526 -0
  48. mlmm/dft.py +1041 -0
  49. mlmm/energy_diagram.py +253 -0
  50. mlmm/extract.py +2213 -0
  51. mlmm/fix_altloc.py +464 -0
  52. mlmm/freq.py +1406 -0
  53. mlmm/harmonic_constraints.py +140 -0
  54. mlmm/hessian_cache.py +44 -0
  55. mlmm/hessian_calc.py +174 -0
  56. mlmm/irc.py +638 -0
  57. mlmm/mlmm_calc.py +2262 -0
  58. mlmm/mm_parm.py +945 -0
  59. mlmm/oniom_export.py +1983 -0
  60. mlmm/oniom_import.py +457 -0
  61. mlmm/opt.py +1742 -0
  62. mlmm/path_opt.py +1353 -0
  63. mlmm/path_search.py +2299 -0
  64. mlmm/preflight.py +88 -0
  65. mlmm/py.typed +1 -0
  66. mlmm/pysis_runner.py +45 -0
  67. mlmm/scan.py +1047 -0
  68. mlmm/scan2d.py +1226 -0
  69. mlmm/scan3d.py +1265 -0
  70. mlmm/scan_common.py +184 -0
  71. mlmm/summary_log.py +736 -0
  72. mlmm/trj2fig.py +448 -0
  73. mlmm/tsopt.py +2871 -0
  74. mlmm/utils.py +2309 -0
  75. mlmm/xtb_embedcharge_correction.py +475 -0
  76. mlmm_toolkit-0.2.2.dev0.dist-info/METADATA +1159 -0
  77. mlmm_toolkit-0.2.2.dev0.dist-info/RECORD +372 -0
  78. mlmm_toolkit-0.2.2.dev0.dist-info/WHEEL +5 -0
  79. mlmm_toolkit-0.2.2.dev0.dist-info/entry_points.txt +2 -0
  80. mlmm_toolkit-0.2.2.dev0.dist-info/licenses/LICENSE +674 -0
  81. mlmm_toolkit-0.2.2.dev0.dist-info/top_level.txt +4 -0
  82. pysisyphus/Geometry.py +1667 -0
  83. pysisyphus/LICENSE +674 -0
  84. pysisyphus/TableFormatter.py +63 -0
  85. pysisyphus/TablePrinter.py +74 -0
  86. pysisyphus/__init__.py +12 -0
  87. pysisyphus/calculators/AFIR.py +452 -0
  88. pysisyphus/calculators/AnaPot.py +20 -0
  89. pysisyphus/calculators/AnaPot2.py +48 -0
  90. pysisyphus/calculators/AnaPot3.py +12 -0
  91. pysisyphus/calculators/AnaPot4.py +20 -0
  92. pysisyphus/calculators/AnaPotBase.py +337 -0
  93. pysisyphus/calculators/AnaPotCBM.py +25 -0
  94. pysisyphus/calculators/AtomAtomTransTorque.py +154 -0
  95. pysisyphus/calculators/CFOUR.py +250 -0
  96. pysisyphus/calculators/Calculator.py +844 -0
  97. pysisyphus/calculators/CerjanMiller.py +24 -0
  98. pysisyphus/calculators/Composite.py +123 -0
  99. pysisyphus/calculators/ConicalIntersection.py +171 -0
  100. pysisyphus/calculators/DFTBp.py +430 -0
  101. pysisyphus/calculators/DFTD3.py +66 -0
  102. pysisyphus/calculators/DFTD4.py +84 -0
  103. pysisyphus/calculators/Dalton.py +61 -0
  104. pysisyphus/calculators/Dimer.py +681 -0
  105. pysisyphus/calculators/Dummy.py +20 -0
  106. pysisyphus/calculators/EGO.py +76 -0
  107. pysisyphus/calculators/EnergyMin.py +224 -0
  108. pysisyphus/calculators/ExternalPotential.py +264 -0
  109. pysisyphus/calculators/FakeASE.py +35 -0
  110. pysisyphus/calculators/FourWellAnaPot.py +28 -0
  111. pysisyphus/calculators/FreeEndNEBPot.py +39 -0
  112. pysisyphus/calculators/Gaussian09.py +18 -0
  113. pysisyphus/calculators/Gaussian16.py +726 -0
  114. pysisyphus/calculators/HardSphere.py +159 -0
  115. pysisyphus/calculators/IDPPCalculator.py +49 -0
  116. pysisyphus/calculators/IPIClient.py +133 -0
  117. pysisyphus/calculators/IPIServer.py +234 -0
  118. pysisyphus/calculators/LEPSBase.py +24 -0
  119. pysisyphus/calculators/LEPSExpr.py +139 -0
  120. pysisyphus/calculators/LennardJones.py +80 -0
  121. pysisyphus/calculators/MOPAC.py +219 -0
  122. pysisyphus/calculators/MullerBrownSympyPot.py +51 -0
  123. pysisyphus/calculators/MultiCalc.py +85 -0
  124. pysisyphus/calculators/NFK.py +45 -0
  125. pysisyphus/calculators/OBabel.py +87 -0
  126. pysisyphus/calculators/ONIOMv2.py +1129 -0
  127. pysisyphus/calculators/ORCA.py +893 -0
  128. pysisyphus/calculators/ORCA5.py +6 -0
  129. pysisyphus/calculators/OpenMM.py +88 -0
  130. pysisyphus/calculators/OpenMolcas.py +281 -0
  131. pysisyphus/calculators/OverlapCalculator.py +908 -0
  132. pysisyphus/calculators/Psi4.py +218 -0
  133. pysisyphus/calculators/PyPsi4.py +37 -0
  134. pysisyphus/calculators/PySCF.py +341 -0
  135. pysisyphus/calculators/PyXTB.py +73 -0
  136. pysisyphus/calculators/QCEngine.py +106 -0
  137. pysisyphus/calculators/Rastrigin.py +22 -0
  138. pysisyphus/calculators/Remote.py +76 -0
  139. pysisyphus/calculators/Rosenbrock.py +15 -0
  140. pysisyphus/calculators/SocketCalc.py +97 -0
  141. pysisyphus/calculators/TIP3P.py +111 -0
  142. pysisyphus/calculators/TransTorque.py +161 -0
  143. pysisyphus/calculators/Turbomole.py +965 -0
  144. pysisyphus/calculators/VRIPot.py +37 -0
  145. pysisyphus/calculators/WFOWrapper.py +333 -0
  146. pysisyphus/calculators/WFOWrapper2.py +341 -0
  147. pysisyphus/calculators/XTB.py +418 -0
  148. pysisyphus/calculators/__init__.py +81 -0
  149. pysisyphus/calculators/cosmo_data.py +139 -0
  150. pysisyphus/calculators/parser.py +150 -0
  151. pysisyphus/color.py +19 -0
  152. pysisyphus/config.py +133 -0
  153. pysisyphus/constants.py +65 -0
  154. pysisyphus/cos/AdaptiveNEB.py +230 -0
  155. pysisyphus/cos/ChainOfStates.py +725 -0
  156. pysisyphus/cos/FreeEndNEB.py +25 -0
  157. pysisyphus/cos/FreezingString.py +103 -0
  158. pysisyphus/cos/GrowingChainOfStates.py +71 -0
  159. pysisyphus/cos/GrowingNT.py +309 -0
  160. pysisyphus/cos/GrowingString.py +508 -0
  161. pysisyphus/cos/NEB.py +189 -0
  162. pysisyphus/cos/SimpleZTS.py +64 -0
  163. pysisyphus/cos/__init__.py +22 -0
  164. pysisyphus/cos/stiffness.py +199 -0
  165. pysisyphus/drivers/__init__.py +17 -0
  166. pysisyphus/drivers/afir.py +855 -0
  167. pysisyphus/drivers/barriers.py +271 -0
  168. pysisyphus/drivers/birkholz.py +138 -0
  169. pysisyphus/drivers/cluster.py +318 -0
  170. pysisyphus/drivers/diabatization.py +133 -0
  171. pysisyphus/drivers/merge.py +368 -0
  172. pysisyphus/drivers/merge_mol2.py +322 -0
  173. pysisyphus/drivers/opt.py +375 -0
  174. pysisyphus/drivers/perf.py +91 -0
  175. pysisyphus/drivers/pka.py +52 -0
  176. pysisyphus/drivers/precon_pos_rot.py +669 -0
  177. pysisyphus/drivers/rates.py +480 -0
  178. pysisyphus/drivers/replace.py +219 -0
  179. pysisyphus/drivers/scan.py +212 -0
  180. pysisyphus/drivers/spectrum.py +166 -0
  181. pysisyphus/drivers/thermo.py +31 -0
  182. pysisyphus/dynamics/Gaussian.py +103 -0
  183. pysisyphus/dynamics/__init__.py +20 -0
  184. pysisyphus/dynamics/colvars.py +136 -0
  185. pysisyphus/dynamics/driver.py +297 -0
  186. pysisyphus/dynamics/helpers.py +256 -0
  187. pysisyphus/dynamics/lincs.py +105 -0
  188. pysisyphus/dynamics/mdp.py +364 -0
  189. pysisyphus/dynamics/rattle.py +121 -0
  190. pysisyphus/dynamics/thermostats.py +128 -0
  191. pysisyphus/dynamics/wigner.py +266 -0
  192. pysisyphus/elem_data.py +3473 -0
  193. pysisyphus/exceptions.py +2 -0
  194. pysisyphus/filtertrj.py +69 -0
  195. pysisyphus/helpers.py +623 -0
  196. pysisyphus/helpers_pure.py +649 -0
  197. pysisyphus/init_logging.py +50 -0
  198. pysisyphus/intcoords/Bend.py +69 -0
  199. pysisyphus/intcoords/Bend2.py +25 -0
  200. pysisyphus/intcoords/BondedFragment.py +32 -0
  201. pysisyphus/intcoords/Cartesian.py +41 -0
  202. pysisyphus/intcoords/CartesianCoords.py +140 -0
  203. pysisyphus/intcoords/Coords.py +56 -0
  204. pysisyphus/intcoords/DLC.py +197 -0
  205. pysisyphus/intcoords/DistanceFunction.py +34 -0
  206. pysisyphus/intcoords/DummyImproper.py +70 -0
  207. pysisyphus/intcoords/DummyTorsion.py +72 -0
  208. pysisyphus/intcoords/LinearBend.py +105 -0
  209. pysisyphus/intcoords/LinearDisplacement.py +80 -0
  210. pysisyphus/intcoords/OutOfPlane.py +59 -0
  211. pysisyphus/intcoords/PrimTypes.py +286 -0
  212. pysisyphus/intcoords/Primitive.py +137 -0
  213. pysisyphus/intcoords/RedundantCoords.py +659 -0
  214. pysisyphus/intcoords/RobustTorsion.py +59 -0
  215. pysisyphus/intcoords/Rotation.py +147 -0
  216. pysisyphus/intcoords/Stretch.py +31 -0
  217. pysisyphus/intcoords/Torsion.py +101 -0
  218. pysisyphus/intcoords/Torsion2.py +25 -0
  219. pysisyphus/intcoords/Translation.py +45 -0
  220. pysisyphus/intcoords/__init__.py +61 -0
  221. pysisyphus/intcoords/augment_bonds.py +126 -0
  222. pysisyphus/intcoords/derivatives.py +10512 -0
  223. pysisyphus/intcoords/eval.py +80 -0
  224. pysisyphus/intcoords/exceptions.py +37 -0
  225. pysisyphus/intcoords/findiffs.py +48 -0
  226. pysisyphus/intcoords/generate_derivatives.py +414 -0
  227. pysisyphus/intcoords/helpers.py +235 -0
  228. pysisyphus/intcoords/logging_conf.py +10 -0
  229. pysisyphus/intcoords/mp_derivatives.py +10836 -0
  230. pysisyphus/intcoords/setup.py +962 -0
  231. pysisyphus/intcoords/setup_fast.py +176 -0
  232. pysisyphus/intcoords/update.py +272 -0
  233. pysisyphus/intcoords/valid.py +89 -0
  234. pysisyphus/interpolate/Geodesic.py +93 -0
  235. pysisyphus/interpolate/IDPP.py +55 -0
  236. pysisyphus/interpolate/Interpolator.py +116 -0
  237. pysisyphus/interpolate/LST.py +70 -0
  238. pysisyphus/interpolate/Redund.py +152 -0
  239. pysisyphus/interpolate/__init__.py +9 -0
  240. pysisyphus/interpolate/helpers.py +34 -0
  241. pysisyphus/io/__init__.py +22 -0
  242. pysisyphus/io/aomix.py +178 -0
  243. pysisyphus/io/cjson.py +24 -0
  244. pysisyphus/io/crd.py +101 -0
  245. pysisyphus/io/cube.py +220 -0
  246. pysisyphus/io/fchk.py +184 -0
  247. pysisyphus/io/hdf5.py +49 -0
  248. pysisyphus/io/hessian.py +72 -0
  249. pysisyphus/io/mol2.py +146 -0
  250. pysisyphus/io/molden.py +293 -0
  251. pysisyphus/io/orca.py +189 -0
  252. pysisyphus/io/pdb.py +269 -0
  253. pysisyphus/io/psf.py +79 -0
  254. pysisyphus/io/pubchem.py +31 -0
  255. pysisyphus/io/qcschema.py +34 -0
  256. pysisyphus/io/sdf.py +29 -0
  257. pysisyphus/io/xyz.py +61 -0
  258. pysisyphus/io/zmat.py +175 -0
  259. pysisyphus/irc/DWI.py +108 -0
  260. pysisyphus/irc/DampedVelocityVerlet.py +134 -0
  261. pysisyphus/irc/Euler.py +22 -0
  262. pysisyphus/irc/EulerPC.py +345 -0
  263. pysisyphus/irc/GonzalezSchlegel.py +187 -0
  264. pysisyphus/irc/IMKMod.py +164 -0
  265. pysisyphus/irc/IRC.py +878 -0
  266. pysisyphus/irc/IRCDummy.py +10 -0
  267. pysisyphus/irc/Instanton.py +307 -0
  268. pysisyphus/irc/LQA.py +53 -0
  269. pysisyphus/irc/ModeKill.py +136 -0
  270. pysisyphus/irc/ParamPlot.py +53 -0
  271. pysisyphus/irc/RK4.py +36 -0
  272. pysisyphus/irc/__init__.py +31 -0
  273. pysisyphus/irc/initial_displ.py +219 -0
  274. pysisyphus/linalg.py +411 -0
  275. pysisyphus/line_searches/Backtracking.py +88 -0
  276. pysisyphus/line_searches/HagerZhang.py +184 -0
  277. pysisyphus/line_searches/LineSearch.py +232 -0
  278. pysisyphus/line_searches/StrongWolfe.py +108 -0
  279. pysisyphus/line_searches/__init__.py +9 -0
  280. pysisyphus/line_searches/interpol.py +15 -0
  281. pysisyphus/modefollow/NormalMode.py +40 -0
  282. pysisyphus/modefollow/__init__.py +10 -0
  283. pysisyphus/modefollow/davidson.py +199 -0
  284. pysisyphus/modefollow/lanczos.py +95 -0
  285. pysisyphus/optimizers/BFGS.py +99 -0
  286. pysisyphus/optimizers/BacktrackingOptimizer.py +113 -0
  287. pysisyphus/optimizers/ConjugateGradient.py +98 -0
  288. pysisyphus/optimizers/CubicNewton.py +75 -0
  289. pysisyphus/optimizers/FIRE.py +113 -0
  290. pysisyphus/optimizers/HessianOptimizer.py +1176 -0
  291. pysisyphus/optimizers/LBFGS.py +228 -0
  292. pysisyphus/optimizers/LayerOpt.py +411 -0
  293. pysisyphus/optimizers/MicroOptimizer.py +169 -0
  294. pysisyphus/optimizers/NCOptimizer.py +90 -0
  295. pysisyphus/optimizers/Optimizer.py +1084 -0
  296. pysisyphus/optimizers/PreconLBFGS.py +260 -0
  297. pysisyphus/optimizers/PreconSteepestDescent.py +7 -0
  298. pysisyphus/optimizers/QuickMin.py +74 -0
  299. pysisyphus/optimizers/RFOptimizer.py +181 -0
  300. pysisyphus/optimizers/RSA.py +99 -0
  301. pysisyphus/optimizers/StabilizedQNMethod.py +248 -0
  302. pysisyphus/optimizers/SteepestDescent.py +23 -0
  303. pysisyphus/optimizers/StringOptimizer.py +173 -0
  304. pysisyphus/optimizers/__init__.py +41 -0
  305. pysisyphus/optimizers/closures.py +301 -0
  306. pysisyphus/optimizers/cls_map.py +58 -0
  307. pysisyphus/optimizers/exceptions.py +6 -0
  308. pysisyphus/optimizers/gdiis.py +280 -0
  309. pysisyphus/optimizers/guess_hessians.py +311 -0
  310. pysisyphus/optimizers/hessian_updates.py +355 -0
  311. pysisyphus/optimizers/poly_fit.py +285 -0
  312. pysisyphus/optimizers/precon.py +153 -0
  313. pysisyphus/optimizers/restrict_step.py +24 -0
  314. pysisyphus/pack.py +172 -0
  315. pysisyphus/peakdetect.py +948 -0
  316. pysisyphus/plot.py +1031 -0
  317. pysisyphus/run.py +2106 -0
  318. pysisyphus/socket_helper.py +74 -0
  319. pysisyphus/stocastic/FragmentKick.py +132 -0
  320. pysisyphus/stocastic/Kick.py +81 -0
  321. pysisyphus/stocastic/Pipeline.py +303 -0
  322. pysisyphus/stocastic/__init__.py +21 -0
  323. pysisyphus/stocastic/align.py +127 -0
  324. pysisyphus/testing.py +96 -0
  325. pysisyphus/thermo.py +156 -0
  326. pysisyphus/trj.py +824 -0
  327. pysisyphus/tsoptimizers/RSIRFOptimizer.py +56 -0
  328. pysisyphus/tsoptimizers/RSPRFOptimizer.py +182 -0
  329. pysisyphus/tsoptimizers/TRIM.py +59 -0
  330. pysisyphus/tsoptimizers/TSHessianOptimizer.py +463 -0
  331. pysisyphus/tsoptimizers/__init__.py +23 -0
  332. pysisyphus/wavefunction/Basis.py +239 -0
  333. pysisyphus/wavefunction/DIIS.py +76 -0
  334. pysisyphus/wavefunction/__init__.py +25 -0
  335. pysisyphus/wavefunction/build_ext.py +42 -0
  336. pysisyphus/wavefunction/cart2sph.py +190 -0
  337. pysisyphus/wavefunction/diabatization.py +304 -0
  338. pysisyphus/wavefunction/excited_states.py +435 -0
  339. pysisyphus/wavefunction/gen_ints.py +1811 -0
  340. pysisyphus/wavefunction/helpers.py +104 -0
  341. pysisyphus/wavefunction/ints/__init__.py +0 -0
  342. pysisyphus/wavefunction/ints/boys.py +193 -0
  343. pysisyphus/wavefunction/ints/boys_table_N_64_xasym_27.1_step_0.01.npy +0 -0
  344. pysisyphus/wavefunction/ints/cart_gto3d.py +176 -0
  345. pysisyphus/wavefunction/ints/coulomb3d.py +25928 -0
  346. pysisyphus/wavefunction/ints/diag_quadrupole3d.py +10036 -0
  347. pysisyphus/wavefunction/ints/dipole3d.py +8762 -0
  348. pysisyphus/wavefunction/ints/int2c2e3d.py +7198 -0
  349. pysisyphus/wavefunction/ints/int3c2e3d_sph.py +65040 -0
  350. pysisyphus/wavefunction/ints/kinetic3d.py +8240 -0
  351. pysisyphus/wavefunction/ints/ovlp3d.py +3777 -0
  352. pysisyphus/wavefunction/ints/quadrupole3d.py +15054 -0
  353. pysisyphus/wavefunction/ints/self_ovlp3d.py +198 -0
  354. pysisyphus/wavefunction/localization.py +458 -0
  355. pysisyphus/wavefunction/multipole.py +159 -0
  356. pysisyphus/wavefunction/normalization.py +36 -0
  357. pysisyphus/wavefunction/pop_analysis.py +134 -0
  358. pysisyphus/wavefunction/shells.py +1171 -0
  359. pysisyphus/wavefunction/wavefunction.py +504 -0
  360. pysisyphus/wrapper/__init__.py +11 -0
  361. pysisyphus/wrapper/exceptions.py +2 -0
  362. pysisyphus/wrapper/jmol.py +120 -0
  363. pysisyphus/wrapper/mwfn.py +169 -0
  364. pysisyphus/wrapper/packmol.py +71 -0
  365. pysisyphus/xyzloader.py +168 -0
  366. pysisyphus/yaml_mods.py +45 -0
  367. thermoanalysis/LICENSE +674 -0
  368. thermoanalysis/QCData.py +244 -0
  369. thermoanalysis/__init__.py +0 -0
  370. thermoanalysis/config.py +3 -0
  371. thermoanalysis/constants.py +20 -0
  372. thermoanalysis/thermo.py +1011 -0
@@ -0,0 +1,640 @@
1
+ #include <torch/extension.h>
2
+
3
+ #include <algorithm>
4
+ #include <cmath>
5
+ #include <cstdint>
6
+ #include <vector>
7
+
8
+ #ifdef _OPENMP
9
+ #include <omp.h>
10
+ #endif
11
+
12
+ // References:
13
+ // - hessian_ff term definitions and analytical force equations:
14
+ // - hessian_ff/terms/bond.py
15
+ // - hessian_ff/terms/angle.py
16
+ // - hessian_ff/terms/dihedral.py
17
+ // - hessian_ff/terms/cmap.py
18
+ // - OpenMM conventions used for compatibility:
19
+ // - Periodic torsion / dihedral sign conventions:
20
+ // https://github.com/openmm/openmm/blob/master/platforms/reference/src/SimTKReference/ReferenceProperDihedralBond.cpp
21
+ // - CMAP bicubic map conventions:
22
+ // https://github.com/openmm/openmm/blob/master/openmmapi/src/CMAPTorsionForceImpl.cpp
23
+
24
+ namespace {
25
+
26
// 2*pi in double precision; the period used when wrapping angles.
constexpr double kTwoPi = 6.283185307179586476925286766559;

// Clamp x into [lo, hi], computed as std::max(lo, std::min(hi, x)).
// Because std::min/std::max return their first argument when the comparison is
// false, a NaN x yields hi (for lo <= hi) rather than propagating NaN.
inline double clamp(double x, double lo, double hi) {
  return std::max(lo, std::min(hi, x));
}

// Wrap an angle in radians into the half-open interval [0, 2*pi).
//
// std::fmod keeps the sign of x, so a negative remainder is shifted up by one
// period. When that remainder is tiny (below about half an ulp of 2*pi,
// ~4.4e-16) the addition rounds to exactly kTwoPi. That value lies outside
// [0, 2*pi) and could push a caller that bins the angle on a periodic grid one
// cell past the end. It is equivalent to 0, so fold it back to 0.
inline double wrap_0_2pi(double x) {
  double y = std::fmod(x, kTwoPi);
  if (y < 0.0) {
    y += kTwoPi;
  }
  if (y >= kTwoPi) {
    // Only reachable via the rounding case above; |fmod(x, 2pi)| < 2pi.
    y = 0.0;
  }
  return y;
}
39
+
40
// Cross product c = a x b of two 3-vectors passed component-wise.
// Inputs are taken by value, so the outputs may safely alias caller variables
// that also supplied inputs.
inline void cross3(
    double ax, double ay, double az,
    double bx, double by, double bz,
    double& cx, double& cy, double& cz) {
  const double x_component = ay * bz - az * by;
  const double y_component = az * bx - ax * bz;
  const double z_component = ax * by - ay * bx;
  cx = x_component;
  cy = y_component;
  cz = z_component;
}
54
+
55
// Euclidean inner product a . b of two 3-vectors passed component-wise.
// Kept as one expression, evaluated left to right as (ax*bx + ay*by) + az*bz.
inline double dot3(double ax, double ay, double az, double bx, double by, double bz) {
  return ax * bx + ay * by + az * bz;
}
64
+
65
// Add the double-precision vector (fx, fy, fz) to the three consecutive
// entries of `force` that belong to `atom` (three scalars per atom), narrowing
// each component to scalar_t before accumulating.
template <typename scalar_t>
inline void add_force(scalar_t* force, int64_t atom, double fx, double fy, double fz) {
  scalar_t* const row = force + 3 * atom;
  row[0] += static_cast<scalar_t>(fx);
  row[1] += static_cast<scalar_t>(fy);
  row[2] += static_cast<scalar_t>(fz);
}
71
+
72
+ template <typename scalar_t>
73
+ inline double dihedral_angle_from_points(
74
+ const scalar_t* coords,
75
+ int64_t i,
76
+ int64_t j,
77
+ int64_t k,
78
+ int64_t l,
79
+ double* out_cross12,
80
+ double* out_cross23,
81
+ double* out_v1,
82
+ double* out_v2,
83
+ double* out_v3) {
84
+ // OpenMM-compatible signed dihedral convention used in hessian_ff/terms/dihedral.py.
85
+ const double p0x = static_cast<double>(coords[3 * i + 0]);
86
+ const double p0y = static_cast<double>(coords[3 * i + 1]);
87
+ const double p0z = static_cast<double>(coords[3 * i + 2]);
88
+ const double p1x = static_cast<double>(coords[3 * j + 0]);
89
+ const double p1y = static_cast<double>(coords[3 * j + 1]);
90
+ const double p1z = static_cast<double>(coords[3 * j + 2]);
91
+ const double p2x = static_cast<double>(coords[3 * k + 0]);
92
+ const double p2y = static_cast<double>(coords[3 * k + 1]);
93
+ const double p2z = static_cast<double>(coords[3 * k + 2]);
94
+ const double p3x = static_cast<double>(coords[3 * l + 0]);
95
+ const double p3y = static_cast<double>(coords[3 * l + 1]);
96
+ const double p3z = static_cast<double>(coords[3 * l + 2]);
97
+
98
+ const double v1x = p0x - p1x;
99
+ const double v1y = p0y - p1y;
100
+ const double v1z = p0z - p1z;
101
+ const double v2x = p2x - p1x;
102
+ const double v2y = p2y - p1y;
103
+ const double v2z = p2z - p1z;
104
+ const double v3x = p2x - p3x;
105
+ const double v3y = p2y - p3y;
106
+ const double v3z = p2z - p3z;
107
+
108
+ double c12x, c12y, c12z;
109
+ double c23x, c23y, c23z;
110
+ cross3(v1x, v1y, v1z, v2x, v2y, v2z, c12x, c12y, c12z);
111
+ cross3(v2x, v2y, v2z, v3x, v3y, v3z, c23x, c23y, c23z);
112
+
113
+ const double n12 = std::max(std::sqrt(dot3(c12x, c12y, c12z, c12x, c12y, c12z)), 1.0e-12);
114
+ const double n23 = std::max(std::sqrt(dot3(c23x, c23y, c23z, c23x, c23y, c23z)), 1.0e-12);
115
+ const double cos_phi = clamp(dot3(c12x, c12y, c12z, c23x, c23y, c23z) / (n12 * n23), -1.0, 1.0);
116
+ double phi = std::acos(cos_phi);
117
+
118
+ const double sign_probe = dot3(v1x, v1y, v1z, c23x, c23y, c23z);
119
+ if (sign_probe < 0.0) {
120
+ phi = -phi;
121
+ }
122
+
123
+ if (out_cross12 != nullptr) {
124
+ out_cross12[0] = c12x;
125
+ out_cross12[1] = c12y;
126
+ out_cross12[2] = c12z;
127
+ }
128
+ if (out_cross23 != nullptr) {
129
+ out_cross23[0] = c23x;
130
+ out_cross23[1] = c23y;
131
+ out_cross23[2] = c23z;
132
+ }
133
+ if (out_v1 != nullptr) {
134
+ out_v1[0] = v1x;
135
+ out_v1[1] = v1y;
136
+ out_v1[2] = v1z;
137
+ }
138
+ if (out_v2 != nullptr) {
139
+ out_v2[0] = v2x;
140
+ out_v2[1] = v2y;
141
+ out_v2[2] = v2z;
142
+ }
143
+ if (out_v3 != nullptr) {
144
+ out_v3[0] = v3x;
145
+ out_v3[1] = v3y;
146
+ out_v3[2] = v3z;
147
+ }
148
+ return phi;
149
+ }
150
+
151
+ template <typename scalar_t>
152
+ inline void accumulate_dihedral_force(
153
+ const scalar_t* coords,
154
+ int64_t i,
155
+ int64_t j,
156
+ int64_t k,
157
+ int64_t l,
158
+ double dE_dphi,
159
+ scalar_t* force) {
160
+ double cross12[3], cross23[3], v1[3], v2[3], v3[3];
161
+ (void)dihedral_angle_from_points(
162
+ coords,
163
+ i,
164
+ j,
165
+ k,
166
+ l,
167
+ cross12,
168
+ cross23,
169
+ v1,
170
+ v2,
171
+ v3);
172
+
173
+ const double norm_cross12_sq = std::max(dot3(cross12[0], cross12[1], cross12[2], cross12[0], cross12[1], cross12[2]), 1.0e-24);
174
+ const double norm_cross23_sq = std::max(dot3(cross23[0], cross23[1], cross23[2], cross23[0], cross23[1], cross23[2]), 1.0e-24);
175
+ const double norm_v2_sq = std::max(dot3(v2[0], v2[1], v2[2], v2[0], v2[1], v2[2]), 1.0e-24);
176
+ const double norm_v2 = std::max(std::sqrt(norm_v2_sq), 1.0e-12);
177
+
178
+ const double f0 = (-dE_dphi * norm_v2) / norm_cross12_sq;
179
+ const double f3 = (dE_dphi * norm_v2) / norm_cross23_sq;
180
+ const double f1 = dot3(v1[0], v1[1], v1[2], v2[0], v2[1], v2[2]) / norm_v2_sq;
181
+ const double f2 = dot3(v3[0], v3[1], v3[2], v2[0], v2[1], v2[2]) / norm_v2_sq;
182
+
183
+ const double ff0x = f0 * cross12[0];
184
+ const double ff0y = f0 * cross12[1];
185
+ const double ff0z = f0 * cross12[2];
186
+ const double ff3x = f3 * cross23[0];
187
+ const double ff3y = f3 * cross23[1];
188
+ const double ff3z = f3 * cross23[2];
189
+
190
+ const double sx = f1 * ff0x - f2 * ff3x;
191
+ const double sy = f1 * ff0y - f2 * ff3y;
192
+ const double sz = f1 * ff0z - f2 * ff3z;
193
+
194
+ const double ff1x = ff0x - sx;
195
+ const double ff1y = ff0y - sy;
196
+ const double ff1z = ff0z - sz;
197
+ const double ff2x = ff3x + sx;
198
+ const double ff2y = ff3y + sy;
199
+ const double ff2z = ff3z + sz;
200
+
201
+ add_force(force, i, ff0x, ff0y, ff0z);
202
+ add_force(force, j, -ff1x, -ff1y, -ff1z);
203
+ add_force(force, k, -ff2x, -ff2y, -ff2z);
204
+ add_force(force, l, ff3x, ff3y, ff3z);
205
+ }
206
+
207
+ std::vector<torch::Tensor> bonded_energy_force_cpu(
208
+ const torch::Tensor& coords,
209
+ const torch::Tensor& bond_i,
210
+ const torch::Tensor& bond_j,
211
+ const torch::Tensor& bond_k,
212
+ const torch::Tensor& bond_r0,
213
+ const torch::Tensor& angle_i,
214
+ const torch::Tensor& angle_j,
215
+ const torch::Tensor& angle_k,
216
+ const torch::Tensor& angle_k0,
217
+ const torch::Tensor& angle_t0,
218
+ const torch::Tensor& dihed_i,
219
+ const torch::Tensor& dihed_j,
220
+ const torch::Tensor& dihed_k,
221
+ const torch::Tensor& dihed_l,
222
+ const torch::Tensor& dihed_force,
223
+ const torch::Tensor& dihed_period,
224
+ const torch::Tensor& dihed_phase,
225
+ const torch::Tensor& cmap_type,
226
+ const torch::Tensor& cmap_i,
227
+ const torch::Tensor& cmap_j,
228
+ const torch::Tensor& cmap_k,
229
+ const torch::Tensor& cmap_l,
230
+ const torch::Tensor& cmap_m,
231
+ const torch::Tensor& cmap_size,
232
+ const torch::Tensor& cmap_delta,
233
+ const torch::Tensor& cmap_offset,
234
+ const torch::Tensor& cmap_coeff) {
235
+ auto c = coords.contiguous();
236
+ TORCH_CHECK(c.device().is_cpu(), "bonded_energy_force_cpu expects CPU coords");
237
+ TORCH_CHECK(c.dim() == 2 && c.size(1) == 3, "coords must be [N,3]");
238
+
239
+ auto bi = bond_i.contiguous();
240
+ auto bj = bond_j.contiguous();
241
+ auto bk = bond_k.contiguous();
242
+ auto br0 = bond_r0.contiguous();
243
+
244
+ auto ai = angle_i.contiguous();
245
+ auto aj = angle_j.contiguous();
246
+ auto ak = angle_k.contiguous();
247
+ auto ak0 = angle_k0.contiguous();
248
+ auto at0 = angle_t0.contiguous();
249
+
250
+ auto di = dihed_i.contiguous();
251
+ auto dj = dihed_j.contiguous();
252
+ auto dk = dihed_k.contiguous();
253
+ auto dl = dihed_l.contiguous();
254
+ auto df = dihed_force.contiguous();
255
+ auto dp = dihed_period.contiguous();
256
+ auto dph = dihed_phase.contiguous();
257
+
258
+ auto ct = cmap_type.contiguous();
259
+ auto ci = cmap_i.contiguous();
260
+ auto cj = cmap_j.contiguous();
261
+ auto ck = cmap_k.contiguous();
262
+ auto cl = cmap_l.contiguous();
263
+ auto cm = cmap_m.contiguous();
264
+ auto csize = cmap_size.contiguous();
265
+ auto cdelta = cmap_delta.contiguous();
266
+ auto coff = cmap_offset.contiguous();
267
+ auto ccoef = cmap_coeff.contiguous();
268
+
269
+ TORCH_CHECK(bi.scalar_type() == torch::kInt64, "bond_i must be int64");
270
+ TORCH_CHECK(ai.scalar_type() == torch::kInt64, "angle_i must be int64");
271
+ TORCH_CHECK(di.scalar_type() == torch::kInt64, "dihed_i must be int64");
272
+ TORCH_CHECK(ct.scalar_type() == torch::kInt64, "cmap_type must be int64");
273
+ TORCH_CHECK(csize.scalar_type() == torch::kInt64, "cmap_size must be int64");
274
+ TORCH_CHECK(coff.scalar_type() == torch::kInt64, "cmap_offset must be int64");
275
+
276
+ auto force = torch::zeros_like(c);
277
+ auto e_bond = torch::zeros({}, c.options());
278
+ auto e_angle = torch::zeros({}, c.options());
279
+ auto e_dihed = torch::zeros({}, c.options());
280
+ auto e_cmap = torch::zeros({}, c.options());
281
+
282
+ AT_DISPATCH_FLOATING_TYPES(c.scalar_type(), "bonded_energy_force_cpu", [&] {
283
+ const scalar_t* xyz = c.data_ptr<scalar_t>();
284
+ scalar_t* f = force.data_ptr<scalar_t>();
285
+ const int64_t natom = c.size(0);
286
+ const int64_t stride = 3 * natom;
287
+ const int64_t nbond = bi.numel();
288
+ const int64_t nangle = ai.numel();
289
+ const int64_t ndihed = di.numel();
290
+ const int64_t ncmap = ct.numel();
291
+
292
+ const int64_t* bi_ptr = bi.data_ptr<int64_t>();
293
+ const int64_t* bj_ptr = bj.data_ptr<int64_t>();
294
+ const scalar_t* bk_ptr = bk.data_ptr<scalar_t>();
295
+ const scalar_t* br0_ptr = br0.data_ptr<scalar_t>();
296
+
297
+ const int64_t* ai_ptr = ai.data_ptr<int64_t>();
298
+ const int64_t* aj_ptr = aj.data_ptr<int64_t>();
299
+ const int64_t* ak_ptr = ak.data_ptr<int64_t>();
300
+ const scalar_t* ak0_ptr = ak0.data_ptr<scalar_t>();
301
+ const scalar_t* at0_ptr = at0.data_ptr<scalar_t>();
302
+
303
+ const int64_t* di_ptr = di.data_ptr<int64_t>();
304
+ const int64_t* dj_ptr = dj.data_ptr<int64_t>();
305
+ const int64_t* dk_ptr = dk.data_ptr<int64_t>();
306
+ const int64_t* dl_ptr = dl.data_ptr<int64_t>();
307
+ const scalar_t* df_ptr = df.data_ptr<scalar_t>();
308
+ const scalar_t* dp_ptr = dp.data_ptr<scalar_t>();
309
+ const scalar_t* dph_ptr = dph.data_ptr<scalar_t>();
310
+
311
+ const int64_t* ct_ptr = ct.data_ptr<int64_t>();
312
+ const int64_t* ci_ptr = ci.data_ptr<int64_t>();
313
+ const int64_t* cj_ptr = cj.data_ptr<int64_t>();
314
+ const int64_t* ck_ptr = ck.data_ptr<int64_t>();
315
+ const int64_t* cl_ptr = cl.data_ptr<int64_t>();
316
+ const int64_t* cm_ptr = cm.data_ptr<int64_t>();
317
+ const int64_t* csize_ptr = csize.data_ptr<int64_t>();
318
+ const scalar_t* cdelta_ptr = cdelta.data_ptr<scalar_t>();
319
+ const int64_t* coff_ptr = coff.data_ptr<int64_t>();
320
+ const scalar_t* ccoef_ptr = ccoef.data_ptr<scalar_t>();
321
+
322
+ int nthreads = 1;
323
+ #ifdef _OPENMP
324
+ nthreads = std::max(1, omp_get_max_threads());
325
+ #endif
326
+ std::vector<scalar_t> force_tls(
327
+ static_cast<size_t>(nthreads) * static_cast<size_t>(stride),
328
+ static_cast<scalar_t>(0));
329
+ std::vector<double> eb_tls(static_cast<size_t>(nthreads), 0.0);
330
+ std::vector<double> ea_tls(static_cast<size_t>(nthreads), 0.0);
331
+ std::vector<double> ed_tls(static_cast<size_t>(nthreads), 0.0);
332
+ std::vector<double> ec_tls(static_cast<size_t>(nthreads), 0.0);
333
+
334
+ // Bond
335
+ #ifdef _OPENMP
336
+ #pragma omp parallel num_threads(nthreads)
337
+ #endif
338
+ {
339
+ int tid = 0;
340
+ #ifdef _OPENMP
341
+ tid = omp_get_thread_num();
342
+ #endif
343
+ scalar_t* fl = force_tls.data() + static_cast<size_t>(tid) * static_cast<size_t>(stride);
344
+ double eb_local = 0.0;
345
+ #ifdef _OPENMP
346
+ #pragma omp for schedule(static)
347
+ #endif
348
+ for (int64_t p = 0; p < nbond; ++p) {
349
+ const int64_t i = bi_ptr[p];
350
+ const int64_t j = bj_ptr[p];
351
+ const double k = static_cast<double>(bk_ptr[p]);
352
+ const double r0 = static_cast<double>(br0_ptr[p]);
353
+
354
+ const double dx = static_cast<double>(xyz[3 * j + 0] - xyz[3 * i + 0]);
355
+ const double dy = static_cast<double>(xyz[3 * j + 1] - xyz[3 * i + 1]);
356
+ const double dz = static_cast<double>(xyz[3 * j + 2] - xyz[3 * i + 2]);
357
+ const double r2 = std::max(dx * dx + dy * dy + dz * dz, 1.0e-24);
358
+ const double inv_r = 1.0 / std::sqrt(r2);
359
+ const double r = r2 * inv_r;
360
+ const double dr = r - r0;
361
+ eb_local += k * dr * dr;
362
+ const double fs = 2.0 * k * dr * inv_r;
363
+ add_force(fl, i, fs * dx, fs * dy, fs * dz);
364
+ add_force(fl, j, -fs * dx, -fs * dy, -fs * dz);
365
+ }
366
+ eb_tls[static_cast<size_t>(tid)] += eb_local;
367
+ }
368
+
369
+ // Angle
370
+ #ifdef _OPENMP
371
+ #pragma omp parallel num_threads(nthreads)
372
+ #endif
373
+ {
374
+ int tid = 0;
375
+ #ifdef _OPENMP
376
+ tid = omp_get_thread_num();
377
+ #endif
378
+ scalar_t* fl = force_tls.data() + static_cast<size_t>(tid) * static_cast<size_t>(stride);
379
+ double ea_local = 0.0;
380
+ #ifdef _OPENMP
381
+ #pragma omp for schedule(static)
382
+ #endif
383
+ for (int64_t p = 0; p < nangle; ++p) {
384
+ const int64_t i = ai_ptr[p];
385
+ const int64_t j = aj_ptr[p];
386
+ const int64_t k = ak_ptr[p];
387
+ const double ktheta = static_cast<double>(ak0_ptr[p]);
388
+ const double theta0 = static_cast<double>(at0_ptr[p]);
389
+
390
+ const double d0x = static_cast<double>(xyz[3 * j + 0] - xyz[3 * i + 0]);
391
+ const double d0y = static_cast<double>(xyz[3 * j + 1] - xyz[3 * i + 1]);
392
+ const double d0z = static_cast<double>(xyz[3 * j + 2] - xyz[3 * i + 2]);
393
+ const double d1x = static_cast<double>(xyz[3 * j + 0] - xyz[3 * k + 0]);
394
+ const double d1y = static_cast<double>(xyz[3 * j + 1] - xyz[3 * k + 1]);
395
+ const double d1z = static_cast<double>(xyz[3 * j + 2] - xyz[3 * k + 2]);
396
+
397
+ double px, py, pz;
398
+ cross3(d0x, d0y, d0z, d1x, d1y, d1z, px, py, pz);
399
+
400
+ const double r20 = std::max(dot3(d0x, d0y, d0z, d0x, d0y, d0z), 1.0e-24);
401
+ const double r21 = std::max(dot3(d1x, d1y, d1z, d1x, d1y, d1z), 1.0e-24);
402
+ const double rp = std::max(std::sqrt(dot3(px, py, pz, px, py, pz)), 1.0e-12);
403
+ const double dot = dot3(d0x, d0y, d0z, d1x, d1y, d1z);
404
+ const double cos_theta = clamp(dot / std::sqrt(r20 * r21), -1.0, 1.0);
405
+ const double theta = std::acos(cos_theta);
406
+
407
+ const double dtheta = theta - theta0;
408
+ ea_local += ktheta * dtheta * dtheta;
409
+ const double dE_dtheta = 2.0 * ktheta * dtheta;
410
+
411
+ const double term_i = dE_dtheta / (r20 * rp);
412
+ const double term_k = -dE_dtheta / (r21 * rp);
413
+
414
+ double cix, ciy, ciz;
415
+ double ckx, cky, ckz;
416
+ cross3(d0x, d0y, d0z, px, py, pz, cix, ciy, ciz);
417
+ cross3(d1x, d1y, d1z, px, py, pz, ckx, cky, ckz);
418
+
419
+ const double fix = cix * term_i;
420
+ const double fiy = ciy * term_i;
421
+ const double fiz = ciz * term_i;
422
+ const double fkx = ckx * term_k;
423
+ const double fky = cky * term_k;
424
+ const double fkz = ckz * term_k;
425
+ const double fjx = -(fix + fkx);
426
+ const double fjy = -(fiy + fky);
427
+ const double fjz = -(fiz + fkz);
428
+
429
+ add_force(fl, i, fix, fiy, fiz);
430
+ add_force(fl, j, fjx, fjy, fjz);
431
+ add_force(fl, k, fkx, fky, fkz);
432
+ }
433
+ ea_tls[static_cast<size_t>(tid)] += ea_local;
434
+ }
435
+
436
+ // Dihedral
437
+ #ifdef _OPENMP
438
+ #pragma omp parallel num_threads(nthreads)
439
+ #endif
440
+ {
441
+ int tid = 0;
442
+ #ifdef _OPENMP
443
+ tid = omp_get_thread_num();
444
+ #endif
445
+ scalar_t* fl = force_tls.data() + static_cast<size_t>(tid) * static_cast<size_t>(stride);
446
+ double ed_local = 0.0;
447
+ #ifdef _OPENMP
448
+ #pragma omp for schedule(static)
449
+ #endif
450
+ for (int64_t p = 0; p < ndihed; ++p) {
451
+ const int64_t i = di_ptr[p];
452
+ const int64_t j = dj_ptr[p];
453
+ const int64_t k = dk_ptr[p];
454
+ const int64_t l = dl_ptr[p];
455
+ const double kf = static_cast<double>(df_ptr[p]);
456
+ const double n = std::abs(static_cast<double>(dp_ptr[p]));
457
+ const double phase = static_cast<double>(dph_ptr[p]);
458
+
459
+ const double phi = dihedral_angle_from_points<scalar_t>(
460
+ xyz, i, j, k, l, nullptr, nullptr, nullptr, nullptr, nullptr);
461
+ const double delta = n * phi - phase;
462
+ ed_local += kf * (1.0 + std::cos(delta));
463
+ const double dE_dphi = -kf * n * std::sin(delta);
464
+ accumulate_dihedral_force<scalar_t>(xyz, i, j, k, l, dE_dphi, fl);
465
+ }
466
+ ed_tls[static_cast<size_t>(tid)] += ed_local;
467
+ }
468
+
469
+ // CMAP
470
+ #ifdef _OPENMP
471
+ #pragma omp parallel num_threads(nthreads)
472
+ #endif
473
+ {
474
+ int tid = 0;
475
+ #ifdef _OPENMP
476
+ tid = omp_get_thread_num();
477
+ #endif
478
+ scalar_t* fl = force_tls.data() + static_cast<size_t>(tid) * static_cast<size_t>(stride);
479
+ double ec_local = 0.0;
480
+ #ifdef _OPENMP
481
+ #pragma omp for schedule(static)
482
+ #endif
483
+ for (int64_t p = 0; p < ncmap; ++p) {
484
+ const int64_t tmap = ct_ptr[p];
485
+ const int64_t i = ci_ptr[p];
486
+ const int64_t j = cj_ptr[p];
487
+ const int64_t k = ck_ptr[p];
488
+ const int64_t l = cl_ptr[p];
489
+ const int64_t m = cm_ptr[p];
490
+
491
+ const double phi = dihedral_angle_from_points<scalar_t>(
492
+ xyz, i, j, k, l, nullptr, nullptr, nullptr, nullptr, nullptr);
493
+ const double psi = dihedral_angle_from_points<scalar_t>(
494
+ xyz, j, k, l, m, nullptr, nullptr, nullptr, nullptr, nullptr);
495
+ const double delta = static_cast<double>(cdelta_ptr[tmap]);
496
+ const int64_t size = csize_ptr[tmap];
497
+
498
+ const double ang_a = wrap_0_2pi(phi + kTwoPi);
499
+ const double ang_b = wrap_0_2pi(psi + kTwoPi);
500
+ const double u = ang_a / delta;
501
+ const double v = ang_b / delta;
502
+ int64_t su = static_cast<int64_t>(std::floor(u));
503
+ int64_t sv = static_cast<int64_t>(std::floor(v));
504
+ su = std::min(su, size - 1);
505
+ sv = std::min(sv, size - 1);
506
+ const double da = u - static_cast<double>(su);
507
+ const double db = v - static_cast<double>(sv);
508
+
509
+ const int64_t patch = su + size * sv;
510
+ const int64_t coeff_row = coff_ptr[tmap] + patch;
511
+ const scalar_t* coeff = ccoef_ptr + 16 * coeff_row;
512
+
513
+ const double db2 = db * db;
514
+ const double da2 = da * da;
515
+ const double da3 = da2 * da;
516
+
517
+ double ppoly[4];
518
+ double pd[4];
519
+ for (int r = 0; r < 4; ++r) {
520
+ const double c0 = static_cast<double>(coeff[4 * r + 0]);
521
+ const double c1 = static_cast<double>(coeff[4 * r + 1]);
522
+ const double c2 = static_cast<double>(coeff[4 * r + 2]);
523
+ const double c3 = static_cast<double>(coeff[4 * r + 3]);
524
+ ppoly[r] = c0 + c1 * db + c2 * db2 + c3 * db2 * db;
525
+ pd[r] = c1 + 2.0 * c2 * db + 3.0 * c3 * db2;
526
+ }
527
+
528
+ const double e_val = ppoly[0] + ppoly[1] * da + ppoly[2] * da2 + ppoly[3] * da3;
529
+ ec_local += e_val;
530
+
531
+ const double dE_dda = ppoly[1] + 2.0 * ppoly[2] * da + 3.0 * ppoly[3] * da2;
532
+ const double dE_ddb = pd[0] + pd[1] * da + pd[2] * da2 + pd[3] * da3;
533
+ const double dE_dphi = dE_dda / delta;
534
+ const double dE_dpsi = dE_ddb / delta;
535
+
536
+ accumulate_dihedral_force<scalar_t>(xyz, i, j, k, l, dE_dphi, fl);
537
+ accumulate_dihedral_force<scalar_t>(xyz, j, k, l, m, dE_dpsi, fl);
538
+ }
539
+ ec_tls[static_cast<size_t>(tid)] += ec_local;
540
+ }
541
+
542
+ double eb = 0.0;
543
+ double ea = 0.0;
544
+ double ed = 0.0;
545
+ double ec = 0.0;
546
+ for (int tid = 0; tid < nthreads; ++tid) {
547
+ eb += eb_tls[static_cast<size_t>(tid)];
548
+ ea += ea_tls[static_cast<size_t>(tid)];
549
+ ed += ed_tls[static_cast<size_t>(tid)];
550
+ ec += ec_tls[static_cast<size_t>(tid)];
551
+ }
552
+
553
+ for (int64_t idx = 0; idx < stride; ++idx) {
554
+ double acc = 0.0;
555
+ for (int tid = 0; tid < nthreads; ++tid) {
556
+ acc += static_cast<double>(
557
+ force_tls[static_cast<size_t>(tid) * static_cast<size_t>(stride) + static_cast<size_t>(idx)]);
558
+ }
559
+ f[idx] = static_cast<scalar_t>(acc);
560
+ }
561
+
562
+ e_bond.fill_(static_cast<scalar_t>(eb));
563
+ e_angle.fill_(static_cast<scalar_t>(ea));
564
+ e_dihed.fill_(static_cast<scalar_t>(ed));
565
+ e_cmap.fill_(static_cast<scalar_t>(ec));
566
+ });
567
+
568
+ return {e_bond, e_angle, e_dihed, e_cmap, force};
569
+ }
570
+
571
+ } // namespace
572
+
573
+ std::vector<torch::Tensor> bonded_energy_force(
574
+ const torch::Tensor& coords,
575
+ const torch::Tensor& bond_i,
576
+ const torch::Tensor& bond_j,
577
+ const torch::Tensor& bond_k,
578
+ const torch::Tensor& bond_r0,
579
+ const torch::Tensor& angle_i,
580
+ const torch::Tensor& angle_j,
581
+ const torch::Tensor& angle_k,
582
+ const torch::Tensor& angle_k0,
583
+ const torch::Tensor& angle_t0,
584
+ const torch::Tensor& dihed_i,
585
+ const torch::Tensor& dihed_j,
586
+ const torch::Tensor& dihed_k,
587
+ const torch::Tensor& dihed_l,
588
+ const torch::Tensor& dihed_force,
589
+ const torch::Tensor& dihed_period,
590
+ const torch::Tensor& dihed_phase,
591
+ const torch::Tensor& cmap_type,
592
+ const torch::Tensor& cmap_i,
593
+ const torch::Tensor& cmap_j,
594
+ const torch::Tensor& cmap_k,
595
+ const torch::Tensor& cmap_l,
596
+ const torch::Tensor& cmap_m,
597
+ const torch::Tensor& cmap_size,
598
+ const torch::Tensor& cmap_delta,
599
+ const torch::Tensor& cmap_offset,
600
+ const torch::Tensor& cmap_coeff) {
601
+ if (!coords.device().is_cpu()) {
602
+ TORCH_CHECK(false, "bonded_energy_force currently supports CPU tensors only");
603
+ }
604
+ return bonded_energy_force_cpu(
605
+ coords,
606
+ bond_i,
607
+ bond_j,
608
+ bond_k,
609
+ bond_r0,
610
+ angle_i,
611
+ angle_j,
612
+ angle_k,
613
+ angle_k0,
614
+ angle_t0,
615
+ dihed_i,
616
+ dihed_j,
617
+ dihed_k,
618
+ dihed_l,
619
+ dihed_force,
620
+ dihed_period,
621
+ dihed_phase,
622
+ cmap_type,
623
+ cmap_i,
624
+ cmap_j,
625
+ cmap_k,
626
+ cmap_l,
627
+ cmap_m,
628
+ cmap_size,
629
+ cmap_delta,
630
+ cmap_offset,
631
+ cmap_coeff);
632
+ }
633
+
634
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
635
+ m.doc() = "hessian_ff native bonded extension";
636
+ m.def(
637
+ "bonded_energy_force",
638
+ &bonded_energy_force,
639
+ "Compute bonded energy/force (bond/angle/dihedral/cmap)");
640
+ }