mlmm-toolkit 0.2.2.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (372) hide show
  1. hessian_ff/__init__.py +50 -0
  2. hessian_ff/analytical_hessian.py +609 -0
  3. hessian_ff/constants.py +46 -0
  4. hessian_ff/forcefield.py +339 -0
  5. hessian_ff/loaders.py +608 -0
  6. hessian_ff/native/Makefile +8 -0
  7. hessian_ff/native/__init__.py +28 -0
  8. hessian_ff/native/analytical_hessian.py +88 -0
  9. hessian_ff/native/analytical_hessian_ext.cpp +258 -0
  10. hessian_ff/native/bonded.py +82 -0
  11. hessian_ff/native/bonded_ext.cpp +640 -0
  12. hessian_ff/native/loader.py +349 -0
  13. hessian_ff/native/nonbonded.py +118 -0
  14. hessian_ff/native/nonbonded_ext.cpp +1150 -0
  15. hessian_ff/prmtop_parmed.py +23 -0
  16. hessian_ff/system.py +107 -0
  17. hessian_ff/terms/__init__.py +14 -0
  18. hessian_ff/terms/angle.py +73 -0
  19. hessian_ff/terms/bond.py +44 -0
  20. hessian_ff/terms/cmap.py +406 -0
  21. hessian_ff/terms/dihedral.py +141 -0
  22. hessian_ff/terms/nonbonded.py +209 -0
  23. hessian_ff/tests/__init__.py +0 -0
  24. hessian_ff/tests/conftest.py +75 -0
  25. hessian_ff/tests/data/small/complex.parm7 +1346 -0
  26. hessian_ff/tests/data/small/complex.pdb +125 -0
  27. hessian_ff/tests/data/small/complex.rst7 +63 -0
  28. hessian_ff/tests/test_coords_input.py +44 -0
  29. hessian_ff/tests/test_energy_force.py +49 -0
  30. hessian_ff/tests/test_hessian.py +137 -0
  31. hessian_ff/tests/test_smoke.py +18 -0
  32. hessian_ff/tests/test_validation.py +40 -0
  33. hessian_ff/workflows.py +889 -0
  34. mlmm/__init__.py +36 -0
  35. mlmm/__main__.py +7 -0
  36. mlmm/_version.py +34 -0
  37. mlmm/add_elem_info.py +374 -0
  38. mlmm/advanced_help.py +91 -0
  39. mlmm/align_freeze_atoms.py +601 -0
  40. mlmm/all.py +3535 -0
  41. mlmm/bond_changes.py +231 -0
  42. mlmm/bool_compat.py +223 -0
  43. mlmm/cli.py +574 -0
  44. mlmm/cli_utils.py +166 -0
  45. mlmm/default_group.py +337 -0
  46. mlmm/defaults.py +467 -0
  47. mlmm/define_layer.py +526 -0
  48. mlmm/dft.py +1041 -0
  49. mlmm/energy_diagram.py +253 -0
  50. mlmm/extract.py +2213 -0
  51. mlmm/fix_altloc.py +464 -0
  52. mlmm/freq.py +1406 -0
  53. mlmm/harmonic_constraints.py +140 -0
  54. mlmm/hessian_cache.py +44 -0
  55. mlmm/hessian_calc.py +174 -0
  56. mlmm/irc.py +638 -0
  57. mlmm/mlmm_calc.py +2262 -0
  58. mlmm/mm_parm.py +945 -0
  59. mlmm/oniom_export.py +1983 -0
  60. mlmm/oniom_import.py +457 -0
  61. mlmm/opt.py +1742 -0
  62. mlmm/path_opt.py +1353 -0
  63. mlmm/path_search.py +2299 -0
  64. mlmm/preflight.py +88 -0
  65. mlmm/py.typed +1 -0
  66. mlmm/pysis_runner.py +45 -0
  67. mlmm/scan.py +1047 -0
  68. mlmm/scan2d.py +1226 -0
  69. mlmm/scan3d.py +1265 -0
  70. mlmm/scan_common.py +184 -0
  71. mlmm/summary_log.py +736 -0
  72. mlmm/trj2fig.py +448 -0
  73. mlmm/tsopt.py +2871 -0
  74. mlmm/utils.py +2309 -0
  75. mlmm/xtb_embedcharge_correction.py +475 -0
  76. mlmm_toolkit-0.2.2.dev0.dist-info/METADATA +1159 -0
  77. mlmm_toolkit-0.2.2.dev0.dist-info/RECORD +372 -0
  78. mlmm_toolkit-0.2.2.dev0.dist-info/WHEEL +5 -0
  79. mlmm_toolkit-0.2.2.dev0.dist-info/entry_points.txt +2 -0
  80. mlmm_toolkit-0.2.2.dev0.dist-info/licenses/LICENSE +674 -0
  81. mlmm_toolkit-0.2.2.dev0.dist-info/top_level.txt +4 -0
  82. pysisyphus/Geometry.py +1667 -0
  83. pysisyphus/LICENSE +674 -0
  84. pysisyphus/TableFormatter.py +63 -0
  85. pysisyphus/TablePrinter.py +74 -0
  86. pysisyphus/__init__.py +12 -0
  87. pysisyphus/calculators/AFIR.py +452 -0
  88. pysisyphus/calculators/AnaPot.py +20 -0
  89. pysisyphus/calculators/AnaPot2.py +48 -0
  90. pysisyphus/calculators/AnaPot3.py +12 -0
  91. pysisyphus/calculators/AnaPot4.py +20 -0
  92. pysisyphus/calculators/AnaPotBase.py +337 -0
  93. pysisyphus/calculators/AnaPotCBM.py +25 -0
  94. pysisyphus/calculators/AtomAtomTransTorque.py +154 -0
  95. pysisyphus/calculators/CFOUR.py +250 -0
  96. pysisyphus/calculators/Calculator.py +844 -0
  97. pysisyphus/calculators/CerjanMiller.py +24 -0
  98. pysisyphus/calculators/Composite.py +123 -0
  99. pysisyphus/calculators/ConicalIntersection.py +171 -0
  100. pysisyphus/calculators/DFTBp.py +430 -0
  101. pysisyphus/calculators/DFTD3.py +66 -0
  102. pysisyphus/calculators/DFTD4.py +84 -0
  103. pysisyphus/calculators/Dalton.py +61 -0
  104. pysisyphus/calculators/Dimer.py +681 -0
  105. pysisyphus/calculators/Dummy.py +20 -0
  106. pysisyphus/calculators/EGO.py +76 -0
  107. pysisyphus/calculators/EnergyMin.py +224 -0
  108. pysisyphus/calculators/ExternalPotential.py +264 -0
  109. pysisyphus/calculators/FakeASE.py +35 -0
  110. pysisyphus/calculators/FourWellAnaPot.py +28 -0
  111. pysisyphus/calculators/FreeEndNEBPot.py +39 -0
  112. pysisyphus/calculators/Gaussian09.py +18 -0
  113. pysisyphus/calculators/Gaussian16.py +726 -0
  114. pysisyphus/calculators/HardSphere.py +159 -0
  115. pysisyphus/calculators/IDPPCalculator.py +49 -0
  116. pysisyphus/calculators/IPIClient.py +133 -0
  117. pysisyphus/calculators/IPIServer.py +234 -0
  118. pysisyphus/calculators/LEPSBase.py +24 -0
  119. pysisyphus/calculators/LEPSExpr.py +139 -0
  120. pysisyphus/calculators/LennardJones.py +80 -0
  121. pysisyphus/calculators/MOPAC.py +219 -0
  122. pysisyphus/calculators/MullerBrownSympyPot.py +51 -0
  123. pysisyphus/calculators/MultiCalc.py +85 -0
  124. pysisyphus/calculators/NFK.py +45 -0
  125. pysisyphus/calculators/OBabel.py +87 -0
  126. pysisyphus/calculators/ONIOMv2.py +1129 -0
  127. pysisyphus/calculators/ORCA.py +893 -0
  128. pysisyphus/calculators/ORCA5.py +6 -0
  129. pysisyphus/calculators/OpenMM.py +88 -0
  130. pysisyphus/calculators/OpenMolcas.py +281 -0
  131. pysisyphus/calculators/OverlapCalculator.py +908 -0
  132. pysisyphus/calculators/Psi4.py +218 -0
  133. pysisyphus/calculators/PyPsi4.py +37 -0
  134. pysisyphus/calculators/PySCF.py +341 -0
  135. pysisyphus/calculators/PyXTB.py +73 -0
  136. pysisyphus/calculators/QCEngine.py +106 -0
  137. pysisyphus/calculators/Rastrigin.py +22 -0
  138. pysisyphus/calculators/Remote.py +76 -0
  139. pysisyphus/calculators/Rosenbrock.py +15 -0
  140. pysisyphus/calculators/SocketCalc.py +97 -0
  141. pysisyphus/calculators/TIP3P.py +111 -0
  142. pysisyphus/calculators/TransTorque.py +161 -0
  143. pysisyphus/calculators/Turbomole.py +965 -0
  144. pysisyphus/calculators/VRIPot.py +37 -0
  145. pysisyphus/calculators/WFOWrapper.py +333 -0
  146. pysisyphus/calculators/WFOWrapper2.py +341 -0
  147. pysisyphus/calculators/XTB.py +418 -0
  148. pysisyphus/calculators/__init__.py +81 -0
  149. pysisyphus/calculators/cosmo_data.py +139 -0
  150. pysisyphus/calculators/parser.py +150 -0
  151. pysisyphus/color.py +19 -0
  152. pysisyphus/config.py +133 -0
  153. pysisyphus/constants.py +65 -0
  154. pysisyphus/cos/AdaptiveNEB.py +230 -0
  155. pysisyphus/cos/ChainOfStates.py +725 -0
  156. pysisyphus/cos/FreeEndNEB.py +25 -0
  157. pysisyphus/cos/FreezingString.py +103 -0
  158. pysisyphus/cos/GrowingChainOfStates.py +71 -0
  159. pysisyphus/cos/GrowingNT.py +309 -0
  160. pysisyphus/cos/GrowingString.py +508 -0
  161. pysisyphus/cos/NEB.py +189 -0
  162. pysisyphus/cos/SimpleZTS.py +64 -0
  163. pysisyphus/cos/__init__.py +22 -0
  164. pysisyphus/cos/stiffness.py +199 -0
  165. pysisyphus/drivers/__init__.py +17 -0
  166. pysisyphus/drivers/afir.py +855 -0
  167. pysisyphus/drivers/barriers.py +271 -0
  168. pysisyphus/drivers/birkholz.py +138 -0
  169. pysisyphus/drivers/cluster.py +318 -0
  170. pysisyphus/drivers/diabatization.py +133 -0
  171. pysisyphus/drivers/merge.py +368 -0
  172. pysisyphus/drivers/merge_mol2.py +322 -0
  173. pysisyphus/drivers/opt.py +375 -0
  174. pysisyphus/drivers/perf.py +91 -0
  175. pysisyphus/drivers/pka.py +52 -0
  176. pysisyphus/drivers/precon_pos_rot.py +669 -0
  177. pysisyphus/drivers/rates.py +480 -0
  178. pysisyphus/drivers/replace.py +219 -0
  179. pysisyphus/drivers/scan.py +212 -0
  180. pysisyphus/drivers/spectrum.py +166 -0
  181. pysisyphus/drivers/thermo.py +31 -0
  182. pysisyphus/dynamics/Gaussian.py +103 -0
  183. pysisyphus/dynamics/__init__.py +20 -0
  184. pysisyphus/dynamics/colvars.py +136 -0
  185. pysisyphus/dynamics/driver.py +297 -0
  186. pysisyphus/dynamics/helpers.py +256 -0
  187. pysisyphus/dynamics/lincs.py +105 -0
  188. pysisyphus/dynamics/mdp.py +364 -0
  189. pysisyphus/dynamics/rattle.py +121 -0
  190. pysisyphus/dynamics/thermostats.py +128 -0
  191. pysisyphus/dynamics/wigner.py +266 -0
  192. pysisyphus/elem_data.py +3473 -0
  193. pysisyphus/exceptions.py +2 -0
  194. pysisyphus/filtertrj.py +69 -0
  195. pysisyphus/helpers.py +623 -0
  196. pysisyphus/helpers_pure.py +649 -0
  197. pysisyphus/init_logging.py +50 -0
  198. pysisyphus/intcoords/Bend.py +69 -0
  199. pysisyphus/intcoords/Bend2.py +25 -0
  200. pysisyphus/intcoords/BondedFragment.py +32 -0
  201. pysisyphus/intcoords/Cartesian.py +41 -0
  202. pysisyphus/intcoords/CartesianCoords.py +140 -0
  203. pysisyphus/intcoords/Coords.py +56 -0
  204. pysisyphus/intcoords/DLC.py +197 -0
  205. pysisyphus/intcoords/DistanceFunction.py +34 -0
  206. pysisyphus/intcoords/DummyImproper.py +70 -0
  207. pysisyphus/intcoords/DummyTorsion.py +72 -0
  208. pysisyphus/intcoords/LinearBend.py +105 -0
  209. pysisyphus/intcoords/LinearDisplacement.py +80 -0
  210. pysisyphus/intcoords/OutOfPlane.py +59 -0
  211. pysisyphus/intcoords/PrimTypes.py +286 -0
  212. pysisyphus/intcoords/Primitive.py +137 -0
  213. pysisyphus/intcoords/RedundantCoords.py +659 -0
  214. pysisyphus/intcoords/RobustTorsion.py +59 -0
  215. pysisyphus/intcoords/Rotation.py +147 -0
  216. pysisyphus/intcoords/Stretch.py +31 -0
  217. pysisyphus/intcoords/Torsion.py +101 -0
  218. pysisyphus/intcoords/Torsion2.py +25 -0
  219. pysisyphus/intcoords/Translation.py +45 -0
  220. pysisyphus/intcoords/__init__.py +61 -0
  221. pysisyphus/intcoords/augment_bonds.py +126 -0
  222. pysisyphus/intcoords/derivatives.py +10512 -0
  223. pysisyphus/intcoords/eval.py +80 -0
  224. pysisyphus/intcoords/exceptions.py +37 -0
  225. pysisyphus/intcoords/findiffs.py +48 -0
  226. pysisyphus/intcoords/generate_derivatives.py +414 -0
  227. pysisyphus/intcoords/helpers.py +235 -0
  228. pysisyphus/intcoords/logging_conf.py +10 -0
  229. pysisyphus/intcoords/mp_derivatives.py +10836 -0
  230. pysisyphus/intcoords/setup.py +962 -0
  231. pysisyphus/intcoords/setup_fast.py +176 -0
  232. pysisyphus/intcoords/update.py +272 -0
  233. pysisyphus/intcoords/valid.py +89 -0
  234. pysisyphus/interpolate/Geodesic.py +93 -0
  235. pysisyphus/interpolate/IDPP.py +55 -0
  236. pysisyphus/interpolate/Interpolator.py +116 -0
  237. pysisyphus/interpolate/LST.py +70 -0
  238. pysisyphus/interpolate/Redund.py +152 -0
  239. pysisyphus/interpolate/__init__.py +9 -0
  240. pysisyphus/interpolate/helpers.py +34 -0
  241. pysisyphus/io/__init__.py +22 -0
  242. pysisyphus/io/aomix.py +178 -0
  243. pysisyphus/io/cjson.py +24 -0
  244. pysisyphus/io/crd.py +101 -0
  245. pysisyphus/io/cube.py +220 -0
  246. pysisyphus/io/fchk.py +184 -0
  247. pysisyphus/io/hdf5.py +49 -0
  248. pysisyphus/io/hessian.py +72 -0
  249. pysisyphus/io/mol2.py +146 -0
  250. pysisyphus/io/molden.py +293 -0
  251. pysisyphus/io/orca.py +189 -0
  252. pysisyphus/io/pdb.py +269 -0
  253. pysisyphus/io/psf.py +79 -0
  254. pysisyphus/io/pubchem.py +31 -0
  255. pysisyphus/io/qcschema.py +34 -0
  256. pysisyphus/io/sdf.py +29 -0
  257. pysisyphus/io/xyz.py +61 -0
  258. pysisyphus/io/zmat.py +175 -0
  259. pysisyphus/irc/DWI.py +108 -0
  260. pysisyphus/irc/DampedVelocityVerlet.py +134 -0
  261. pysisyphus/irc/Euler.py +22 -0
  262. pysisyphus/irc/EulerPC.py +345 -0
  263. pysisyphus/irc/GonzalezSchlegel.py +187 -0
  264. pysisyphus/irc/IMKMod.py +164 -0
  265. pysisyphus/irc/IRC.py +878 -0
  266. pysisyphus/irc/IRCDummy.py +10 -0
  267. pysisyphus/irc/Instanton.py +307 -0
  268. pysisyphus/irc/LQA.py +53 -0
  269. pysisyphus/irc/ModeKill.py +136 -0
  270. pysisyphus/irc/ParamPlot.py +53 -0
  271. pysisyphus/irc/RK4.py +36 -0
  272. pysisyphus/irc/__init__.py +31 -0
  273. pysisyphus/irc/initial_displ.py +219 -0
  274. pysisyphus/linalg.py +411 -0
  275. pysisyphus/line_searches/Backtracking.py +88 -0
  276. pysisyphus/line_searches/HagerZhang.py +184 -0
  277. pysisyphus/line_searches/LineSearch.py +232 -0
  278. pysisyphus/line_searches/StrongWolfe.py +108 -0
  279. pysisyphus/line_searches/__init__.py +9 -0
  280. pysisyphus/line_searches/interpol.py +15 -0
  281. pysisyphus/modefollow/NormalMode.py +40 -0
  282. pysisyphus/modefollow/__init__.py +10 -0
  283. pysisyphus/modefollow/davidson.py +199 -0
  284. pysisyphus/modefollow/lanczos.py +95 -0
  285. pysisyphus/optimizers/BFGS.py +99 -0
  286. pysisyphus/optimizers/BacktrackingOptimizer.py +113 -0
  287. pysisyphus/optimizers/ConjugateGradient.py +98 -0
  288. pysisyphus/optimizers/CubicNewton.py +75 -0
  289. pysisyphus/optimizers/FIRE.py +113 -0
  290. pysisyphus/optimizers/HessianOptimizer.py +1176 -0
  291. pysisyphus/optimizers/LBFGS.py +228 -0
  292. pysisyphus/optimizers/LayerOpt.py +411 -0
  293. pysisyphus/optimizers/MicroOptimizer.py +169 -0
  294. pysisyphus/optimizers/NCOptimizer.py +90 -0
  295. pysisyphus/optimizers/Optimizer.py +1084 -0
  296. pysisyphus/optimizers/PreconLBFGS.py +260 -0
  297. pysisyphus/optimizers/PreconSteepestDescent.py +7 -0
  298. pysisyphus/optimizers/QuickMin.py +74 -0
  299. pysisyphus/optimizers/RFOptimizer.py +181 -0
  300. pysisyphus/optimizers/RSA.py +99 -0
  301. pysisyphus/optimizers/StabilizedQNMethod.py +248 -0
  302. pysisyphus/optimizers/SteepestDescent.py +23 -0
  303. pysisyphus/optimizers/StringOptimizer.py +173 -0
  304. pysisyphus/optimizers/__init__.py +41 -0
  305. pysisyphus/optimizers/closures.py +301 -0
  306. pysisyphus/optimizers/cls_map.py +58 -0
  307. pysisyphus/optimizers/exceptions.py +6 -0
  308. pysisyphus/optimizers/gdiis.py +280 -0
  309. pysisyphus/optimizers/guess_hessians.py +311 -0
  310. pysisyphus/optimizers/hessian_updates.py +355 -0
  311. pysisyphus/optimizers/poly_fit.py +285 -0
  312. pysisyphus/optimizers/precon.py +153 -0
  313. pysisyphus/optimizers/restrict_step.py +24 -0
  314. pysisyphus/pack.py +172 -0
  315. pysisyphus/peakdetect.py +948 -0
  316. pysisyphus/plot.py +1031 -0
  317. pysisyphus/run.py +2106 -0
  318. pysisyphus/socket_helper.py +74 -0
  319. pysisyphus/stocastic/FragmentKick.py +132 -0
  320. pysisyphus/stocastic/Kick.py +81 -0
  321. pysisyphus/stocastic/Pipeline.py +303 -0
  322. pysisyphus/stocastic/__init__.py +21 -0
  323. pysisyphus/stocastic/align.py +127 -0
  324. pysisyphus/testing.py +96 -0
  325. pysisyphus/thermo.py +156 -0
  326. pysisyphus/trj.py +824 -0
  327. pysisyphus/tsoptimizers/RSIRFOptimizer.py +56 -0
  328. pysisyphus/tsoptimizers/RSPRFOptimizer.py +182 -0
  329. pysisyphus/tsoptimizers/TRIM.py +59 -0
  330. pysisyphus/tsoptimizers/TSHessianOptimizer.py +463 -0
  331. pysisyphus/tsoptimizers/__init__.py +23 -0
  332. pysisyphus/wavefunction/Basis.py +239 -0
  333. pysisyphus/wavefunction/DIIS.py +76 -0
  334. pysisyphus/wavefunction/__init__.py +25 -0
  335. pysisyphus/wavefunction/build_ext.py +42 -0
  336. pysisyphus/wavefunction/cart2sph.py +190 -0
  337. pysisyphus/wavefunction/diabatization.py +304 -0
  338. pysisyphus/wavefunction/excited_states.py +435 -0
  339. pysisyphus/wavefunction/gen_ints.py +1811 -0
  340. pysisyphus/wavefunction/helpers.py +104 -0
  341. pysisyphus/wavefunction/ints/__init__.py +0 -0
  342. pysisyphus/wavefunction/ints/boys.py +193 -0
  343. pysisyphus/wavefunction/ints/boys_table_N_64_xasym_27.1_step_0.01.npy +0 -0
  344. pysisyphus/wavefunction/ints/cart_gto3d.py +176 -0
  345. pysisyphus/wavefunction/ints/coulomb3d.py +25928 -0
  346. pysisyphus/wavefunction/ints/diag_quadrupole3d.py +10036 -0
  347. pysisyphus/wavefunction/ints/dipole3d.py +8762 -0
  348. pysisyphus/wavefunction/ints/int2c2e3d.py +7198 -0
  349. pysisyphus/wavefunction/ints/int3c2e3d_sph.py +65040 -0
  350. pysisyphus/wavefunction/ints/kinetic3d.py +8240 -0
  351. pysisyphus/wavefunction/ints/ovlp3d.py +3777 -0
  352. pysisyphus/wavefunction/ints/quadrupole3d.py +15054 -0
  353. pysisyphus/wavefunction/ints/self_ovlp3d.py +198 -0
  354. pysisyphus/wavefunction/localization.py +458 -0
  355. pysisyphus/wavefunction/multipole.py +159 -0
  356. pysisyphus/wavefunction/normalization.py +36 -0
  357. pysisyphus/wavefunction/pop_analysis.py +134 -0
  358. pysisyphus/wavefunction/shells.py +1171 -0
  359. pysisyphus/wavefunction/wavefunction.py +504 -0
  360. pysisyphus/wrapper/__init__.py +11 -0
  361. pysisyphus/wrapper/exceptions.py +2 -0
  362. pysisyphus/wrapper/jmol.py +120 -0
  363. pysisyphus/wrapper/mwfn.py +169 -0
  364. pysisyphus/wrapper/packmol.py +71 -0
  365. pysisyphus/xyzloader.py +168 -0
  366. pysisyphus/yaml_mods.py +45 -0
  367. thermoanalysis/LICENSE +674 -0
  368. thermoanalysis/QCData.py +244 -0
  369. thermoanalysis/__init__.py +0 -0
  370. thermoanalysis/config.py +3 -0
  371. thermoanalysis/constants.py +20 -0
  372. thermoanalysis/thermo.py +1011 -0
@@ -0,0 +1,889 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ import time
5
+ from pathlib import Path
6
+ from typing import Any, Dict, Optional, Sequence, Union
7
+
8
+ import torch
9
+
10
+ from .analytical_hessian import build_analytical_hessian
11
+ from .forcefield import ForceFieldTorch
12
+ from .loaders import load_coords, load_system
13
+ from .system import AmberSystem
14
+ from .terms.nonbonded import NonbondedTerm
15
+
16
# Filesystem path accepted as str or pathlib.Path.
PathLike = Union[str, Path]
# Coordinate input: a path to a coordinate file, an in-memory tensor, or an
# object understood by ``load_coords`` (Any — exact set defined in loaders).
CoordsLike = Union[PathLike, torch.Tensor, Any]
# Batch of coordinate frames: a stacked [B, N, 3] tensor or a sequence of
# per-frame ``CoordsLike`` inputs.
CoordsBatchLike = Union[torch.Tensor, Sequence[CoordsLike]]
19
+
20
+
21
+ # ---------------------------------------------------------------------------
22
+ # Hessian mode synonym map
23
+ # ---------------------------------------------------------------------------
24
# Maps accepted Hessian-mode spellings onto the three internal backends
# ("analytical", "autograd", "fd").  Keys are all lowercase; presumably the
# caller lower-cases user input before lookup — confirm at the call site.
_HESSIAN_MODE_MAP: Dict[str, str] = {
    "analytical": "analytical",
    "forcejacobian": "analytical",
    "force_jacobian": "analytical",
    "customautograd": "analytical",
    "custom_autograd": "analytical",
    "autograd": "autograd",
    "finitedifferenceforce": "fd",
    "finite_difference_force": "fd",
    "finitedifference": "fd",
    "finite_difference": "fd",
    "fd": "fd",
}
37
+
38
+
39
+ # ---------------------------------------------------------------------------
40
+ # Runtime cache
41
+ # ---------------------------------------------------------------------------
42
@dataclass
class _RuntimeEntry:
    """One cached runtime: parsed system, its force field, and a coord buffer."""

    # Topology/parameter data produced by ``load_system``.
    system: AmberSystem
    # Force-field evaluator constructed once from ``system``.
    ff: ForceFieldTorch
    # Reusable coordinate tensor; ``_load_runtime`` refreshes it in place when
    # the incoming coordinates match its shape/dtype/device.
    coords_buffer: Optional[torch.Tensor] = None
47
+
48
+
49
# Process-wide cache of runtime entries, keyed by the 4-tuple built in
# ``_load_runtime``: (resolved prmtop path, device string, dtype string,
# nonbonded_cpu_fast flag).  Annotation corrected: the key has four elements,
# not five.  Cleared via ``clear_runtime_cache``.
_RUNTIME_CACHE: Dict[tuple[str, str, str, bool], _RuntimeEntry] = {}
50
+
51
+
52
def clear_runtime_cache() -> None:
    """Drop every cached runtime entry (system, force field, coord buffer)."""
    _RUNTIME_CACHE.clear()
55
+
56
+
57
+ # ---------------------------------------------------------------------------
58
+ # Small helpers
59
+ # ---------------------------------------------------------------------------
60
def _dtype_from_double(double: bool) -> torch.dtype:
    """Translate the ``double`` flag into the matching torch dtype."""
    if double:
        return torch.float64
    return torch.float32
62
+
63
+
64
def _precision_info(double: bool) -> Dict[str, Any]:
    """Describe the precision settings as JSON-friendly metadata."""
    dtype = _dtype_from_double(double)
    # Known dtypes get a short canonical name; anything else falls back to
    # torch's own string representation.
    known_names = {torch.float64: "float64", torch.float32: "float32"}
    dtype_name = known_names.get(dtype, str(dtype))
    return {"double": bool(double), "torch_dtype": dtype_name}
73
+
74
+
75
def _configure_torch_threads(num_threads: Optional[int]) -> Optional[int]:
    """Pin torch and BLAS thread counts.

    ``None`` leaves the runtime defaults untouched and returns ``None``;
    otherwise the validated count is applied and returned.

    Raises:
        ValueError: if ``num_threads`` is below 1.
    """
    if num_threads is None:
        return None
    count = int(num_threads)
    if count < 1:
        raise ValueError(f"num_threads must be >=1, got {num_threads}")
    import os

    # Export to the common BLAS/OpenMP knobs as well, so native libraries
    # spawned later honor the same limit.
    value = str(count)
    for env_key in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS"):
        os.environ[env_key] = value
    torch.set_num_threads(count)
    return count
87
+
88
+
89
def _energy_terms_to_float(energy_terms: Dict[str, torch.Tensor]) -> Dict[str, float]:
    """Convert the ``E_*`` scalar tensors to plain floats, dropping other keys."""
    converted: Dict[str, float] = {}
    for name, tensor in energy_terms.items():
        if name.startswith("E_"):
            converted[name] = float(tensor.detach().cpu())
    return converted
91
+
92
+ # ---------------------------------------------------------------------------
93
+ # MPI helpers (unified)
94
+ # ---------------------------------------------------------------------------
95
def _get_mpi_context(
    mpi: bool,
) -> tuple[Optional[Any], Optional[Any], int, int]:
    """Return ``(MPI module, communicator, rank, size)``.

    When ``mpi`` is falsy, serial placeholders ``(None, None, 0, 1)`` are
    returned.  Otherwise ``mpi4py`` must be importable.

    Raises:
        RuntimeError: if ``mpi=True`` but mpi4py cannot be imported.
    """
    if not mpi:
        return None, None, 0, 1
    try:
        from mpi4py import MPI  # type: ignore
    except Exception as e:
        message = (
            "mpi=True requires mpi4py and a working MPI runtime (libmpi). "
            f"Original import error: {e}"
        )
        raise RuntimeError(message) from e
    comm = MPI.COMM_WORLD
    rank = int(comm.Get_rank())
    size = int(comm.Get_size())
    return MPI, comm, rank, size
109
+
110
+
111
def _mpi_reduce_tensor(local: torch.Tensor, mpi_obj: Any, comm: Any) -> torch.Tensor:
    """Sum-allreduce a tensor of any shape across MPI ranks.

    With ``comm=None`` the input is returned unchanged (serial fast path).
    """
    if comm is None:
        return local
    # mpi4py's buffer API needs a contiguous 1-D numpy array on the host.
    send_buf = local.detach().cpu().contiguous().reshape(-1).numpy()
    recv_buf = send_buf.copy()
    comm.Allreduce(send_buf, recv_buf, op=mpi_obj.SUM)
    reduced = torch.from_numpy(recv_buf).reshape(local.shape)
    return reduced.to(dtype=local.dtype)
119
+
120
+
121
+ # ---------------------------------------------------------------------------
122
+ # Result dict builder (eliminates repeated metadata assembly)
123
+ # ---------------------------------------------------------------------------
124
def _build_result_meta(
    *,
    double: bool,
    num_threads: Optional[int] = None,
    mpi_rank: int = 0,
    mpi_size: int = 1,
    **extra: Any,
) -> Dict[str, Any]:
    """Assemble the metadata dict shared by every public result payload.

    ``extra`` entries come first, then backend/thread/MPI fields, then the
    precision info from :func:`_precision_info`.
    """
    meta: Dict[str, Any] = dict(extra)
    meta["eval_backend"] = "torch"
    if num_threads is not None:
        meta["num_threads"] = int(num_threads)
    meta["mpi_enabled"] = bool(mpi_size > 1)
    meta["mpi_rank"] = int(mpi_rank)
    meta["mpi_size"] = int(mpi_size)
    meta.update(_precision_info(double))
    return meta
141
+
142
+
143
+ # ---------------------------------------------------------------------------
144
+ # Runtime loading
145
+ # ---------------------------------------------------------------------------
146
def _load_runtime(
    prmtop: PathLike,
    coords: CoordsLike,
    device: Union[str, torch.device],
    double: bool,
    requires_grad: bool = False,
    nonbonded_cpu_fast: bool = True,
) -> tuple[AmberSystem, torch.Tensor, ForceFieldTorch]:
    """Load (or reuse from cache) the system/force-field pair plus coordinates.

    System and force field are memoized in ``_RUNTIME_CACHE`` keyed by the
    resolved prmtop path, device, dtype, and ``nonbonded_cpu_fast``.  The
    coordinate tensor is copied into a cached buffer in place when shapes
    match; with ``requires_grad=True`` a fresh leaf tensor is returned
    instead so autograd never touches the shared buffer.

    Returns:
        ``(system, xyz, ff)`` — the topology, the coordinate tensor (2-D,
        presumably [natom, 3]; produced by ``load_coords`` — confirm there),
        and the force-field evaluator.
    """
    dtype = _dtype_from_double(double)
    dev = str(torch.device(device))
    cache_key = (
        str(Path(prmtop).resolve()),
        dev,
        str(dtype),
        bool(nonbonded_cpu_fast),
    )

    cached = _RUNTIME_CACHE.get(cache_key)
    if cached is None:
        # First use of this (prmtop, device, dtype, flag) combination: parse
        # the topology and build the force field once, then cache both.
        system = load_system(prmtop, device=device).to(dtype=dtype)
        ff = ForceFieldTorch(
            system,
            nonbonded_cpu_fast=nonbonded_cpu_fast,
        )
        entry = _RuntimeEntry(system=system, ff=ff, coords_buffer=None)
        _RUNTIME_CACHE[cache_key] = entry
    else:
        entry = cached
        system = entry.system
        ff = entry.ff

    xyz_loaded = load_coords(coords, natom=system.natom, device=device, dtype=dtype)
    if requires_grad:
        # Autograd path: return a detached leaf with grad tracking enabled and
        # skip the shared buffer entirely.
        xyz = xyz_loaded.clone().detach().requires_grad_(True)
        return system, xyz, ff

    buf = entry.coords_buffer
    if (
        buf is None
        or buf.shape != xyz_loaded.shape
        or buf.dtype != xyz_loaded.dtype
        or buf.device != xyz_loaded.device
    ):
        # (Re)allocate the buffer when it does not match the incoming frame.
        entry.coords_buffer = xyz_loaded.clone().detach()
    else:
        # Reuse the existing allocation; copy new coordinates in place.
        entry.coords_buffer.copy_(xyz_loaded)
    xyz = entry.coords_buffer
    return system, xyz, ff
194
+
195
+
196
+ # ---------------------------------------------------------------------------
197
+ # MPI-distributed energy/force on CPU
198
+ # ---------------------------------------------------------------------------
199
def _dist_energy_force_cpu(
    system: AmberSystem,
    coords: torch.Tensor,
    *,
    force_calc_mode: str,
    mpi_obj: Any,
    mpi_comm: Any,
    mpi_rank: int,
    mpi_size: int,
) -> tuple[Dict[str, torch.Tensor], torch.Tensor]:
    """MPI-distributed energy/force evaluation on CPU.

    Bonded terms (bond/angle/dihedral/CMAP) are evaluated on rank 0 only;
    each rank evaluates a round-robin shard of the nonbonded pair lists.
    Per-rank contributions are summed with an MPI Allreduce, so every rank
    returns identical global energy terms and forces.

    Raises:
        ValueError: if ``force_calc_mode`` is not 'Analytical' (any casing)
            or ``coords`` is not a CPU tensor.
    """
    if str(force_calc_mode).strip().lower() != "analytical":
        raise ValueError("mpi=True for E/F currently requires force_calc_mode='Analytical'")
    if coords.device.type != "cpu":
        raise ValueError("Distributed CPU E/F path requires CPU tensor")

    # Zero-initialized bonded contributions; non-root ranks keep these zeros
    # so the Allreduce below sums the bonded terms exactly once.
    e_bond = coords.new_zeros(())
    e_angle = coords.new_zeros(())
    e_dihed = coords.new_zeros(())
    e_cmap = coords.new_zeros(())
    force_bonded = torch.zeros_like(coords)

    if mpi_rank == 0:
        ff_root = ForceFieldTorch(system)
        e_bond, f_bond = ff_root.bond.energy_force(coords)
        e_angle, f_angle = ff_root.angle.energy_force(coords)
        e_dihed, f_dihed = ff_root.dihedral.energy_force(coords)
        e_cmap, f_cmap = ff_root.cmap.energy_force(coords)
        force_bonded = f_bond + f_angle + f_dihed + f_cmap

    def _shard(x: torch.Tensor) -> torch.Tensor:
        # Round-robin slice of a pair list for this rank; no-op when serial
        # or the list is empty.
        if mpi_size <= 1 or x.numel() == 0:
            return x
        return x[mpi_rank::mpi_size]

    # Nonbonded term rebuilt per call with this rank's shard of the general
    # and 1-4 pair lists.
    nb_term = NonbondedTerm(
        natom=system.natom,
        charge=system.charge,
        atom_type=system.atom_type,
        lj_acoef=system.lj_acoef,
        lj_bcoef=system.lj_bcoef,
        hb_acoef=system.hb_acoef,
        hb_bcoef=system.hb_bcoef,
        nb_index=system.nb_index,
        pair_i=_shard(system.pair_i),
        pair_j=_shard(system.pair_j),
        pair14_i=_shard(system.pair14_i),
        pair14_j=_shard(system.pair14_j),
        pair14_inv_scee=_shard(system.pair14_inv_scee),
        pair14_inv_scnb=_shard(system.pair14_inv_scnb),
    )
    nb_e, f_nb = nb_term.energy_force(coords)

    local_force = force_bonded + f_nb
    local_terms: Dict[str, torch.Tensor] = {
        "E_bond": e_bond,
        "E_angle": e_angle,
        "E_dihedral": e_dihed,
        "E_cmap": e_cmap,
        "E_coul": nb_e.coulomb,
        "E_lj": nb_e.lj,
        "E_coul14": nb_e.coulomb14,
        "E_lj14": nb_e.lj14,
    }

    if mpi_size > 1:
        # Sum partial forces and per-term energies across all ranks.
        force_global = _mpi_reduce_tensor(local_force, mpi_obj=mpi_obj, comm=mpi_comm)
        terms_global = {
            k: _mpi_reduce_tensor(v, mpi_obj=mpi_obj, comm=mpi_comm)
            for k, v in local_terms.items()
        }
    else:
        force_global = local_force
        terms_global = local_terms

    # Derived totals assembled after the reduction so they are globally
    # consistent on every rank.
    e_nb_total = terms_global["E_coul"] + terms_global["E_lj"] + terms_global["E_coul14"] + terms_global["E_lj14"]
    e_total = terms_global["E_bond"] + terms_global["E_angle"] + terms_global["E_dihedral"] + terms_global["E_cmap"] + e_nb_total
    out = dict(terms_global)
    out["E_total"] = e_total
    out["E_nonbonded_total"] = e_nb_total
    return out, force_global
279
+
280
+
281
+ # ---------------------------------------------------------------------------
282
+ # Core energy/force computation (shared by torch_energy and torch_force)
283
+ # ---------------------------------------------------------------------------
284
def _compute_energy_force(
    prmtop: PathLike,
    coords: CoordsLike,
    *,
    device: Union[str, torch.device],
    double: bool,
    force_calc_mode: str,
    with_force: bool,
    num_threads: Optional[int],
    mpi: bool,
) -> tuple[Dict[str, Any], Optional[torch.Tensor], Optional[int], int, int]:
    """Shared computation for torch_energy/torch_force.

    Returns (energy_terms_float, force_or_None, used_threads, mpi_rank, mpi_size).
    """
    used_threads = _configure_torch_threads(num_threads)
    mpi_obj, mpi_comm, mpi_rank, mpi_size = _get_mpi_context(mpi)

    dev = torch.device(device)
    # The distributed path only shards CPU tensors; reject GPU + MPI early.
    if mpi_size > 1 and dev.type != "cpu":
        raise ValueError("mpi=True currently supports CPU execution only")

    force_mode_lower = str(force_calc_mode).strip().lower()
    system, xyz, ff = _load_runtime(
        prmtop=prmtop,
        coords=coords,
        device=device,
        double=double,
        # Autograd forces need a grad-tracking leaf tensor.
        requires_grad=(with_force and force_mode_lower == "autograd"),
        # Fast CPU nonbonded path disabled under CPU autograd — presumably it
        # is not autograd-differentiable; confirm in ForceFieldTorch.
        nonbonded_cpu_fast=not (torch.device(device).type == "cpu" and force_mode_lower == "autograd"),
    )

    if mpi_size > 1:
        # Multi-rank: bonded on rank 0, nonbonded sharded, then Allreduce.
        out, force_tensor = _dist_energy_force_cpu(
            system=system,
            coords=xyz,
            force_calc_mode=force_calc_mode,
            mpi_obj=mpi_obj,
            mpi_comm=mpi_comm,
            mpi_rank=mpi_rank,
            mpi_size=mpi_size,
        )
    elif with_force:
        # Serial energy + force in the requested mode.
        out, force_tensor = ff.energy_force(coords=xyz, force_calc_mode=force_calc_mode)
    else:
        # Energy-only evaluation; no force tensor produced.
        out = ff(xyz)
        force_tensor = None

    energy_dict = _energy_terms_to_float(out)
    return energy_dict, force_tensor, used_threads, mpi_rank, mpi_size
334
+
335
+
336
+ # ---------------------------------------------------------------------------
337
+ # Public API: system_summary
338
+ # ---------------------------------------------------------------------------
339
def system_summary(prmtop: PathLike, device: Union[str, torch.device] = "cpu") -> Dict[str, int]:
    """Return basic term/pair counts for a prmtop."""
    system = load_system(prmtop, device=device)
    # Counts derived from the element counts of the system's index tensors.
    count_sources = {
        "n_bond": system.bond_i,
        "n_angle": system.angle_i,
        "n_dihedral": system.dihed_i,
        "n_cmap": system.cmap_type,
        "n_pair_general": system.pair_i,
        "n_pair_14": system.pair14_i,
    }
    summary: Dict[str, int] = {"n_atom": int(system.natom)}
    for key, tensor in count_sources.items():
        summary[key] = int(tensor.numel())
    return summary
351
+
352
+
353
+ # ---------------------------------------------------------------------------
354
+ # Public API: torch_energy
355
+ # ---------------------------------------------------------------------------
356
def torch_energy(
    prmtop: PathLike,
    coords: CoordsLike,
    device: Union[str, torch.device] = "cpu",
    with_grad: bool = False,
    double: bool = True,
    force_calc_mode: str = "Analytical",
    num_threads: Optional[int] = None,
    mpi: bool = False,
) -> Dict[str, Any]:
    """Evaluate Torch MM energies, optionally reporting gradient metadata.

    With ``with_grad=True`` the result additionally carries ``grad_shape``,
    ``grad_norm`` and ``force_calc_mode`` (gradient = negative force).
    """
    terms, force, threads_used, rank, size = _compute_energy_force(
        prmtop=prmtop,
        coords=coords,
        device=device,
        double=double,
        force_calc_mode=force_calc_mode,
        with_force=with_grad,
        num_threads=num_threads,
        mpi=mpi,
    )
    payload = _build_result_meta(
        double=double,
        num_threads=threads_used,
        mpi_rank=rank,
        mpi_size=size,
        **terms,
    )
    if not (with_grad and force is not None):
        return payload
    # Report summary statistics of the gradient instead of the raw tensor.
    grad = force.neg().detach().cpu()
    payload["grad_shape"] = [int(grad.shape[0]), int(grad.shape[1])]
    payload["grad_norm"] = float(torch.linalg.norm(grad))
    payload["force_calc_mode"] = str(force_calc_mode)
    return payload
390
+
391
+
392
+ # ---------------------------------------------------------------------------
393
+ # Public API: torch_force
394
+ # ---------------------------------------------------------------------------
395
def torch_force(
    prmtop: PathLike,
    coords: CoordsLike,
    device: Union[str, torch.device] = "cpu",
    double: bool = True,
    force_calc_mode: str = "Analytical",
    save_force: Optional[PathLike] = None,
    num_threads: Optional[int] = None,
    mpi: bool = False,
) -> Dict[str, Any]:
    """Compute force (negative gradient) from Torch energy.

    Returns summary statistics (shape, norm, max-abs) plus the total energy;
    the full force tensor is written to ``save_force`` when given.
    """
    terms, raw_force, threads_used, rank, size = _compute_energy_force(
        prmtop=prmtop,
        coords=coords,
        device=device,
        double=double,
        force_calc_mode=force_calc_mode,
        with_force=True,
        num_threads=num_threads,
        mpi=mpi,
    )
    force = raw_force.detach().cpu()
    # Prefer the canonical E_total key, falling back as the original did.
    total = terms.get("E_total", terms.get("E_total_kcalmol", 0.0))
    if not isinstance(total, float):
        total = float(total)

    summary = _build_result_meta(
        double=double,
        num_threads=threads_used,
        mpi_rank=rank,
        mpi_size=size,
        E_total_kcalmol=total,
        force_shape=[int(force.shape[0]), int(force.shape[1])],
        force_norm=float(torch.linalg.norm(force)),
        force_maxabs=float(force.abs().max()),
        force_calc_mode=str(force_calc_mode),
    )
    if save_force is not None:
        out_path = Path(save_force)
        torch.save(force, out_path)
        summary["force_file"] = str(out_path)
    return summary
435
+
436
+
437
+ # ---------------------------------------------------------------------------
438
+ # Core batch computation (shared by torch_energy_batch / torch_force_batch)
439
+ # ---------------------------------------------------------------------------
440
def _prepare_batch_ff(
    prmtop: PathLike,
    coords_batch: CoordsBatchLike,
    *,
    device: Union[str, torch.device],
    double: bool,
    force_calc_mode: str,
    with_force: bool,
) -> tuple[AmberSystem, torch.Tensor, ForceFieldTorch]:
    """Load the system plus batched coordinates and build the batch force field.

    Returns (system, coords [B,N,3], force field). Raises ValueError for an
    empty batch or a shape mismatch against the topology's atom count.
    """
    dtype = _dtype_from_double(double)
    system = load_system(prmtop, device=device).to(dtype=dtype)

    if isinstance(coords_batch, torch.Tensor):
        batch = coords_batch.to(device=device, dtype=dtype)
    else:
        # Treat the input as a sequence of per-frame coordinate sources.
        loaded = [
            load_coords(p, natom=system.natom, device=device, dtype=dtype)
            for p in coords_batch
        ]
        if not loaded:
            raise ValueError("coords_batch is empty")
        batch = torch.stack(loaded, dim=0)

    if batch.ndim != 3:
        raise ValueError(f"coords_batch must have shape [B,N,3], got {tuple(batch.shape)}")
    if int(batch.shape[1]) != int(system.natom) or int(batch.shape[2]) != 3:
        raise ValueError(
            f"coords_batch shape mismatch: expected [B,{system.natom},3], got {tuple(batch.shape)}"
        )

    # The differentiable CPU nonbonded path is only needed when a CPU
    # autograd force is requested; otherwise keep the fast kernel.
    needs_autograd_nb = (
        torch.device(device).type == "cpu"
        and with_force
        and str(force_calc_mode).strip().lower() == "autograd"
    )
    ff = ForceFieldTorch(system, nonbonded_cpu_fast=not needs_autograd_nb)
    return system, batch, ff
473
+
474
+
475
+ # ---------------------------------------------------------------------------
476
+ # Public API: torch_energy_batch
477
+ # ---------------------------------------------------------------------------
478
def torch_energy_batch(
    prmtop: PathLike,
    coords_batch: CoordsBatchLike,
    device: Union[str, torch.device] = "cpu",
    with_grad: bool = False,
    double: bool = True,
    force_calc_mode: str = "Analytical",
    batch_mode: str = "vmap",
    microbatch_size: Optional[int] = None,
    num_threads: Optional[int] = None,
) -> Dict[str, Any]:
    """Evaluate energies for batched coordinates [B,N,3]."""
    threads_used = _configure_torch_threads(num_threads)
    system, batch, ff = _prepare_batch_ff(
        prmtop, coords_batch,
        device=device, double=double,
        force_calc_mode=force_calc_mode, with_force=with_grad,
    )

    gradient = None
    start = time.perf_counter()
    if with_grad:
        # Energy + force in one pass; gradient is the negated force.
        terms, forces = ff.energy_force_batch(
            batch, force_calc_mode=force_calc_mode,
            batch_mode=batch_mode, microbatch_size=microbatch_size,
        )
        gradient = (-forces).detach().cpu()
    else:
        terms = ff.forward_batch(batch, batch_mode=batch_mode, microbatch_size=microbatch_size)
    wall = time.perf_counter() - start

    summary: Dict[str, Any] = {
        "batch_size": int(batch.shape[0]),
        "n_atom": int(system.natom),
        "batch_mode": str(batch_mode),
        "elapsed_s": float(wall),
        "energy_terms_kcalmol": {name: terms[name].detach().cpu().tolist() for name in terms},
    }
    if gradient is not None:
        summary["grad_shape"] = [int(gradient.shape[0]), int(gradient.shape[1]), int(gradient.shape[2])]
        summary["grad_norm"] = float(torch.linalg.norm(gradient))
        summary["force_calc_mode"] = str(force_calc_mode)
    if microbatch_size is not None:
        summary["microbatch_size"] = int(microbatch_size)
    if threads_used is not None:
        summary["num_threads"] = int(threads_used)
    summary.update(_precision_info(double))
    return summary
525
+
526
+
527
+ # ---------------------------------------------------------------------------
528
+ # Public API: torch_force_batch
529
+ # ---------------------------------------------------------------------------
530
def torch_force_batch(
    prmtop: PathLike,
    coords_batch: CoordsBatchLike,
    device: Union[str, torch.device] = "cpu",
    double: bool = True,
    force_calc_mode: str = "Analytical",
    batch_mode: str = "vmap",
    microbatch_size: Optional[int] = None,
    save_force: Optional[PathLike] = None,
    num_threads: Optional[int] = None,
) -> Dict[str, Any]:
    """Compute batched force tensors for coordinates [B,N,3]."""
    threads_used = _configure_torch_threads(num_threads)
    system, batch, ff = _prepare_batch_ff(
        prmtop, coords_batch,
        device=device, double=double,
        force_calc_mode=force_calc_mode, with_force=True,
    )

    # Time only the energy/force evaluation itself.
    start = time.perf_counter()
    terms, forces = ff.energy_force_batch(
        batch, force_calc_mode=force_calc_mode,
        batch_mode=batch_mode, microbatch_size=microbatch_size,
    )
    wall = time.perf_counter() - start
    force_cpu = forces.detach().cpu()

    summary: Dict[str, Any] = {
        "batch_size": int(batch.shape[0]),
        "n_atom": int(system.natom),
        "batch_mode": str(batch_mode),
        "elapsed_s": float(wall),
        "E_total_kcalmol": terms["E_total"].detach().cpu().tolist(),
        "force_shape": [int(force_cpu.shape[0]), int(force_cpu.shape[1]), int(force_cpu.shape[2])],
        "force_norm": float(torch.linalg.norm(force_cpu)),
        "force_maxabs": float(torch.max(torch.abs(force_cpu))),
        "force_calc_mode": str(force_calc_mode),
    }
    if microbatch_size is not None:
        summary["microbatch_size"] = int(microbatch_size)
    if save_force is not None:
        out_path = Path(save_force)
        torch.save(force_cpu, out_path)
        summary["force_file"] = str(out_path)
    if threads_used is not None:
        summary["num_threads"] = int(threads_used)
    summary.update(_precision_info(double))
    return summary
578
+
579
+
580
+ # ---------------------------------------------------------------------------
581
+ # Public API: torch_hessian
582
+ # ---------------------------------------------------------------------------
583
+ def _normalize_active_atoms(natom: int, active_atoms: Sequence[int]) -> list[int]:
584
+ out: list[int] = []
585
+ seen: set[int] = set()
586
+ for a in active_atoms:
587
+ ia = int(a)
588
+ if ia < 0 or ia >= natom:
589
+ raise ValueError(f"active atom index out of range: {ia} (natom={natom})")
590
+ if ia in seen:
591
+ continue
592
+ seen.add(ia)
593
+ out.append(ia)
594
+ if not out:
595
+ raise ValueError("active atom list is empty")
596
+ return out
597
+
598
+
599
def torch_hessian(
    prmtop: PathLike,
    coords: CoordsLike,
    device: Union[str, torch.device] = "cpu",
    double: bool = True,
    hessian_calc_mode: str = "FiniteDifferenceForce",
    force_calc_mode: str = "Analytical",
    hessian_delta: float = 1.0e-4,
    fd_column_batch: int = 16,
    partial_hessian: bool = True,
    active_atoms: Optional[Sequence[int]] = None,
    save_hessian: Optional[PathLike] = None,
    num_threads: Optional[int] = None,
    mpi: bool = False,
) -> Dict[str, Any]:
    """Compute Hessian by one of three modes: Analytical, Autograd, or FiniteDifferenceForce.

    The Hessian is restricted to ``active_atoms`` when ``partial_hessian`` is
    true; otherwise all atoms contribute (3*natom degrees of freedom). With
    ``mpi=True`` and more than one rank, the Analytical and FD modes
    distribute the work across ranks and reduce the result; Autograd does not
    support MPI. Returns a metadata dict (shape, norms, timing, MPI info);
    the matrix itself is only written to disk when ``save_hessian`` is given.

    Raises:
        ValueError: unknown mode, MPI on non-CPU device, Autograd+MPI,
            Analytical with a non-analytical force mode, or bad
            partial_hessian/active_atoms combinations.
    """
    used_threads = _configure_torch_threads(num_threads)
    mpi_obj, mpi_comm, mpi_rank, mpi_size = _get_mpi_context(mpi)

    dev = torch.device(device)
    if mpi_size > 1 and dev.type != "cpu":
        raise ValueError("mpi=True currently supports CPU execution only")

    # Map the user-facing mode string to an internal tag ("analytical",
    # "autograd", or "fd"); None means the string was not recognized.
    mode = _HESSIAN_MODE_MAP.get(str(hessian_calc_mode).strip().lower())
    if mode is None:
        raise ValueError(f"Unknown hessian_calc_mode: {hessian_calc_mode!r}")

    force_mode = str(force_calc_mode).strip().lower()
    # Autograd anywhere in the pipeline needs the differentiable (slower)
    # CPU nonbonded implementation.
    need_diff_nonbonded = dev.type == "cpu" and (mode == "autograd" or force_mode == "autograd")

    system, xyz, ff = _load_runtime(
        prmtop=prmtop, coords=coords, device=device,
        double=double, requires_grad=False,
        nonbonded_cpu_fast=not need_diff_nonbonded,
    )

    # Resolve the active-atom selection (deduplicated, validated) or all atoms.
    natom = int(system.natom)
    if partial_hessian:
        if active_atoms is None:
            raise ValueError("partial_hessian=true requires active_atoms")
        active_list = _normalize_active_atoms(natom, active_atoms)
    else:
        if active_atoms is not None:
            raise ValueError("active_atoms is only valid when partial_hessian=true")
        active_list = list(range(natom))

    active_idx = torch.tensor(active_list, dtype=torch.int64, device=xyz.device)
    x_active_base = xyz.index_select(0, active_idx).clone().detach()
    ndof = int(x_active_base.numel())  # 3 * len(active_list)

    n_force_eval: Optional[int] = None
    analytical_meta: Optional[Dict[str, int]] = None

    if mode == "analytical":
        if force_mode != "analytical":
            raise ValueError("hessian_calc_mode='Analytical' requires force_calc_mode='Analytical'")
        t0 = time.perf_counter()
        h_local, analytical_meta_local = build_analytical_hessian(
            system=system, coords=xyz, active_atoms=active_list,
            mpi_rank=int(mpi_rank), mpi_size=int(mpi_size),
        )
        elapsed_local = time.perf_counter() - t0
        if mpi_size > 1:
            # Sum the partial Hessians across ranks; report the slowest
            # rank's wall time (MAX) as the overall elapsed time.
            h2 = _mpi_reduce_tensor(h_local, mpi_obj=mpi_obj, comm=mpi_comm)
            elapsed = float(mpi_comm.allreduce(elapsed_local, op=mpi_obj.MAX))
            # Per-rank term counters are additive; other metadata (e.g.
            # dimensions) is identical on every rank, so MAX is a no-op merge.
            meta_sum_keys = {
                "bond_pairs_used", "angle_terms_used", "dihedral_terms_used",
                "cmap_terms_used", "nonbond_pairs_used", "nonbond14_pairs_used",
                "analytical_force_evals",
            }
            analytical_meta = {}
            for k, v in analytical_meta_local.items():
                op = mpi_obj.SUM if k in meta_sum_keys else mpi_obj.MAX
                analytical_meta[k] = int(mpi_comm.allreduce(int(v), op=op))
        else:
            h2 = h_local
            elapsed = elapsed_local
            analytical_meta = analytical_meta_local
        # Enforce exact symmetry by averaging with the transpose.
        h2 = (0.5 * (h2 + h2.T)).detach().cpu()
        n_force_eval = int(analytical_meta.get("analytical_force_evals", 0))

    elif mode == "autograd":
        if mpi_size > 1:
            raise ValueError("hessian_calc_mode='Autograd' does not support mpi=True")

        def _energy_fn(x_sub: torch.Tensor) -> torch.Tensor:
            # Scatter the active-atom coordinates back into the full frame
            # (index_copy returns a new tensor; xyz is not mutated).
            x_full = xyz.index_copy(0, active_idx, x_sub)
            return ff(x_full)["E_total"]

        x_active = x_active_base.requires_grad_(True)
        t0 = time.perf_counter()
        h4 = torch.autograd.functional.hessian(_energy_fn, x_active, vectorize=False)
        elapsed = time.perf_counter() - t0
        # Collapse the [n,3,n,3] result to a square [ndof, ndof] matrix.
        h2 = h4.reshape(ndof, ndof).detach().cpu()

    else: # mode == "fd"
        # Central finite differences of the force: column j of H is
        # -(F(x + delta*e_j) - F(x - delta*e_j)) / (2*delta).
        delta = float(hessian_delta)
        fd_batch_cols = max(1, int(fd_column_batch))
        t0 = time.perf_counter()
        if mpi_size > 1:
            # Each rank fills a strided subset of columns; zeros elsewhere so
            # the cross-rank SUM reduction assembles the full matrix.
            h2 = torch.zeros((ndof, ndof), dtype=xyz.dtype, device="cpu")
            col_iter = range(mpi_rank, ndof, mpi_size)
        else:
            h2 = torch.empty((ndof, ndof), dtype=xyz.dtype, device="cpu")
            col_iter = range(ndof)
        cols = list(col_iter)
        n_force_eval_local = 0
        for start in range(0, len(cols), fd_batch_cols):
            sub_cols = cols[start : start + fd_batch_cols]
            bsz = len(sub_cols)
            # Two displaced frames (+delta, -delta) per column, batched.
            xb = xyz.unsqueeze(0).repeat(2 * bsz, 1, 1).clone()
            for bi, col in enumerate(sub_cols):
                a = col // 3  # position within active_list
                c = col % 3   # Cartesian component
                atom_idx = int(active_list[a])
                xb[2 * bi, atom_idx, c] = xb[2 * bi, atom_idx, c] + delta
                xb[2 * bi + 1, atom_idx, c] = xb[2 * bi + 1, atom_idx, c] - delta
            _, force_b = ff.energy_force_batch(
                xb, force_calc_mode=force_calc_mode, batch_mode="loop",
                microbatch_size=max(1, min(2 * bsz, 64)),
            )
            for bi, col in enumerate(sub_cols):
                fp = force_b[2 * bi].index_select(0, active_idx).reshape(-1)
                fm = force_b[2 * bi + 1].index_select(0, active_idx).reshape(-1)
                h2[:, col] = (-(fp - fm) / (2.0 * delta)).detach().cpu()
            n_force_eval_local += 2 * bsz
        elapsed_local = time.perf_counter() - t0
        if mpi_size > 1:
            h2 = _mpi_reduce_tensor(h2, mpi_obj=mpi_obj, comm=mpi_comm)
            elapsed = float(mpi_comm.allreduce(elapsed_local, op=mpi_obj.MAX))
            n_force_eval = int(mpi_comm.allreduce(n_force_eval_local, op=mpi_obj.SUM))
        else:
            elapsed = elapsed_local
            n_force_eval = n_force_eval_local

    # Largest asymmetry |H - H^T|; a quality metric for the FD/autograd paths.
    hasym = torch.max(torch.abs(h2 - h2.T))

    out: Dict[str, Any] = {
        "hessian_calc_mode": str(hessian_calc_mode),
        "force_calc_mode": str(force_calc_mode),
        "partial_hessian": bool(partial_hessian),
        "n_atom_total": natom,
        "active_atom_count": len(active_list),
        "active_dof": ndof,
        "hessian_shape": [ndof, ndof],
        "hessian_maxabs": float(torch.max(torch.abs(h2))),
        "hessian_max_asym": float(hasym),
        "hessian_elapsed_s": float(elapsed),
        "mpi_enabled": bool(mpi_size > 1),
        "mpi_rank": int(mpi_rank),
        "mpi_size": int(mpi_size),
    }
    if used_threads is not None:
        out["num_threads"] = int(used_threads)
    if mode == "fd":
        out["hessian_delta_A"] = float(hessian_delta)
        out["fd_column_batch"] = int(fd_column_batch)
        out["force_evals"] = int(n_force_eval) if n_force_eval is not None else None
    elif mode == "analytical":
        out["hessian_delta_A"] = float(hessian_delta)
        if n_force_eval is not None:
            out["force_evals"] = int(n_force_eval)
        if analytical_meta is not None:
            for k, v in analytical_meta.items():
                out[k] = int(v)
    if partial_hessian:
        out["active_atoms"] = active_list

    if save_hessian is not None:
        save_path = Path(save_hessian)
        # Under MPI only rank 0 writes, but every rank reports the path.
        if mpi_size == 1 or mpi_rank == 0:
            torch.save(
                {
                    "hessian": h2,
                    "active_atoms": torch.tensor(active_list, dtype=torch.int64),
                    "partial_hessian": bool(partial_hessian),
                    "dtype": _precision_info(double)["torch_dtype"],
                },
                save_path,
            )
        out["hessian_file"] = str(save_path)

    out.update(_precision_info(double))
    return out
783
+
784
+
785
+ # ---------------------------------------------------------------------------
786
+ # OpenMM compatibility (lazy import)
787
+ # ---------------------------------------------------------------------------
788
def _openmm_positions_from_coords(coords: CoordsLike, natom: int):
    """Convert a coordinate input to OpenMM positions.

    Amber restart and PDB files go through the native OpenMM readers; xyz
    files and in-memory coordinates are loaded via ``load_coords`` and
    tagged with Angstrom units.
    """
    from openmm import app, unit

    if isinstance(coords, (str, Path)):
        path = Path(coords)
        ext = path.suffix.lower()
        if ext in {".rst7", ".inpcrd", ".crd", ".restrt"}:
            return app.AmberInpcrdFile(str(path)).getPositions(asNumpy=True)
        if ext == ".pdb":
            return app.PDBFile(str(path)).getPositions(asNumpy=True)
        if ext == ".xyz":
            frame = load_coords(path, natom=natom, device="cpu", dtype=torch.float64)
            return frame.detach().cpu().numpy() * unit.angstrom
        raise ValueError(f"Unsupported coords file for OpenMM: {path}")

    # Non-path input: array-like / tensor-like coordinates.
    frame = load_coords(coords, natom=natom, device="cpu", dtype=torch.float64)
    return frame.detach().cpu().numpy() * unit.angstrom
804
+
805
+
806
def verify_openmm(
    prmtop: PathLike,
    coords: CoordsLike,
    device: Union[str, torch.device] = "cpu",
    openmm_platform: Optional[str] = None,
    double: bool = True,
) -> Dict[str, Any]:
    """Validate Torch energies/gradients against OpenMM.

    Evaluates the same prmtop/coords with both the Torch force field and an
    OpenMM NoCutoff system (no constraints, flexible water) on the OpenMM CPU
    platform, and returns total-energy and gradient differences plus
    per-term energy breakdowns from both engines.

    Raises:
        RuntimeError: OpenMM is not importable, or its CPU platform is missing.
        ValueError: a non-CPU ``openmm_platform`` was requested.
    """
    try:
        from openmm import app, openmm, unit
    except Exception as e:
        raise RuntimeError(
            "OpenMM is not installed. Install it (pip install openmm) and retry."
        ) from e

    # --- Torch side: energy terms and gradient via backprop -------------
    prmtop_path = Path(prmtop)
    _, xyz, ff = _load_runtime(
        prmtop=prmtop_path, coords=coords, device=device,
        double=double, requires_grad=True,
        nonbonded_cpu_fast=torch.device(device).type != "cpu",
    )
    e_torch = ff(xyz)
    e_total = e_torch["E_total"]
    e_total.backward()
    # Compare in float64 on CPU regardless of the run precision/device.
    grad_torch = xyz.grad.detach().cpu().to(torch.float64)

    # --- OpenMM side: matching NoCutoff, unconstrained system -----------
    prmtop_omm = app.AmberPrmtopFile(str(prmtop_path))
    omm_system = prmtop_omm.createSystem(
        nonbondedMethod=app.NoCutoff, constraints=None, rigidWater=False,
    )
    # One force group per force so per-term energies can be queried below.
    for gi, force in enumerate(omm_system.getForces()):
        force.setForceGroup(gi)

    # Integrator is required by Context but never stepped.
    integrator = openmm.VerletIntegrator(1.0 * unit.femtoseconds)
    preferred = openmm_platform or "CPU"
    if str(preferred).strip().upper() != "CPU":
        raise ValueError(
            "Only OpenMM CPU platform is allowed in this workflow "
            f"(got openmm_platform={openmm_platform!r})."
        )
    names = [openmm.Platform.getPlatform(i).getName() for i in range(openmm.Platform.getNumPlatforms())]
    if "CPU" not in names:
        raise RuntimeError(f"OpenMM CPU platform is not available. Platforms: {names}")
    selected_platform = "CPU"
    platform = openmm.Platform.getPlatformByName(selected_platform)
    context = openmm.Context(omm_system, integrator, platform)

    natom = int(omm_system.getNumParticles())
    context.setPositions(_openmm_positions_from_coords(coords, natom=natom))

    # Total energy and forces in kcal/mol and kcal/mol/Angstrom.
    state = context.getState(getEnergy=True, getForces=True)
    e_openmm = state.getPotentialEnergy().value_in_unit(unit.kilocalories_per_mole)
    f_openmm = state.getForces(asNumpy=True).value_in_unit(
        unit.kilocalories_per_mole / unit.angstrom
    )
    grad_openmm = -torch.tensor(f_openmm, dtype=torch.float64)

    # Per-force-group energies keyed by the OpenMM force class name.
    omm_terms: Dict[str, float] = {}
    for gi, force in enumerate(omm_system.getForces()):
        st = context.getState(getEnergy=True, groups=1 << gi)
        omm_terms[type(force).__name__] = float(
            st.getPotentialEnergy().value_in_unit(unit.kilocalories_per_mole)
        )

    e_t = float(e_total.detach().cpu())
    gdiff = grad_torch - grad_openmm

    out: Dict[str, Any] = {
        "E_total_torch_kcalmol": e_t,
        "E_total_openmm_kcalmol": float(e_openmm),
        "E_total_diff_kcalmol": float(e_t - float(e_openmm)),
        "grad_rms_kcalmolA": float(torch.sqrt(torch.mean(gdiff**2))),
        "grad_maxabs_kcalmolA": float(torch.max(torch.abs(gdiff))),
        "torch_terms_kcalmol": _energy_terms_to_float(e_torch),
        "openmm_force_terms_kcalmol": omm_terms,
        "notes": [
            "OpenMM NonbondedForce energy includes 1-4 exceptions; Torch prints E_coul14/E_lj14 separately.",
            "Agreement depends on matching coordinate units (Angstrom) and using NoCutoff + constraints=None + rigidWater=False.",
        ],
    }
    out["openmm_platform"] = platform.getName()

    out.update(_precision_info(double))
    return out