mlmm-toolkit 0.2.2.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (372) hide show
  1. hessian_ff/__init__.py +50 -0
  2. hessian_ff/analytical_hessian.py +609 -0
  3. hessian_ff/constants.py +46 -0
  4. hessian_ff/forcefield.py +339 -0
  5. hessian_ff/loaders.py +608 -0
  6. hessian_ff/native/Makefile +8 -0
  7. hessian_ff/native/__init__.py +28 -0
  8. hessian_ff/native/analytical_hessian.py +88 -0
  9. hessian_ff/native/analytical_hessian_ext.cpp +258 -0
  10. hessian_ff/native/bonded.py +82 -0
  11. hessian_ff/native/bonded_ext.cpp +640 -0
  12. hessian_ff/native/loader.py +349 -0
  13. hessian_ff/native/nonbonded.py +118 -0
  14. hessian_ff/native/nonbonded_ext.cpp +1150 -0
  15. hessian_ff/prmtop_parmed.py +23 -0
  16. hessian_ff/system.py +107 -0
  17. hessian_ff/terms/__init__.py +14 -0
  18. hessian_ff/terms/angle.py +73 -0
  19. hessian_ff/terms/bond.py +44 -0
  20. hessian_ff/terms/cmap.py +406 -0
  21. hessian_ff/terms/dihedral.py +141 -0
  22. hessian_ff/terms/nonbonded.py +209 -0
  23. hessian_ff/tests/__init__.py +0 -0
  24. hessian_ff/tests/conftest.py +75 -0
  25. hessian_ff/tests/data/small/complex.parm7 +1346 -0
  26. hessian_ff/tests/data/small/complex.pdb +125 -0
  27. hessian_ff/tests/data/small/complex.rst7 +63 -0
  28. hessian_ff/tests/test_coords_input.py +44 -0
  29. hessian_ff/tests/test_energy_force.py +49 -0
  30. hessian_ff/tests/test_hessian.py +137 -0
  31. hessian_ff/tests/test_smoke.py +18 -0
  32. hessian_ff/tests/test_validation.py +40 -0
  33. hessian_ff/workflows.py +889 -0
  34. mlmm/__init__.py +36 -0
  35. mlmm/__main__.py +7 -0
  36. mlmm/_version.py +34 -0
  37. mlmm/add_elem_info.py +374 -0
  38. mlmm/advanced_help.py +91 -0
  39. mlmm/align_freeze_atoms.py +601 -0
  40. mlmm/all.py +3535 -0
  41. mlmm/bond_changes.py +231 -0
  42. mlmm/bool_compat.py +223 -0
  43. mlmm/cli.py +574 -0
  44. mlmm/cli_utils.py +166 -0
  45. mlmm/default_group.py +337 -0
  46. mlmm/defaults.py +467 -0
  47. mlmm/define_layer.py +526 -0
  48. mlmm/dft.py +1041 -0
  49. mlmm/energy_diagram.py +253 -0
  50. mlmm/extract.py +2213 -0
  51. mlmm/fix_altloc.py +464 -0
  52. mlmm/freq.py +1406 -0
  53. mlmm/harmonic_constraints.py +140 -0
  54. mlmm/hessian_cache.py +44 -0
  55. mlmm/hessian_calc.py +174 -0
  56. mlmm/irc.py +638 -0
  57. mlmm/mlmm_calc.py +2262 -0
  58. mlmm/mm_parm.py +945 -0
  59. mlmm/oniom_export.py +1983 -0
  60. mlmm/oniom_import.py +457 -0
  61. mlmm/opt.py +1742 -0
  62. mlmm/path_opt.py +1353 -0
  63. mlmm/path_search.py +2299 -0
  64. mlmm/preflight.py +88 -0
  65. mlmm/py.typed +1 -0
  66. mlmm/pysis_runner.py +45 -0
  67. mlmm/scan.py +1047 -0
  68. mlmm/scan2d.py +1226 -0
  69. mlmm/scan3d.py +1265 -0
  70. mlmm/scan_common.py +184 -0
  71. mlmm/summary_log.py +736 -0
  72. mlmm/trj2fig.py +448 -0
  73. mlmm/tsopt.py +2871 -0
  74. mlmm/utils.py +2309 -0
  75. mlmm/xtb_embedcharge_correction.py +475 -0
  76. mlmm_toolkit-0.2.2.dev0.dist-info/METADATA +1159 -0
  77. mlmm_toolkit-0.2.2.dev0.dist-info/RECORD +372 -0
  78. mlmm_toolkit-0.2.2.dev0.dist-info/WHEEL +5 -0
  79. mlmm_toolkit-0.2.2.dev0.dist-info/entry_points.txt +2 -0
  80. mlmm_toolkit-0.2.2.dev0.dist-info/licenses/LICENSE +674 -0
  81. mlmm_toolkit-0.2.2.dev0.dist-info/top_level.txt +4 -0
  82. pysisyphus/Geometry.py +1667 -0
  83. pysisyphus/LICENSE +674 -0
  84. pysisyphus/TableFormatter.py +63 -0
  85. pysisyphus/TablePrinter.py +74 -0
  86. pysisyphus/__init__.py +12 -0
  87. pysisyphus/calculators/AFIR.py +452 -0
  88. pysisyphus/calculators/AnaPot.py +20 -0
  89. pysisyphus/calculators/AnaPot2.py +48 -0
  90. pysisyphus/calculators/AnaPot3.py +12 -0
  91. pysisyphus/calculators/AnaPot4.py +20 -0
  92. pysisyphus/calculators/AnaPotBase.py +337 -0
  93. pysisyphus/calculators/AnaPotCBM.py +25 -0
  94. pysisyphus/calculators/AtomAtomTransTorque.py +154 -0
  95. pysisyphus/calculators/CFOUR.py +250 -0
  96. pysisyphus/calculators/Calculator.py +844 -0
  97. pysisyphus/calculators/CerjanMiller.py +24 -0
  98. pysisyphus/calculators/Composite.py +123 -0
  99. pysisyphus/calculators/ConicalIntersection.py +171 -0
  100. pysisyphus/calculators/DFTBp.py +430 -0
  101. pysisyphus/calculators/DFTD3.py +66 -0
  102. pysisyphus/calculators/DFTD4.py +84 -0
  103. pysisyphus/calculators/Dalton.py +61 -0
  104. pysisyphus/calculators/Dimer.py +681 -0
  105. pysisyphus/calculators/Dummy.py +20 -0
  106. pysisyphus/calculators/EGO.py +76 -0
  107. pysisyphus/calculators/EnergyMin.py +224 -0
  108. pysisyphus/calculators/ExternalPotential.py +264 -0
  109. pysisyphus/calculators/FakeASE.py +35 -0
  110. pysisyphus/calculators/FourWellAnaPot.py +28 -0
  111. pysisyphus/calculators/FreeEndNEBPot.py +39 -0
  112. pysisyphus/calculators/Gaussian09.py +18 -0
  113. pysisyphus/calculators/Gaussian16.py +726 -0
  114. pysisyphus/calculators/HardSphere.py +159 -0
  115. pysisyphus/calculators/IDPPCalculator.py +49 -0
  116. pysisyphus/calculators/IPIClient.py +133 -0
  117. pysisyphus/calculators/IPIServer.py +234 -0
  118. pysisyphus/calculators/LEPSBase.py +24 -0
  119. pysisyphus/calculators/LEPSExpr.py +139 -0
  120. pysisyphus/calculators/LennardJones.py +80 -0
  121. pysisyphus/calculators/MOPAC.py +219 -0
  122. pysisyphus/calculators/MullerBrownSympyPot.py +51 -0
  123. pysisyphus/calculators/MultiCalc.py +85 -0
  124. pysisyphus/calculators/NFK.py +45 -0
  125. pysisyphus/calculators/OBabel.py +87 -0
  126. pysisyphus/calculators/ONIOMv2.py +1129 -0
  127. pysisyphus/calculators/ORCA.py +893 -0
  128. pysisyphus/calculators/ORCA5.py +6 -0
  129. pysisyphus/calculators/OpenMM.py +88 -0
  130. pysisyphus/calculators/OpenMolcas.py +281 -0
  131. pysisyphus/calculators/OverlapCalculator.py +908 -0
  132. pysisyphus/calculators/Psi4.py +218 -0
  133. pysisyphus/calculators/PyPsi4.py +37 -0
  134. pysisyphus/calculators/PySCF.py +341 -0
  135. pysisyphus/calculators/PyXTB.py +73 -0
  136. pysisyphus/calculators/QCEngine.py +106 -0
  137. pysisyphus/calculators/Rastrigin.py +22 -0
  138. pysisyphus/calculators/Remote.py +76 -0
  139. pysisyphus/calculators/Rosenbrock.py +15 -0
  140. pysisyphus/calculators/SocketCalc.py +97 -0
  141. pysisyphus/calculators/TIP3P.py +111 -0
  142. pysisyphus/calculators/TransTorque.py +161 -0
  143. pysisyphus/calculators/Turbomole.py +965 -0
  144. pysisyphus/calculators/VRIPot.py +37 -0
  145. pysisyphus/calculators/WFOWrapper.py +333 -0
  146. pysisyphus/calculators/WFOWrapper2.py +341 -0
  147. pysisyphus/calculators/XTB.py +418 -0
  148. pysisyphus/calculators/__init__.py +81 -0
  149. pysisyphus/calculators/cosmo_data.py +139 -0
  150. pysisyphus/calculators/parser.py +150 -0
  151. pysisyphus/color.py +19 -0
  152. pysisyphus/config.py +133 -0
  153. pysisyphus/constants.py +65 -0
  154. pysisyphus/cos/AdaptiveNEB.py +230 -0
  155. pysisyphus/cos/ChainOfStates.py +725 -0
  156. pysisyphus/cos/FreeEndNEB.py +25 -0
  157. pysisyphus/cos/FreezingString.py +103 -0
  158. pysisyphus/cos/GrowingChainOfStates.py +71 -0
  159. pysisyphus/cos/GrowingNT.py +309 -0
  160. pysisyphus/cos/GrowingString.py +508 -0
  161. pysisyphus/cos/NEB.py +189 -0
  162. pysisyphus/cos/SimpleZTS.py +64 -0
  163. pysisyphus/cos/__init__.py +22 -0
  164. pysisyphus/cos/stiffness.py +199 -0
  165. pysisyphus/drivers/__init__.py +17 -0
  166. pysisyphus/drivers/afir.py +855 -0
  167. pysisyphus/drivers/barriers.py +271 -0
  168. pysisyphus/drivers/birkholz.py +138 -0
  169. pysisyphus/drivers/cluster.py +318 -0
  170. pysisyphus/drivers/diabatization.py +133 -0
  171. pysisyphus/drivers/merge.py +368 -0
  172. pysisyphus/drivers/merge_mol2.py +322 -0
  173. pysisyphus/drivers/opt.py +375 -0
  174. pysisyphus/drivers/perf.py +91 -0
  175. pysisyphus/drivers/pka.py +52 -0
  176. pysisyphus/drivers/precon_pos_rot.py +669 -0
  177. pysisyphus/drivers/rates.py +480 -0
  178. pysisyphus/drivers/replace.py +219 -0
  179. pysisyphus/drivers/scan.py +212 -0
  180. pysisyphus/drivers/spectrum.py +166 -0
  181. pysisyphus/drivers/thermo.py +31 -0
  182. pysisyphus/dynamics/Gaussian.py +103 -0
  183. pysisyphus/dynamics/__init__.py +20 -0
  184. pysisyphus/dynamics/colvars.py +136 -0
  185. pysisyphus/dynamics/driver.py +297 -0
  186. pysisyphus/dynamics/helpers.py +256 -0
  187. pysisyphus/dynamics/lincs.py +105 -0
  188. pysisyphus/dynamics/mdp.py +364 -0
  189. pysisyphus/dynamics/rattle.py +121 -0
  190. pysisyphus/dynamics/thermostats.py +128 -0
  191. pysisyphus/dynamics/wigner.py +266 -0
  192. pysisyphus/elem_data.py +3473 -0
  193. pysisyphus/exceptions.py +2 -0
  194. pysisyphus/filtertrj.py +69 -0
  195. pysisyphus/helpers.py +623 -0
  196. pysisyphus/helpers_pure.py +649 -0
  197. pysisyphus/init_logging.py +50 -0
  198. pysisyphus/intcoords/Bend.py +69 -0
  199. pysisyphus/intcoords/Bend2.py +25 -0
  200. pysisyphus/intcoords/BondedFragment.py +32 -0
  201. pysisyphus/intcoords/Cartesian.py +41 -0
  202. pysisyphus/intcoords/CartesianCoords.py +140 -0
  203. pysisyphus/intcoords/Coords.py +56 -0
  204. pysisyphus/intcoords/DLC.py +197 -0
  205. pysisyphus/intcoords/DistanceFunction.py +34 -0
  206. pysisyphus/intcoords/DummyImproper.py +70 -0
  207. pysisyphus/intcoords/DummyTorsion.py +72 -0
  208. pysisyphus/intcoords/LinearBend.py +105 -0
  209. pysisyphus/intcoords/LinearDisplacement.py +80 -0
  210. pysisyphus/intcoords/OutOfPlane.py +59 -0
  211. pysisyphus/intcoords/PrimTypes.py +286 -0
  212. pysisyphus/intcoords/Primitive.py +137 -0
  213. pysisyphus/intcoords/RedundantCoords.py +659 -0
  214. pysisyphus/intcoords/RobustTorsion.py +59 -0
  215. pysisyphus/intcoords/Rotation.py +147 -0
  216. pysisyphus/intcoords/Stretch.py +31 -0
  217. pysisyphus/intcoords/Torsion.py +101 -0
  218. pysisyphus/intcoords/Torsion2.py +25 -0
  219. pysisyphus/intcoords/Translation.py +45 -0
  220. pysisyphus/intcoords/__init__.py +61 -0
  221. pysisyphus/intcoords/augment_bonds.py +126 -0
  222. pysisyphus/intcoords/derivatives.py +10512 -0
  223. pysisyphus/intcoords/eval.py +80 -0
  224. pysisyphus/intcoords/exceptions.py +37 -0
  225. pysisyphus/intcoords/findiffs.py +48 -0
  226. pysisyphus/intcoords/generate_derivatives.py +414 -0
  227. pysisyphus/intcoords/helpers.py +235 -0
  228. pysisyphus/intcoords/logging_conf.py +10 -0
  229. pysisyphus/intcoords/mp_derivatives.py +10836 -0
  230. pysisyphus/intcoords/setup.py +962 -0
  231. pysisyphus/intcoords/setup_fast.py +176 -0
  232. pysisyphus/intcoords/update.py +272 -0
  233. pysisyphus/intcoords/valid.py +89 -0
  234. pysisyphus/interpolate/Geodesic.py +93 -0
  235. pysisyphus/interpolate/IDPP.py +55 -0
  236. pysisyphus/interpolate/Interpolator.py +116 -0
  237. pysisyphus/interpolate/LST.py +70 -0
  238. pysisyphus/interpolate/Redund.py +152 -0
  239. pysisyphus/interpolate/__init__.py +9 -0
  240. pysisyphus/interpolate/helpers.py +34 -0
  241. pysisyphus/io/__init__.py +22 -0
  242. pysisyphus/io/aomix.py +178 -0
  243. pysisyphus/io/cjson.py +24 -0
  244. pysisyphus/io/crd.py +101 -0
  245. pysisyphus/io/cube.py +220 -0
  246. pysisyphus/io/fchk.py +184 -0
  247. pysisyphus/io/hdf5.py +49 -0
  248. pysisyphus/io/hessian.py +72 -0
  249. pysisyphus/io/mol2.py +146 -0
  250. pysisyphus/io/molden.py +293 -0
  251. pysisyphus/io/orca.py +189 -0
  252. pysisyphus/io/pdb.py +269 -0
  253. pysisyphus/io/psf.py +79 -0
  254. pysisyphus/io/pubchem.py +31 -0
  255. pysisyphus/io/qcschema.py +34 -0
  256. pysisyphus/io/sdf.py +29 -0
  257. pysisyphus/io/xyz.py +61 -0
  258. pysisyphus/io/zmat.py +175 -0
  259. pysisyphus/irc/DWI.py +108 -0
  260. pysisyphus/irc/DampedVelocityVerlet.py +134 -0
  261. pysisyphus/irc/Euler.py +22 -0
  262. pysisyphus/irc/EulerPC.py +345 -0
  263. pysisyphus/irc/GonzalezSchlegel.py +187 -0
  264. pysisyphus/irc/IMKMod.py +164 -0
  265. pysisyphus/irc/IRC.py +878 -0
  266. pysisyphus/irc/IRCDummy.py +10 -0
  267. pysisyphus/irc/Instanton.py +307 -0
  268. pysisyphus/irc/LQA.py +53 -0
  269. pysisyphus/irc/ModeKill.py +136 -0
  270. pysisyphus/irc/ParamPlot.py +53 -0
  271. pysisyphus/irc/RK4.py +36 -0
  272. pysisyphus/irc/__init__.py +31 -0
  273. pysisyphus/irc/initial_displ.py +219 -0
  274. pysisyphus/linalg.py +411 -0
  275. pysisyphus/line_searches/Backtracking.py +88 -0
  276. pysisyphus/line_searches/HagerZhang.py +184 -0
  277. pysisyphus/line_searches/LineSearch.py +232 -0
  278. pysisyphus/line_searches/StrongWolfe.py +108 -0
  279. pysisyphus/line_searches/__init__.py +9 -0
  280. pysisyphus/line_searches/interpol.py +15 -0
  281. pysisyphus/modefollow/NormalMode.py +40 -0
  282. pysisyphus/modefollow/__init__.py +10 -0
  283. pysisyphus/modefollow/davidson.py +199 -0
  284. pysisyphus/modefollow/lanczos.py +95 -0
  285. pysisyphus/optimizers/BFGS.py +99 -0
  286. pysisyphus/optimizers/BacktrackingOptimizer.py +113 -0
  287. pysisyphus/optimizers/ConjugateGradient.py +98 -0
  288. pysisyphus/optimizers/CubicNewton.py +75 -0
  289. pysisyphus/optimizers/FIRE.py +113 -0
  290. pysisyphus/optimizers/HessianOptimizer.py +1176 -0
  291. pysisyphus/optimizers/LBFGS.py +228 -0
  292. pysisyphus/optimizers/LayerOpt.py +411 -0
  293. pysisyphus/optimizers/MicroOptimizer.py +169 -0
  294. pysisyphus/optimizers/NCOptimizer.py +90 -0
  295. pysisyphus/optimizers/Optimizer.py +1084 -0
  296. pysisyphus/optimizers/PreconLBFGS.py +260 -0
  297. pysisyphus/optimizers/PreconSteepestDescent.py +7 -0
  298. pysisyphus/optimizers/QuickMin.py +74 -0
  299. pysisyphus/optimizers/RFOptimizer.py +181 -0
  300. pysisyphus/optimizers/RSA.py +99 -0
  301. pysisyphus/optimizers/StabilizedQNMethod.py +248 -0
  302. pysisyphus/optimizers/SteepestDescent.py +23 -0
  303. pysisyphus/optimizers/StringOptimizer.py +173 -0
  304. pysisyphus/optimizers/__init__.py +41 -0
  305. pysisyphus/optimizers/closures.py +301 -0
  306. pysisyphus/optimizers/cls_map.py +58 -0
  307. pysisyphus/optimizers/exceptions.py +6 -0
  308. pysisyphus/optimizers/gdiis.py +280 -0
  309. pysisyphus/optimizers/guess_hessians.py +311 -0
  310. pysisyphus/optimizers/hessian_updates.py +355 -0
  311. pysisyphus/optimizers/poly_fit.py +285 -0
  312. pysisyphus/optimizers/precon.py +153 -0
  313. pysisyphus/optimizers/restrict_step.py +24 -0
  314. pysisyphus/pack.py +172 -0
  315. pysisyphus/peakdetect.py +948 -0
  316. pysisyphus/plot.py +1031 -0
  317. pysisyphus/run.py +2106 -0
  318. pysisyphus/socket_helper.py +74 -0
  319. pysisyphus/stocastic/FragmentKick.py +132 -0
  320. pysisyphus/stocastic/Kick.py +81 -0
  321. pysisyphus/stocastic/Pipeline.py +303 -0
  322. pysisyphus/stocastic/__init__.py +21 -0
  323. pysisyphus/stocastic/align.py +127 -0
  324. pysisyphus/testing.py +96 -0
  325. pysisyphus/thermo.py +156 -0
  326. pysisyphus/trj.py +824 -0
  327. pysisyphus/tsoptimizers/RSIRFOptimizer.py +56 -0
  328. pysisyphus/tsoptimizers/RSPRFOptimizer.py +182 -0
  329. pysisyphus/tsoptimizers/TRIM.py +59 -0
  330. pysisyphus/tsoptimizers/TSHessianOptimizer.py +463 -0
  331. pysisyphus/tsoptimizers/__init__.py +23 -0
  332. pysisyphus/wavefunction/Basis.py +239 -0
  333. pysisyphus/wavefunction/DIIS.py +76 -0
  334. pysisyphus/wavefunction/__init__.py +25 -0
  335. pysisyphus/wavefunction/build_ext.py +42 -0
  336. pysisyphus/wavefunction/cart2sph.py +190 -0
  337. pysisyphus/wavefunction/diabatization.py +304 -0
  338. pysisyphus/wavefunction/excited_states.py +435 -0
  339. pysisyphus/wavefunction/gen_ints.py +1811 -0
  340. pysisyphus/wavefunction/helpers.py +104 -0
  341. pysisyphus/wavefunction/ints/__init__.py +0 -0
  342. pysisyphus/wavefunction/ints/boys.py +193 -0
  343. pysisyphus/wavefunction/ints/boys_table_N_64_xasym_27.1_step_0.01.npy +0 -0
  344. pysisyphus/wavefunction/ints/cart_gto3d.py +176 -0
  345. pysisyphus/wavefunction/ints/coulomb3d.py +25928 -0
  346. pysisyphus/wavefunction/ints/diag_quadrupole3d.py +10036 -0
  347. pysisyphus/wavefunction/ints/dipole3d.py +8762 -0
  348. pysisyphus/wavefunction/ints/int2c2e3d.py +7198 -0
  349. pysisyphus/wavefunction/ints/int3c2e3d_sph.py +65040 -0
  350. pysisyphus/wavefunction/ints/kinetic3d.py +8240 -0
  351. pysisyphus/wavefunction/ints/ovlp3d.py +3777 -0
  352. pysisyphus/wavefunction/ints/quadrupole3d.py +15054 -0
  353. pysisyphus/wavefunction/ints/self_ovlp3d.py +198 -0
  354. pysisyphus/wavefunction/localization.py +458 -0
  355. pysisyphus/wavefunction/multipole.py +159 -0
  356. pysisyphus/wavefunction/normalization.py +36 -0
  357. pysisyphus/wavefunction/pop_analysis.py +134 -0
  358. pysisyphus/wavefunction/shells.py +1171 -0
  359. pysisyphus/wavefunction/wavefunction.py +504 -0
  360. pysisyphus/wrapper/__init__.py +11 -0
  361. pysisyphus/wrapper/exceptions.py +2 -0
  362. pysisyphus/wrapper/jmol.py +120 -0
  363. pysisyphus/wrapper/mwfn.py +169 -0
  364. pysisyphus/wrapper/packmol.py +71 -0
  365. pysisyphus/xyzloader.py +168 -0
  366. pysisyphus/yaml_mods.py +45 -0
  367. thermoanalysis/LICENSE +674 -0
  368. thermoanalysis/QCData.py +244 -0
  369. thermoanalysis/__init__.py +0 -0
  370. thermoanalysis/config.py +3 -0
  371. thermoanalysis/constants.py +20 -0
  372. thermoanalysis/thermo.py +1011 -0
mlmm/mlmm_calc.py ADDED
@@ -0,0 +1,2262 @@
1
+ """
2
+ ONIOM-like ML/MM calculator coupling MLIP backends (UMA, ORB, MACE, AIMNet2)
3
+ with hessian_ff (MM).
4
+
5
+ Example:
6
+ calc = mlmm(input_pdb="input.pdb", real_parm7="real.parm7", model_pdb="model.pdb", charge=0)
7
+
8
+ For detailed documentation, see: docs/mlmm_calc.md
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import abc
14
+ import logging
15
+ import os
16
+ import warnings
17
+ import shutil
18
+ import tempfile
19
+ import time
20
+ from dataclasses import dataclass
21
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
22
+ from concurrent.futures import ThreadPoolExecutor
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ import click
27
+ import numpy as np
28
+ import torch
29
+ import torch.nn as nn
30
+
31
+ from ase import Atoms
32
+ from ase.io import read
33
+ from ase.calculators.calculator import Calculator, all_changes
34
+ from ase.constraints import FixAtoms
35
+
36
+ import parmed as pmd
37
+ from hessian_ff import ForceFieldTorch, load_coords, load_system
38
+ from hessian_ff.analytical_hessian import build_analytical_hessian
39
+
40
+ # Optional OpenMM import
41
+ try:
42
+ import openmm as mm
43
+ from openmm import app, unit, Platform
44
+ from openmm.unit import ScaledUnit, joule
45
+ HAS_OPENMM = True
46
+ except ImportError:
47
+ HAS_OPENMM = False
48
+
49
+ # Optional fairchem import (UMA backend)
50
+ try:
51
+ from fairchem.core import pretrained_mlip
52
+ from fairchem.core.datasets.atomic_data import AtomicData
53
+ from fairchem.core.datasets import data_list_collater
54
+ HAS_FAIRCHEM = True
55
+ except ImportError:
56
+ HAS_FAIRCHEM = False
57
+
58
+ # Optional ORB backend
59
+ try:
60
+ import orb_models # noqa: F401
61
+ HAS_ORB = True
62
+ except ImportError:
63
+ HAS_ORB = False
64
+
65
+ # Optional MACE backend
66
+ try:
67
+ import mace # noqa: F401
68
+ HAS_MACE = True
69
+ except ImportError:
70
+ HAS_MACE = False
71
+
72
+ # Optional AIMNet2 backend
73
+ try:
74
+ import aimnet # noqa: F401
75
+ HAS_AIMNET2 = True
76
+ except ImportError:
77
+ HAS_AIMNET2 = False
78
+
79
+ # ---------- PySisyphus unit constants ----------
80
+ from pysisyphus.constants import BOHR2ANG, ANG2BOHR, AU2EV, AU2KCALPERMOL
81
+ EV2AU = 1.0 / AU2EV # eV → Hartree
82
+ KCALMOL2EV = AU2EV / AU2KCALPERMOL # kcal/mol -> eV
83
+
84
+
85
+ # ======================================================================
86
+ # ML Backend Abstraction
87
+ # ======================================================================
88
+
89
+
90
class _MLBackend(abc.ABC):
    """Internal abstraction for the ML part of the ONIOM ML/MM coupling.

    Each backend must provide energy/force evaluation and Hessian computation.
    All quantities are in eV and Angstrom.
    """

    @abc.abstractmethod
    def eval(
        self, atoms: Atoms, need_grad: bool = True
    ) -> Tuple[float, np.ndarray, Any]:
        """Evaluate energy and forces.

        Returns
        -------
        E : float
            Energy in eV.
        F : ndarray (N, 3)
            Forces in eV/Å.
        opaque : Any
            Backend-specific data needed for analytical Hessian (e.g., batch).
        """

    @abc.abstractmethod
    def hessian_analytical(self, opaque: Any, n_atoms: int, *, dtype: torch.dtype) -> torch.Tensor:
        """Compute analytical Hessian from the opaque batch returned by eval().

        Returns Hessian as a (N, 3, N, 3) torch Tensor in eV/Ų.
        """

    def hessian_fd(
        self,
        atoms: Atoms,
        freeze_model: Sequence[int],
        *,
        eps_ang: float = 1.0e-3,
        dtype: torch.dtype = torch.float32,
        device: torch.device = torch.device("cpu"),
    ) -> torch.Tensor:
        """Compute Hessian via finite differences (central difference).

        Generic implementation that works for all backends: every Cartesian
        component of every non-frozen atom is displaced by ``+eps_ang`` and
        ``-eps_ang`` and the forces returned by :meth:`eval` are differenced.

        Parameters
        ----------
        atoms : Atoms
            Structure to differentiate. Its positions are displaced in place
            during the calculation and restored before returning, including
            when :meth:`eval` raises.
        freeze_model : sequence of int
            Atom indices that are not displaced. Their Hessian columns stay
            zero; the rows belonging to them are still filled from the
            displacements of the active atoms.
        eps_ang : float
            Displacement step. Must be non-zero.
        dtype, device
            dtype and device of the returned tensor.

        Returns
        -------
        torch.Tensor
            Hessian viewed as ``(N, 3, N, 3)``.

        Raises
        ------
        ValueError
            If ``eps_ang`` is zero (the difference quotient would be NaN/inf).
        """
        if eps_ang == 0:
            raise ValueError("eps_ang must be non-zero for a finite-difference Hessian.")

        n_atoms = len(atoms)
        dof = n_atoms * 3

        frozen_set = {int(i) for i in freeze_model}
        active_dof_idx = [
            3 * i + j
            for i in range(n_atoms)
            if i not in frozen_set
            for j in range(3)
        ]

        H = torch.zeros((dof, dof), device=device, dtype=dtype)
        coord0 = atoms.get_positions().copy()
        try:
            for k in active_dof_idx:
                a, c = divmod(k, 3)

                # Forward displacement.
                atoms.positions = coord0.copy()
                atoms.positions[a, c] = coord0[a, c] + eps_ang
                _, Fp, _ = self.eval(atoms, need_grad=False)

                # Backward displacement.
                atoms.positions = coord0.copy()
                atoms.positions[a, c] = coord0[a, c] - eps_ang
                _, Fm, _ = self.eval(atoms, need_grad=False)

                # H[:, k] = dG/dx_k = -dF/dx_k (central difference).
                col = -(torch.from_numpy(Fp.reshape(-1)) - torch.from_numpy(Fm.reshape(-1))) / (2.0 * eps_ang)
                H[:, k] = col.to(device, dtype=dtype)
        finally:
            # Always hand the caller back the undisplaced geometry, even if a
            # backend evaluation failed part-way through the loop.
            atoms.positions = coord0
        return H.view(n_atoms, 3, n_atoms, 3)

    @property
    @abc.abstractmethod
    def supports_analytical_hessian(self) -> bool:
        """Whether this backend supports analytical Hessian."""

    @property
    @abc.abstractmethod
    def device(self) -> torch.device:
        """The torch device this backend uses."""
169
+
170
+
171
class _UMABackend(_MLBackend):
    """ML backend built on a FAIR-Chem UMA predictor."""

    def __init__(
        self,
        *,
        uma_model: str = "uma-s-1p1",
        uma_task_name: str = "omol",
        model_charge: int = 0,
        model_mult: int = 1,
        ml_device: torch.device,
    ):
        """Load the pretrained predictor and cache graph-construction settings.

        Raises
        ------
        ImportError
            If fairchem-core could not be imported at module load time.
        """
        if not HAS_FAIRCHEM:
            raise ImportError(
                "fairchem-core is required for the UMA backend. "
                "Install with `pip install fairchem-core` "
                "and ensure Hugging Face authentication is configured."
            )
        self._device = ml_device
        # The predictor is requested with a plain "cuda"/"cpu" string.
        predictor_device = "cpu" if ml_device.type != "cuda" else "cuda"
        self._AtomicData = AtomicData
        self._data_list_collater = data_list_collater
        self.predictor = pretrained_mlip.get_predict_unit(uma_model, device=predictor_device)

        # Inference mode, with every dropout probability forced to zero.
        self.predictor.model.eval()
        for module in self.predictor.model.modules():
            if isinstance(module, nn.Dropout):
                module.p = 0.0

        self.uma_task_name = uma_task_name
        self.model_charge = model_charge
        self.model_mult = model_mult

        # Unwrap a ``.module`` wrapper when present before reading the backbone.
        wrapped = self.predictor.model
        backbone = getattr(wrapped, "module", wrapped).backbone
        self._uma_max_neigh = getattr(backbone, "max_neighbors", None)
        self._uma_radius = getattr(backbone, "cutoff", None)

    @property
    def supports_analytical_hessian(self) -> bool:
        """UMA provides an autograd-based analytical Hessian."""
        return True

    @property
    def device(self) -> torch.device:
        """The torch device passed in at construction."""
        return self._device

    def eval(self, atoms: Atoms, need_grad: bool = True) -> Tuple[float, np.ndarray, Any]:
        """Return ``(energy, forces, batch)`` for ``atoms``.

        Charge and ``spin = model_mult - 1`` are written into ``atoms.info``
        before the graph is built. The returned batch carries the position
        tensor (with ``requires_grad`` set to ``need_grad``) so that
        :meth:`hessian_analytical` can differentiate through it.
        """
        state = {"charge": self.model_charge, "spin": self.model_mult - 1}
        atoms.info.update(state)

        sample = self._AtomicData.from_ase(
            atoms,
            max_neigh=self._uma_max_neigh,
            radius=self._uma_radius,
            r_edges=False,
        ).to(self._device)
        sample.dataset = self.uma_task_name

        batch = self._data_list_collater([sample], otf_graph=True).to(self._device)
        positions = batch.pos.detach().clone().to(self._device)
        positions.requires_grad_(need_grad)
        batch.pos = positions

        if not need_grad:
            with torch.no_grad():
                result = self.predictor.predict(batch)
        else:
            result = self.predictor.predict(batch)

        energy = float(result["energy"].squeeze().detach().item())
        forces = result["forces"].detach().cpu().numpy()
        return energy, forces, batch

    def hessian_analytical(self, opaque: Any, n_atoms: int, *, dtype: torch.dtype) -> torch.Tensor:
        """Differentiate the predicted energy twice w.r.t. the batch positions.

        Parameter gradients are disabled and the model is switched to train
        mode while the Hessian is built; both are restored in ``finally``.
        NOTE(review): train mode is presumably needed for the double backward
        pass — confirm against the fairchem predictor.
        """
        batch = opaque
        params = self.predictor.model.parameters
        saved_flags = [p.requires_grad for p in params()]
        for p in params():
            p.requires_grad_(False)

        self.predictor.model.train()
        try:
            flat0 = batch.pos.view(-1)

            def energy_fn(flat_pos: torch.Tensor):
                # Route the flattened coordinates back through the batch.
                batch.pos = flat_pos.view(-1, 3)
                prediction = self.predictor.predict(batch)
                return prediction["energy"].squeeze()

            hess_flat = torch.autograd.functional.hessian(energy_fn, flat0, vectorize=False)
            hess = hess_flat.view(n_atoms, 3, n_atoms, 3).to(dtype).detach()
        finally:
            self.predictor.model.eval()
            for p, flag in zip(params(), saved_flags):
                p.requires_grad_(flag)
            if self._device.type == "cuda":
                torch.cuda.empty_cache()
        return hess
258
+
259
+
260
class _ASEMLBackend(_MLBackend):
    """Shared base for ML backends driven through an ASE calculator
    (ORB, MACE, AIMNet2).

    Subclasses must set ``self._ase_calc`` (an ASE Calculator) and
    ``self._device``.
    """

    _ase_calc: Calculator
    _device: torch.device
    _model_charge: int = 0
    _model_mult: int = 1

    @property
    def supports_analytical_hessian(self) -> bool:
        """ASE-calculator backends only offer finite-difference Hessians."""
        return False

    @property
    def device(self) -> torch.device:
        """The torch device stored by the subclass."""
        return self._device

    def eval(self, atoms: Atoms, need_grad: bool = True) -> Tuple[float, np.ndarray, Any]:
        """Run the ASE calculator on a copy of ``atoms``.

        Returns ``(energy, forces, None)``; there is no opaque payload and
        ``need_grad`` is not consulted.
        """
        work = atoms.copy()
        work.calc = self._ase_calc
        # Some calculators pick charge/multiplicity up from ``atoms.info``
        # (e.g. AIMNet2 reads atoms.info['charge'] and atoms.info['mult']).
        work.info["charge"] = self._model_charge
        work.info["mult"] = self._model_mult
        energy = float(work.get_potential_energy())
        forces = np.array(work.get_forces(), dtype=np.float64)
        return energy, forces, None

    def hessian_analytical(self, opaque: Any, n_atoms: int, *, dtype: torch.dtype) -> torch.Tensor:
        """Always raises ``NotImplementedError``."""
        message = (
            f"Analytical Hessian is not supported by {self.__class__.__name__}. "
            "Use hessian_calc_mode='FiniteDifference'."
        )
        raise NotImplementedError(message)
296
+
297
+
298
class _OrbBackend(_ASEMLBackend):
    """ML backend wrapping an ORB (Orbital Materials) ASE calculator."""

    def __init__(
        self,
        *,
        orb_model: str = "orb_v3_conservative_omol",
        model_charge: int = 0,
        model_mult: int = 1,
        ml_device: torch.device,
        **_kwargs,  # absorb unused keys (e.g. orb_precision)
    ):
        """Build the ORB calculator named by ``orb_model``.

        ``orb_model`` is looked up as an attribute of
        ``orb_models.forcefield.pretrained`` and called with the device string.

        Raises
        ------
        ImportError
            If orb-models could not be imported at module load time.
        """
        if not HAS_ORB:
            raise ImportError(
                "orb-models is required for the ORB backend. "
                "Install with `pip install orb-models`."
            )
        from orb_models.forcefield import pretrained
        from orb_models.forcefield.calculator import ORBCalculator

        target = "cpu" if ml_device.type != "cuda" else "cuda"
        factory = getattr(pretrained, orb_model)
        forcefield = factory(device=target)

        self._ase_calc = ORBCalculator(forcefield, device=target)
        self._device = ml_device
        self._model_charge = model_charge
        self._model_mult = model_mult
324
+
325
+
326
class _MACEBackend(_ASEMLBackend):
    """ML backend wrapping a MACE ASE calculator."""

    def __init__(
        self,
        *,
        mace_model: str = "MACE-OMOL-0",
        mace_dtype: str = "float64",
        model_charge: int = 0,
        model_mult: int = 1,
        ml_device: torch.device,
    ):
        """Pick a MACE calculator factory from the (case-insensitive) model name.

        * ``mp:<name>`` / ``mace-mp...``  -> ``mace_mp``
        * ``off:<name>`` / ``mace-off...`` -> ``mace_off``
        * ``anicc...`` / ``mace-anicc...`` -> ``mace_anicc``
        * anything else (including ``omol`` names and local files)
          -> ``mace_off`` with the name passed through unchanged

        For the first two families, text after the first ``:`` is used as the
        model name when a colon is present.

        Raises
        ------
        ImportError
            If mace-torch could not be imported at module load time.
        """
        if not HAS_MACE:
            raise ImportError(
                "mace-torch is required for the MACE backend. "
                "Install with `pip install mace-torch`."
            )
        from mace.calculators import mace_off, mace_mp, mace_anicc

        target = "cpu" if ml_device.type != "cuda" else "cuda"
        lowered = mace_model.lower()

        # Name with an optional "<family>:" prefix stripped.
        _, colon, remainder = mace_model.partition(":")
        stripped_name = remainder if colon else mace_model

        if lowered.startswith(("mp:", "mace-mp")):
            calculator = mace_mp(
                model=stripped_name, device=target, default_dtype=mace_dtype
            )
        elif lowered.startswith(("off:", "mace-off")):
            calculator = mace_off(
                model=stripped_name, device=target, default_dtype=mace_dtype
            )
        elif lowered.startswith(("anicc", "mace-anicc")):
            calculator = mace_anicc(device=target, default_dtype=mace_dtype)
        else:
            # MACE-OMOL names, local model files and any other identifier are
            # all handed to mace_off with the name exactly as given.
            calculator = mace_off(
                model=mace_model, device=target, default_dtype=mace_dtype
            )

        self._ase_calc = calculator
        self._device = ml_device
        self._model_charge = model_charge
        self._model_mult = model_mult
375
+
376
+
377
class _AIMNet2Backend(_ASEMLBackend):
    """ML backend wrapping the AIMNet2 ASE calculator."""

    def __init__(
        self,
        *,
        aimnet2_model: str = "aimnet2",
        model_charge: int = 0,
        model_mult: int = 1,
        ml_device: torch.device,
    ):
        """Create the AIMNet2 calculator and record charge/multiplicity.

        Raises
        ------
        ImportError
            If aimnet could not be imported at module load time.
        """
        if not HAS_AIMNET2:
            raise ImportError(
                "aimnet is required for the AIMNet2 backend. "
                "Install with `pip install aimnet`."
            )
        from aimnet.calculators import AIMNet2Calculator

        target = "cpu" if ml_device.type != "cuda" else "cuda"
        calculator = AIMNet2Calculator(model=aimnet2_model, device=target)

        self._ase_calc = calculator
        self._device = ml_device
        self._model_charge = model_charge
        self._model_mult = model_mult
400
+
401
+
402
def _create_ml_backend(
    backend: str,
    *,
    uma_model: str = "uma-s-1p1",
    uma_task_name: str = "omol",
    orb_model: str = "orb_v3_conservative_omol",
    mace_model: str = "MACE-OMOL-0",
    mace_dtype: str = "float64",
    aimnet2_model: str = "aimnet2",
    model_charge: int = 0,
    model_mult: int = 1,
    ml_device: torch.device,
) -> _MLBackend:
    """Instantiate the ML backend selected by ``backend``.

    The name is stripped and lower-cased before matching against ``uma``,
    ``orb``, ``mace`` and ``aimnet2``. Only the options relevant to the chosen
    backend are forwarded, together with charge, multiplicity and device.

    Raises
    ------
    ValueError
        If the normalized name is not one of the four known backends.
    """
    backend = backend.strip().lower()

    # Arguments every backend constructor receives.
    common = {
        "model_charge": model_charge,
        "model_mult": model_mult,
        "ml_device": ml_device,
    }

    if backend == "uma":
        return _UMABackend(uma_model=uma_model, uma_task_name=uma_task_name, **common)
    if backend == "orb":
        return _OrbBackend(orb_model=orb_model, **common)
    if backend == "mace":
        return _MACEBackend(mace_model=mace_model, mace_dtype=mace_dtype, **common)
    if backend == "aimnet2":
        return _AIMNet2Backend(aimnet2_model=aimnet2_model, **common)

    raise ValueError(
        f"Unknown ML backend '{backend}'. Choose from: uma, orb, mace, aimnet2."
    )
451
+
452
+
453
+ # ======================================================================
454
+ # xTB Point-Charge Embedding Correction
455
+ # ======================================================================
456
+
457
+
458
+ class _EmbedChargeCorrection:
459
+ """xTB-based point-charge embedding correction for ONIOM ML/MM.
460
+
461
+ Computes the electrostatic interaction between the ML region and
462
+ the MM point charges via xTB:
463
+
464
+ dE = E_xTB(ML + MM_charges) - E_xTB(ML_only)
465
+ dF = F_xTB(ML + MM_charges) - F_xTB(ML_only)
466
+
467
+ This accounts for the environmental electrostatic effect of MM
468
+ atoms on the ML region, which is not captured by the subtractive
469
+ ONIOM scheme alone.
470
+ """
471
+
472
+ def __init__(
473
+ self,
474
+ *,
475
+ xtb_cmd: str = "xtb",
476
+ xtb_acc: float = 0.2,
477
+ xtb_workdir: str = "tmp",
478
+ xtb_keep_files: bool = False,
479
+ xtb_ncores: int = 4,
480
+ hessian_step: float = 1.0e-3,
481
+ ):
482
+ self.xtb_cmd = xtb_cmd
483
+ self.xtb_acc = xtb_acc
484
+ self.xtb_workdir = xtb_workdir
485
+ self.xtb_keep_files = xtb_keep_files
486
+ self.xtb_ncores = xtb_ncores
487
+ self.hessian_step = hessian_step
488
+
489
+ def compute_correction(
490
+ self,
491
+ symbols: List[str],
492
+ coords_ml_ang: np.ndarray,
493
+ mm_coords_ang: np.ndarray,
494
+ mm_charges: np.ndarray,
495
+ charge: int,
496
+ multiplicity: int,
497
+ *,
498
+ need_forces: bool = False,
499
+ need_hessian: bool = False,
500
+ ) -> Tuple[float, Optional[np.ndarray], Optional[np.ndarray]]:
501
+ """Compute point-charge embedding correction.
502
+
503
+ Parameters
504
+ ----------
505
+ symbols : list of str
506
+ Element symbols for ML atoms.
507
+ coords_ml_ang : ndarray (N_ML, 3)
508
+ Coordinates of ML atoms in Angstrom.
509
+ mm_coords_ang : ndarray (N_MM, 3)
510
+ Coordinates of MM point charges in Angstrom.
511
+ mm_charges : ndarray (N_MM,)
512
+ Charges of MM point charges in atomic units.
513
+ charge : int
514
+ Total charge of the ML region.
515
+ multiplicity : int
516
+ Spin multiplicity of the ML region.
517
+ need_forces : bool
518
+ Whether to compute force corrections.
519
+ need_hessian : bool
520
+ Whether to compute Hessian corrections.
521
+
522
+ Returns
523
+ -------
524
+ dE : float
525
+ Energy correction in eV.
526
+ dF_ml : ndarray (N_ML, 3) or None
527
+ Force corrections for ML atoms in eV/Å.
528
+ dH_ml : ndarray (3*N_ML, 3*N_ML) or None
529
+ Hessian correction for ML atoms in eV/Ų.
530
+ """
531
+ from .xtb_embedcharge_correction import delta_embedcharge_minus_noembed
532
+
533
+ n_ml = len(symbols)
534
+ mm_coords = np.asarray(mm_coords_ang, dtype=np.float64).reshape(-1, 3)
535
+ mm_q = np.asarray(mm_charges, dtype=np.float64).reshape(-1)
536
+ n_mm = mm_q.shape[0]
537
+
538
+ if n_mm == 0:
539
+ dF = np.zeros((n_ml, 3), dtype=np.float64) if need_forces else None
540
+ dH = np.zeros((3 * n_ml, 3 * n_ml), dtype=np.float64) if need_hessian else None
541
+ return 0.0, dF, dH
542
+
543
+ dE_ev, dF_full_ev, dH_full_ev = delta_embedcharge_minus_noembed(
544
+ symbols=symbols,
545
+ coords_q_ang=np.asarray(coords_ml_ang, dtype=np.float64).reshape(-1, 3),
546
+ mm_coords_ang=mm_coords,
547
+ mm_charges=mm_q,
548
+ charge=charge,
549
+ multiplicity=multiplicity,
550
+ need_forces=need_forces or need_hessian,
551
+ need_hessian=need_hessian,
552
+ xtb_cmd=self.xtb_cmd,
553
+ xtb_acc=self.xtb_acc,
554
+ xtb_workdir=self.xtb_workdir,
555
+ xtb_keep_files=self.xtb_keep_files,
556
+ ncores=self.xtb_ncores,
557
+ hessian_step=self.hessian_step,
558
+ )
559
+
560
+ dF_ml = None
561
+ if dF_full_ev is not None:
562
+ # Extract only the ML-atom forces (first n_ml rows)
563
+ dF_ml = np.asarray(dF_full_ev, dtype=np.float64).reshape(-1, 3)[:n_ml]
564
+
565
+ dH_ml = None
566
+ if dH_full_ev is not None:
567
+ # Extract only the ML-atom Hessian block
568
+ dof_ml = 3 * n_ml
569
+ dH_full = np.asarray(dH_full_ev, dtype=np.float64)
570
+ dH_ml = dH_full[:dof_ml, :dof_ml]
571
+
572
+ return float(dE_ev), dF_ml, dH_ml
573
+
574
+
575
+ # ======================================================================
576
+ # Utilities
577
+ # ======================================================================
578
+
579
def _fixed_indices_from_constraints(atoms: Atoms) -> set[int]:
    """Return the atom indices pinned by ``FixAtoms`` constraints on *atoms*.

    Constraints of any other type are ignored. A missing or empty
    constraint list yields an empty set.
    """
    return {
        int(index)
        for constraint in (atoms.constraints or [])
        if isinstance(constraint, FixAtoms)
        for index in constraint.get_indices()
    }
585
+
586
+
587
def _normalize_prmtop_lj_tables(parm7_path: str) -> None:
    """Trim over-long Lennard-Jones tables in a parm7 file, in place.

    ParmEd slicing can leave ``LENNARD_JONES_*COEF`` longer than the ``POINTERS``
    ``NTYPES`` expectation. If the file already loads as an ``AmberParm``
    nothing is touched. Otherwise, when the load error mentions one of the
    LJ coefficient flags, each table is cut back to ``NTYPES*(NTYPES+1)/2``
    entries, the file is rewritten, and it is loaded again as a check.

    Raises
    ------
    ValueError
        If ``POINTERS`` has fewer than two entries or an LJ table is shorter
        than expected.
    Exception
        Any load error unrelated to the LJ coefficient flags is re-raised.
    """
    from parmed.amber import AmberFormat, AmberParm

    try:
        AmberParm(parm7_path)
    except Exception as exc:
        text = str(exc)
        lj_related = (
            "FLAG LENNARD_JONES_ACOEF" in text
            or "FLAG LENNARD_JONES_BCOEF" in text
        )
        if not lj_related:
            raise
    else:
        # Topology is already loadable; nothing to normalize.
        return

    raw = AmberFormat(parm7_path)
    pointers = list(raw.parm_data.get("POINTERS", []))
    if len(pointers) < 2:
        raise ValueError(f"Invalid POINTERS section in parm7: {parm7_path}")
    # The second POINTERS entry is read as NTYPES.
    ntypes = int(pointers[1])
    expected = ntypes * (ntypes + 1) // 2

    trimmed_any = False
    for key in ("LENNARD_JONES_ACOEF", "LENNARD_JONES_BCOEF"):
        values = list(raw.parm_data.get(key, []))
        if len(values) < expected:
            raise ValueError(
                f"{key} has {len(values)} entries but expected at least {expected} "
                f"from NTYPES={ntypes} in {parm7_path}."
            )
        if len(values) > expected:
            # Drop only the trailing surplus.
            raw.parm_data[key] = values[:expected]
            trimmed_any = True

    if trimmed_any:
        raw.write_parm(parm7_path)

    # Validate normalized topology immediately.
    AmberParm(parm7_path)
631
+
632
+
633
+ # ======================================================================
634
+ # hessian_ff (MM) -> ASE calculator
635
+ # ======================================================================
636
+
637
+ def _expand_partial_hessian(
638
+ h_sub: np.ndarray,
639
+ active_atoms: np.ndarray,
640
+ n_atoms: int,
641
+ *,
642
+ dtype: np.dtype,
643
+ ) -> np.ndarray:
644
+ h_full = np.zeros((3 * n_atoms, 3 * n_atoms), dtype=dtype)
645
+ for i_local, i_atom in enumerate(active_atoms):
646
+ i0 = 3 * int(i_atom)
647
+ for j_local, j_atom in enumerate(active_atoms):
648
+ j0 = 3 * int(j_atom)
649
+ h_full[i0:i0 + 3, j0:j0 + 3] = h_sub[
650
+ 3 * i_local:3 * i_local + 3,
651
+ 3 * j_local:3 * j_local + 3,
652
+ ]
653
+ return h_full
654
+
655
+
656
class hessianffCalculator(Calculator):
    """ASE calculator that evaluates the MM layer through ``hessian_ff``.

    The force field always runs on CPU in float64. Energies, forces and
    Hessians coming back from ``hessian_ff`` are multiplied by
    ``KCALMOL2EV`` before being returned.
    """

    implemented_properties = ["energy", "forces"]

    def __init__(
        self,
        parm7: str,
        rst7: Optional[str] = None,
        *,
        device: str = "auto",
        cuda_idx: int = 0,
        threads: int = 16,
        **kwargs,
    ):
        """Load the topology and allocate the reusable coordinate buffer.

        Parameters
        ----------
        parm7 : str
            Amber parm7 topology passed to ``load_system``.
        rst7 : str, optional
            If given, its coordinates pre-fill the coordinate buffer.
        device : str
            Only "auto" or "cpu" (case-insensitive) are accepted.
        cuda_idx : int
            Stored on the instance; not otherwise used by this class.
        threads : int
            When positive and different from the current setting, applied
            through ``torch.set_num_threads``.
        **kwargs
            Forwarded to the ASE ``Calculator`` constructor.

        Raises
        ------
        ValueError
            If ``device`` is anything other than "auto" or "cpu".
        """
        super().__init__(**kwargs)

        if str(device).lower() not in ("auto", "cpu"):
            raise ValueError(
                "MM backend 'hessian_ff' is CPU-only. "
                f"Got device={device!r}. Use mm_device='cpu' or 'auto'."
            )

        self.device = "cpu"
        self.cuda_idx = int(cuda_idx)
        self.threads = int(threads)
        if self.threads > 0:
            if torch.get_num_threads() != self.threads:
                torch.set_num_threads(self.threads)

        cpu = torch.device("cpu")
        self.system = load_system(parm7, device="cpu").to(dtype=torch.float64)
        self.ff = ForceFieldTorch(self.system)
        self.natom = int(self.system.natom)
        self._coords_dtype = torch.float64
        self._coords_device = cpu
        # Single buffer reused for every evaluation.
        self._coord_buf = torch.empty((self.natom, 3), dtype=torch.float64, device=cpu)

        if rst7 is None:
            return
        initial_xyz = load_coords(
            rst7,
            natom=self.natom,
            device=self._coords_device,
            dtype=self._coords_dtype,
        )
        self._coord_buf.copy_(initial_xyz)

    def _positions_to_tensor(self, positions_ang: np.ndarray) -> torch.Tensor:
        """Copy positions into the shared buffer and return that buffer.

        Raises ``ValueError`` when the array is not shaped ``(natom, 3)``.
        The returned tensor is overwritten by the next call.
        """
        arr = np.asarray(positions_ang, dtype=np.float64)
        if arr.shape != (self.natom, 3):
            raise ValueError(
                f"Coordinate shape mismatch for '{type(self).__name__}': "
                f"got {arr.shape}, expected ({self.natom}, 3)."
            )
        source = torch.as_tensor(arr, dtype=self._coords_dtype, device=self._coords_device)
        self._coord_buf.copy_(source)
        return self._coord_buf

    def _energy_forces_from_positions(self, positions_ang: np.ndarray) -> Tuple[float, np.ndarray]:
        """Evaluate analytical energy and forces, scaled by ``KCALMOL2EV``."""
        out, force = self.ff.energy_force(
            self._positions_to_tensor(positions_ang),
            force_calc_mode="Analytical",
        )
        raw_energy = float(out["E_total"].detach().cpu())
        raw_forces = force.detach().cpu().numpy().astype(np.float64, copy=False)
        return raw_energy * KCALMOL2EV, raw_forces * KCALMOL2EV

    def calculate(self, atoms: Atoms = None, properties=None, system_changes=all_changes):
        """ASE entry point: fill ``self.results`` with energy and forces."""
        super().calculate(atoms, properties, system_changes)
        if atoms is None:
            raise ValueError("ASE Atoms is required for MM evaluation.")
        energy_ev, forces_ev = self._energy_forces_from_positions(atoms.get_positions())
        self.results = dict(energy=energy_ev, forces=forces_ev)

    def analytical_hessian(
        self,
        atoms: Atoms,
        *,
        info_path: Optional[str] = None,
        dtype: np.dtype = np.float64,
        return_partial_hessian: bool = False,
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Build the analytical Hessian over atoms not fixed by ``FixAtoms``.

        Parameters
        ----------
        atoms : Atoms
            Structure to evaluate; its ``FixAtoms`` constraints define the
            frozen atoms.
        info_path : str, optional
            If given, a short text summary is written there (parent
            directory created when needed).
        dtype : numpy dtype
            dtype of the returned matrix.
        return_partial_hessian : bool
            When True, return the active-atom sub-Hessian together with the
            active atom indices; otherwise return the zero-padded full
            matrix and ``None``.
        """
        n_total = len(atoms)
        frozen = _fixed_indices_from_constraints(atoms)
        active_atoms = np.asarray(
            [idx for idx in range(n_total) if idx not in frozen],
            dtype=int,
        )

        # Everything frozen: nothing to differentiate.
        if not active_atoms.size:
            if return_partial_hessian:
                return np.zeros((0, 0), dtype=dtype), active_atoms
            return np.zeros((3 * n_total, 3 * n_total), dtype=dtype), None

        if info_path is not None:
            parent = os.path.dirname(info_path)
            if parent:
                os.makedirs(parent, exist_ok=True)
            with open(info_path, "w", encoding="utf-8") as log:
                log.writelines(
                    [
                        "Analytical Hessian (hessian_ff)\n",
                        "--------------------------------\n",
                        f"n_active_atoms = {active_atoms.size}\n",
                    ]
                )
                log.flush()

        h_local, _ = build_analytical_hessian(
            system=self.system,
            coords=self._positions_to_tensor(atoms.get_positions()),
            active_atoms=active_atoms.tolist(),
        )
        h_sub = h_local.detach().cpu().numpy().astype(np.float64, copy=False) * KCALMOL2EV
        h_sub = np.asarray(h_sub, dtype=dtype)

        if return_partial_hessian:
            return h_sub, active_atoms
        return _expand_partial_hessian(h_sub, active_atoms, n_total, dtype=dtype), None

    def finite_difference_hessian(
        self,
        atoms: Atoms,
        *,
        delta: float = 1e-3,
        info_path: Optional[str] = None,
        dtype: np.dtype = np.float64,
        return_partial_hessian: bool = False,
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Compatibility wrapper that returns the analytical Hessian.

        ``delta`` is converted to float (so a non-numeric value still fails)
        but is otherwise unused.
        """
        float(delta)  # Kept for backward-compatible signature.
        return self.analytical_hessian(
            atoms,
            info_path=info_path,
            dtype=dtype,
            return_partial_hessian=return_partial_hessian,
        )
778
+
779
+
780
+
781
+ # ======================================================================
782
+ # OpenMM Calculator
783
+ # ======================================================================
784
class OpenMMCalculator(Calculator):
    """
    ASE Calculator wrapper for OpenMM backend (finite-difference Hessian).

    This calculator uses OpenMM for MM force field evaluation and supports
    CUDA/CPU platforms. Unlike hessianffCalculator, it computes Hessians
    via numerical finite differences.

    Parameters
    ----------
    parm7 : str
        Path to Amber parm7 topology file.
    rst7 : str
        Path to Amber rst7 coordinate file.
    device : str, default "auto"
        Platform selection: "auto", "cuda", or "cpu" (case-insensitive).
        Any other value selects the CPU platform.
    cuda_idx : int, default 0
        CUDA device index when device="cuda".
    threads : int, default 16
        Number of CPU threads when device="cpu".
    """

    implemented_properties = ["energy", "forces"]

    def __init__(
        self,
        parm7: str,
        rst7: str,
        *,
        device: str = "auto",
        cuda_idx: int = 0,
        threads: int = 16,
        **kwargs,
    ):
        """Create the OpenMM system and context for the given topology.

        Raises
        ------
        ImportError
            If OpenMM was not importable (``HAS_OPENMM`` false).
        """
        super().__init__(**kwargs)

        if not HAS_OPENMM:
            raise ImportError(
                "OpenMM is required for OpenMMCalculator. "
                "Install with: conda install -c conda-forge openmm"
            )

        # Normalize so "CUDA" / "Auto" are not silently treated as CPU.
        device = str(device).strip().lower()

        # Auto-detect device
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"

        # Platform selection
        if device == "cuda":
            platform = Platform.getPlatformByName("CUDA")
            properties = {
                "CudaDeviceIndex": str(cuda_idx),
                "CudaPrecision": "double",
                "DeterministicForces": "true",
                "CudaUseBlockingSync": "true",
            }
        else:
            platform = Platform.getPlatformByName("CPU")
            properties = {"Threads": str(threads)}

        # Load Amber topology and coordinates
        self.prmtop = app.AmberPrmtopFile(parm7)
        inpcrd = app.AmberInpcrdFile(rst7)

        # Create OpenMM system and context
        self.system = self.prmtop.createSystem(
            nonbondedMethod=app.NoCutoff,
            rigidWater=False,
        )
        # The integrator is only needed to build a Context; zero step size.
        self.integrator = mm.VerletIntegrator(0 * unit.femtoseconds)
        self.context = mm.Context(self.system, self.integrator, platform, properties)
        self.context.setPositions(inpcrd.positions)

    def calculate(self, atoms: Atoms = None, properties=None, system_changes=all_changes):
        """Compute energy and forces for the given atoms."""
        super().calculate(atoms, properties, system_changes)

        # Define eV unit for OpenMM
        ev_base_unit = ScaledUnit(1.602176634e-19, joule, "electron volt", "eV")
        eV = unit.Unit({ev_base_unit: 1.0})

        # Update positions and get state
        self.context.setPositions(atoms.get_positions() * unit.angstrom)
        state = self.context.getState(getEnergy=True, getForces=True)

        # Extract energy and forces in eV units
        energy = state.getPotentialEnergy().value_in_unit(eV / unit.item)
        forces = state.getForces(asNumpy=True).value_in_unit(eV / unit.angstrom / unit.item)

        self.results = {"energy": energy, "forces": forces}

    def finite_difference_hessian(
        self,
        atoms: Atoms,
        *,
        delta: float = 0.01,
        info_path: Optional[str] = None,
        dtype: np.dtype = np.float64,
        return_partial_hessian: bool = False,
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Compute Hessian via finite differences using hessian_calc utility.

        Parameters
        ----------
        atoms : Atoms
            Structure to differentiate.
        delta : float, default 0.01
            Displacement size in Angstrom.
        info_path : str | None
            Progress log file path.
        dtype : numpy dtype, default float64
            Data type for the Hessian matrix.
        return_partial_hessian : bool, default False
            If True, return only the active sub-Hessian and active atom indices.

        Returns
        -------
        H : ndarray
            With ``return_partial_hessian=False``: the matrix returned by
            ``hessian_calc`` (documented upstream as the full (3N, 3N)
            Hessian in eV/Ų). With ``return_partial_hessian=True``: the
            (3*n_active, 3*n_active) block for atoms not fixed by
            ``FixAtoms``.
        active_atoms : ndarray of int | None
            Active atom indices (only if return_partial_hessian=True).
        """
        from .hessian_calc import hessian_calc

        H_full = hessian_calc(atoms, self, delta=delta, info_path=info_path, dtype=dtype)

        if not return_partial_hessian:
            return H_full, None

        fixed = _fixed_indices_from_constraints(atoms)
        # Explicit integer dtype: an empty list would otherwise become a
        # float64 array, unlike the hessian_ff calculator's index array.
        active_atoms = np.asarray(
            [i for i in range(len(atoms)) if i not in fixed],
            dtype=int,
        )
        # Extract active sub-Hessian to match hessian_ff convention:
        # DOFs 3*i, 3*i+1, 3*i+2 for each active atom, in ascending order.
        idx3 = (3 * active_atoms[:, None] + np.arange(3)).reshape(-1)
        H_sub = H_full[np.ix_(idx3, idx3)]
        return H_sub, active_atoms
920
+
921
+
922
+ # ======================================================================
923
+ # ML/MM Core (Multi-Backend)
924
+ # ======================================================================
925
+
926
@dataclass(frozen=True)
class _MLHighOut:
    """Immutable result bundle for one ML (high-level) evaluation.

    NOTE(review): the field meanings noted below are inferred from the field
    names and types only; confirm against the code that builds this object.
    """

    E: float  # presumably the ML energy of the model system
    F: np.ndarray  # presumably the matching ML forces
    H: Optional[torch.Tensor]  # optional Hessian tensor; may be None
    timing: Dict[str, float | str]  # labelled timing/info entries (numbers or strings)
932
+
933
+
934
@dataclass(frozen=True)
class _MMLowOut:
    """Immutable result bundle for the MM (low-level) evaluations.

    NOTE(review): the field meanings noted below are inferred from the field
    names and types only; confirm against the code that builds this object.
    """

    E_real: float  # presumably MM energy of the real (full) system
    F_real: np.ndarray  # presumably MM forces of the real system
    E_model: float  # presumably MM energy of the model system
    F_model: np.ndarray  # presumably MM forces of the model system
    H_real: Optional[np.ndarray]  # optional real-system Hessian; may be None
    H_model: Optional[np.ndarray]  # optional model-system Hessian; may be None
    active_atoms_from_fd: Optional[np.ndarray]  # presumably active atom indices reported by the Hessian routine
    timing: Dict[str, float | str]  # labelled timing/info entries (numbers or strings)
944
+
945
+
946
+ class MLMMCore:
947
+ """ONIOM-like ML/MM engine supporting multiple MLIP backends.
948
+
949
+ Supported ML backends: UMA (default), ORB, MACE, AIMNet2.
950
+ Supported MM backends: hessian_ff (analytical), OpenMM (FD).
951
+ Optional xTB point-charge embedding correction for environmental effects.
952
+ """
953
+
954
def __init__(
    self,
    *,
    input_pdb: str = None,
    real_parm7: str = None,
    model_pdb: str = None,
    model_charge: Optional[int] = 0,
    model_mult: int = 1,
    link_mlmm: List[Tuple[str, str]] | None = None,
    # ML backend selection
    backend: str = "uma",
    uma_model: str = "uma-s-1p1",
    uma_task_name: str = "omol",
    orb_model: str = "orb_v3_conservative_omol",
    orb_precision: str = "float32",
    mace_model: str = "MACE-OMOL-0",
    mace_dtype: str = "float64",
    aimnet2_model: str = "aimnet2",
    # MM settings
    mm_fd: bool = True,
    mm_fd_dir: Optional[str] = None,
    mm_fd_delta: float = 1e-3,
    symmetrize_hessian: bool = True,
    print_timing: bool = True,
    print_vram: bool = True,
    H_double: bool = False,
    ml_device: str = "auto",
    ml_cuda_idx: int = 0,
    mm_backend: str = "hessian_ff",
    mm_device: str = "cpu",
    mm_cuda_idx: int = 0,
    mm_threads: int = 16,
    freeze_atoms: List[int] | None = None,
    ml_hessian_mode: str = "FiniteDifference",
    hessian_calc_mode: Optional[str] = None,
    return_partial_hessian: bool = True,
    hess_cutoff: Optional[float] = None,
    movable_cutoff: Optional[float] = None,
    use_bfactor_layers: bool = False,
    hess_mm_atoms: Optional[List[int]] = None,
    movable_mm_atoms: Optional[List[int]] = None,
    frozen_mm_atoms: Optional[List[int]] = None,
    # Point-charge embedding correction
    embedcharge: bool = False,
    embedcharge_step: float = 1.0e-3,
    embedcharge_cutoff: Optional[float] = None,
    xtb_cmd: str = "xtb",
    xtb_acc: float = 0.2,
    xtb_workdir: str = "tmp",
    xtb_keep_files: bool = False,
    xtb_ncores: int = 4,
    **kwargs,
):
    """Prepare topologies, layer assignments and the ML/MM calculators.

    Parameters
    ----------
    input_pdb, real_parm7, model_pdb : str
        Required. Full-system PDB, full-system Amber topology, and PDB of
        the ML region. They are copied into a private temporary directory.
    model_charge, model_mult : int
        Charge (``None`` is treated as 0) and spin multiplicity handed to
        the ML backend.
    link_mlmm : list of (str, str), optional
        Explicit ML/MM link-atom pairs; detected automatically when omitted.
    backend, uma_model, uma_task_name, orb_model, mace_model, mace_dtype, aimnet2_model
        ML backend name and its model options, forwarded to
        ``_create_ml_backend``.
    orb_precision : str
        Accepted for the signature but not forwarded anywhere in this
        constructor.
    mm_fd, mm_fd_dir, mm_fd_delta, symmetrize_hessian, print_timing, print_vram
        Stored on the instance; ``mm_fd_dir`` is created when missing.
    H_double : bool
        Selects float64 (True) or float32 (False) for ``H_dtype`` / ``H_np_dtype``.
    ml_device, ml_cuda_idx : str, int
        "auto" picks CUDA when available; "cuda" uses ``cuda:{ml_cuda_idx}``;
        anything else runs on CPU.
    mm_backend, mm_device, mm_cuda_idx, mm_threads
        MM engine ("hessian_ff" or "openmm") and its device settings.
    freeze_atoms : list of int, optional
        Extra frozen atoms, merged with the frozen layer indices.
    ml_hessian_mode, hessian_calc_mode : str
        ML Hessian mode; ``hessian_calc_mode`` wins when given. Values
        starting with "analyt" (case-insensitive) select analytical,
        everything else finite differences.
    return_partial_hessian : bool
        Stored on the instance.
    hess_cutoff, movable_cutoff, use_bfactor_layers, hess_mm_atoms, movable_mm_atoms, frozen_mm_atoms
        Inputs to the layer assignment done by ``_compute_layer_indices``.
    embedcharge, embedcharge_step, embedcharge_cutoff, xtb_cmd, xtb_acc, xtb_workdir, xtb_keep_files, xtb_ncores
        Enable and configure the xTB point-charge embedding correction.
    **kwargs
        Only the deprecated names ``real_pdb`` (alias of ``input_pdb``),
        ``real_rst7``, ``vib_run`` and ``vib_dir`` (ignored) are accepted.

    Raises
    ------
    TypeError
        On unknown keyword arguments or when ``input_pdb``, ``real_parm7``
        or ``model_pdb`` is missing.
    ValueError
        On an atom-count mismatch between ``input_pdb`` and ``real_parm7``,
        or an unknown ``mm_backend``.
    """
    # --- v0.1.x backward compatibility aliases ---
    if "real_pdb" in kwargs:
        warnings.warn("'real_pdb' is deprecated; use 'input_pdb'.", DeprecationWarning, stacklevel=2)
        if input_pdb is None:
            input_pdb = kwargs.pop("real_pdb")
        else:
            kwargs.pop("real_pdb")
    for _old_name in ("real_rst7", "vib_run", "vib_dir"):
        if _old_name in kwargs:
            warnings.warn(f"'{_old_name}' is no longer used and will be ignored.", DeprecationWarning, stacklevel=2)
            kwargs.pop(_old_name)
    if kwargs:
        raise TypeError(f"MLMMCore.__init__() got unexpected keyword arguments: {', '.join(kwargs)}")
    if input_pdb is None:
        raise TypeError("MLMMCore.__init__() missing required keyword argument: 'input_pdb'")
    # These default to None only so every argument can be keyword-only.
    # Fail here with a clear message, before any temp dir is created,
    # instead of inside shutil.copy with an unrelated error.
    for _arg_name, _arg_value in (("real_parm7", real_parm7), ("model_pdb", model_pdb)):
        if _arg_value is None:
            raise TypeError(
                f"MLMMCore.__init__() missing required keyword argument: '{_arg_name}'"
            )

    # Work on private copies so the caller's files are never modified.
    self._tmpdir_obj = tempfile.TemporaryDirectory()
    self.tmpdir: str = self._tmpdir_obj.name
    for src, dst in [(input_pdb, "input.pdb"), (real_parm7, "real.parm7"), (model_pdb, "model.pdb")]:
        shutil.copy(src, os.path.join(self.tmpdir, dst))

    self.input_pdb = os.path.join(self.tmpdir, "input.pdb")
    self.real_parm7 = os.path.join(self.tmpdir, "real.parm7")
    self.real_rst7 = os.path.join(self.tmpdir, "real.rst7")
    self.model_pdb = os.path.join(self.tmpdir, "model.pdb")
    self.model_parm7 = os.path.join(self.tmpdir, "model.parm7")
    self.model_rst7 = os.path.join(self.tmpdir, "model.rst7")

    real_top = pmd.load_file(self.real_parm7)
    start_struct = pmd.load_file(self.input_pdb)
    real_n_atoms = int(len(real_top.atoms))
    start_n_atoms = int(len(start_struct.atoms))
    if start_n_atoms != real_n_atoms:
        raise ValueError(
            "Atom-count mismatch between input structure and real topology: "
            f"input_pdb='{input_pdb}' has {start_n_atoms} atoms, "
            f"real_parm7='{real_parm7}' expects {real_n_atoms} atoms. "
            "Provide a full-system input structure consistent with the parm7."
        )
    # Attach the PDB coordinates to the topology and write both files.
    real_top.coordinates = start_struct.coordinates
    real_top.box = None
    real_top.save(self.real_parm7, overwrite=True)
    real_top.save(self.real_rst7, overwrite=True)

    # _ml_prep reads self.link_mlmm; _mk_model_parm7 reads self.ml_ID.
    self.link_mlmm = link_mlmm
    self.ml_ID, self.mlmm_links = self._ml_prep()
    self.selection_indices = self._mk_model_parm7()

    # Layer-assignment inputs must be set before _compute_layer_indices.
    self.hess_cutoff = hess_cutoff
    self.movable_cutoff = movable_cutoff
    self.use_bfactor_layers = use_bfactor_layers
    self._original_input_pdb = input_pdb
    self._explicit_hess_mm_atoms = hess_mm_atoms
    self._explicit_movable_mm_atoms = movable_mm_atoms
    self._explicit_frozen_mm_atoms = frozen_mm_atoms
    self._compute_layer_indices(real_top.coordinates)

    self.freeze_atoms = [] if freeze_atoms is None else list(freeze_atoms)
    if self.frozen_layer_indices:
        self.freeze_atoms = sorted(set(self.freeze_atoms) | set(self.frozen_layer_indices))

    # Atoms outside the Hessian-target set.
    hess_set = set(self.hess_indices)
    all_atoms = set(range(len(real_top.atoms)))
    self.hess_freeze_atoms = sorted(all_atoms - hess_set)

    self.return_partial_hessian = bool(return_partial_hessian)

    self._n_real = len(real_top.atoms)
    self._idx_map_real_to_model = {idx: pos for pos, idx in enumerate(self.selection_indices)}
    self._update_active_dof_mappings()

    self.H_double = bool(H_double)
    self.H_dtype = torch.float64 if self.H_double else torch.float32
    self.H_np_dtype = np.float64 if self.H_double else np.float32

    self.mm_fd = mm_fd
    self.mm_fd_dir = mm_fd_dir
    self.mm_fd_delta = mm_fd_delta
    self.symmetrize_hessian = symmetrize_hessian
    self.print_timing = bool(print_timing)
    self.print_vram = bool(print_vram)
    if self.mm_fd_dir and not os.path.exists(self.mm_fd_dir):
        os.makedirs(self.mm_fd_dir, exist_ok=True)

    if ml_device == "auto":
        ml_device = "cuda" if torch.cuda.is_available() else "cpu"
    self.device_str = ml_device
    self.ml_device = torch.device(f"cuda:{ml_cuda_idx}" if ml_device == "cuda" else "cpu")

    self.model_charge = int(0 if model_charge is None else model_charge)
    self.model_mult = int(model_mult)
    self.backend_name = str(backend).strip().lower() if backend is not None else "uma"

    # Create ML backend via factory
    # NOTE(review): orb_precision is accepted by this constructor but is not
    # passed on here; confirm whether the ORB backend should receive it.
    self._ml_backend = _create_ml_backend(
        self.backend_name,
        uma_model=uma_model,
        uma_task_name=uma_task_name,
        orb_model=orb_model,
        mace_model=mace_model,
        mace_dtype=mace_dtype,
        aimnet2_model=aimnet2_model,
        model_charge=self.model_charge,
        model_mult=self.model_mult,
        ml_device=self.ml_device,
    )

    # Point-charge embedding correction
    self.embedcharge = bool(embedcharge)
    self.embedcharge_cutoff = embedcharge_cutoff
    self._embed_correction: Optional[_EmbedChargeCorrection] = None
    if self.embedcharge:
        self._embed_correction = _EmbedChargeCorrection(
            xtb_cmd=xtb_cmd,
            xtb_acc=xtb_acc,
            xtb_workdir=xtb_workdir,
            xtb_keep_files=xtb_keep_files,
            xtb_ncores=xtb_ncores,
            hessian_step=embedcharge_step,
        )

    # MM backend selection: hessian_ff or openmm
    self.mm_backend = str(mm_backend).strip().lower()
    if self.mm_backend == "openmm":
        self.calc_real_low = OpenMMCalculator(
            parm7=self.real_parm7, rst7=self.real_rst7,
            device=mm_device, cuda_idx=mm_cuda_idx, threads=mm_threads
        )
        self.calc_model_low = OpenMMCalculator(
            parm7=self.model_parm7, rst7=self.model_rst7,
            device=mm_device, cuda_idx=mm_cuda_idx, threads=mm_threads
        )
    elif self.mm_backend == "hessian_ff":
        self.calc_real_low = hessianffCalculator(
            parm7=self.real_parm7, rst7=None,
            device=mm_device, cuda_idx=mm_cuda_idx, threads=mm_threads
        )
        self.calc_model_low = hessianffCalculator(
            parm7=self.model_parm7, rst7=None,
            device=mm_device, cuda_idx=mm_cuda_idx, threads=mm_threads
        )
    else:
        raise ValueError(
            f"Unknown mm_backend '{mm_backend}'. Choose 'hessian_ff' or 'openmm'."
        )

    # hessian_calc_mode takes precedence over ml_hessian_mode when given.
    mode_in = hessian_calc_mode if hessian_calc_mode is not None else ml_hessian_mode
    mode = (mode_in or "Analytical").strip().lower()
    self._ml_hessian_mode = "analytical" if mode.startswith("analyt") else "fd"

    self._atoms_real_tpl = read(self.input_pdb)
    self._atoms_model_tpl = read(self.model_pdb)
    # Model template extended by one placeholder hydrogen per ML/MM link.
    tmp = self._atoms_model_tpl.copy()
    for _ in self.mlmm_links:
        tmp += Atoms("H", positions=[[0.0, 0.0, 0.0]])
    self._atoms_model_LH_tpl = tmp
1163
+
1164
def cleanup(self):
    """Remove the temporary working directory, if one was created.

    Safe to call on a partially constructed instance. A failure while
    removing the directory is logged at debug level and not re-raised.
    """
    tmpdir_obj = getattr(self, "_tmpdir_obj", None)
    if tmpdir_obj is None:
        return
    try:
        tmpdir_obj.cleanup()
    except Exception:
        logger.debug("Failed to clean up tmpdir", exc_info=True)
1171
+
1172
def __del__(self):
    """Finalizer: delegate to ``cleanup()`` when the object is collected."""
    self.cleanup()
1174
+
1175
+ @staticmethod
1176
+ def _pdb_atom_key(line: str) -> str:
1177
+ return f"{line[12:16].strip()} {line[17:20].strip()} {line[22:26].strip()}"
1178
+
1179
+ def _ml_prep(self) -> Tuple[List[str], List[Tuple[int, int]]]:
1180
+ ml_region = set()
1181
+ with open(self.model_pdb) as fh:
1182
+ for ln in fh:
1183
+ if ln.startswith(("ATOM", "HETATM")):
1184
+ ml_region.add(self._pdb_atom_key(ln))
1185
+
1186
+ leap_atoms: List[Dict] = []
1187
+ with open(self.input_pdb) as fh:
1188
+ for ln in fh:
1189
+ if not ln.startswith(("ATOM", "HETATM")):
1190
+ continue
1191
+ leap_atoms.append(
1192
+ {
1193
+ "idx": int(ln[6:11]),
1194
+ "id": self._pdb_atom_key(ln),
1195
+ "elem": ln[76:78].strip(),
1196
+ "coord": np.array([float(ln[30:38]), float(ln[38:46]), float(ln[46:54])]),
1197
+ }
1198
+ )
1199
+
1200
+ ml_ID = [str(a["idx"]) for a in leap_atoms if a["id"] in ml_region]
1201
+
1202
+ if self.link_mlmm:
1203
+ processed = [(" ".join(q.split()[:3]), " ".join(m.split()[:3])) for q, m in self.link_mlmm]
1204
+
1205
+ ml_indices: List[int] = []
1206
+ mm_indices: List[int] = []
1207
+ for a in leap_atoms:
1208
+ for qnm, mnm in processed:
1209
+ if a["id"] == qnm:
1210
+ ml_indices.append(a["idx"])
1211
+ elif a["id"] == mnm:
1212
+ mm_indices.append(a["idx"])
1213
+
1214
+ if len(set(ml_indices)) != len(ml_indices) or len(set(mm_indices)) != len(mm_indices):
1215
+ raise ValueError("Duplicated ML or MM indices in link specification.")
1216
+ mlmm_links = list(zip(ml_indices, mm_indices))
1217
+ else:
1218
+ threshold = 1.7
1219
+ ml_set = {a["idx"] for a in leap_atoms if a["id"] in ml_region}
1220
+ coords = {a["idx"]: a["coord"] for a in leap_atoms}
1221
+ elem = {a["idx"]: a["elem"] for a in leap_atoms}
1222
+
1223
+ ml_indices: List[int] = []
1224
+ mm_indices: List[int] = []
1225
+ for qidx in ml_set:
1226
+ for a in leap_atoms:
1227
+ midx = a["idx"]
1228
+ if midx in ml_set:
1229
+ continue
1230
+ if (
1231
+ np.linalg.norm(coords[midx] - coords[qidx]) < threshold
1232
+ and (
1233
+ (elem[midx] == "C" and elem[qidx] == "C")
1234
+ or (elem[midx] == "N" and elem[qidx] == "C")
1235
+ or (elem[midx] == "C" and elem[qidx] == "N")
1236
+ )
1237
+ ):
1238
+ ml_indices.append(qidx)
1239
+ mm_indices.append(midx)
1240
+
1241
+ if len(set(ml_indices)) != len(ml_indices) or len(set(mm_indices)) != len(mm_indices):
1242
+ raise ValueError(
1243
+ "Automatic link detection produced duplicate pairs. Specify 'link_mlmm' manually."
1244
+ )
1245
+ mlmm_links = list(zip(ml_indices, mm_indices))
1246
+
1247
+ return ml_ID, mlmm_links
1248
+
1249
def _mk_model_parm7(self) -> List[int]:
    """Write the model-system parm7/rst7 and return the selected atom indices.

    Each entry of ``self.ml_ID`` is treated as a 1-based position into the
    real topology's atom list; the returned list holds the ``idx``
    attribute of those atoms. When the selection has as many entries as
    the real system has atoms, the real files are copied unchanged.
    Otherwise the topology is sliced, saved, and its Lennard-Jones tables
    are normalized before the coordinates are written.
    """
    real = pmd.load_file(self.real_parm7, self.real_rst7)
    real.box = None
    selection = [real.atoms[int(serial) - 1].idx for serial in self.ml_ID]

    if len(selection) == len(real.atoms):
        # ML region spans the whole system: reuse the real files as they are.
        for src, dst in (
            (self.real_parm7, self.model_parm7),
            (self.real_rst7, self.model_rst7),
        ):
            shutil.copy(src, dst)
    else:
        model = real[selection]
        model.box = None
        model.save(self.model_parm7, overwrite=True)
        _normalize_prmtop_lj_tables(self.model_parm7)
        model.save(self.model_rst7, overwrite=True)

    return selection
1266
+
1267
+ def _compute_layer_indices(self, coords: np.ndarray) -> None:
1268
+ self.ml_indices = sorted(self.selection_indices)
1269
+
1270
+ n_atoms = int(coords.shape[0])
1271
+ all_indices = set(range(n_atoms))
1272
+ mm_indices = all_indices - set(self.ml_indices)
1273
+
1274
+ has_explicit = (
1275
+ self._explicit_hess_mm_atoms is not None
1276
+ or self._explicit_movable_mm_atoms is not None
1277
+ or self._explicit_frozen_mm_atoms is not None
1278
+ )
1279
+ if has_explicit:
1280
+ explicit_hess = set(self._explicit_hess_mm_atoms or [])
1281
+ explicit_movable = set(self._explicit_movable_mm_atoms or [])
1282
+ explicit_frozen = set(self._explicit_frozen_mm_atoms or [])
1283
+
1284
+ for idx_set, name in [
1285
+ (explicit_hess, "hess_mm_atoms"),
1286
+ (explicit_movable, "movable_mm_atoms"),
1287
+ (explicit_frozen, "frozen_mm_atoms"),
1288
+ ]:
1289
+ for idx in idx_set:
1290
+ if idx < 0 or idx >= n_atoms:
1291
+ raise ValueError(f"Invalid atom index {idx} in {name}: must be 0 <= idx < {n_atoms}")
1292
+ if idx in self.ml_indices:
1293
+ raise ValueError(f"Atom index {idx} in {name} is also in ML region (model_pdb)")
1294
+
1295
+ self.hess_mm_indices = sorted(explicit_hess & mm_indices)
1296
+ self.movable_mm_indices = sorted(explicit_movable & mm_indices)
1297
+ self.frozen_layer_indices = sorted(explicit_frozen & mm_indices)
1298
+
1299
+ assigned_mm = explicit_hess | explicit_movable | explicit_frozen
1300
+ unassigned_mm = mm_indices - assigned_mm
1301
+ self.movable_mm_indices = sorted(set(self.movable_mm_indices) | unassigned_mm)
1302
+
1303
+ self.hess_indices = sorted(self.ml_indices + self.hess_mm_indices)
1304
+ self.movable_indices = sorted(self.ml_indices + self.hess_mm_indices + self.movable_mm_indices)
1305
+ return
1306
+
1307
+ if self.use_bfactor_layers:
1308
+ from .utils import read_bfactors_from_pdb, parse_layer_indices_from_bfactors, has_valid_layer_bfactors
1309
+ from pathlib import Path
1310
+
1311
+ bfactors = read_bfactors_from_pdb(Path(self._original_input_pdb))
1312
+ if has_valid_layer_bfactors(bfactors):
1313
+ layer_info = parse_layer_indices_from_bfactors(bfactors)
1314
+
1315
+ movable_from_layer = set(layer_info["movable_mm_indices"]) & mm_indices
1316
+ frozen_from_layer = set(layer_info["frozen_indices"]) & mm_indices
1317
+ hess_from_layer = set(layer_info["hess_mm_indices"]) & mm_indices
1318
+
1319
+ # Unassigned MM atoms default to movable.
1320
+ assigned_mm = movable_from_layer | frozen_from_layer | hess_from_layer
1321
+ unassigned_mm = mm_indices - assigned_mm
1322
+ movable_pool = set(movable_from_layer) | set(unassigned_mm)
1323
+
1324
+ # Hessian-target MM selection:
1325
+ # 1) If hess_cutoff is set, use distance-to-ML over movable MM pool.
1326
+ # 2) Otherwise, keep any Layer-2 assignments (if present).
1327
+ hess_mm: set[int]
1328
+ if self.hess_cutoff is not None:
1329
+ ml_coords = coords[self.ml_indices]
1330
+
1331
+ def min_dist_to_ml(atom_idx: int) -> float:
1332
+ atom_coord = coords[atom_idx]
1333
+ dists = np.linalg.norm(ml_coords - atom_coord, axis=1)
1334
+ return float(np.min(dists))
1335
+
1336
+ hess_cut = float(self.hess_cutoff)
1337
+ hess_mm = {idx for idx in movable_pool if min_dist_to_ml(idx) <= hess_cut}
1338
+ else:
1339
+ hess_mm = set(hess_from_layer)
1340
+
1341
+ movable_mm = movable_pool - hess_mm
1342
+
1343
+ self.hess_mm_indices = sorted(hess_mm)
1344
+ self.movable_mm_indices = sorted(movable_mm)
1345
+ self.frozen_layer_indices = sorted(frozen_from_layer)
1346
+
1347
+ self.hess_indices = sorted(self.ml_indices + self.hess_mm_indices)
1348
+ self.movable_indices = sorted(self.ml_indices + self.hess_mm_indices + self.movable_mm_indices)
1349
+ return
1350
+
1351
+ if self.hess_cutoff is None and self.movable_cutoff is None:
1352
+ self.hess_mm_indices = sorted(mm_indices)
1353
+ self.movable_mm_indices = []
1354
+ self.frozen_layer_indices = []
1355
+ self.hess_indices = sorted(self.ml_indices + self.hess_mm_indices)
1356
+ self.movable_indices = sorted(self.ml_indices + self.hess_mm_indices)
1357
+ return
1358
+
1359
+ ml_coords = coords[self.ml_indices]
1360
+
1361
+ def min_dist_to_ml(atom_idx: int) -> float:
1362
+ atom_coord = coords[atom_idx]
1363
+ dists = np.linalg.norm(ml_coords - atom_coord, axis=1)
1364
+ return float(np.min(dists))
1365
+
1366
+ hess_mm: List[int] = []
1367
+ movable_mm: List[int] = []
1368
+ frozen_mm: List[int] = []
1369
+
1370
+ hess_cut = self.hess_cutoff if self.hess_cutoff is not None else float("inf")
1371
+ mov_cut = self.movable_cutoff if self.movable_cutoff is not None else float("inf")
1372
+
1373
+ for idx in mm_indices:
1374
+ d = min_dist_to_ml(idx)
1375
+ if d <= hess_cut:
1376
+ hess_mm.append(idx)
1377
+ elif d <= mov_cut:
1378
+ movable_mm.append(idx)
1379
+ else:
1380
+ frozen_mm.append(idx)
1381
+
1382
+ self.hess_mm_indices = sorted(hess_mm)
1383
+ self.movable_mm_indices = sorted(movable_mm)
1384
+ self.frozen_layer_indices = sorted(frozen_mm)
1385
+ self.hess_indices = sorted(self.ml_indices + self.hess_mm_indices)
1386
+ self.movable_indices = sorted(self.ml_indices + self.hess_mm_indices + self.movable_mm_indices)
1387
+
1388
+ def _update_active_dof_mappings(self) -> None:
1389
+ freeze_set = set(self.freeze_atoms)
1390
+ self.active_atoms_real = [i for i in range(self._n_real) if i not in freeze_set]
1391
+ self.n_active_real = len(self.active_atoms_real)
1392
+ self.full_to_active_real = {a: i for i, a in enumerate(self.active_atoms_real)}
1393
+ self.active_to_full_real = {i: a for i, a in enumerate(self.active_atoms_real)}
1394
+
1395
+ hess_freeze_set = set(self.hess_freeze_atoms)
1396
+ self.hess_active_atoms = [i for i in range(self._n_real) if i not in hess_freeze_set]
1397
+ self.n_hess_active = len(self.hess_active_atoms)
1398
+ self.full_to_hess_active = {a: i for i, a in enumerate(self.hess_active_atoms)}
1399
+ self.hess_active_to_full = {i: a for i, a in enumerate(self.hess_active_atoms)}
1400
+
1401
+ self.ml_hess_active_indices = [
1402
+ self.full_to_hess_active[i] for i in self.selection_indices if i in self.full_to_hess_active
1403
+ ]
1404
+
1405
+ self.freeze_model = [
1406
+ self._idx_map_real_to_model[i] for i in self.freeze_atoms if i in self._idx_map_real_to_model
1407
+ ]
1408
+
1409
+ def _build_within_partial_hessian(self) -> Dict[str, np.ndarray | int | str]:
1410
+ """Build metadata for a partial (Hessian-target-only) Hessian."""
1411
+ n_real = int(self._n_real)
1412
+ active_atoms = np.asarray(self.hess_active_atoms, dtype=int)
1413
+ active_n_atoms = int(active_atoms.size)
1414
+
1415
+ active_dofs = np.empty(active_n_atoms * 3, dtype=int)
1416
+ for i, a in enumerate(active_atoms):
1417
+ base = 3 * int(a)
1418
+ active_dofs[3 * i:3 * i + 3] = (base, base + 1, base + 2)
1419
+
1420
+ full_to_active = -np.ones(n_real, dtype=int)
1421
+ if active_n_atoms:
1422
+ full_to_active[active_atoms] = np.arange(active_n_atoms, dtype=int)
1423
+
1424
+ return {
1425
+ "kind": "hess-target-only",
1426
+ "active_atoms": active_atoms,
1427
+ "active_dofs": active_dofs,
1428
+ "active_to_full": active_atoms.copy(),
1429
+ "full_to_active": full_to_active,
1430
+ "full_n_atoms": n_real,
1431
+ "full_n_dof": int(3 * n_real),
1432
+ "active_n_atoms": active_n_atoms,
1433
+ "active_n_dof": int(3 * active_n_atoms),
1434
+ }
1435
+
1436
+ def _prep_3_layer_atoms(self, real_coord: np.ndarray):
1437
+ atoms_real = self._atoms_real_tpl.copy()
1438
+ atoms_real.set_positions(real_coord)
1439
+
1440
+ atoms_model = self._atoms_model_tpl.copy()
1441
+ atoms_model_LH = self._atoms_model_LH_tpl.copy()
1442
+
1443
+ for i, ridx in enumerate(self.ml_ID):
1444
+ pos = atoms_real[int(ridx) - 1].position
1445
+ atoms_model[i].position = pos
1446
+ atoms_model_LH[i].position = pos
1447
+
1448
+ added_link_atoms = []
1449
+ base_model_len = len(self._atoms_model_tpl)
1450
+ for k, (ml_idx, mm_idx) in enumerate(self.mlmm_links):
1451
+ ml_i = ml_idx - 1
1452
+ mm_i = mm_idx - 1
1453
+ ml_elem = atoms_real[ml_i].symbol
1454
+ if ml_elem == "C":
1455
+ dist = 1.09
1456
+ elif ml_elem == "N":
1457
+ dist = 1.01
1458
+ else:
1459
+ raise ValueError(
1460
+ f"Unsupported link parent element: {ml_elem}. Only C and N are supported."
1461
+ )
1462
+ vec = atoms_real[mm_i].position - atoms_real[ml_i].position
1463
+ R = np.linalg.norm(vec)
1464
+ if R < 1e-6:
1465
+ continue
1466
+ u = vec / R
1467
+ H_pos = atoms_real[ml_i].position + u * dist
1468
+ link_idx_in_model_LH = base_model_len + k
1469
+ atoms_model_LH[link_idx_in_model_LH].position = H_pos
1470
+ added_link_atoms.append((link_idx_in_model_LH, ml_i, mm_i, dist))
1471
+
1472
+ freeze_model: List[int] = []
1473
+ if self.freeze_atoms:
1474
+ atoms_real.set_constraint(FixAtoms(indices=self.freeze_atoms))
1475
+ real_to_model = self._idx_map_real_to_model
1476
+ freeze_model = [real_to_model[i] for i in self.freeze_atoms if i in real_to_model]
1477
+ if freeze_model:
1478
+ atoms_model.set_constraint(FixAtoms(indices=freeze_model))
1479
+ atoms_model_LH.set_constraint(FixAtoms(indices=freeze_model))
1480
+
1481
+ return atoms_real, atoms_model, atoms_model_LH, added_link_atoms, freeze_model
1482
+
1483
+ @staticmethod
1484
+ def _jacobian_blocks_numpy(r_ml: np.ndarray, r_mm: np.ndarray, dist: float) -> Optional[np.ndarray]:
1485
+ """Returns J shape (6, 3): rows=[Q_xyz, M_xyz], cols=L_xyz."""
1486
+ vec = r_mm - r_ml
1487
+ R = np.linalg.norm(vec)
1488
+ if R < 1e-12:
1489
+ return None
1490
+ u = vec / R
1491
+ I = np.eye(3)
1492
+ du_dQ = (I - np.outer(u, u)) / R
1493
+ dR_dQ = I - dist * du_dQ
1494
+ dR_dM = dist * du_dQ
1495
+ return np.hstack([dR_dQ, dR_dM]).T
1496
+
1497
+ @staticmethod
1498
+ def _jacobian_blocks_torch(
1499
+ r_ml: torch.Tensor,
1500
+ r_mm: torch.Tensor,
1501
+ dist: float,
1502
+ *,
1503
+ dtype: torch.dtype,
1504
+ device: torch.device,
1505
+ ) -> Optional[torch.Tensor]:
1506
+ """Returns K shape (3, 6): rows=L_xyz, cols=[Q_xyz, M_xyz]."""
1507
+ vec = r_mm - r_ml
1508
+ Rlen = torch.norm(vec)
1509
+ if float(Rlen) < 1e-12:
1510
+ return None
1511
+ u = vec / Rlen
1512
+ I3 = torch.eye(3, dtype=dtype, device=device)
1513
+ du_dQ = (I3 - torch.outer(u, u)) / Rlen
1514
+ dR_dQ = I3 - dist * du_dQ
1515
+ dR_dM = dist * du_dQ
1516
+ return torch.hstack([dR_dQ, dR_dM])
1517
+
1518
def _get_mm_charges(self, atom_indices: Sequence[int]) -> np.ndarray:
    """Return the MM partial charges of ``atom_indices`` as a float64 array.

    Two calculator types are recognised on ``self.calc_real_low``: the
    hessian_ff calculator (charges read from ``calc.system.charge``) and
    the OpenMM calculator (charges read from the system's
    ``NonbondedForce``, converted to elementary-charge units). When
    neither source is available a ``RuntimeWarning`` is issued and zeros
    are returned.
    """
    low_calc = self.calc_real_low

    # hessian_ff backend: per-atom charges live on calc.system.charge.
    if isinstance(low_calc, hessianffCalculator) and hasattr(low_calc, "system"):
        values = [low_calc.system.charge[idx].item() for idx in atom_indices]
        return np.array(values, dtype=np.float64)

    # OpenMM backend: take the first force whose class is NonbondedForce.
    if isinstance(low_calc, OpenMMCalculator) and HAS_OPENMM:
        omm_system = low_calc.system
        candidates = (omm_system.getForce(k) for k in range(omm_system.getNumForces()))
        nonbonded = next(
            (f for f in candidates if f.__class__.__name__ == "NonbondedForce"),
            None,
        )
        if nonbonded is not None:
            values = [
                nonbonded.getParticleParameters(idx)[0].value_in_unit(unit.elementary_charge)
                for idx in atom_indices
            ]
            return np.array(values, dtype=np.float64)

    # Nothing usable was found: warn and fall back to neutral charges.
    warnings.warn(
        "Could not extract MM charges from the calculator; returning zeros. "
        "Embedcharge correction will have no effect.",
        RuntimeWarning,
    )
    return np.zeros(len(atom_indices), dtype=np.float64)
1550
+
1551
def _eval_ml_high(self, atoms_model_LH: Atoms, freeze_model: Sequence[int], *, return_hessian: bool) -> _MLHighOut:
    """Evaluate the ML backend on the model-plus-link-atom system.

    Energy and forces are always computed. When ``return_hessian`` is
    true the Hessian comes from the backend's analytical routine if that
    mode is selected and supported, otherwise from its finite-difference
    routine (step 1.0e-3). Backend name, Hessian mode and Hessian
    wall-clock time are reported in the returned timing dictionary.
    """
    stats: Dict[str, float | str] = {}
    energy, forces, opaque = self._ml_backend.eval(atoms_model_LH, need_grad=True)
    stats["ml_backend"] = self.backend_name

    if not return_hessian:
        return _MLHighOut(E=energy, F=forces, H=None, timing=stats)

    n_atoms_lh = len(atoms_model_LH)
    backend = self._ml_backend
    if self._ml_hessian_mode == "analytical" and backend.supports_analytical_hessian:
        started = time.perf_counter()
        hessian = backend.hessian_analytical(opaque, n_atoms_lh, dtype=self.H_dtype)
        mode_label = "Analytical"
    else:
        started = time.perf_counter()
        hessian = backend.hessian_fd(
            atoms_model_LH,
            freeze_model,
            eps_ang=1.0e-3,
            dtype=self.H_dtype,
            device=self.ml_device,
        )
        mode_label = "FiniteDifference"
    stats["ml_hessian_mode"] = mode_label
    stats["ml_hessian_s"] = time.perf_counter() - started

    return _MLHighOut(E=energy, F=forces, H=hessian, timing=stats)
1574
+
1575
def _eval_mm_low(self, atoms_real: Atoms, atoms_model: Atoms, *, return_hessian: bool) -> _MMLowOut:
    """Evaluate the MM calculators on the real and the model system.

    Energies and forces are always computed for both systems. Finite
    difference Hessians are added only when ``return_hessian`` is true and
    ``self.mm_fd`` is exactly ``True``: the real-system Hessian is requested
    as a partial Hessian with ``self.hess_freeze_atoms`` fixed, the
    model-system Hessian as a full one. Per-system wall-clock times are
    reported in the returned timing dictionary.
    """
    stats: Dict[str, float | str] = {}

    atoms_real.calc = self.calc_real_low
    atoms_model.calc = self.calc_model_low

    energy_real = atoms_real.get_potential_energy()
    forces_real = np.double(atoms_real.get_forces())

    energy_model = atoms_model.get_potential_energy()
    forces_model = np.double(atoms_model.get_forces())

    hess_real = None
    hess_model = None
    fd_active_atoms = None

    if return_hessian and self.mm_fd is True:
        # Optional per-system progress logs, written only when a directory is configured.
        if self.mm_fd_dir:
            real_log = os.path.join(self.mm_fd_dir, "real.log")
            model_log = os.path.join(self.mm_fd_dir, "model.log")
        else:
            real_log = None
            model_log = None

        # Work on a copy so the Hessian-specific constraints replace, rather
        # than stack on, whatever constraints atoms_real already carries.
        hess_atoms = atoms_real.copy()
        hess_atoms.set_constraint()
        hess_atoms.calc = self.calc_real_low
        if self.hess_freeze_atoms:
            hess_atoms.set_constraint(FixAtoms(indices=self.hess_freeze_atoms))

        started = time.perf_counter()
        hess_real, fd_active_atoms = self.calc_real_low.finite_difference_hessian(
            hess_atoms,
            delta=self.mm_fd_delta,
            info_path=real_log,
            dtype=self.H_np_dtype,
            return_partial_hessian=True,
        )
        stats["mm_fd_real_s"] = time.perf_counter() - started

        started = time.perf_counter()
        hess_model, _ = self.calc_model_low.finite_difference_hessian(
            atoms_model,
            delta=self.mm_fd_delta,
            info_path=model_log,
            dtype=self.H_np_dtype,
            return_partial_hessian=False,
        )
        stats["mm_fd_model_s"] = time.perf_counter() - started
        stats["mm_fd_total_s"] = float(stats["mm_fd_real_s"]) + float(stats["mm_fd_model_s"])

    return _MMLowOut(
        E_real=energy_real,
        F_real=forces_real,
        E_model=energy_model,
        F_model=forces_model,
        H_real=hess_real,
        H_model=hess_model,
        active_atoms_from_fd=fd_active_atoms,
        timing=stats,
    )
1633
+
1634
def compute(
    self,
    coord_ang: np.ndarray,
    *,
    return_forces: bool = False,
    return_hessian: bool = False,
) -> Dict:
    """Evaluate the subtractive ML/MM energy and, optionally, forces and Hessian.

    The total energy is ``E_real(MM) + E_model+link(ML) - E_model(MM)``.
    Forces follow the same combination on the ML atoms; the force on each
    link atom is redistributed onto its ML parent and MM partner through
    the link-atom Jacobian. When point-charge embedding is enabled its
    energy / force / Hessian corrections are added on the ML atoms.

    The Hessian is assembled on the Hessian-active atoms as an
    ``(n, 3, n, 3)`` tensor from: the MM real-system Hessian, the
    ML-minus-MM model block, link-atom self terms (including the
    second-order term that arises because the link position depends
    non-linearly on its anchors), link-ML coupling, link-link coupling and
    the embedding correction.

    Args:
        coord_ang: coordinates of the full (real) system, passed to
            ``_prep_3_layer_atoms``.
        return_forces: include ``"forces"`` in the result.
        return_hessian: include ``"hessian"`` (forces are then computed too).

    Returns:
        A dict with ``"energy"``; ``"forces"`` when either flag is set; and,
        when ``return_hessian`` is set, ``"hessian"``, ``"timing"`` and (for
        partial Hessians) ``"within_partial_hessian"``.

    Raises:
        RuntimeError: if the MM finite-difference Hessians are missing, or
            if the atoms they were computed for differ from
            ``self.hess_active_atoms``.
    """
    timing: Dict[str, float | str] = {}
    hess_total_start: Optional[float] = time.perf_counter() if return_hessian else None
    hess_vram_base_alloc: Optional[float] = None
    hess_vram_base_reserved: Optional[float] = None
    hess_vram_total: Optional[float] = None
    if return_hessian and self.print_vram and self.ml_device.type == "cuda":
        # Record the baseline so the report below shows only this call's peak.
        torch.cuda.synchronize(device=self.ml_device)
        hess_vram_base_alloc = float(torch.cuda.memory_allocated(device=self.ml_device))
        hess_vram_base_reserved = float(torch.cuda.memory_reserved(device=self.ml_device))
        hess_vram_total = float(torch.cuda.get_device_properties(self.ml_device).total_memory)
        torch.cuda.reset_peak_memory_stats(device=self.ml_device)

    atoms_real, atoms_model, atoms_model_LH, added_link_atoms, freeze_model = self._prep_3_layer_atoms(coord_ang)
    atoms_real.set_pbc(False)
    atoms_model.set_pbc(False)
    atoms_model_LH.set_pbc(False)

    # Overlap the ML and MM evaluations only when they run on different devices.
    use_parallel = (self.ml_device.type == "cuda") and (getattr(self.calc_real_low, "device", None) == "cpu")
    if use_parallel:
        with ThreadPoolExecutor(max_workers=2) as executor:
            fut_ml = executor.submit(self._eval_ml_high, atoms_model_LH, freeze_model, return_hessian=return_hessian)
            fut_mm = executor.submit(self._eval_mm_low, atoms_real, atoms_model, return_hessian=return_hessian)
            ml_out = fut_ml.result()
            mm_out = fut_mm.result()
    else:
        ml_out = self._eval_ml_high(atoms_model_LH, freeze_model, return_hessian=return_hessian)
        mm_out = self._eval_mm_low(atoms_real, atoms_model, return_hessian=return_hessian)

    timing.update(ml_out.timing)
    timing.update(mm_out.timing)

    total_E = mm_out.E_real + ml_out.E - mm_out.E_model
    results: Dict = {"energy": total_E}

    if return_forces or return_hessian:
        F_combined = np.copy(mm_out.F_real)
        for i, ridx in enumerate(self.selection_indices):
            F_combined[ridx] += ml_out.F[i] - mm_out.F_model[i]

        # Redistribute each link-atom force onto its ML parent and MM partner.
        real_to_model = self._idx_map_real_to_model
        for link_idx, ml_idx, mm_idx, dist in added_link_atoms:
            ml_model_idx = real_to_model[ml_idx]
            r_ml = atoms_model_LH[ml_model_idx].position
            r_mm = atoms_real[mm_idx].position
            grad_link = ml_out.F[link_idx]
            J = self._jacobian_blocks_numpy(r_ml, r_mm, dist)
            if J is None:
                continue
            redistributed = J @ grad_link
            F_combined[ml_idx] += redistributed[:3]
            F_combined[mm_idx] += redistributed[3:]
        results["forces"] = F_combined

    # Point-charge embedding correction (optional)
    embed_dH = None
    if self.embedcharge and self._embed_correction is not None:
        t0_embed = time.perf_counter()
        # ML atom symbols and coordinates, both in model-atom order
        n_model_tpl = len(self._atoms_model_tpl)
        ml_symbols = [atoms_model_LH[i].symbol for i in range(n_model_tpl)]
        ml_coords = np.array([atoms_model_LH[i].position for i in range(n_model_tpl)])
        # MM atom coordinates and charges from the real topology
        ml_set = set(self.selection_indices)
        mm_atom_indices = [i for i in range(len(atoms_real)) if i not in ml_set]
        if mm_atom_indices and self.embedcharge_cutoff is not None:
            from scipy.spatial.distance import cdist
            # Use a dedicated array for the distance screen. Re-using
            # `ml_coords` here would hand the correction coordinates in
            # sorted real-index order while `ml_symbols` stays in model order.
            ml_coords_for_cutoff = atoms_real.get_positions()[sorted(ml_set)]
            mm_coords_all = atoms_real.get_positions()[mm_atom_indices]
            dists = cdist(mm_coords_all, ml_coords_for_cutoff).min(axis=1)
            n_before = len(mm_atom_indices)
            mask = dists <= self.embedcharge_cutoff
            mm_atom_indices = [mm_atom_indices[j] for j in range(n_before) if mask[j]]
            if self.print_timing and not getattr(self, '_embedcharge_logged', False):
                print(f"[embedcharge] {len(mm_atom_indices)}/{n_before} MM atoms within {self.embedcharge_cutoff:.1f} Å cutoff.")
                self._embedcharge_logged = True
        if mm_atom_indices:
            mm_coords = atoms_real.get_positions()[mm_atom_indices]
            # Get MM partial charges from the topology
            mm_charges = self._get_mm_charges(mm_atom_indices)

            dE_embed, dF_embed, dH_embed = self._embed_correction.compute_correction(
                symbols=ml_symbols,
                coords_ml_ang=ml_coords,
                mm_coords_ang=mm_coords,
                mm_charges=mm_charges,
                charge=self.model_charge,
                multiplicity=self.model_mult,
                need_forces=return_forces or return_hessian,
                need_hessian=return_hessian,
            )

            # Add energy correction
            results["energy"] += dE_embed

            # Add force corrections on ML atoms
            if dF_embed is not None and (return_forces or return_hessian):
                for i, ridx in enumerate(self.selection_indices):
                    if i < len(dF_embed):
                        results["forces"][ridx] += dF_embed[i]

            # Store Hessian correction for later assembly
            if dH_embed is not None:
                embed_dH = dH_embed

        timing["embedcharge_s"] = time.perf_counter() - t0_embed

    if return_hessian:
        n_real = len(atoms_real)
        n_ml = len(self.selection_indices)
        n_hess_active = self.n_hess_active

        if self.mm_fd is True:
            if mm_out.H_real is None or mm_out.H_model is None:
                raise RuntimeError("MM Hessians were not computed as expected.")

            if mm_out.active_atoms_from_fd is not None:
                expected = set(self.hess_active_atoms)
                got = set(mm_out.active_atoms_from_fd.tolist())
                if expected != got:
                    raise RuntimeError(
                        f"Hessian active atoms mismatch: expected {len(expected)} atoms, got {len(got)}"
                    )

            H = torch.from_numpy(mm_out.H_real).to(self.ml_device, self.H_dtype)
            H = H.view(n_hess_active, 3, n_hess_active, 3)

            H_model = torch.from_numpy(mm_out.H_model).to(self.ml_device, self.H_dtype)
            H_model = H_model.view(n_ml, 3, n_ml, 3)
        else:
            # No MM Hessian requested: start from zero so only ML terms contribute.
            H = torch.zeros((n_hess_active, 3, n_hess_active, 3), dtype=self.H_dtype, device=self.ml_device)
            H_model = torch.zeros((n_ml, 3, n_ml, 3), dtype=self.H_dtype, device=self.ml_device)

        H_high = ml_out.H
        # (position in the ML selection, position among Hessian-active atoms)
        ml_pairs = [
            (i, self.full_to_hess_active[gi_real])
            for i, gi_real in enumerate(self.selection_indices)
            if gi_real in self.full_to_hess_active
        ]
        if ml_pairs:
            ml_sel_idx = torch.as_tensor([p[0] for p in ml_pairs], dtype=torch.long, device=self.ml_device)
            ml_active_idx = torch.as_tensor([p[1] for p in ml_pairs], dtype=torch.long, device=self.ml_device)
        else:
            ml_sel_idx = torch.empty((0,), dtype=torch.long, device=self.ml_device)
            ml_active_idx = torch.empty((0,), dtype=torch.long, device=self.ml_device)

        # --- ML-ML block: replace the MM model Hessian by the ML one. ---
        if H_high is not None and ml_sel_idx.numel() > 0:
            t_asm = time.perf_counter()
            H_high_mm = H_high.index_select(0, ml_sel_idx).index_select(2, ml_sel_idx)
            H_model_mm = H_model.index_select(0, ml_sel_idx).index_select(2, ml_sel_idx)
            delta_mm = H_high_mm - H_model_mm
            # Two separated index tensors put the indexed dims first: (K, K, 3, 3).
            H[ml_active_idx[:, None], :, ml_active_idx[None, :], :] += delta_mm.permute(0, 2, 1, 3)
            timing["hess_asm_mlml_s"] = time.perf_counter() - t_asm
        del H_model

        # --- Collect links whose ML parent and MM partner are both Hessian-active. ---
        real_to_model = self._idx_map_real_to_model
        link_data: List[Tuple[int, int, int, int, int, float, torch.Tensor]] = []
        for link_idx, ml_idx, mm_idx, dist in added_link_atoms:
            ml_model_idx = real_to_model[ml_idx]
            r_ml_t = torch.tensor(atoms_model_LH[ml_model_idx].position, dtype=self.H_dtype, device=self.ml_device)
            r_mm_t = torch.tensor(atoms_real[mm_idx].position, dtype=self.H_dtype, device=self.ml_device)
            K = self._jacobian_blocks_torch(r_ml_t, r_mm_t, dist, dtype=self.H_dtype, device=self.ml_device)
            if K is None:
                continue
            ml_active = self.full_to_hess_active.get(ml_idx)
            mm_active = self.full_to_hess_active.get(mm_idx)
            if ml_active is None or mm_active is None:
                continue
            link_data.append((link_idx, ml_idx, mm_idx, ml_active, mm_active, dist, K))

        # --- Link self terms. Integer-only indexing yields views, so add_ writes into H. ---
        F_high_t = torch.as_tensor(ml_out.F, dtype=self.H_dtype, device=self.ml_device)
        has_link_force = bool((F_high_t.abs() > 1e-12).any().item())
        if link_data and (H_high is not None or has_link_force):
            t_asm = time.perf_counter()
            I3 = torch.eye(3, dtype=self.H_dtype, device=self.ml_device)
            for link_idx, ml_idx, mm_idx, ml_active, mm_active, dist, K in link_data:

                if H_high is not None:
                    # First-order (chain-rule) term: K^T H_LL K.
                    H_l = H_high[link_idx, :, link_idx, :]
                    H_self = K.T @ H_l @ K
                    H[ml_active, :, ml_active, :].add_(H_self[0:3, 0:3])
                    H[ml_active, :, mm_active, :].add_(H_self[0:3, 3:6])
                    H[mm_active, :, ml_active, :].add_(H_self[3:6, 0:3])
                    H[mm_active, :, mm_active, :].add_(H_self[3:6, 3:6])

                # Second-order term: link-atom gradient contracted with the
                # second derivative of the link position w.r.t. its anchors.
                f_L = -F_high_t[link_idx]

                r_ml_t = torch.as_tensor(atoms_model_LH[real_to_model[ml_idx]].position,
                                         dtype=self.H_dtype, device=self.ml_device)
                r_mm_t = torch.as_tensor(atoms_real[mm_idx].position,
                                         dtype=self.H_dtype, device=self.ml_device)
                v = r_mm_t - r_ml_t
                R_sq = torch.dot(v, v)
                inv_R = torch.rsqrt(torch.clamp(R_sq, min=1.0e-24))
                inv_R2 = inv_R * inv_R
                u = v * inv_R

                alpha = torch.dot(u, f_L)
                uuT = torch.outer(u, u)
                ufT = torch.outer(u, f_L)
                fTu = torch.outer(f_L, u)
                B = (alpha * (3.0 * uuT - I3) - (ufT + fTu)) * inv_R2

                H_corr6 = torch.zeros((6, 6), dtype=self.H_dtype, device=self.ml_device)
                H_corr6[0:3, 0:3] = B
                H_corr6[3:6, 3:6] = B
                H_corr6[0:3, 3:6] = -B
                H_corr6[3:6, 0:3] = -B
                H_corr6.mul_(dist)

                H[ml_active, :, ml_active, :].add_(H_corr6[0:3, 0:3])
                H[ml_active, :, mm_active, :].add_(H_corr6[0:3, 3:6])
                H[mm_active, :, ml_active, :].add_(H_corr6[3:6, 0:3])
                H[mm_active, :, mm_active, :].add_(H_corr6[3:6, 3:6])
            timing["hess_asm_link_self_s"] = time.perf_counter() - t_asm

        # --- Link-ML coupling. ---
        if H_high is not None and link_data and ml_sel_idx.numel() > 0:
            t_asm = time.perf_counter()
            for link_idx, _ml_idx, _mm_idx, ml_active, mm_active, _dist, K in link_data:
                H_coup = H_high[link_idx].index_select(1, ml_sel_idx).permute(1, 0, 2).contiguous()  # (K,3,3)
                H_row = torch.einsum("ac,bcd->bad", K.T, H_coup)  # (K,6,3)
                H_col = torch.einsum("bca,cd->bad", H_coup, K)  # (K,3,6)

                # Mixed scalar/tensor indexing in PyTorch returns (3, K, 3) for
                # H[scalar, :, tensor, :], so align H_row blocks explicitly.
                # Indexing with a tensor returns a copy, so an in-place add_
                # on it would be discarded; indexed `+=` writes back into H.
                H[ml_active, :, ml_active_idx, :] += H_row[:, 0:3, :].permute(1, 0, 2)
                H[mm_active, :, ml_active_idx, :] += H_row[:, 3:6, :].permute(1, 0, 2)
                H[ml_active_idx, :, ml_active, :] += H_col[:, :, 0:3]
                H[ml_active_idx, :, mm_active, :] += H_col[:, :, 3:6]
            timing["hess_asm_link_ml_s"] = time.perf_counter() - t_asm

        # --- Link-link coupling between distinct link atoms (both orderings). ---
        if H_high is not None and link_data:
            t_asm = time.perf_counter()
            n_links = len(link_data)
            for a in range(n_links):
                link_idx_a, _ml_a, _mm_a, ml_a_active, mm_a_active, _dist_a, K_a = link_data[a]

                for b in range(a + 1, n_links):
                    link_idx_b, _ml_b, _mm_b, ml_b_active, mm_b_active, _dist_b, K_b = link_data[b]

                    H_ab = H_high[link_idx_a, :, link_idx_b, :]
                    HAB = K_a.T @ H_ab @ K_b

                    H[ml_a_active, :, ml_b_active, :].add_(HAB[0:3, 0:3])
                    H[ml_a_active, :, mm_b_active, :].add_(HAB[0:3, 3:6])
                    H[mm_a_active, :, ml_b_active, :].add_(HAB[3:6, 0:3])
                    H[mm_a_active, :, mm_b_active, :].add_(HAB[3:6, 3:6])

                    HBA = HAB.T
                    H[ml_b_active, :, ml_a_active, :].add_(HBA[0:3, 0:3])
                    H[ml_b_active, :, mm_a_active, :].add_(HBA[0:3, 3:6])
                    H[mm_b_active, :, ml_a_active, :].add_(HBA[3:6, 0:3])
                    H[mm_b_active, :, mm_a_active, :].add_(HBA[3:6, 3:6])
            timing["hess_asm_link_link_s"] = time.perf_counter() - t_asm

        # Add point-charge embedding Hessian correction
        if embed_dH is not None:
            t_asm = time.perf_counter()
            n_model_atoms = len(self.selection_indices)
            dH_t = torch.from_numpy(embed_dH).to(self.ml_device, self.H_dtype)
            dH_t = dH_t.view(n_model_atoms, 3, n_model_atoms, 3)
            if ml_sel_idx.numel() > 0:
                dH_sub = dH_t.index_select(0, ml_sel_idx).index_select(2, ml_sel_idx)
                H[ml_active_idx[:, None], :, ml_active_idx[None, :], :] += dH_sub.permute(0, 2, 1, 3)
            timing["hess_asm_embed_s"] = time.perf_counter() - t_asm

        if self.symmetrize_hessian:
            t_asm = time.perf_counter()
            H_flat = H.view(3 * n_hess_active, 3 * n_hess_active)
            H_flat = (H_flat + H_flat.t()).mul_(0.5)
            H = H_flat.view(n_hess_active, 3, n_hess_active, 3)
            timing["hess_asm_sym_s"] = time.perf_counter() - t_asm

        if self.return_partial_hessian:
            results["hessian"] = H.detach()
            results["within_partial_hessian"] = self._build_within_partial_hessian()
        else:
            # Scatter the active block into a zero-padded full-system Hessian.
            t_asm = time.perf_counter()
            H_full = torch.zeros((n_real, 3, n_real, 3), dtype=self.H_dtype, device=self.ml_device)
            active_idx = torch.as_tensor(self.hess_active_atoms, dtype=torch.long, device=self.ml_device)
            if active_idx.numel() > 0:
                H_full[active_idx[:, None], :, active_idx[None, :], :] = H.permute(0, 2, 1, 3).contiguous()
            results["hessian"] = H_full.detach()
            timing["hess_asm_full_expand_s"] = time.perf_counter() - t_asm
            del H_full

        if hess_total_start is not None:
            timing["hessian_total_s"] = time.perf_counter() - hess_total_start
        results["timing"] = timing
        if self.print_timing:
            ml_mode = timing.get("ml_hessian_mode")
            ml_time = timing.get("ml_hessian_s")
            if ml_mode is not None and ml_time is not None:
                click.echo(f"[HessianTiming] ML Hessian ({ml_mode}): {ml_time:.2f} s")
            if "mm_fd_total_s" in timing:
                click.echo(
                    f"[HessianTiming] MM Hessian: REAL {timing['mm_fd_real_s']:.2f} s | "
                    f"MODEL {timing['mm_fd_model_s']:.2f} s | "
                    f"total {timing['mm_fd_total_s']:.2f} s"
                )
            asm_parts = []
            for key, label in (
                ("hess_asm_mlml_s", "ML-ML"),
                ("hess_asm_link_self_s", "link-self"),
                ("hess_asm_link_ml_s", "link-ML"),
                ("hess_asm_link_link_s", "link-link"),
                ("hess_asm_sym_s", "sym"),
                ("hess_asm_full_expand_s", "full-expand"),
            ):
                if key in timing:
                    asm_parts.append(f"{label} {float(timing[key]):.2f} s")
            if asm_parts:
                click.echo(f"[HessianTiming] Assembly: {' | '.join(asm_parts)}")
            click.echo(f"[HessianTiming] Hessian total: {timing['hessian_total_s']:.2f} s")
        if self.print_vram and self.ml_device.type == "cuda":
            torch.cuda.synchronize(device=self.ml_device)
            base_alloc = float(hess_vram_base_alloc or 0.0)
            base_reserved = float(hess_vram_base_reserved or 0.0)
            peak_alloc = max(
                float(torch.cuda.max_memory_allocated(device=self.ml_device)) - base_alloc,
                0.0,
            ) / 1e9
            peak_reserved_abs = float(torch.cuda.max_memory_reserved(device=self.ml_device))
            peak_reserved = max(
                peak_reserved_abs - base_reserved,
                0.0,
            ) / 1e9
            total_vram = float(hess_vram_total or torch.cuda.get_device_properties(self.ml_device).total_memory) / 1e9
            remaining_vram = max((total_vram * 1e9) - peak_reserved_abs, 0.0) / 1e9
            click.echo(
                f"[HessianVRAM] total={total_vram:.3f} GB | "
                f"peak_allocated={peak_alloc:.3f} GB | "
                f"peak_reserved={peak_reserved:.3f} GB | "
                f"remaining={remaining_vram:.3f} GB"
            )

        # Drop local references to the large tensors before releasing cached blocks.
        del H, H_high
        if self.ml_device.type == "cuda":
            torch.cuda.empty_cache()

    return results
1980
+
1981
+
1982
+ # ======================================================================
1983
+ # ASE Calculator wrapper for ML/MM (ONIOM)
1984
+ # ======================================================================
1985
+
1986
class MLMMASECalculator(Calculator):
    """ASE Calculator wrapping MLMMCore for use with DMF and other ASE-based methods.

    The underlying MLMMCore takes full-system coordinates (Angstrom) and
    returns energy in eV and forces in eV/Angstrom, which matches ASE conventions.
    """

    implemented_properties = ["energy", "forces"]

    def __init__(self, core: "MLMMCore", **kwargs):
        """Store the ML/MM core; remaining keyword arguments go to the ASE base class."""
        super().__init__(**kwargs)
        self.core = core

    def calculate(self, atoms=None, properties=("energy",), system_changes=all_changes):
        """Compute the requested properties and store them in ``self.results``.

        Args:
            atoms: structure to evaluate. When ``None``, the structure
                already attached to the calculator (``self.atoms``) is used.
            properties: names of the requested properties; forces are
                evaluated only when ``"forces"`` is among them.
            system_changes: forwarded unchanged to the base class.
        """
        super().calculate(atoms, properties, system_changes)
        # Fall back to the structure held by the base class so that a call
        # with atoms=None does not dereference None.
        target = self.atoms if atoms is None else atoms
        coord_ang = target.get_positions().astype(float)
        want_forces = "forces" in properties
        res = self.core.compute(coord_ang, return_forces=want_forces, return_hessian=False)
        self.results = {
            "energy": float(res["energy"]),
        }
        if want_forces:
            self.results["forces"] = res["forces"].reshape(-1, 3)
2009
+
2010
+
2011
+ # ======================================================================
2012
+ # PySisyphus Calculator (ML/MM)
2013
+ # ======================================================================
2014
+
2015
+ from pysisyphus.calculators.Calculator import Calculator as PySiCalc
2016
+
2017
+
2018
class mlmm(PySiCalc):
    """PySisyphus calculator that delegates all work to an ``MLMMCore``.

    Incoming coordinates are multiplied by ``BOHR2ANG`` before the core is
    called.  Energies, forces and Hessians coming back from the core are
    multiplied by ``EV2AU``, ``EV2AU / ANG2BOHR`` and
    ``EV2AU / ANG2BOHR / ANG2BOHR`` respectively (presumably Bohr -> Angstrom
    on the way in and eV-based -> atomic units on the way out; confirm against
    the constant definitions).
    """

    implemented_properties = ["energy", "forces", "hessian"]

    def __init__(
        self,
        input_pdb: Optional[str] = None,
        real_parm7: Optional[str] = None,
        model_pdb: Optional[str] = None,
        *,
        model_charge: int = 0,
        model_mult: int = 1,
        link_mlmm: List[Tuple[str, str]] | None = None,
        # ML backend selection
        backend: str = "uma",
        uma_model: str = "uma-s-1p1",
        uma_task_name: str = "omol",
        orb_model: str = "orb_v3_conservative_omol",
        orb_precision: str = "float32",
        mace_model: str = "MACE-OMOL-0",
        mace_dtype: str = "float64",
        aimnet2_model: str = "aimnet2",
        # MM settings
        mm_fd: bool = True,
        mm_fd_dir: Optional[str] = None,
        mm_fd_delta: float = 1e-3,
        symmetrize_hessian: bool = True,
        out_hess_torch: bool = True,
        H_double: bool = False,
        ml_hessian_mode: str = "FiniteDifference",
        hessian_calc_mode: Optional[str] = None,
        ml_device: str = "auto",
        ml_cuda_idx: int = 0,
        mm_device: str = "cpu",
        mm_cuda_idx: int = 0,
        mm_threads: int = 16,
        mm_backend: str = "hessian_ff",
        freeze_atoms: List[int] | None = None,
        return_partial_hessian: bool = True,
        print_timing: bool = True,
        print_vram: bool = True,
        hess_cutoff: Optional[float] = None,
        movable_cutoff: Optional[float] = None,
        use_bfactor_layers: bool = False,
        hess_mm_atoms: Optional[List[int]] = None,
        movable_mm_atoms: Optional[List[int]] = None,
        frozen_mm_atoms: Optional[List[int]] = None,
        # Point-charge embedding correction
        embedcharge: bool = False,
        embedcharge_step: float = 1.0e-3,
        embedcharge_cutoff: Optional[float] = None,
        xtb_cmd: str = "xtb",
        xtb_acc: float = 0.2,
        xtb_workdir: str = "tmp",
        xtb_keep_files: bool = False,
        xtb_ncores: int = 4,
        **kwargs,
    ):
        """Build the underlying ``MLMMCore`` and set up output conversion.

        ``model_charge`` / ``model_mult`` are given to the base calculator as
        ``charge`` / ``mult`` and also to the core.  ``out_hess_torch`` only
        controls the Hessian container returned by this wrapper, and
        ``H_double`` is both forwarded to the core and used to pick the torch
        dtype of the returned Hessian.  Every remaining named option except
        ``orb_precision`` is forwarded to ``MLMMCore`` unchanged.  Leftover
        ``**kwargs`` (after the legacy names below are stripped) go to the
        base calculator.

        Legacy keywords: ``real_pdb`` is accepted as a deprecated alias of
        ``input_pdb`` (an explicit ``input_pdb`` wins); ``real_rst7``,
        ``vib_run`` and ``vib_dir`` are discarded.  Each emits a
        ``DeprecationWarning``.
        """
        # --- v0.1.x backward compatibility aliases ---
        if "real_pdb" in kwargs:
            warnings.warn("'real_pdb' is deprecated; use 'input_pdb'.", DeprecationWarning, stacklevel=2)
            legacy_pdb = kwargs.pop("real_pdb")
            if input_pdb is None:
                input_pdb = legacy_pdb

        dropped_names = [name for name in ("real_rst7", "vib_run", "vib_dir") if name in kwargs]
        for name in dropped_names:
            warnings.warn(f"'{name}' is no longer used and will be ignored.", DeprecationWarning, stacklevel=2)
            del kwargs[name]

        # Stored before the base initializer runs and before ``self.core`` exists.
        self._freeze_atoms = list(freeze_atoms) if freeze_atoms is not None else []
        super().__init__(charge=model_charge, mult=model_mult, **kwargs)

        # NOTE(review): ``orb_precision`` is accepted above but is not passed
        # on here -- confirm whether MLMMCore is supposed to receive it.
        self.core = MLMMCore(
            # structure / topology inputs
            input_pdb=input_pdb,
            real_parm7=real_parm7,
            model_pdb=model_pdb,
            model_charge=model_charge,
            model_mult=model_mult,
            link_mlmm=link_mlmm,
            # ML backend selection
            backend=backend,
            uma_model=uma_model,
            uma_task_name=uma_task_name,
            orb_model=orb_model,
            mace_model=mace_model,
            mace_dtype=mace_dtype,
            aimnet2_model=aimnet2_model,
            ml_device=ml_device,
            ml_cuda_idx=ml_cuda_idx,
            # MM settings
            mm_fd=mm_fd,
            mm_fd_dir=mm_fd_dir,
            mm_fd_delta=mm_fd_delta,
            mm_device=mm_device,
            mm_cuda_idx=mm_cuda_idx,
            mm_threads=mm_threads,
            mm_backend=mm_backend,
            # Hessian options
            symmetrize_hessian=symmetrize_hessian,
            H_double=H_double,
            ml_hessian_mode=ml_hessian_mode,
            hessian_calc_mode=hessian_calc_mode,
            return_partial_hessian=return_partial_hessian,
            # atom selections
            freeze_atoms=self._freeze_atoms,
            hess_cutoff=hess_cutoff,
            movable_cutoff=movable_cutoff,
            use_bfactor_layers=use_bfactor_layers,
            hess_mm_atoms=hess_mm_atoms,
            movable_mm_atoms=movable_mm_atoms,
            frozen_mm_atoms=frozen_mm_atoms,
            # diagnostics
            print_timing=print_timing,
            print_vram=print_vram,
            # Point-charge embedding correction
            embedcharge=embedcharge,
            embedcharge_step=embedcharge_step,
            embedcharge_cutoff=embedcharge_cutoff,
            xtb_cmd=xtb_cmd,
            xtb_acc=xtb_acc,
            xtb_workdir=xtb_workdir,
            xtb_keep_files=xtb_keep_files,
            xtb_ncores=xtb_ncores,
        )

        self.out_hess_torch = bool(out_hess_torch)
        self.hess_torch_double = bool(H_double)
        # Factor applied to the core's Hessian before it is returned.
        self._hess_scale = EV2AU / ANG2BOHR / ANG2BOHR

    @property
    def freeze_atoms(self) -> List[int] | None:
        """Frozen-atom indices as currently held by the core."""
        return self.core.freeze_atoms

    @freeze_atoms.setter
    def freeze_atoms(self, indices: List[int] | None):
        """Replace the frozen-atom list (``None`` means empty) and let the core refresh its DOF mappings."""
        new_frozen = list(indices) if indices is not None else []
        self._freeze_atoms = new_frozen
        self.core.freeze_atoms = new_frozen
        self.core._update_active_dof_mappings()

    def _run_core(self, coords, *, want_forces: bool, want_hessian: bool):
        """Evaluate the core at ``coords`` and return a dict of scaled results.

        The dict always contains ``"energy"``.  ``"forces"`` (flattened) is
        added when forces or a Hessian are requested.  For a Hessian request,
        ``"hessian"`` is added and ``"within_partial_hessian"`` is copied
        through when the core reports it.
        """
        need_forces = want_forces or want_hessian
        positions_ang = np.asarray(coords).reshape(-1, 3) * BOHR2ANG
        res = self.core.compute(positions_ang, return_forces=need_forces, return_hessian=want_hessian)

        results = {"energy": res["energy"] * EV2AU}
        if need_forces:
            results["forces"] = (res["forces"] * (EV2AU / ANG2BOHR)).flatten()
        if not want_hessian:
            return results

        hess = res.pop("hessian")
        # Collapse to a 2D matrix of shape (size(0) * 3, size(2) * 3).
        hess = hess.view(hess.size(0) * 3, hess.size(2) * 3)
        # In-place scaling: the tensor handed back by the core is modified.
        hess.mul_(self._hess_scale)

        if self.out_hess_torch:
            dtype = torch.float64 if self.hess_torch_double else torch.float32
            results["hessian"] = hess.to(dtype).detach().requires_grad_(False)
        else:
            results["hessian"] = hess.detach().cpu().numpy()

        if "within_partial_hessian" in res:
            results["within_partial_hessian"] = res["within_partial_hessian"]
        return results

    def get_energy(self, elem, coords):
        """Return the energy-only result dict; ``elem`` is not used here."""
        return self._run_core(coords, want_forces=False, want_hessian=False)

    def get_forces(self, elem, coords):
        """Return the result dict with energy and forces; ``elem`` is not used here."""
        return self._run_core(coords, want_forces=True, want_hessian=False)

    def get_hessian(self, elem, coords):
        """Return the result dict with energy, forces and Hessian; ``elem`` is not used here."""
        return self._run_core(coords, want_forces=True, want_hessian=True)
2178
+
2179
+
2180
+ # ======================================================================
2181
+ # PySisyphus Calculator (MM-only)
2182
+ # ======================================================================
2183
+
2184
+
2185
class mlmm_mm_only(PySiCalc):
    """PySisyphus calculator that returns MM-only energy and forces (F_real_mm).

    Intended for microiteration, where the MM region is relaxed without any
    ML evaluation.  It reuses the ``MLMMCore`` of an existing ``mlmm``
    calculator so that topology and force field objects are not
    re-initialized.
    """

    implemented_properties = ["energy", "forces"]

    def __init__(self, core: "MLMMCore", *, freeze_atoms: list[int] | None = None, **kwargs):
        """Wrap ``core``; charge and multiplicity are taken from it.

        ``freeze_atoms`` lists atom indices whose forces are zeroed in the
        output; a falsy value means no atom is frozen.  Remaining ``kwargs``
        go to the base calculator.
        """
        super().__init__(charge=core.model_charge, mult=core.model_mult, **kwargs)
        self.core = core
        self._freeze_atoms = list(freeze_atoms or [])

    def _run_core(self, coords, *, want_forces: bool):
        """Evaluate ``core.calc_real_low`` at ``coords`` and return scaled results.

        The returned dict always has ``"energy"``; ``"forces"`` (flattened,
        frozen atoms zeroed) is added when ``want_forces`` is true.
        """
        positions_ang = np.asarray(coords).reshape(-1, 3) * BOHR2ANG

        # Work on a copy of the core's template so the template is not touched.
        atoms = self.core._atoms_real_tpl.copy()
        atoms.set_positions(positions_ang)
        atoms.set_pbc(False)
        atoms.calc = self.core.calc_real_low

        results = {"energy": float(atoms.get_potential_energy()) * EV2AU}
        if not want_forces:
            return results

        forces = np.double(atoms.get_forces())
        n_atoms = forces.shape[0]
        # Zero forces on frozen atoms; indices outside [0, n_atoms) are skipped.
        frozen_rows = [idx for idx in self._freeze_atoms if 0 <= idx < n_atoms]
        if frozen_rows:
            forces[frozen_rows, :] = 0.0

        results["forces"] = (forces * (EV2AU / ANG2BOHR)).flatten()
        return results

    def get_energy(self, elem, coords):
        """Return the energy-only result dict; ``elem`` is not used here."""
        return self._run_core(coords, want_forces=False)

    def get_forces(self, elem, coords):
        """Return the result dict with energy and forces; ``elem`` is not used here."""
        return self._run_core(coords, want_forces=True)

    def get_hessian(self, elem, coords):
        """Always raise ``NotImplementedError``: no Hessian is available here."""
        raise NotImplementedError("MM-only calculator does not support Hessian computation.")
2225
+
2226
+
2227
+ # ======================================================================
2228
+ # v0.1.x compatibility: mlmm_ase() factory
2229
+ # ======================================================================
2230
+
2231
+
2232
def mlmm_ase(**kwargs):
    """v0.1.x compatibility wrapper (deprecated).

    Every keyword argument is forwarded to ``MLMMCore`` and the resulting
    core is wrapped in an ``MLMMASECalculator``, i.e. the call is
    equivalent to::

        MLMMASECalculator(MLMMCore(**kwargs))

    A ``DeprecationWarning`` is emitted on each call.
    """
    deprecation_note = "mlmm_ase() is deprecated; use MLMMASECalculator(MLMMCore(...)) instead."
    # stacklevel=2 attributes the warning to the caller of mlmm_ase().
    warnings.warn(deprecation_note, DeprecationWarning, stacklevel=2)
    return MLMMASECalculator(MLMMCore(**kwargs))
2247
+
2248
+
2249
+ # ======================================================================
2250
+ # CLI registration
2251
+ # ======================================================================
2252
+
2253
+ from pysisyphus import run as _run
2254
+
2255
+
2256
def run_pysis_mlmm():
    """Register ``mlmm`` under the key ``"mlmm"`` in pysisyphus' ``CALC_DICT``, then call ``run()``."""
    calc_registry = _run.CALC_DICT
    calc_registry["mlmm"] = mlmm
    _run.run()


# Script entry point: only runs when the module is executed directly.
if __name__ == "__main__":
    run_pysis_mlmm()