mlmm-toolkit 0.2.2.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (372) hide show
  1. hessian_ff/__init__.py +50 -0
  2. hessian_ff/analytical_hessian.py +609 -0
  3. hessian_ff/constants.py +46 -0
  4. hessian_ff/forcefield.py +339 -0
  5. hessian_ff/loaders.py +608 -0
  6. hessian_ff/native/Makefile +8 -0
  7. hessian_ff/native/__init__.py +28 -0
  8. hessian_ff/native/analytical_hessian.py +88 -0
  9. hessian_ff/native/analytical_hessian_ext.cpp +258 -0
  10. hessian_ff/native/bonded.py +82 -0
  11. hessian_ff/native/bonded_ext.cpp +640 -0
  12. hessian_ff/native/loader.py +349 -0
  13. hessian_ff/native/nonbonded.py +118 -0
  14. hessian_ff/native/nonbonded_ext.cpp +1150 -0
  15. hessian_ff/prmtop_parmed.py +23 -0
  16. hessian_ff/system.py +107 -0
  17. hessian_ff/terms/__init__.py +14 -0
  18. hessian_ff/terms/angle.py +73 -0
  19. hessian_ff/terms/bond.py +44 -0
  20. hessian_ff/terms/cmap.py +406 -0
  21. hessian_ff/terms/dihedral.py +141 -0
  22. hessian_ff/terms/nonbonded.py +209 -0
  23. hessian_ff/tests/__init__.py +0 -0
  24. hessian_ff/tests/conftest.py +75 -0
  25. hessian_ff/tests/data/small/complex.parm7 +1346 -0
  26. hessian_ff/tests/data/small/complex.pdb +125 -0
  27. hessian_ff/tests/data/small/complex.rst7 +63 -0
  28. hessian_ff/tests/test_coords_input.py +44 -0
  29. hessian_ff/tests/test_energy_force.py +49 -0
  30. hessian_ff/tests/test_hessian.py +137 -0
  31. hessian_ff/tests/test_smoke.py +18 -0
  32. hessian_ff/tests/test_validation.py +40 -0
  33. hessian_ff/workflows.py +889 -0
  34. mlmm/__init__.py +36 -0
  35. mlmm/__main__.py +7 -0
  36. mlmm/_version.py +34 -0
  37. mlmm/add_elem_info.py +374 -0
  38. mlmm/advanced_help.py +91 -0
  39. mlmm/align_freeze_atoms.py +601 -0
  40. mlmm/all.py +3535 -0
  41. mlmm/bond_changes.py +231 -0
  42. mlmm/bool_compat.py +223 -0
  43. mlmm/cli.py +574 -0
  44. mlmm/cli_utils.py +166 -0
  45. mlmm/default_group.py +337 -0
  46. mlmm/defaults.py +467 -0
  47. mlmm/define_layer.py +526 -0
  48. mlmm/dft.py +1041 -0
  49. mlmm/energy_diagram.py +253 -0
  50. mlmm/extract.py +2213 -0
  51. mlmm/fix_altloc.py +464 -0
  52. mlmm/freq.py +1406 -0
  53. mlmm/harmonic_constraints.py +140 -0
  54. mlmm/hessian_cache.py +44 -0
  55. mlmm/hessian_calc.py +174 -0
  56. mlmm/irc.py +638 -0
  57. mlmm/mlmm_calc.py +2262 -0
  58. mlmm/mm_parm.py +945 -0
  59. mlmm/oniom_export.py +1983 -0
  60. mlmm/oniom_import.py +457 -0
  61. mlmm/opt.py +1742 -0
  62. mlmm/path_opt.py +1353 -0
  63. mlmm/path_search.py +2299 -0
  64. mlmm/preflight.py +88 -0
  65. mlmm/py.typed +1 -0
  66. mlmm/pysis_runner.py +45 -0
  67. mlmm/scan.py +1047 -0
  68. mlmm/scan2d.py +1226 -0
  69. mlmm/scan3d.py +1265 -0
  70. mlmm/scan_common.py +184 -0
  71. mlmm/summary_log.py +736 -0
  72. mlmm/trj2fig.py +448 -0
  73. mlmm/tsopt.py +2871 -0
  74. mlmm/utils.py +2309 -0
  75. mlmm/xtb_embedcharge_correction.py +475 -0
  76. mlmm_toolkit-0.2.2.dev0.dist-info/METADATA +1159 -0
  77. mlmm_toolkit-0.2.2.dev0.dist-info/RECORD +372 -0
  78. mlmm_toolkit-0.2.2.dev0.dist-info/WHEEL +5 -0
  79. mlmm_toolkit-0.2.2.dev0.dist-info/entry_points.txt +2 -0
  80. mlmm_toolkit-0.2.2.dev0.dist-info/licenses/LICENSE +674 -0
  81. mlmm_toolkit-0.2.2.dev0.dist-info/top_level.txt +4 -0
  82. pysisyphus/Geometry.py +1667 -0
  83. pysisyphus/LICENSE +674 -0
  84. pysisyphus/TableFormatter.py +63 -0
  85. pysisyphus/TablePrinter.py +74 -0
  86. pysisyphus/__init__.py +12 -0
  87. pysisyphus/calculators/AFIR.py +452 -0
  88. pysisyphus/calculators/AnaPot.py +20 -0
  89. pysisyphus/calculators/AnaPot2.py +48 -0
  90. pysisyphus/calculators/AnaPot3.py +12 -0
  91. pysisyphus/calculators/AnaPot4.py +20 -0
  92. pysisyphus/calculators/AnaPotBase.py +337 -0
  93. pysisyphus/calculators/AnaPotCBM.py +25 -0
  94. pysisyphus/calculators/AtomAtomTransTorque.py +154 -0
  95. pysisyphus/calculators/CFOUR.py +250 -0
  96. pysisyphus/calculators/Calculator.py +844 -0
  97. pysisyphus/calculators/CerjanMiller.py +24 -0
  98. pysisyphus/calculators/Composite.py +123 -0
  99. pysisyphus/calculators/ConicalIntersection.py +171 -0
  100. pysisyphus/calculators/DFTBp.py +430 -0
  101. pysisyphus/calculators/DFTD3.py +66 -0
  102. pysisyphus/calculators/DFTD4.py +84 -0
  103. pysisyphus/calculators/Dalton.py +61 -0
  104. pysisyphus/calculators/Dimer.py +681 -0
  105. pysisyphus/calculators/Dummy.py +20 -0
  106. pysisyphus/calculators/EGO.py +76 -0
  107. pysisyphus/calculators/EnergyMin.py +224 -0
  108. pysisyphus/calculators/ExternalPotential.py +264 -0
  109. pysisyphus/calculators/FakeASE.py +35 -0
  110. pysisyphus/calculators/FourWellAnaPot.py +28 -0
  111. pysisyphus/calculators/FreeEndNEBPot.py +39 -0
  112. pysisyphus/calculators/Gaussian09.py +18 -0
  113. pysisyphus/calculators/Gaussian16.py +726 -0
  114. pysisyphus/calculators/HardSphere.py +159 -0
  115. pysisyphus/calculators/IDPPCalculator.py +49 -0
  116. pysisyphus/calculators/IPIClient.py +133 -0
  117. pysisyphus/calculators/IPIServer.py +234 -0
  118. pysisyphus/calculators/LEPSBase.py +24 -0
  119. pysisyphus/calculators/LEPSExpr.py +139 -0
  120. pysisyphus/calculators/LennardJones.py +80 -0
  121. pysisyphus/calculators/MOPAC.py +219 -0
  122. pysisyphus/calculators/MullerBrownSympyPot.py +51 -0
  123. pysisyphus/calculators/MultiCalc.py +85 -0
  124. pysisyphus/calculators/NFK.py +45 -0
  125. pysisyphus/calculators/OBabel.py +87 -0
  126. pysisyphus/calculators/ONIOMv2.py +1129 -0
  127. pysisyphus/calculators/ORCA.py +893 -0
  128. pysisyphus/calculators/ORCA5.py +6 -0
  129. pysisyphus/calculators/OpenMM.py +88 -0
  130. pysisyphus/calculators/OpenMolcas.py +281 -0
  131. pysisyphus/calculators/OverlapCalculator.py +908 -0
  132. pysisyphus/calculators/Psi4.py +218 -0
  133. pysisyphus/calculators/PyPsi4.py +37 -0
  134. pysisyphus/calculators/PySCF.py +341 -0
  135. pysisyphus/calculators/PyXTB.py +73 -0
  136. pysisyphus/calculators/QCEngine.py +106 -0
  137. pysisyphus/calculators/Rastrigin.py +22 -0
  138. pysisyphus/calculators/Remote.py +76 -0
  139. pysisyphus/calculators/Rosenbrock.py +15 -0
  140. pysisyphus/calculators/SocketCalc.py +97 -0
  141. pysisyphus/calculators/TIP3P.py +111 -0
  142. pysisyphus/calculators/TransTorque.py +161 -0
  143. pysisyphus/calculators/Turbomole.py +965 -0
  144. pysisyphus/calculators/VRIPot.py +37 -0
  145. pysisyphus/calculators/WFOWrapper.py +333 -0
  146. pysisyphus/calculators/WFOWrapper2.py +341 -0
  147. pysisyphus/calculators/XTB.py +418 -0
  148. pysisyphus/calculators/__init__.py +81 -0
  149. pysisyphus/calculators/cosmo_data.py +139 -0
  150. pysisyphus/calculators/parser.py +150 -0
  151. pysisyphus/color.py +19 -0
  152. pysisyphus/config.py +133 -0
  153. pysisyphus/constants.py +65 -0
  154. pysisyphus/cos/AdaptiveNEB.py +230 -0
  155. pysisyphus/cos/ChainOfStates.py +725 -0
  156. pysisyphus/cos/FreeEndNEB.py +25 -0
  157. pysisyphus/cos/FreezingString.py +103 -0
  158. pysisyphus/cos/GrowingChainOfStates.py +71 -0
  159. pysisyphus/cos/GrowingNT.py +309 -0
  160. pysisyphus/cos/GrowingString.py +508 -0
  161. pysisyphus/cos/NEB.py +189 -0
  162. pysisyphus/cos/SimpleZTS.py +64 -0
  163. pysisyphus/cos/__init__.py +22 -0
  164. pysisyphus/cos/stiffness.py +199 -0
  165. pysisyphus/drivers/__init__.py +17 -0
  166. pysisyphus/drivers/afir.py +855 -0
  167. pysisyphus/drivers/barriers.py +271 -0
  168. pysisyphus/drivers/birkholz.py +138 -0
  169. pysisyphus/drivers/cluster.py +318 -0
  170. pysisyphus/drivers/diabatization.py +133 -0
  171. pysisyphus/drivers/merge.py +368 -0
  172. pysisyphus/drivers/merge_mol2.py +322 -0
  173. pysisyphus/drivers/opt.py +375 -0
  174. pysisyphus/drivers/perf.py +91 -0
  175. pysisyphus/drivers/pka.py +52 -0
  176. pysisyphus/drivers/precon_pos_rot.py +669 -0
  177. pysisyphus/drivers/rates.py +480 -0
  178. pysisyphus/drivers/replace.py +219 -0
  179. pysisyphus/drivers/scan.py +212 -0
  180. pysisyphus/drivers/spectrum.py +166 -0
  181. pysisyphus/drivers/thermo.py +31 -0
  182. pysisyphus/dynamics/Gaussian.py +103 -0
  183. pysisyphus/dynamics/__init__.py +20 -0
  184. pysisyphus/dynamics/colvars.py +136 -0
  185. pysisyphus/dynamics/driver.py +297 -0
  186. pysisyphus/dynamics/helpers.py +256 -0
  187. pysisyphus/dynamics/lincs.py +105 -0
  188. pysisyphus/dynamics/mdp.py +364 -0
  189. pysisyphus/dynamics/rattle.py +121 -0
  190. pysisyphus/dynamics/thermostats.py +128 -0
  191. pysisyphus/dynamics/wigner.py +266 -0
  192. pysisyphus/elem_data.py +3473 -0
  193. pysisyphus/exceptions.py +2 -0
  194. pysisyphus/filtertrj.py +69 -0
  195. pysisyphus/helpers.py +623 -0
  196. pysisyphus/helpers_pure.py +649 -0
  197. pysisyphus/init_logging.py +50 -0
  198. pysisyphus/intcoords/Bend.py +69 -0
  199. pysisyphus/intcoords/Bend2.py +25 -0
  200. pysisyphus/intcoords/BondedFragment.py +32 -0
  201. pysisyphus/intcoords/Cartesian.py +41 -0
  202. pysisyphus/intcoords/CartesianCoords.py +140 -0
  203. pysisyphus/intcoords/Coords.py +56 -0
  204. pysisyphus/intcoords/DLC.py +197 -0
  205. pysisyphus/intcoords/DistanceFunction.py +34 -0
  206. pysisyphus/intcoords/DummyImproper.py +70 -0
  207. pysisyphus/intcoords/DummyTorsion.py +72 -0
  208. pysisyphus/intcoords/LinearBend.py +105 -0
  209. pysisyphus/intcoords/LinearDisplacement.py +80 -0
  210. pysisyphus/intcoords/OutOfPlane.py +59 -0
  211. pysisyphus/intcoords/PrimTypes.py +286 -0
  212. pysisyphus/intcoords/Primitive.py +137 -0
  213. pysisyphus/intcoords/RedundantCoords.py +659 -0
  214. pysisyphus/intcoords/RobustTorsion.py +59 -0
  215. pysisyphus/intcoords/Rotation.py +147 -0
  216. pysisyphus/intcoords/Stretch.py +31 -0
  217. pysisyphus/intcoords/Torsion.py +101 -0
  218. pysisyphus/intcoords/Torsion2.py +25 -0
  219. pysisyphus/intcoords/Translation.py +45 -0
  220. pysisyphus/intcoords/__init__.py +61 -0
  221. pysisyphus/intcoords/augment_bonds.py +126 -0
  222. pysisyphus/intcoords/derivatives.py +10512 -0
  223. pysisyphus/intcoords/eval.py +80 -0
  224. pysisyphus/intcoords/exceptions.py +37 -0
  225. pysisyphus/intcoords/findiffs.py +48 -0
  226. pysisyphus/intcoords/generate_derivatives.py +414 -0
  227. pysisyphus/intcoords/helpers.py +235 -0
  228. pysisyphus/intcoords/logging_conf.py +10 -0
  229. pysisyphus/intcoords/mp_derivatives.py +10836 -0
  230. pysisyphus/intcoords/setup.py +962 -0
  231. pysisyphus/intcoords/setup_fast.py +176 -0
  232. pysisyphus/intcoords/update.py +272 -0
  233. pysisyphus/intcoords/valid.py +89 -0
  234. pysisyphus/interpolate/Geodesic.py +93 -0
  235. pysisyphus/interpolate/IDPP.py +55 -0
  236. pysisyphus/interpolate/Interpolator.py +116 -0
  237. pysisyphus/interpolate/LST.py +70 -0
  238. pysisyphus/interpolate/Redund.py +152 -0
  239. pysisyphus/interpolate/__init__.py +9 -0
  240. pysisyphus/interpolate/helpers.py +34 -0
  241. pysisyphus/io/__init__.py +22 -0
  242. pysisyphus/io/aomix.py +178 -0
  243. pysisyphus/io/cjson.py +24 -0
  244. pysisyphus/io/crd.py +101 -0
  245. pysisyphus/io/cube.py +220 -0
  246. pysisyphus/io/fchk.py +184 -0
  247. pysisyphus/io/hdf5.py +49 -0
  248. pysisyphus/io/hessian.py +72 -0
  249. pysisyphus/io/mol2.py +146 -0
  250. pysisyphus/io/molden.py +293 -0
  251. pysisyphus/io/orca.py +189 -0
  252. pysisyphus/io/pdb.py +269 -0
  253. pysisyphus/io/psf.py +79 -0
  254. pysisyphus/io/pubchem.py +31 -0
  255. pysisyphus/io/qcschema.py +34 -0
  256. pysisyphus/io/sdf.py +29 -0
  257. pysisyphus/io/xyz.py +61 -0
  258. pysisyphus/io/zmat.py +175 -0
  259. pysisyphus/irc/DWI.py +108 -0
  260. pysisyphus/irc/DampedVelocityVerlet.py +134 -0
  261. pysisyphus/irc/Euler.py +22 -0
  262. pysisyphus/irc/EulerPC.py +345 -0
  263. pysisyphus/irc/GonzalezSchlegel.py +187 -0
  264. pysisyphus/irc/IMKMod.py +164 -0
  265. pysisyphus/irc/IRC.py +878 -0
  266. pysisyphus/irc/IRCDummy.py +10 -0
  267. pysisyphus/irc/Instanton.py +307 -0
  268. pysisyphus/irc/LQA.py +53 -0
  269. pysisyphus/irc/ModeKill.py +136 -0
  270. pysisyphus/irc/ParamPlot.py +53 -0
  271. pysisyphus/irc/RK4.py +36 -0
  272. pysisyphus/irc/__init__.py +31 -0
  273. pysisyphus/irc/initial_displ.py +219 -0
  274. pysisyphus/linalg.py +411 -0
  275. pysisyphus/line_searches/Backtracking.py +88 -0
  276. pysisyphus/line_searches/HagerZhang.py +184 -0
  277. pysisyphus/line_searches/LineSearch.py +232 -0
  278. pysisyphus/line_searches/StrongWolfe.py +108 -0
  279. pysisyphus/line_searches/__init__.py +9 -0
  280. pysisyphus/line_searches/interpol.py +15 -0
  281. pysisyphus/modefollow/NormalMode.py +40 -0
  282. pysisyphus/modefollow/__init__.py +10 -0
  283. pysisyphus/modefollow/davidson.py +199 -0
  284. pysisyphus/modefollow/lanczos.py +95 -0
  285. pysisyphus/optimizers/BFGS.py +99 -0
  286. pysisyphus/optimizers/BacktrackingOptimizer.py +113 -0
  287. pysisyphus/optimizers/ConjugateGradient.py +98 -0
  288. pysisyphus/optimizers/CubicNewton.py +75 -0
  289. pysisyphus/optimizers/FIRE.py +113 -0
  290. pysisyphus/optimizers/HessianOptimizer.py +1176 -0
  291. pysisyphus/optimizers/LBFGS.py +228 -0
  292. pysisyphus/optimizers/LayerOpt.py +411 -0
  293. pysisyphus/optimizers/MicroOptimizer.py +169 -0
  294. pysisyphus/optimizers/NCOptimizer.py +90 -0
  295. pysisyphus/optimizers/Optimizer.py +1084 -0
  296. pysisyphus/optimizers/PreconLBFGS.py +260 -0
  297. pysisyphus/optimizers/PreconSteepestDescent.py +7 -0
  298. pysisyphus/optimizers/QuickMin.py +74 -0
  299. pysisyphus/optimizers/RFOptimizer.py +181 -0
  300. pysisyphus/optimizers/RSA.py +99 -0
  301. pysisyphus/optimizers/StabilizedQNMethod.py +248 -0
  302. pysisyphus/optimizers/SteepestDescent.py +23 -0
  303. pysisyphus/optimizers/StringOptimizer.py +173 -0
  304. pysisyphus/optimizers/__init__.py +41 -0
  305. pysisyphus/optimizers/closures.py +301 -0
  306. pysisyphus/optimizers/cls_map.py +58 -0
  307. pysisyphus/optimizers/exceptions.py +6 -0
  308. pysisyphus/optimizers/gdiis.py +280 -0
  309. pysisyphus/optimizers/guess_hessians.py +311 -0
  310. pysisyphus/optimizers/hessian_updates.py +355 -0
  311. pysisyphus/optimizers/poly_fit.py +285 -0
  312. pysisyphus/optimizers/precon.py +153 -0
  313. pysisyphus/optimizers/restrict_step.py +24 -0
  314. pysisyphus/pack.py +172 -0
  315. pysisyphus/peakdetect.py +948 -0
  316. pysisyphus/plot.py +1031 -0
  317. pysisyphus/run.py +2106 -0
  318. pysisyphus/socket_helper.py +74 -0
  319. pysisyphus/stocastic/FragmentKick.py +132 -0
  320. pysisyphus/stocastic/Kick.py +81 -0
  321. pysisyphus/stocastic/Pipeline.py +303 -0
  322. pysisyphus/stocastic/__init__.py +21 -0
  323. pysisyphus/stocastic/align.py +127 -0
  324. pysisyphus/testing.py +96 -0
  325. pysisyphus/thermo.py +156 -0
  326. pysisyphus/trj.py +824 -0
  327. pysisyphus/tsoptimizers/RSIRFOptimizer.py +56 -0
  328. pysisyphus/tsoptimizers/RSPRFOptimizer.py +182 -0
  329. pysisyphus/tsoptimizers/TRIM.py +59 -0
  330. pysisyphus/tsoptimizers/TSHessianOptimizer.py +463 -0
  331. pysisyphus/tsoptimizers/__init__.py +23 -0
  332. pysisyphus/wavefunction/Basis.py +239 -0
  333. pysisyphus/wavefunction/DIIS.py +76 -0
  334. pysisyphus/wavefunction/__init__.py +25 -0
  335. pysisyphus/wavefunction/build_ext.py +42 -0
  336. pysisyphus/wavefunction/cart2sph.py +190 -0
  337. pysisyphus/wavefunction/diabatization.py +304 -0
  338. pysisyphus/wavefunction/excited_states.py +435 -0
  339. pysisyphus/wavefunction/gen_ints.py +1811 -0
  340. pysisyphus/wavefunction/helpers.py +104 -0
  341. pysisyphus/wavefunction/ints/__init__.py +0 -0
  342. pysisyphus/wavefunction/ints/boys.py +193 -0
  343. pysisyphus/wavefunction/ints/boys_table_N_64_xasym_27.1_step_0.01.npy +0 -0
  344. pysisyphus/wavefunction/ints/cart_gto3d.py +176 -0
  345. pysisyphus/wavefunction/ints/coulomb3d.py +25928 -0
  346. pysisyphus/wavefunction/ints/diag_quadrupole3d.py +10036 -0
  347. pysisyphus/wavefunction/ints/dipole3d.py +8762 -0
  348. pysisyphus/wavefunction/ints/int2c2e3d.py +7198 -0
  349. pysisyphus/wavefunction/ints/int3c2e3d_sph.py +65040 -0
  350. pysisyphus/wavefunction/ints/kinetic3d.py +8240 -0
  351. pysisyphus/wavefunction/ints/ovlp3d.py +3777 -0
  352. pysisyphus/wavefunction/ints/quadrupole3d.py +15054 -0
  353. pysisyphus/wavefunction/ints/self_ovlp3d.py +198 -0
  354. pysisyphus/wavefunction/localization.py +458 -0
  355. pysisyphus/wavefunction/multipole.py +159 -0
  356. pysisyphus/wavefunction/normalization.py +36 -0
  357. pysisyphus/wavefunction/pop_analysis.py +134 -0
  358. pysisyphus/wavefunction/shells.py +1171 -0
  359. pysisyphus/wavefunction/wavefunction.py +504 -0
  360. pysisyphus/wrapper/__init__.py +11 -0
  361. pysisyphus/wrapper/exceptions.py +2 -0
  362. pysisyphus/wrapper/jmol.py +120 -0
  363. pysisyphus/wrapper/mwfn.py +169 -0
  364. pysisyphus/wrapper/packmol.py +71 -0
  365. pysisyphus/xyzloader.py +168 -0
  366. pysisyphus/yaml_mods.py +45 -0
  367. thermoanalysis/LICENSE +674 -0
  368. thermoanalysis/QCData.py +244 -0
  369. thermoanalysis/__init__.py +0 -0
  370. thermoanalysis/config.py +3 -0
  371. thermoanalysis/constants.py +20 -0
  372. thermoanalysis/thermo.py +1011 -0
@@ -0,0 +1,23 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Any, Dict, List, Union
5
+
6
+
7
def read_prmtop_with_parmed(prmtop_path: Union[str, Path]) -> Dict[str, List[Any]]:
    """Load a prmtop file through ParmEd and return its raw sections.

    The file is parsed with ``parmed.amber.AmberParm``. Every entry of the
    resulting ``parm_data`` mapping (prmtop %FLAG name -> raw array) is copied
    into a plain ``list`` and returned under the same key.

    Raises:
        ImportError: if ``parmed.amber.AmberParm`` cannot be imported, whatever
            the underlying reason; the original error is chained as the cause.
    """
    try:
        from parmed.amber import AmberParm  # type: ignore
    except Exception as import_error:  # pragma: no cover
        message = "ParmEd is required but not installed. Install it (pip install parmed)."
        raise ImportError(message) from import_error

    # Normalise through Path first so str and Path inputs take the same route.
    topology = AmberParm(str(Path(prmtop_path)))

    raw_sections: Dict[str, List[Any]] = {}
    for flag, values in topology.parm_data.items():
        raw_sections[flag] = list(values)
    return raw_sections
hessian_ff/system.py ADDED
@@ -0,0 +1,107 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ from dataclasses import dataclass
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+
9
+
10
@dataclass(frozen=True)
class AmberSystem:
    """Tensorized representation of an AMBER prmtop system.

    Floating tensors are loaded as torch.float64 by default and can be cast to
    torch.float32 when needed.
    Coordinates are expected in Angstrom.

    Notes on charge units:
    - This package reads prmtop via ParmEd, which returns atomic charges in units
      of elementary charge (i.e., already de-scaled from raw %FLAG CHARGE).
    - Coulomb energy in kcal/mol is computed as:
        E = 332.0637132991921 * sum_{i<j} q_i q_j / r_ij
      where r_ij is in Angstrom.
    """

    # ---- atom-level ----
    natom: int
    charge: torch.Tensor  # [N] atomic charges in e
    atom_type: torch.Tensor  # [N] Lennard-Jones type index (0-based int64)

    # ---- bonded terms ----
    bond_i: torch.Tensor  # [Nb]
    bond_j: torch.Tensor  # [Nb]
    bond_k: torch.Tensor  # [Nb] force constant
    bond_r0: torch.Tensor  # [Nb] equilibrium distance

    angle_i: torch.Tensor  # [Na]
    angle_j: torch.Tensor  # [Na]
    angle_k: torch.Tensor  # [Na]
    angle_k0: torch.Tensor  # [Na] force constant
    angle_t0: torch.Tensor  # [Na] equilibrium angle (radians)

    dihed_i: torch.Tensor  # [Nd]
    dihed_j: torch.Tensor  # [Nd]
    dihed_k: torch.Tensor  # [Nd]
    dihed_l: torch.Tensor  # [Nd]
    dihed_force: torch.Tensor  # [Nd]
    dihed_period: torch.Tensor  # [Nd]
    dihed_phase: torch.Tensor  # [Nd] (radians)

    # ---- nonbonded parameter tables ----
    lj_acoef: torch.Tensor  # [Nlj] A in A/r^12
    lj_bcoef: torch.Tensor  # [Nlj] B in B/r^6
    nb_index: torch.Tensor  # [ntypes, ntypes] raw NONBONDED_PARM_INDEX (Fortran-style signed int)
    hb_acoef: torch.Tensor  # [Nhb] HBOND A in A/r^12 (used when nb_index<0)
    hb_bcoef: torch.Tensor  # [Nhb] HBOND B in B/r^10 (used when nb_index<0)

    # ---- pair lists (precomputed for no-PBC O(N^2) evaluation) ----
    pair_i: torch.Tensor  # [Np] nonbonded general pairs (i<j), excludes exclusions and 1-4
    pair_j: torch.Tensor  # [Np]

    pair14_i: torch.Tensor  # [N14] 1-4 pairs (i<j)
    pair14_j: torch.Tensor  # [N14]
    pair14_inv_scee: torch.Tensor  # [N14] multiply Coulomb by this
    pair14_inv_scnb: torch.Tensor  # [N14] multiply LJ by this

    # ---- CMAP term (optional; CHARMM-style correction map) ----
    cmap_type: torch.Tensor  # [Ncmap] map type index (0-based)
    cmap_i: torch.Tensor  # [Ncmap] first torsion atom i
    cmap_j: torch.Tensor  # [Ncmap] first torsion atom j
    cmap_k: torch.Tensor  # [Ncmap] first torsion atom k
    cmap_l: torch.Tensor  # [Ncmap] first torsion atom l
    cmap_m: torch.Tensor  # [Ncmap] second torsion terminal atom m
    cmap_resolution: Tuple[int, ...]  # one resolution per map type
    cmap_maps: Tuple[torch.Tensor, ...]  # flattened map values (kcal/mol) in OpenMM ordering
    cmap_size: torch.Tensor  # [Nmap] grid size per map type
    cmap_delta: torch.Tensor  # [Nmap] angular grid width (radian)
    cmap_offset: torch.Tensor  # [Nmap] flattened patch row offset
    cmap_coeff: torch.Tensor  # [sum(size*size),16] bicubic coefficients

    def to(
        self,
        device: torch.device | str | None = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "AmberSystem":
        """Return a copy of this system moved to a device and/or floating dtype.

        Args:
            device: target device for every tensor field; ``None`` leaves the
                device untouched.
            dtype: target dtype applied only to floating-point tensors; integer
                tensors (index arrays) keep their dtype. ``None`` leaves dtypes
                untouched.

        Returns:
            A new instance of the same class as ``self``. Non-tensor fields are
            carried over as-is. When neither ``device`` nor ``dtype`` is given,
            the new instance references the same tensor objects.
        """

        target_device = torch.device(device) if device is not None else None

        def mv(x: torch.Tensor) -> torch.Tensor:
            kw = {}
            if target_device is not None:
                kw["device"] = target_device
            if dtype is not None and x.is_floating_point():
                kw["dtype"] = dtype
            return x.to(**kw) if kw else x

        kwargs = {}
        for f in dataclasses.fields(self):
            val = getattr(self, f.name)
            if isinstance(val, torch.Tensor):
                kwargs[f.name] = mv(val)
            elif isinstance(val, tuple) and val and isinstance(val[0], torch.Tensor):
                # Tuple-of-tensors fields (e.g. cmap_maps) are converted element-wise.
                kwargs[f.name] = tuple(mv(x) for x in val)
            else:
                kwargs[f.name] = val
        # Rebuild with the runtime class: dataclasses.fields(self) already
        # includes subclass fields, so a hard-coded AmberSystem(...) would
        # downcast (or fail on) subclasses.
        return type(self)(**kwargs)
@@ -0,0 +1,14 @@
1
+ from .bond import BondTerm
2
+ from .angle import AngleTerm
3
+ from .dihedral import DihedralTerm
4
+ from .cmap import CMapTerm
5
+ from .nonbonded import NonbondedTerm, NonbondedEnergies
6
+
7
+ __all__ = [
8
+ "BondTerm",
9
+ "AngleTerm",
10
+ "DihedralTerm",
11
+ "CMapTerm",
12
+ "NonbondedTerm",
13
+ "NonbondedEnergies",
14
+ ]
@@ -0,0 +1,73 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
class AngleTerm(nn.Module):
    """AMBER angle term: E = sum k (theta - theta0)^2

    theta is computed via atan2(|u x v|, u·v) for numerical stability, both in
    ``forward`` and in ``energy_force``, so the two entry points return the
    same energy for the same coordinates.
    """

    def __init__(
        self,
        i: torch.Tensor,
        j: torch.Tensor,
        k: torch.Tensor,
        k_theta: torch.Tensor,
        theta0: torch.Tensor,
    ):
        """Store atom indices (i, j, k; j is the vertex) and angle parameters.

        Index tensors are converted to int64; ``k_theta`` and ``theta0`` are
        registered as buffers unchanged.
        """
        super().__init__()
        self.register_buffer("i", i.long())
        self.register_buffer("j", j.long())
        self.register_buffer("k", k.long())
        self.register_buffer("k_theta", k_theta)
        self.register_buffer("theta0", theta0)

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        """Return the total angle energy as a 0-dim tensor (0 if no angles)."""
        if self.i.numel() == 0:
            return coords.new_zeros(())
        u = coords[self.i] - coords[self.j]
        v = coords[self.k] - coords[self.j]
        dot = torch.sum(u * v, dim=-1)
        cross = torch.linalg.norm(torch.cross(u, v, dim=-1), dim=-1)
        theta = torch.atan2(cross, dot)
        return torch.sum(self.k_theta * (theta - self.theta0) ** 2)

    def energy_force(self, coords: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return angle energy and analytical force.

        Returns:
            ``(energy, force)`` where ``energy`` is a 0-dim tensor and ``force``
            has the same shape as ``coords``; contributions of angles sharing an
            atom are accumulated with ``index_add_``.
        """
        force = torch.zeros_like(coords)
        if self.i.numel() == 0:
            return coords.new_zeros(()), force

        # OpenMM-style geometry for numerical stability.
        d0 = coords[self.j] - coords[self.i]
        d1 = coords[self.j] - coords[self.k]
        p = torch.cross(d0, d1, dim=-1)

        r20 = torch.sum(d0 * d0, dim=-1).clamp_min(1.0e-24)
        r21 = torch.sum(d1 * d1, dim=-1).clamp_min(1.0e-24)
        p_norm = torch.linalg.norm(p, dim=-1)
        # Clamp only the copy used as a denominator in the force expressions.
        rp = p_norm.clamp_min(1.0e-12)
        dot = torch.sum(d0 * d1, dim=-1)
        # Same atan2 form as forward(): d0 = -u and d1 = -v, so the cross
        # product and dot product are unchanged. acos(clamp(cos)) loses the
        # angle near 0 and pi and made this energy disagree with forward().
        theta = torch.atan2(p_norm, dot)

        dtheta = theta - self.theta0
        e = torch.sum(self.k_theta * dtheta * dtheta)

        # E = k (theta-theta0)^2, dE/dtheta = 2k(theta-theta0)
        dE_dtheta = 2.0 * self.k_theta * dtheta

        term_i = dE_dtheta / (r20 * rp)
        term_k = -dE_dtheta / (r21 * rp)

        fi = torch.cross(d0, p, dim=-1) * term_i.unsqueeze(-1)
        fk = torch.cross(d1, p, dim=-1) * term_k.unsqueeze(-1)
        # The three forces of one angle sum to zero.
        fj = -(fi + fk)

        force.index_add_(0, self.i, fi)
        force.index_add_(0, self.j, fj)
        force.index_add_(0, self.k, fk)
        return e, force
@@ -0,0 +1,44 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
class BondTerm(nn.Module):
    """AMBER bond term: E = sum k (r - r0)^2"""

    def __init__(self, i: torch.Tensor, j: torch.Tensor, k: torch.Tensor, r0: torch.Tensor):
        """Register bond endpoints (cast to int64) and per-bond parameters as buffers."""
        super().__init__()
        for buffer_name, atom_index in (("i", i), ("j", j)):
            self.register_buffer(buffer_name, atom_index.long())
        self.register_buffer("k", k)
        self.register_buffer("r0", r0)

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        """Return the summed bond energy as a 0-dim tensor (0 when there are no bonds)."""
        if self.i.numel() == 0:
            return coords.new_zeros(())
        separation = coords[self.j] - coords[self.i]
        stretch = torch.linalg.norm(separation, dim=-1) - self.r0
        return torch.sum(self.k * stretch ** 2)

    def energy_force(self, coords: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return bond energy and analytical force.

        The force tensor has the shape of ``coords``; each bond adds equal and
        opposite contributions to its two atoms.
        """
        force = torch.zeros_like(coords)
        if self.i.numel() == 0:
            return coords.new_zeros(()), force

        delta = coords[self.j] - coords[self.i]
        # Squared length is floored so the reciprocal square root stays finite.
        dist_sq = torch.sum(delta * delta, dim=-1).clamp_min(1.0e-24)
        inv_dist = torch.rsqrt(dist_sq)
        stretch = dist_sq * inv_dist - self.r0

        energy = torch.sum(self.k * stretch * stretch)

        # dE/dr = 2k(r - r0); the force on atom i points along (r_j - r_i)/r
        # when the bond is stretched.
        pull = (2.0 * self.k * stretch * inv_dist).unsqueeze(-1) * delta

        force.index_add_(0, self.i, pull)
        force.index_add_(0, self.j, -pull)
        return energy, force
@@ -0,0 +1,406 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from typing import Tuple
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+ from .dihedral import _accumulate_dihedral_forces, _dihedral_angle
10
+
11
# Full turn in radians.
_TWO_PI = math.pi * 2.0

# References:
# - OpenMM CMAP coefficient/derivative setup:
#   https://github.com/openmm/openmm/blob/4768436/openmmapi/src/CMAPTorsionForceImpl.cpp
# - OpenMM periodic spline routines:
#   https://github.com/openmm/openmm/blob/4768436/openmmapi/src/SplineFitter.cpp
# - OpenMM reference CMAP patch evaluation:
#   https://github.com/openmm/openmm/blob/4768436/platforms/reference/src/SimTKReference/ReferenceCMAPTorsionIxn.cpp

# OpenMM CMAP bicubic coefficient matrix (CMAPTorsionForceImpl.cpp, wt[k+16*m]).
# The literal below lists the flat table sixteen values per row; the stored
# tensor is its transpose (a float64 16x16 matrix).
_CMAP_WT = torch.tensor(
    [
        [1, 0, -3, 2, 0, 0, 0, 0, -3, 0, 9, -6, 2, 0, -6, 4],
        [0, 0, 0, 0, 0, 0, 0, 0, 3, 0, -9, 6, -2, 0, 6, -4],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, -6, 0, 0, -6, 4],
        [0, 0, 3, -2, 0, 0, 0, 0, 0, 0, -9, 6, 0, 0, 6, -4],
        [0, 0, 0, 0, 1, 0, -3, 2, -2, 0, 6, -4, 1, 0, -3, 2],
        [0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 3, -2, 1, 0, -3, 2],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -3, 2, 0, 0, 3, -2],
        [0, 0, 0, 0, 0, 0, 3, -2, 0, 0, -6, 4, 0, 0, 3, -2],
        [0, 1, -2, 1, 0, 0, 0, 0, 0, -3, 6, -3, 0, 2, -4, 2],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 3, -6, 3, 0, -2, 4, -2],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -3, 3, 0, 0, 2, -2],
        [0, 0, -1, 1, 0, 0, 0, 0, 0, 0, 3, -3, 0, 0, -2, 2],
        [0, 0, 0, 0, 0, 1, -2, 1, 0, -2, 4, -2, 0, 1, -2, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 2, -1, 0, 1, -2, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, -1, 1],
        [0, 0, 0, 0, 0, 0, -1, 1, 0, 0, 2, -2, 0, 0, -1, 1],
    ],
    dtype=torch.float64,
).t()
45
+
46
+
47
def _solve_tridiagonal(
    a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, rhs: torch.Tensor
) -> torch.Tensor:
    """Solve a tridiagonal linear system (Thomas algorithm).

    ``a``, ``b`` and ``c`` hold the sub-, main and super-diagonal, indexed by
    row (``a[0]`` is never read). The system size is taken from ``a``. An empty
    system returns a copy of ``rhs`` and a 1x1 system returns ``rhs / b``; any
    larger system is returned as a new float64 tensor.
    """
    size = int(a.numel())
    if size == 0:
        return rhs.clone()
    if size == 1:
        return rhs / b

    ratio = torch.zeros(size, dtype=torch.float64)
    solution = torch.zeros(size, dtype=torch.float64)

    # Forward elimination: `pivot` is the running modified diagonal entry.
    pivot = b[0]
    solution[0] = rhs[0] / pivot
    for row in range(1, size):
        ratio[row] = c[row - 1] / pivot
        pivot = b[row] - a[row] * ratio[row]
        solution[row] = (rhs[row] - a[row] * solution[row - 1]) / pivot

    # Back substitution from the second-to-last row up to the first.
    for row in reversed(range(size - 1)):
        solution[row] -= ratio[row + 1] * solution[row + 1]
    return solution
69
+
70
+
71
def _create_periodic_spline(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Match OpenMM SplineFitter::createPeriodicSpline behavior.

    Builds the cyclic tridiagonal system for the spline through ``(x, y)``,
    solves it with two plain tridiagonal solves plus a correction term, and
    returns a float64 tensor of ``len(x)`` values whose last entry repeats the
    first.

    Raises:
        ValueError: if fewer than 3 points are given or ``x`` and ``y`` differ
            in length.
    """
    n = int(x.numel())
    if n < 3:
        raise ValueError("Periodic spline requires at least 3 points")
    if y.numel() != n:
        raise ValueError("x/y size mismatch in periodic spline setup")

    ntri = n - 1
    sub = torch.zeros(ntri, dtype=torch.float64)
    diag = torch.zeros(ntri, dtype=torch.float64)
    sup = torch.zeros(ntri, dtype=torch.float64)
    rhs = torch.zeros(ntri, dtype=torch.float64)

    # Row 0 couples the first interval with the last one (periodic wrap).
    sub[0] = x[n - 1] - x[n - 2]
    diag[0] = 2.0 * (x[1] - x[0] + x[n - 1] - x[n - 2])
    sup[0] = x[1] - x[0]
    rhs[0] = 6.0 * (
        (y[1] - y[0]) / (x[1] - x[0]) - (y[n - 1] - y[n - 2]) / (x[n - 1] - x[n - 2])
    )
    for row in range(1, ntri):
        sub[row] = x[row] - x[row - 1]
        diag[row] = 2.0 * (x[row + 1] - x[row - 1])
        sup[row] = x[row + 1] - x[row]
        rhs[row] = 6.0 * (
            (y[row + 1] - y[row]) / (x[row + 1] - x[row])
            - (y[row] - y[row - 1]) / (x[row] - x[row - 1])
        )

    # Corner entries of the cyclic system; gamma is taken from the diagonal
    # before it is modified below.
    beta = sub[0]
    alpha = sup[n - 2]
    gamma = -diag[0]

    diag[0] = diag[0] - gamma
    diag[ntri - 1] = diag[ntri - 1] - alpha * beta / gamma
    deriv = _solve_tridiagonal(sub, diag, sup, rhs)

    # Second solve with the corner vector, then combine both solutions.
    corner = torch.zeros(ntri, dtype=torch.float64)
    corner[0] = gamma
    corner[ntri - 1] = alpha
    z = _solve_tridiagonal(sub, diag, sup, corner)

    numerator = deriv[0] + beta * deriv[ntri - 1] / gamma
    denominator = 1.0 + z[0] + beta * z[ntri - 1] / gamma
    deriv = deriv - (numerator / denominator) * z

    # Close the period: the value at the last knot equals the one at the first.
    return torch.cat([deriv, deriv[:1]])
121
+
122
+
123
def _eval_spline_derivative(
    x: torch.Tensor, y: torch.Tensor, second_deriv: torch.Tensor, t: torch.Tensor
) -> torch.Tensor:
    """Equivalent to OpenMM SplineFitter::evaluateSplineDerivative.

    Bisects ``x`` to find the pair of adjacent knots bracketing ``t`` and
    returns the first derivative of the cubic segment between them, built from
    the knot values ``y`` and the per-knot ``second_deriv`` values.
    """
    lo, hi = 0, int(x.numel()) - 1
    while hi - lo > 1:
        mid = (hi + lo) // 2
        if x[mid] > t:
            hi = mid
        else:
            lo = mid

    width = x[hi] - x[lo]
    a = (x[hi] - t) / width
    b = (t - x[lo]) / width

    # Chord slope plus the correction contributed by the second derivatives.
    chord = (-1.0 / width) * (y[lo] - y[hi])
    weighted = (1.0 - 3.0 * a * a) * second_deriv[lo] + (3.0 * b * b - 1.0) * second_deriv[hi]
    return chord + weighted * width / 6.0
141
+
142
+
143
def _calc_map_derivatives(energy: torch.Tensor, size: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Match OpenMM CMAPTorsionForceImpl::calcMapDerivatives.

    ``energy`` is indexed as a flat ``size * size`` grid. For every grid line a
    periodic spline is fitted over ``size + 1`` evenly spaced knots spanning
    one full turn, and its derivative is evaluated at the first ``size`` knots.

    Returns:
        ``(d1, d2, d12)`` as flat float64 tensors of ``size * size`` entries:
        the d/dphi and d/dpsi derivatives of ``energy`` and the mixed
        derivative, the latter obtained by differentiating ``d2``.
    """
    n_grid = size * size
    d1 = torch.zeros(n_grid, dtype=torch.float64)
    d2 = torch.zeros(n_grid, dtype=torch.float64)
    d12 = torch.zeros(n_grid, dtype=torch.float64)

    knots = torch.arange(size + 1, dtype=torch.float64) * (_TWO_PI / float(size))

    def differentiate_lines(target, source, flat_index):
        # One spline per grid line; flat_index(line, pos) maps to the flat grid.
        for line in range(size):
            samples = torch.zeros(size + 1, dtype=torch.float64)
            for pos in range(size):
                samples[pos] = source[flat_index(line, pos)]
            # Repeat the first sample at the extra knot to close the period.
            samples[size] = source[flat_index(line, 0)]
            second = _create_periodic_spline(knots, samples)
            for pos in range(size):
                target[flat_index(line, pos)] = _eval_spline_derivative(
                    knots, samples, second, knots[pos]
                )

    def contiguous(line, pos):
        return pos + size * line

    def strided(line, pos):
        return line + size * pos

    # d/dphi
    differentiate_lines(d1, energy, contiguous)
    # d/dpsi
    differentiate_lines(d2, energy, strided)
    # d2/(dphi dpsi)
    differentiate_lines(d12, d2, contiguous)

    return d1, d2, d12
182
+
183
+
184
def _calc_map_coefficients(map_values: torch.Tensor, size: int) -> torch.Tensor:
    """Build the 16 bicubic coefficients for every cell of one CMAP grid.

    Create bicubic patch coefficients identical to OpenMM CMAP setup.

    Args:
        map_values: Map energies with exactly ``size * size`` elements; the
            tensor is detached, moved to CPU float64 and flattened.
        size: Number of grid points per axis.

    Returns:
        A float64 tensor of shape ``(size * size, 16)``; row ``i + size * j``
        holds the coefficients of the cell whose corners are grid points
        ``(i, j)``, ``(i+1, j)``, ``(i+1, j+1)`` and ``(i, j+1)``, with
        wrap-around at the map edge.

    Raises:
        ValueError: If ``map_values`` does not hold ``size * size`` elements.
    """
    n_cells = size * size
    if map_values.numel() != n_cells:
        raise ValueError(f"CMAP map size mismatch: expected {size*size}, got {map_values.numel()}")

    energy = map_values.detach().to(dtype=torch.float64, device="cpu").reshape(-1)
    d1, d2, d12 = _calc_map_derivatives(energy, size=size)
    coeff = torch.zeros((n_cells, 16), dtype=torch.float64)
    spacing = _TWO_PI / float(size)

    for j in range(size):
        j_next = (j + 1) % size
        for i in range(size):
            i_next = (i + 1) % size
            k = i + size * j
            # Flat indices of the four cell corners, in the order the weight
            # matrix expects them.
            corners = torch.tensor(
                [k, i_next + size * j, i_next + size * j_next, i + size * j_next],
                dtype=torch.int64,
            )
            # Values, then both first derivatives and the cross derivative,
            # each rescaled from per-radian to per-cell units.
            rhs = torch.cat(
                (
                    energy[corners],
                    d1[corners] * spacing,
                    d2[corners] * spacing,
                    d12[corners] * spacing * spacing,
                )
            )
            coeff[k] = _CMAP_WT @ rhs
    return coeff
221
+
222
+
223
class CMapTerm(nn.Module):
    """CMAP torsion correction implemented with Torch bicubic interpolation.

    The coefficient generation follows OpenMM's CMAP setup logic. The forward
    path is composed of standard Torch ops, so both gradients and Hessians are
    available through autograd.

    Every CMAP entry couples two dihedrals that share three atoms: ``phi``
    over atoms (i, j, k, l) and ``psi`` over atoms (j, k, l, m). The pair is
    wrapped into [0, 2*pi), located on the entry's map grid, and the energy
    is read from that cell's 16 bicubic coefficients.

    Buffers registered on the module:
        cmap_type, cmap_i .. cmap_m: int64 per-entry map index and atom indices.
        map_size: grid points per axis, one value per map.
        map_delta: grid spacing ``2*pi / map_size``, one value per map.
        map_offset: first row of each map inside ``map_coeff``.
        map_coeff: stacked ``(cells, 16)`` bicubic coefficients of all maps.
    """

    def __init__(
        self,
        natom: int,
        cmap_type: torch.Tensor,
        cmap_i: torch.Tensor,
        cmap_j: torch.Tensor,
        cmap_k: torch.Tensor,
        cmap_l: torch.Tensor,
        cmap_m: torch.Tensor,
        cmap_resolution: Tuple[int, ...],
        cmap_maps: Tuple[torch.Tensor, ...],
        *,
        precomputed_size: torch.Tensor | None = None,
        precomputed_delta: torch.Tensor | None = None,
        precomputed_offset: torch.Tensor | None = None,
        precomputed_coeff: torch.Tensor | None = None,
    ):
        """Register index buffers and build (or adopt) the map coefficients.

        Args:
            natom: Number of atoms; stored as ``self.natom``.
            cmap_type: Per-entry index into the list of maps.
            cmap_i, cmap_j, cmap_k, cmap_l, cmap_m: Per-entry atom indices.
            cmap_resolution: Grid points per axis, one int per map.
            cmap_maps: Flat map energies, one tensor per map.
            precomputed_size, precomputed_delta, precomputed_offset,
            precomputed_coeff: Optional ready-made ``map_*`` buffers. They
                are used only when ``precomputed_coeff`` is given, and then
                all four are required.

        Raises:
            ValueError: If ``precomputed_coeff`` is given without the other
                three precomputed buffers, or if ``cmap_resolution`` and
                ``cmap_maps`` differ in length.
        """
        super().__init__()
        self.natom = int(natom)
        for name, index in (
            ("cmap_type", cmap_type),
            ("cmap_i", cmap_i),
            ("cmap_j", cmap_j),
            ("cmap_k", cmap_k),
            ("cmap_l", cmap_l),
            ("cmap_m", cmap_m),
        ):
            self.register_buffer(name, index.long())

        self._enabled = self.cmap_type.numel() > 0
        if not self._enabled:
            # No CMAP entries: keep empty, correctly typed buffers so the
            # module still has a uniform set of attributes.
            self.register_buffer("map_size", torch.zeros((0,), dtype=torch.int64))
            self.register_buffer("map_delta", torch.zeros((0,), dtype=torch.float64))
            self.register_buffer("map_offset", torch.zeros((0,), dtype=torch.int64))
            self.register_buffer("map_coeff", torch.zeros((0, 16), dtype=torch.float64))
            return

        # Use precomputed coefficients if available (avoids redundant computation
        # when AmberSystem already carries cmap_size/delta/offset/coeff).
        if precomputed_coeff is not None:
            precomputed = {
                "map_size": precomputed_size,
                "map_delta": precomputed_delta,
                "map_offset": precomputed_offset,
                "map_coeff": precomputed_coeff,
            }
            # register_buffer silently accepts None, which would only surface
            # later as an opaque indexing error in forward(); fail here instead.
            missing = [name for name, value in precomputed.items() if value is None]
            if missing:
                raise ValueError(
                    "precomputed_coeff requires precomputed_size, precomputed_delta and "
                    f"precomputed_offset as well; missing buffers: {', '.join(missing)}"
                )
            for name, value in precomputed.items():
                self.register_buffer(name, value)
            return

        if len(cmap_resolution) != len(cmap_maps):
            raise ValueError(
                f"CMAP resolution/map count mismatch: {len(cmap_resolution)} vs {len(cmap_maps)}"
            )

        map_size = torch.tensor([int(x) for x in cmap_resolution], dtype=torch.int64)
        coeff_parts = []
        offsets = []
        cursor = 0
        for ngrid, table in zip(cmap_resolution, cmap_maps):
            # Each map contributes ngrid * ngrid coefficient rows, stacked in order.
            offsets.append(cursor)
            coeff_parts.append(_calc_map_coefficients(table, size=int(ngrid)))
            cursor += int(ngrid) * int(ngrid)

        coeff_cat = torch.cat(coeff_parts, dim=0)
        # Coefficients are computed in CPU float64; store them in the dtype and
        # on the device of the input maps.
        target_dtype = cmap_maps[0].dtype
        target_device = cmap_maps[0].device

        self.register_buffer("map_size", map_size.to(device=target_device))
        self.register_buffer(
            "map_delta",
            (torch.full_like(map_size, _TWO_PI, dtype=torch.float64) / map_size.to(torch.float64)).to(
                dtype=target_dtype, device=target_device
            ),
        )
        self.register_buffer("map_offset", torch.tensor(offsets, dtype=torch.int64, device=target_device))
        self.register_buffer("map_coeff", coeff_cat.to(dtype=target_dtype, device=target_device))

    def _locate_patches(
        self, coords: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Locate every CMAP entry inside its map grid.

        Shared by ``forward`` and ``energy_force`` so both evaluate exactly
        the same cell and fractional coordinates.

        Returns:
            ``(da, db, coeff, delta)``: fractional position inside the cell
            along phi and psi, the cell's coefficients reshaped to
            ``(n_entries, 4, 4)``, and the per-entry grid spacing.
        """
        # coords is indexed by atom along its first axis; the trailing layout
        # is whatever _dihedral_angle expects (presumably xyz -- confirm there).
        phi = _dihedral_angle(
            coords[self.cmap_i], coords[self.cmap_j], coords[self.cmap_k], coords[self.cmap_l]
        )
        psi = _dihedral_angle(
            coords[self.cmap_j], coords[self.cmap_k], coords[self.cmap_l], coords[self.cmap_m]
        )

        # OpenMM CMAP interpolation uses [0, 2pi) wrapped angles.
        ang_a = torch.remainder(phi + _TWO_PI, _TWO_PI)
        ang_b = torch.remainder(psi + _TWO_PI, _TWO_PI)

        tmap = self.cmap_type
        size = self.map_size[tmap]
        delta = self.map_delta[tmap].to(dtype=coords.dtype)

        u = ang_a / delta
        v = ang_b / delta
        # Clamp to the last cell so the row index stays inside the map
        # (presumably guards against round-off at the 2*pi edge).
        s = torch.minimum(torch.floor(u).to(torch.int64), size - 1)
        t = torch.minimum(torch.floor(v).to(torch.int64), size - 1)

        da = u - s.to(dtype=coords.dtype)
        db = v - t.to(dtype=coords.dtype)

        # Cell (s, t) of map `tmap` lives at row offset + s + size * t.
        coeff_row = self.map_offset[tmap] + (s + size * t)
        coeff = self.map_coeff[coeff_row].to(dtype=coords.dtype).reshape(-1, 4, 4)
        return da, db, coeff, delta

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        """Return the summed CMAP energy as a scalar tensor.

        Returns a zero scalar with the dtype/device of ``coords`` when the
        term has no entries.
        """
        if not self._enabled:
            return coords.new_zeros(())

        da, db, coeff, _ = self._locate_patches(coords)

        # Horner evaluation in db, then da.
        dbu = db.unsqueeze(-1)
        poly_b = ((coeff[:, :, 3] * dbu + coeff[:, :, 2]) * dbu + coeff[:, :, 1]) * dbu + coeff[:, :, 0]
        e = ((poly_b[:, 3] * da + poly_b[:, 2]) * da + poly_b[:, 1]) * da + poly_b[:, 0]
        return torch.sum(e)

    def energy_force(self, coords: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return CMAP energy and analytical force.

        The force tensor has the shape of ``coords`` and is filled by
        ``_accumulate_dihedral_forces`` from dE/dphi and dE/dpsi; its sign
        convention is whatever that helper implements.
        """
        force = torch.zeros_like(coords)
        if not self._enabled:
            return coords.new_zeros(()), force

        da, db, coeff, delta = self._locate_patches(coords)

        dbu = db.unsqueeze(-1)
        # Row-wise polynomial in db: Pm(db) (m=0..3)
        poly_b = ((coeff[:, :, 3] * dbu + coeff[:, :, 2]) * dbu + coeff[:, :, 1]) * dbu + coeff[:, :, 0]
        # Derivative wrt db: dPm/ddb
        dpoly_b = ((3.0 * coeff[:, :, 3] * dbu + 2.0 * coeff[:, :, 2]) * dbu + coeff[:, :, 1])

        # E = sum_m Pm(db) * da^m
        e = ((poly_b[:, 3] * da + poly_b[:, 2]) * da + poly_b[:, 1]) * da + poly_b[:, 0]
        energy = torch.sum(e)

        # dE/dda and dE/ddb
        dE_dda = (3.0 * poly_b[:, 3] * da + 2.0 * poly_b[:, 2]) * da + poly_b[:, 1]
        dE_ddb = ((dpoly_b[:, 3] * da + dpoly_b[:, 2]) * da + dpoly_b[:, 1]) * da + dpoly_b[:, 0]

        # da = phi/delta - floor(...), db = psi/delta - floor(...)
        # Away from grid boundaries, d(da)/dphi = 1/delta and d(db)/dpsi = 1/delta.
        dE_dphi = dE_dda / delta
        dE_dpsi = dE_ddb / delta

        _accumulate_dihedral_forces(
            force, coords, self.cmap_i, self.cmap_j, self.cmap_k, self.cmap_l, dE_dphi
        )
        _accumulate_dihedral_forces(
            force, coords, self.cmap_j, self.cmap_k, self.cmap_l, self.cmap_m, dE_dpsi
        )
        return energy, force