mlmm-toolkit 0.2.2.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (372) hide show
  1. hessian_ff/__init__.py +50 -0
  2. hessian_ff/analytical_hessian.py +609 -0
  3. hessian_ff/constants.py +46 -0
  4. hessian_ff/forcefield.py +339 -0
  5. hessian_ff/loaders.py +608 -0
  6. hessian_ff/native/Makefile +8 -0
  7. hessian_ff/native/__init__.py +28 -0
  8. hessian_ff/native/analytical_hessian.py +88 -0
  9. hessian_ff/native/analytical_hessian_ext.cpp +258 -0
  10. hessian_ff/native/bonded.py +82 -0
  11. hessian_ff/native/bonded_ext.cpp +640 -0
  12. hessian_ff/native/loader.py +349 -0
  13. hessian_ff/native/nonbonded.py +118 -0
  14. hessian_ff/native/nonbonded_ext.cpp +1150 -0
  15. hessian_ff/prmtop_parmed.py +23 -0
  16. hessian_ff/system.py +107 -0
  17. hessian_ff/terms/__init__.py +14 -0
  18. hessian_ff/terms/angle.py +73 -0
  19. hessian_ff/terms/bond.py +44 -0
  20. hessian_ff/terms/cmap.py +406 -0
  21. hessian_ff/terms/dihedral.py +141 -0
  22. hessian_ff/terms/nonbonded.py +209 -0
  23. hessian_ff/tests/__init__.py +0 -0
  24. hessian_ff/tests/conftest.py +75 -0
  25. hessian_ff/tests/data/small/complex.parm7 +1346 -0
  26. hessian_ff/tests/data/small/complex.pdb +125 -0
  27. hessian_ff/tests/data/small/complex.rst7 +63 -0
  28. hessian_ff/tests/test_coords_input.py +44 -0
  29. hessian_ff/tests/test_energy_force.py +49 -0
  30. hessian_ff/tests/test_hessian.py +137 -0
  31. hessian_ff/tests/test_smoke.py +18 -0
  32. hessian_ff/tests/test_validation.py +40 -0
  33. hessian_ff/workflows.py +889 -0
  34. mlmm/__init__.py +36 -0
  35. mlmm/__main__.py +7 -0
  36. mlmm/_version.py +34 -0
  37. mlmm/add_elem_info.py +374 -0
  38. mlmm/advanced_help.py +91 -0
  39. mlmm/align_freeze_atoms.py +601 -0
  40. mlmm/all.py +3535 -0
  41. mlmm/bond_changes.py +231 -0
  42. mlmm/bool_compat.py +223 -0
  43. mlmm/cli.py +574 -0
  44. mlmm/cli_utils.py +166 -0
  45. mlmm/default_group.py +337 -0
  46. mlmm/defaults.py +467 -0
  47. mlmm/define_layer.py +526 -0
  48. mlmm/dft.py +1041 -0
  49. mlmm/energy_diagram.py +253 -0
  50. mlmm/extract.py +2213 -0
  51. mlmm/fix_altloc.py +464 -0
  52. mlmm/freq.py +1406 -0
  53. mlmm/harmonic_constraints.py +140 -0
  54. mlmm/hessian_cache.py +44 -0
  55. mlmm/hessian_calc.py +174 -0
  56. mlmm/irc.py +638 -0
  57. mlmm/mlmm_calc.py +2262 -0
  58. mlmm/mm_parm.py +945 -0
  59. mlmm/oniom_export.py +1983 -0
  60. mlmm/oniom_import.py +457 -0
  61. mlmm/opt.py +1742 -0
  62. mlmm/path_opt.py +1353 -0
  63. mlmm/path_search.py +2299 -0
  64. mlmm/preflight.py +88 -0
  65. mlmm/py.typed +1 -0
  66. mlmm/pysis_runner.py +45 -0
  67. mlmm/scan.py +1047 -0
  68. mlmm/scan2d.py +1226 -0
  69. mlmm/scan3d.py +1265 -0
  70. mlmm/scan_common.py +184 -0
  71. mlmm/summary_log.py +736 -0
  72. mlmm/trj2fig.py +448 -0
  73. mlmm/tsopt.py +2871 -0
  74. mlmm/utils.py +2309 -0
  75. mlmm/xtb_embedcharge_correction.py +475 -0
  76. mlmm_toolkit-0.2.2.dev0.dist-info/METADATA +1159 -0
  77. mlmm_toolkit-0.2.2.dev0.dist-info/RECORD +372 -0
  78. mlmm_toolkit-0.2.2.dev0.dist-info/WHEEL +5 -0
  79. mlmm_toolkit-0.2.2.dev0.dist-info/entry_points.txt +2 -0
  80. mlmm_toolkit-0.2.2.dev0.dist-info/licenses/LICENSE +674 -0
  81. mlmm_toolkit-0.2.2.dev0.dist-info/top_level.txt +4 -0
  82. pysisyphus/Geometry.py +1667 -0
  83. pysisyphus/LICENSE +674 -0
  84. pysisyphus/TableFormatter.py +63 -0
  85. pysisyphus/TablePrinter.py +74 -0
  86. pysisyphus/__init__.py +12 -0
  87. pysisyphus/calculators/AFIR.py +452 -0
  88. pysisyphus/calculators/AnaPot.py +20 -0
  89. pysisyphus/calculators/AnaPot2.py +48 -0
  90. pysisyphus/calculators/AnaPot3.py +12 -0
  91. pysisyphus/calculators/AnaPot4.py +20 -0
  92. pysisyphus/calculators/AnaPotBase.py +337 -0
  93. pysisyphus/calculators/AnaPotCBM.py +25 -0
  94. pysisyphus/calculators/AtomAtomTransTorque.py +154 -0
  95. pysisyphus/calculators/CFOUR.py +250 -0
  96. pysisyphus/calculators/Calculator.py +844 -0
  97. pysisyphus/calculators/CerjanMiller.py +24 -0
  98. pysisyphus/calculators/Composite.py +123 -0
  99. pysisyphus/calculators/ConicalIntersection.py +171 -0
  100. pysisyphus/calculators/DFTBp.py +430 -0
  101. pysisyphus/calculators/DFTD3.py +66 -0
  102. pysisyphus/calculators/DFTD4.py +84 -0
  103. pysisyphus/calculators/Dalton.py +61 -0
  104. pysisyphus/calculators/Dimer.py +681 -0
  105. pysisyphus/calculators/Dummy.py +20 -0
  106. pysisyphus/calculators/EGO.py +76 -0
  107. pysisyphus/calculators/EnergyMin.py +224 -0
  108. pysisyphus/calculators/ExternalPotential.py +264 -0
  109. pysisyphus/calculators/FakeASE.py +35 -0
  110. pysisyphus/calculators/FourWellAnaPot.py +28 -0
  111. pysisyphus/calculators/FreeEndNEBPot.py +39 -0
  112. pysisyphus/calculators/Gaussian09.py +18 -0
  113. pysisyphus/calculators/Gaussian16.py +726 -0
  114. pysisyphus/calculators/HardSphere.py +159 -0
  115. pysisyphus/calculators/IDPPCalculator.py +49 -0
  116. pysisyphus/calculators/IPIClient.py +133 -0
  117. pysisyphus/calculators/IPIServer.py +234 -0
  118. pysisyphus/calculators/LEPSBase.py +24 -0
  119. pysisyphus/calculators/LEPSExpr.py +139 -0
  120. pysisyphus/calculators/LennardJones.py +80 -0
  121. pysisyphus/calculators/MOPAC.py +219 -0
  122. pysisyphus/calculators/MullerBrownSympyPot.py +51 -0
  123. pysisyphus/calculators/MultiCalc.py +85 -0
  124. pysisyphus/calculators/NFK.py +45 -0
  125. pysisyphus/calculators/OBabel.py +87 -0
  126. pysisyphus/calculators/ONIOMv2.py +1129 -0
  127. pysisyphus/calculators/ORCA.py +893 -0
  128. pysisyphus/calculators/ORCA5.py +6 -0
  129. pysisyphus/calculators/OpenMM.py +88 -0
  130. pysisyphus/calculators/OpenMolcas.py +281 -0
  131. pysisyphus/calculators/OverlapCalculator.py +908 -0
  132. pysisyphus/calculators/Psi4.py +218 -0
  133. pysisyphus/calculators/PyPsi4.py +37 -0
  134. pysisyphus/calculators/PySCF.py +341 -0
  135. pysisyphus/calculators/PyXTB.py +73 -0
  136. pysisyphus/calculators/QCEngine.py +106 -0
  137. pysisyphus/calculators/Rastrigin.py +22 -0
  138. pysisyphus/calculators/Remote.py +76 -0
  139. pysisyphus/calculators/Rosenbrock.py +15 -0
  140. pysisyphus/calculators/SocketCalc.py +97 -0
  141. pysisyphus/calculators/TIP3P.py +111 -0
  142. pysisyphus/calculators/TransTorque.py +161 -0
  143. pysisyphus/calculators/Turbomole.py +965 -0
  144. pysisyphus/calculators/VRIPot.py +37 -0
  145. pysisyphus/calculators/WFOWrapper.py +333 -0
  146. pysisyphus/calculators/WFOWrapper2.py +341 -0
  147. pysisyphus/calculators/XTB.py +418 -0
  148. pysisyphus/calculators/__init__.py +81 -0
  149. pysisyphus/calculators/cosmo_data.py +139 -0
  150. pysisyphus/calculators/parser.py +150 -0
  151. pysisyphus/color.py +19 -0
  152. pysisyphus/config.py +133 -0
  153. pysisyphus/constants.py +65 -0
  154. pysisyphus/cos/AdaptiveNEB.py +230 -0
  155. pysisyphus/cos/ChainOfStates.py +725 -0
  156. pysisyphus/cos/FreeEndNEB.py +25 -0
  157. pysisyphus/cos/FreezingString.py +103 -0
  158. pysisyphus/cos/GrowingChainOfStates.py +71 -0
  159. pysisyphus/cos/GrowingNT.py +309 -0
  160. pysisyphus/cos/GrowingString.py +508 -0
  161. pysisyphus/cos/NEB.py +189 -0
  162. pysisyphus/cos/SimpleZTS.py +64 -0
  163. pysisyphus/cos/__init__.py +22 -0
  164. pysisyphus/cos/stiffness.py +199 -0
  165. pysisyphus/drivers/__init__.py +17 -0
  166. pysisyphus/drivers/afir.py +855 -0
  167. pysisyphus/drivers/barriers.py +271 -0
  168. pysisyphus/drivers/birkholz.py +138 -0
  169. pysisyphus/drivers/cluster.py +318 -0
  170. pysisyphus/drivers/diabatization.py +133 -0
  171. pysisyphus/drivers/merge.py +368 -0
  172. pysisyphus/drivers/merge_mol2.py +322 -0
  173. pysisyphus/drivers/opt.py +375 -0
  174. pysisyphus/drivers/perf.py +91 -0
  175. pysisyphus/drivers/pka.py +52 -0
  176. pysisyphus/drivers/precon_pos_rot.py +669 -0
  177. pysisyphus/drivers/rates.py +480 -0
  178. pysisyphus/drivers/replace.py +219 -0
  179. pysisyphus/drivers/scan.py +212 -0
  180. pysisyphus/drivers/spectrum.py +166 -0
  181. pysisyphus/drivers/thermo.py +31 -0
  182. pysisyphus/dynamics/Gaussian.py +103 -0
  183. pysisyphus/dynamics/__init__.py +20 -0
  184. pysisyphus/dynamics/colvars.py +136 -0
  185. pysisyphus/dynamics/driver.py +297 -0
  186. pysisyphus/dynamics/helpers.py +256 -0
  187. pysisyphus/dynamics/lincs.py +105 -0
  188. pysisyphus/dynamics/mdp.py +364 -0
  189. pysisyphus/dynamics/rattle.py +121 -0
  190. pysisyphus/dynamics/thermostats.py +128 -0
  191. pysisyphus/dynamics/wigner.py +266 -0
  192. pysisyphus/elem_data.py +3473 -0
  193. pysisyphus/exceptions.py +2 -0
  194. pysisyphus/filtertrj.py +69 -0
  195. pysisyphus/helpers.py +623 -0
  196. pysisyphus/helpers_pure.py +649 -0
  197. pysisyphus/init_logging.py +50 -0
  198. pysisyphus/intcoords/Bend.py +69 -0
  199. pysisyphus/intcoords/Bend2.py +25 -0
  200. pysisyphus/intcoords/BondedFragment.py +32 -0
  201. pysisyphus/intcoords/Cartesian.py +41 -0
  202. pysisyphus/intcoords/CartesianCoords.py +140 -0
  203. pysisyphus/intcoords/Coords.py +56 -0
  204. pysisyphus/intcoords/DLC.py +197 -0
  205. pysisyphus/intcoords/DistanceFunction.py +34 -0
  206. pysisyphus/intcoords/DummyImproper.py +70 -0
  207. pysisyphus/intcoords/DummyTorsion.py +72 -0
  208. pysisyphus/intcoords/LinearBend.py +105 -0
  209. pysisyphus/intcoords/LinearDisplacement.py +80 -0
  210. pysisyphus/intcoords/OutOfPlane.py +59 -0
  211. pysisyphus/intcoords/PrimTypes.py +286 -0
  212. pysisyphus/intcoords/Primitive.py +137 -0
  213. pysisyphus/intcoords/RedundantCoords.py +659 -0
  214. pysisyphus/intcoords/RobustTorsion.py +59 -0
  215. pysisyphus/intcoords/Rotation.py +147 -0
  216. pysisyphus/intcoords/Stretch.py +31 -0
  217. pysisyphus/intcoords/Torsion.py +101 -0
  218. pysisyphus/intcoords/Torsion2.py +25 -0
  219. pysisyphus/intcoords/Translation.py +45 -0
  220. pysisyphus/intcoords/__init__.py +61 -0
  221. pysisyphus/intcoords/augment_bonds.py +126 -0
  222. pysisyphus/intcoords/derivatives.py +10512 -0
  223. pysisyphus/intcoords/eval.py +80 -0
  224. pysisyphus/intcoords/exceptions.py +37 -0
  225. pysisyphus/intcoords/findiffs.py +48 -0
  226. pysisyphus/intcoords/generate_derivatives.py +414 -0
  227. pysisyphus/intcoords/helpers.py +235 -0
  228. pysisyphus/intcoords/logging_conf.py +10 -0
  229. pysisyphus/intcoords/mp_derivatives.py +10836 -0
  230. pysisyphus/intcoords/setup.py +962 -0
  231. pysisyphus/intcoords/setup_fast.py +176 -0
  232. pysisyphus/intcoords/update.py +272 -0
  233. pysisyphus/intcoords/valid.py +89 -0
  234. pysisyphus/interpolate/Geodesic.py +93 -0
  235. pysisyphus/interpolate/IDPP.py +55 -0
  236. pysisyphus/interpolate/Interpolator.py +116 -0
  237. pysisyphus/interpolate/LST.py +70 -0
  238. pysisyphus/interpolate/Redund.py +152 -0
  239. pysisyphus/interpolate/__init__.py +9 -0
  240. pysisyphus/interpolate/helpers.py +34 -0
  241. pysisyphus/io/__init__.py +22 -0
  242. pysisyphus/io/aomix.py +178 -0
  243. pysisyphus/io/cjson.py +24 -0
  244. pysisyphus/io/crd.py +101 -0
  245. pysisyphus/io/cube.py +220 -0
  246. pysisyphus/io/fchk.py +184 -0
  247. pysisyphus/io/hdf5.py +49 -0
  248. pysisyphus/io/hessian.py +72 -0
  249. pysisyphus/io/mol2.py +146 -0
  250. pysisyphus/io/molden.py +293 -0
  251. pysisyphus/io/orca.py +189 -0
  252. pysisyphus/io/pdb.py +269 -0
  253. pysisyphus/io/psf.py +79 -0
  254. pysisyphus/io/pubchem.py +31 -0
  255. pysisyphus/io/qcschema.py +34 -0
  256. pysisyphus/io/sdf.py +29 -0
  257. pysisyphus/io/xyz.py +61 -0
  258. pysisyphus/io/zmat.py +175 -0
  259. pysisyphus/irc/DWI.py +108 -0
  260. pysisyphus/irc/DampedVelocityVerlet.py +134 -0
  261. pysisyphus/irc/Euler.py +22 -0
  262. pysisyphus/irc/EulerPC.py +345 -0
  263. pysisyphus/irc/GonzalezSchlegel.py +187 -0
  264. pysisyphus/irc/IMKMod.py +164 -0
  265. pysisyphus/irc/IRC.py +878 -0
  266. pysisyphus/irc/IRCDummy.py +10 -0
  267. pysisyphus/irc/Instanton.py +307 -0
  268. pysisyphus/irc/LQA.py +53 -0
  269. pysisyphus/irc/ModeKill.py +136 -0
  270. pysisyphus/irc/ParamPlot.py +53 -0
  271. pysisyphus/irc/RK4.py +36 -0
  272. pysisyphus/irc/__init__.py +31 -0
  273. pysisyphus/irc/initial_displ.py +219 -0
  274. pysisyphus/linalg.py +411 -0
  275. pysisyphus/line_searches/Backtracking.py +88 -0
  276. pysisyphus/line_searches/HagerZhang.py +184 -0
  277. pysisyphus/line_searches/LineSearch.py +232 -0
  278. pysisyphus/line_searches/StrongWolfe.py +108 -0
  279. pysisyphus/line_searches/__init__.py +9 -0
  280. pysisyphus/line_searches/interpol.py +15 -0
  281. pysisyphus/modefollow/NormalMode.py +40 -0
  282. pysisyphus/modefollow/__init__.py +10 -0
  283. pysisyphus/modefollow/davidson.py +199 -0
  284. pysisyphus/modefollow/lanczos.py +95 -0
  285. pysisyphus/optimizers/BFGS.py +99 -0
  286. pysisyphus/optimizers/BacktrackingOptimizer.py +113 -0
  287. pysisyphus/optimizers/ConjugateGradient.py +98 -0
  288. pysisyphus/optimizers/CubicNewton.py +75 -0
  289. pysisyphus/optimizers/FIRE.py +113 -0
  290. pysisyphus/optimizers/HessianOptimizer.py +1176 -0
  291. pysisyphus/optimizers/LBFGS.py +228 -0
  292. pysisyphus/optimizers/LayerOpt.py +411 -0
  293. pysisyphus/optimizers/MicroOptimizer.py +169 -0
  294. pysisyphus/optimizers/NCOptimizer.py +90 -0
  295. pysisyphus/optimizers/Optimizer.py +1084 -0
  296. pysisyphus/optimizers/PreconLBFGS.py +260 -0
  297. pysisyphus/optimizers/PreconSteepestDescent.py +7 -0
  298. pysisyphus/optimizers/QuickMin.py +74 -0
  299. pysisyphus/optimizers/RFOptimizer.py +181 -0
  300. pysisyphus/optimizers/RSA.py +99 -0
  301. pysisyphus/optimizers/StabilizedQNMethod.py +248 -0
  302. pysisyphus/optimizers/SteepestDescent.py +23 -0
  303. pysisyphus/optimizers/StringOptimizer.py +173 -0
  304. pysisyphus/optimizers/__init__.py +41 -0
  305. pysisyphus/optimizers/closures.py +301 -0
  306. pysisyphus/optimizers/cls_map.py +58 -0
  307. pysisyphus/optimizers/exceptions.py +6 -0
  308. pysisyphus/optimizers/gdiis.py +280 -0
  309. pysisyphus/optimizers/guess_hessians.py +311 -0
  310. pysisyphus/optimizers/hessian_updates.py +355 -0
  311. pysisyphus/optimizers/poly_fit.py +285 -0
  312. pysisyphus/optimizers/precon.py +153 -0
  313. pysisyphus/optimizers/restrict_step.py +24 -0
  314. pysisyphus/pack.py +172 -0
  315. pysisyphus/peakdetect.py +948 -0
  316. pysisyphus/plot.py +1031 -0
  317. pysisyphus/run.py +2106 -0
  318. pysisyphus/socket_helper.py +74 -0
  319. pysisyphus/stocastic/FragmentKick.py +132 -0
  320. pysisyphus/stocastic/Kick.py +81 -0
  321. pysisyphus/stocastic/Pipeline.py +303 -0
  322. pysisyphus/stocastic/__init__.py +21 -0
  323. pysisyphus/stocastic/align.py +127 -0
  324. pysisyphus/testing.py +96 -0
  325. pysisyphus/thermo.py +156 -0
  326. pysisyphus/trj.py +824 -0
  327. pysisyphus/tsoptimizers/RSIRFOptimizer.py +56 -0
  328. pysisyphus/tsoptimizers/RSPRFOptimizer.py +182 -0
  329. pysisyphus/tsoptimizers/TRIM.py +59 -0
  330. pysisyphus/tsoptimizers/TSHessianOptimizer.py +463 -0
  331. pysisyphus/tsoptimizers/__init__.py +23 -0
  332. pysisyphus/wavefunction/Basis.py +239 -0
  333. pysisyphus/wavefunction/DIIS.py +76 -0
  334. pysisyphus/wavefunction/__init__.py +25 -0
  335. pysisyphus/wavefunction/build_ext.py +42 -0
  336. pysisyphus/wavefunction/cart2sph.py +190 -0
  337. pysisyphus/wavefunction/diabatization.py +304 -0
  338. pysisyphus/wavefunction/excited_states.py +435 -0
  339. pysisyphus/wavefunction/gen_ints.py +1811 -0
  340. pysisyphus/wavefunction/helpers.py +104 -0
  341. pysisyphus/wavefunction/ints/__init__.py +0 -0
  342. pysisyphus/wavefunction/ints/boys.py +193 -0
  343. pysisyphus/wavefunction/ints/boys_table_N_64_xasym_27.1_step_0.01.npy +0 -0
  344. pysisyphus/wavefunction/ints/cart_gto3d.py +176 -0
  345. pysisyphus/wavefunction/ints/coulomb3d.py +25928 -0
  346. pysisyphus/wavefunction/ints/diag_quadrupole3d.py +10036 -0
  347. pysisyphus/wavefunction/ints/dipole3d.py +8762 -0
  348. pysisyphus/wavefunction/ints/int2c2e3d.py +7198 -0
  349. pysisyphus/wavefunction/ints/int3c2e3d_sph.py +65040 -0
  350. pysisyphus/wavefunction/ints/kinetic3d.py +8240 -0
  351. pysisyphus/wavefunction/ints/ovlp3d.py +3777 -0
  352. pysisyphus/wavefunction/ints/quadrupole3d.py +15054 -0
  353. pysisyphus/wavefunction/ints/self_ovlp3d.py +198 -0
  354. pysisyphus/wavefunction/localization.py +458 -0
  355. pysisyphus/wavefunction/multipole.py +159 -0
  356. pysisyphus/wavefunction/normalization.py +36 -0
  357. pysisyphus/wavefunction/pop_analysis.py +134 -0
  358. pysisyphus/wavefunction/shells.py +1171 -0
  359. pysisyphus/wavefunction/wavefunction.py +504 -0
  360. pysisyphus/wrapper/__init__.py +11 -0
  361. pysisyphus/wrapper/exceptions.py +2 -0
  362. pysisyphus/wrapper/jmol.py +120 -0
  363. pysisyphus/wrapper/mwfn.py +169 -0
  364. pysisyphus/wrapper/packmol.py +71 -0
  365. pysisyphus/xyzloader.py +168 -0
  366. pysisyphus/yaml_mods.py +45 -0
  367. thermoanalysis/LICENSE +674 -0
  368. thermoanalysis/QCData.py +244 -0
  369. thermoanalysis/__init__.py +0 -0
  370. thermoanalysis/config.py +3 -0
  371. thermoanalysis/constants.py +20 -0
  372. thermoanalysis/thermo.py +1011 -0
mlmm/dft.py ADDED
@@ -0,0 +1,1041 @@
1
+ # mlmm/dft.py
2
+
3
+ """ML/MM-aware single-point DFT for the ML region with energy recombination.
4
+
5
+ Example:
6
+ mlmm dft -i enzyme.pdb --parm real.parm7 --model-pdb ml_region.pdb -q 0
7
+
8
+ For detailed documentation, see: docs/dft.md
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from copy import deepcopy
14
+ from dataclasses import dataclass
15
+ from pathlib import Path
16
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
17
+
18
+ import logging
19
+ import shutil
20
+ import sys
21
+ import tempfile
22
+ import textwrap
23
+ import time
24
+ import traceback
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+ import click
29
+ import numpy as np
30
+ import yaml
31
+
32
+ from ase import Atoms
33
+ from ase.io import read
34
+
35
+ from pysisyphus.constants import AU2EV, AU2KCALPERMOL
36
+
37
+ from .mlmm_calc import hessianffCalculator
38
+ from .opt import (
39
+ GEOM_KW as OPT_GEOM_KW,
40
+ CALC_KW as OPT_CALC_KW,
41
+ _parse_freeze_atoms as _parse_freeze_atoms_opt,
42
+ _normalize_geom_freeze as _normalize_geom_freeze_opt,
43
+ )
44
+ from .utils import (
45
+ apply_layer_freeze_constraints,
46
+ deep_update,
47
+ load_yaml_dict,
48
+ apply_yaml_overrides,
49
+ pretty_block,
50
+ format_freeze_atoms_for_echo,
51
+ format_elapsed,
52
+ merge_freeze_atom_indices,
53
+ prepare_input_structure,
54
+ resolve_charge_spin_or_raise,
55
+ parse_indices_string,
56
+ build_model_pdb_from_bfactors,
57
+ build_model_pdb_from_indices,
58
+ set_convert_file_enabled,
59
+ )
60
+ from .cli_utils import resolve_yaml_sources, load_merged_yaml_cfg, make_is_param_explicit
61
+ from .defaults import DFT_KW as _DFT_KW_DEFAULT
62
+
63
+ from functools import reduce
64
+
65
+ EV2AU = 1.0 / AU2EV  # reciprocal of pysisyphus' AU2EV; presumably converts eV to Hartree (confirm units)
66
+
67
+ # Module-level alias of the defaults imported from .defaults (same object, not a copy): deepcopy at use-site for mutable safety
68
+ DFT_KW = _DFT_KW_DEFAULT
69
+
70
+
71
+ # -----------------------------------------------
72
+ # Helper classes & utilities
73
+ # -----------------------------------------------
74
@dataclass
class MLRegionWorkspace:
    """Bundle of temporary files and structures prepared for an ML-region DFT run.

    The instance owns ``tmpdir``; call :meth:`cleanup` (or use the instance as
    a context manager) to remove the directory once the files are no longer
    needed.
    """

    # Temporary directory that is removed by cleanup().
    tmpdir: tempfile.TemporaryDirectory
    # Copy of the full-system input PDB.
    input_pdb: Path
    # Full-system Amber topology and coordinate files.
    real_parm7: Path
    real_rst7: Path
    # Copy of the PDB that defines the ML region.
    model_pdb: Path
    # Amber topology and coordinate files for the ML region.
    model_parm7: Path
    model_rst7: Path
    # Indices of the ML-region atoms within the full system.
    selection_indices: List[int]
    link_pairs: List[Tuple[int, int]]  # 1-based REAL indices (ml_idx, mm_idx)
    # Structure of the full system.
    atoms_real: Atoms
    # Structure of the ML region.
    atoms_model: Atoms
    # ML-region structure with link hydrogens appended (presumably "lh" = link H).
    atoms_model_lh: Atoms

    def cleanup(self) -> None:
        """Remove the temporary directory and everything stored in it."""
        self.tmpdir.cleanup()

    def __enter__(self) -> "MLRegionWorkspace":
        """Return the workspace itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Clean up the temporary directory on leaving the ``with`` block.

        Exceptions raised inside the block are not suppressed.
        """
        self.cleanup()
91
+
92
+
93
+ def _parse_func_basis(s: str) -> Tuple[str, str]:
94
+ if not s or "/" not in s:
95
+ raise click.BadParameter("Expected 'FUNC/BASIS' (e.g., 'wb97m-v/def2-tzvpd').")
96
+ func, basis = s.split("/", 1)
97
+ func = func.strip()
98
+ basis = basis.strip()
99
+ if not func or not basis:
100
+ raise click.BadParameter("Functional or basis is empty. Example: --func-basis 'wb97m-v/6-31g**'")
101
+ return func, basis
102
+
103
+
104
+ def _atoms_to_xyz_string(atoms: Atoms, comment: str) -> str:
105
+ lines = [str(len(atoms)), comment]
106
+ for sym, (x, y, z) in zip(atoms.get_chemical_symbols(), atoms.get_positions()):
107
+ lines.append(f"{sym:<2s} {x:15.8f} {y:15.8f} {z:15.8f}")
108
+ return "\n".join(lines) + "\n"
109
+
110
+
111
+ def _atoms_to_pyscf_atoms(atoms: Atoms) -> List[Tuple[str, Tuple[float, float, float]]]:
112
+ entries: List[Tuple[str, Tuple[float, float, float]]] = []
113
+ for sym, coord in zip(atoms.get_chemical_symbols(), atoms.get_positions()):
114
+ entries.append((sym, (float(coord[0]), float(coord[1]), float(coord[2]))))
115
+ return entries
116
+
117
+
118
+ def _load_model_region_ids(model_pdb: Path) -> set[str]:
119
+ ids: set[str] = set()
120
+ with model_pdb.open() as fh:
121
+ for line in fh:
122
+ if line.startswith(("ATOM", "HETATM")):
123
+ ids.add(f"{line[12:16].strip()} {line[17:20].strip()} {line[22:26].strip()}")
124
+ if not ids:
125
+ raise ValueError("No atoms found in model_pdb to define the ML region.")
126
+ return ids
127
+
128
+
129
+ def _load_input_atoms(input_pdb: Path) -> List[Dict[str, Any]]:
130
+ atoms: List[Dict[str, Any]] = []
131
+ with input_pdb.open() as fh:
132
+ for line in fh:
133
+ if not line.startswith(("ATOM", "HETATM")):
134
+ continue
135
+ elem = line[76:78].strip()
136
+ if not elem:
137
+ elem = line[12:16].strip()[0]
138
+ atoms.append(
139
+ {
140
+ "idx": int(line[6:11]),
141
+ "id": f"{line[12:16].strip()} {line[17:20].strip()} {line[22:26].strip()}",
142
+ "elem": elem,
143
+ "coord": np.array(
144
+ [float(line[30:38]), float(line[38:46]), float(line[46:54])],
145
+ dtype=float,
146
+ ),
147
+ }
148
+ )
149
+ if not atoms:
150
+ raise ValueError("No ATOM/HETATM records found in the input PDB.")
151
+ return atoms
152
+
153
+
154
+ def _detect_link_pairs(
155
+ leap_atoms: Sequence[Dict[str, Any]],
156
+ ml_region_ids: set[str],
157
+ manual_links: Optional[Sequence[Sequence[str]]],
158
+ ) -> List[Tuple[int, int]]:
159
+ if manual_links:
160
+ processed = [(" ".join(q.split()[:3]), " ".join(m.split()[:3])) for q, m in manual_links]
161
+ ml_indices: List[int] = []
162
+ mm_indices: List[int] = []
163
+ for atom in leap_atoms:
164
+ for qnm, mnm in processed:
165
+ if atom["id"] == qnm:
166
+ ml_indices.append(atom["idx"])
167
+ elif atom["id"] == mnm:
168
+ mm_indices.append(atom["idx"])
169
+ if len(set(ml_indices)) != len(ml_indices) or len(set(mm_indices)) != len(mm_indices):
170
+ raise ValueError("Duplicated ML or MM indices detected in link_mlmm specification.")
171
+ return list(zip(ml_indices, mm_indices))
172
+
173
+ threshold = 1.7
174
+ ml_set = {atom["idx"] for atom in leap_atoms if atom["id"] in ml_region_ids}
175
+ coords = {atom["idx"]: atom["coord"] for atom in leap_atoms}
176
+ elems = {atom["idx"]: atom["elem"] for atom in leap_atoms}
177
+ ml_indices: List[int] = []
178
+ mm_indices: List[int] = []
179
+ for qidx in ml_set:
180
+ for atom in leap_atoms:
181
+ midx = atom["idx"]
182
+ if midx in ml_set:
183
+ continue
184
+ if np.linalg.norm(coords[midx] - coords[qidx]) < threshold and (
185
+ (elems[midx] == "C" and elems[qidx] == "C")
186
+ or (elems[midx] == "N" and elems[qidx] == "C")
187
+ or (elems[midx] == "C" and elems[qidx] == "N")
188
+ ):
189
+ ml_indices.append(qidx)
190
+ mm_indices.append(midx)
191
+ if len(set(ml_indices)) != len(ml_indices) or len(set(mm_indices)) != len(mm_indices):
192
+ raise ValueError("Automatic link detection produced duplicate ML/MM indices; specify link_mlmm explicitly.")
193
+ return list(zip(ml_indices, mm_indices))
194
+
195
+
196
def _append_link_hydrogens(atoms_model: Atoms, atoms_real: Atoms, link_pairs: Sequence[Tuple[int, int]]) -> Atoms:
    """Return a copy of *atoms_model* with one hydrogen appended per link pair.

    For every 1-based ``(ml_idx, mm_idx)`` pair the hydrogen is placed on the
    line from the ML atom towards the MM atom (positions taken from
    *atoms_real*), at 1.09 from a carbon parent or 1.01 from a nitrogen
    parent. Pairs whose two atoms coincide (distance below 1e-8) are skipped.

    Raises:
        ValueError: If the ML parent atom is neither C nor N.
    """
    capped = atoms_model.copy()
    for ml_serial, mm_serial in link_pairs:
        ml_pos_index = ml_serial - 1
        mm_pos_index = mm_serial - 1

        parent_elem = atoms_real[ml_pos_index].symbol.strip().upper()
        if parent_elem == "C":
            bond_length = 1.09
        elif parent_elem == "N":
            bond_length = 1.01
        else:
            raise ValueError(
                f"Unsupported link-atom parent element '{parent_elem}' (only C or N are allowed)."
            )

        direction = atoms_real[mm_pos_index].position - atoms_real[ml_pos_index].position
        separation = np.linalg.norm(direction)
        if separation >= 1e-8:
            # Scale the ML->MM direction to the chosen bond length.
            h_position = atoms_real[ml_pos_index].position + (direction / separation) * bond_length
            capped += Atoms("H", positions=[h_position])
    return capped
217
+
218
+
219
def _prepare_ml_region_workspace(
    *,
    input_pdb: Path,
    real_parm7: Path,
    model_pdb: Path,
    link_mlmm: Optional[Sequence[Sequence[str]]],
) -> MLRegionWorkspace:
    """Stage all files and structures for an ML-region calculation in a temp dir.

    Steps:

    1. Copy *input_pdb*, *real_parm7* and *model_pdb* into a fresh temporary
       directory.
    2. With ParmEd, load the copied topology, assign the coordinates of the
       input structure, drop the box, and write ``real.parm7`` / ``real.rst7``.
    3. Match the atoms of *model_pdb* against the input PDB to obtain the
       ML-region selection and the ML/MM link pairs (*link_mlmm* overrides the
       automatic detection).
    4. Write ``model.parm7`` / ``model.rst7`` for the selection (plain copies
       of the real files when the selection covers every atom).
    5. Read both PDBs with ASE, copy the input coordinates onto the model
       atoms, and append link hydrogens.

    The temporary directory is removed again if any step fails.

    Raises:
        RuntimeError: If the ParmEd stage (step 2) fails for any reason.
        ValueError: If the model PDB has no overlap with the input PDB, or its
            atom count differs from the detected selection.
    """
    tmpdir = tempfile.TemporaryDirectory()
    tmp = Path(tmpdir.name)

    try:
        input_copy = tmp / "input.pdb"
        real_copy = tmp / "real.parm7"
        model_copy = tmp / "model.pdb"
        shutil.copyfile(input_pdb, input_copy)
        shutil.copyfile(real_parm7, real_copy)
        shutil.copyfile(model_pdb, model_copy)

        real_top = None
        real_rst7 = tmp / "real.rst7"
        try:
            import parmed as pmd

            real_top = pmd.load_file(str(real_copy))
            start_struct = pmd.load_file(str(input_copy))
            real_top.coordinates = start_struct.coordinates
            real_top.box = None
            real_top.save(str(real_copy), overwrite=True)
            real_top.save(str(real_rst7), overwrite=True)
        except Exception as exc:  # pragma: no cover - requires parmed
            raise RuntimeError(f"Failed to prepare sanitized Amber inputs: {exc}") from exc

        ml_region_ids = _load_model_region_ids(model_copy)
        leap_atoms = _load_input_atoms(input_copy)
        ml_ids = [atom["idx"] for atom in leap_atoms if atom["id"] in ml_region_ids]
        if not ml_ids:
            raise ValueError("No overlap between model_pdb atoms and the input PDB was found.")

        link_pairs = _detect_link_pairs(leap_atoms, ml_region_ids, link_mlmm)
        # PDB serial numbers are 1-based; the selection is 0-based.
        selection_indices = [idx - 1 for idx in ml_ids]

        model_parm7 = tmp / "model.parm7"
        model_rst7 = tmp / "model.rst7"
        selection = selection_indices
        if len(selection) == len(real_top.atoms):
            # The ML region is the whole system: reuse the real files as-is.
            shutil.copyfile(real_copy, model_parm7)
            shutil.copyfile(real_rst7, model_rst7)
        else:
            model = real_top[selection]
            model.box = None
            model.save(str(model_parm7), overwrite=True)
            model.save(str(model_rst7), overwrite=True)

        atoms_real = read(str(input_copy))
        atoms_model = read(str(model_copy))
        if len(atoms_model) != len(selection_indices):
            raise ValueError(
                "model_pdb atom count does not match the detected ML-region selection from the input PDB."
            )
        # Use the coordinates of the input structure for the model atoms.
        for i, ridx in enumerate(selection_indices):
            atoms_model[i].position = atoms_real[ridx].position

        atoms_model_lh = _append_link_hydrogens(atoms_model, atoms_real, link_pairs)
    except BaseException:
        # `tmp` is a plain Path with no cleanup(); the TemporaryDirectory
        # object is what has to be cleaned up before re-raising.
        tmpdir.cleanup()
        raise

    return MLRegionWorkspace(
        tmpdir=tmpdir,
        input_pdb=input_copy,
        real_parm7=real_copy,
        real_rst7=real_rst7,
        model_pdb=model_copy,
        model_parm7=model_parm7,
        model_rst7=model_rst7,
        selection_indices=selection_indices,
        link_pairs=link_pairs,
        atoms_real=atoms_real,
        atoms_model=atoms_model,
        atoms_model_lh=atoms_model_lh,
    )
299
+
300
+
301
def _hartree_to_kcalmol(Eh: float) -> float:
    """Multiply *Eh* by ``AU2KCALPERMOL`` and return the result as a built-in float."""
    kcal_per_mol = Eh * AU2KCALPERMOL
    return float(kcal_per_mol)
303
+
304
+
305
class FlowList(list):
    """Marker ``list`` subclass with no added behavior; it only carries a distinct type."""
307
+
308
+
309
+ def _flow_seq_representer(dumper, data):
310
+ return dumper.represent_sequence('tag:yaml.org,2002:seq', data, flow_style=True)
311
+
312
+
313
+ yaml.SafeDumper.add_representer(FlowList, _flow_seq_representer)  # module-level registration: yaml.SafeDumper represents FlowList values via _flow_seq_representer (flow style)
314
+
315
+
316
+ def _format_row_for_echo(row: List[Any]) -> str:
317
+ def _fmt(x: Any) -> str:
318
+ if x is None:
319
+ return "null"
320
+ if isinstance(x, float):
321
+ return f"{x:.10g}"
322
+ return str(x)
323
+
324
+ return "[" + ", ".join(_fmt(v) for v in row) + "]"
325
+
326
+
327
+ # ---- PySCF helpers copied from the legacy DFT script ----
328
def fast_iao_mullikan_spin_pop(mol, dm, iaos, verbose=None):
    """Mulliken spin population evaluated in the IAO basis.

    The alpha and beta density matrices ``dm[0]`` / ``dm[1]`` are transformed
    with ``X dm X^H`` (``X`` solves ``S_iao X = iaos^H S``) and passed, with
    the IAO overlap and PySCF's reference molecule, to
    ``pyscf.scf.uhf.mulliken_spin_pop``, whose result is returned.

    A single 2-D density matrix short-circuits to all-zero arrays: one entry
    per IAO and one per atom of the reference molecule.

    Args:
        mol: PySCF molecule; if it has a truthy ``pbc_intor`` attribute, that
            is used for the overlap integrals.
        dm: Density matrix (2-D) or pair of spin density matrices.
        iaos: IAO coefficient matrix.
        verbose: PySCF verbosity; defaults to ``pyscf.lib.logger.DEBUG``.
    """
    import numpy
    from pyscf.lib import logger as pyscf_logger
    from pyscf.lo.iao import reference_mol
    from pyscf.scf import uhf as scf_uhf

    log_level = pyscf_logger.DEBUG if verbose is None else verbose

    ref_mol = reference_mol(mol)
    if getattr(mol, 'pbc_intor', None):
        ao_overlap = mol.pbc_intor('int1e_ovlp')
    else:
        ao_overlap = mol.intor_symmetric('int1e_ovlp')

    iao_h_s = numpy.dot(iaos.T.conj(), ao_overlap)
    iao_overlap = numpy.dot(iao_h_s, iaos)
    to_iao = numpy.linalg.solve(iao_overlap, iao_h_s)

    single_matrix = isinstance(dm, numpy.ndarray) and dm.ndim == 2
    if single_matrix:
        zero_ao = numpy.zeros(iao_overlap.shape[0], dtype=float)
        zero_atoms = numpy.zeros(ref_mol.natm, dtype=float)
        return zero_ao, zero_atoms

    dm_alpha = numpy.dot(numpy.dot(to_iao, dm[0]), to_iao.conj().T)
    dm_beta = numpy.dot(numpy.dot(to_iao, dm[1]), to_iao.conj().T)
    return scf_uhf.mulliken_spin_pop(ref_mol, [dm_alpha, dm_beta], iao_overlap, log_level)
355
+
356
+
357
def _compute_atomic_charges(mol, mf) -> Dict[str, Optional[List[float]]]:
    """Compute per-atom charges with three population schemes.

    Returns a dict with the keys ``"mulliken"``, ``"lowdin"`` and ``"iao"``.
    Each value is a list of floats, or ``None`` when that scheme raised; in
    that case a warning is echoed to stderr and the other schemes still run.
    """
    from pyscf.scf import hf as scf_hf
    from pyscf.lo import iao as lo_iao

    density = mf.make_rdm1()
    overlap = mf.get_ovlp()
    # A 3-D array holds two spin components; sum them for the total density.
    if isinstance(density, np.ndarray) and density.ndim == 3:
        total_density = density[0] + density[1]
    else:
        total_density = density

    try:
        _, mulliken_raw = scf_hf.mulliken_pop(mol, total_density, s=overlap, verbose=0)
        mulliken_charges = np.asarray(mulliken_raw, dtype=float).tolist()
    except Exception as e:
        click.echo(f"[Mulliken] WARNING: Failed to compute Mulliken charges: {e}", err=True)
        mulliken_charges = None

    try:
        _, lowdin_raw = scf_hf.mulliken_pop_meta_lowdin_ao(mol, total_density, verbose=0, s=overlap)
        lowdin_charges = np.asarray(lowdin_raw, dtype=float).tolist()
    except Exception as e:
        click.echo(f"[Löwdin] WARNING: Failed to compute meta-Löwdin charges: {e}", err=True)
        lowdin_charges = None

    iao_charges: Optional[List[float]] = None
    try:
        coeffs = mf.mo_coeff
        occupations = mf.mo_occ
        if isinstance(coeffs, np.ndarray) and coeffs.ndim == 2:
            occupied_mask = np.asarray(occupations) > 0
            occupied_orbitals = coeffs[:, occupied_mask]
        else:
            # Spin-resolved coefficients: only the first component is used.
            occupied_mask = np.asarray(occupations[0]) > 0
            occupied_orbitals = coeffs[0][:, occupied_mask]
        iaos = lo_iao.iao(mol, occupied_orbitals, minao="minao")
        _, iao_raw = lo_iao.fast_iao_mullikan_pop(mol, density, iaos, verbose=0)
        iao_charges = np.asarray(iao_raw, dtype=float).tolist()
    except Exception as e:
        click.echo(f"[IAO] WARNING: Failed to compute IAO charges: {e}", err=True)
        iao_charges = None

    return {
        "mulliken": mulliken_charges,
        "lowdin": lowdin_charges,
        "iao": iao_charges,
    }
401
+
402
+
403
def _compute_atomic_spin_densities(mol, mf) -> Dict[str, Optional[List[float]]]:
    """Return per-atom spin densities from three population analyses.

    Args:
        mol: Molecule object; ``mol.natm`` gives the number of atoms and the
            object is forwarded to the PySCF population routines.
        mf: Converged mean-field object providing ``make_rdm1()``,
            ``get_ovlp()``, ``mo_coeff`` and ``mo_occ``.

    Returns:
        A dict with keys ``"mulliken"``, ``"lowdin"`` and ``"iao"``. Each value
        is a list with one float per atom, or ``None`` when that particular
        analysis raised (a warning is then printed to stderr). When the
        density matrix is not a 3-D NumPy array, every entry is an independent
        list of zeros.
    """
    dm = mf.make_rdm1()
    nat = mol.natm

    if not (isinstance(dm, np.ndarray) and dm.ndim == 3):
        # Not a stacked (3-D) density matrix: report zero spin density.
        # Each key gets its own list so callers can mutate one result without
        # silently changing the others.
        # NOTE(review): a density matrix that is not a NumPy array (e.g. one
        # returned by a GPU backend) also takes this path -- confirm that
        # open-shell GPU runs are converted before reaching this function.
        return {key: [0.0] * nat for key in ("mulliken", "lowdin", "iao")}

    # Imported lazily: only the open-shell analysis below needs these modules.
    from pyscf.scf import uhf as scf_uhf
    from pyscf.lo import iao as lo_iao

    S = mf.get_ovlp()

    try:
        _, Ms_mull = scf_uhf.mulliken_spin_pop(mol, dm, s=S, verbose=0)
        mull = np.asarray(Ms_mull, dtype=float).tolist()
    except Exception as e:
        click.echo(f"[Spin Mulliken] WARNING: Failed to compute Mulliken spin densities: {e}", err=True)
        mull = None

    try:
        _, Ms_low = scf_uhf.mulliken_spin_pop_meta_lowdin_ao(mol, dm, verbose=0, s=S)
        low = np.asarray(Ms_low, dtype=float).tolist()
    except Exception as e:
        click.echo(f"[Spin Löwdin] WARNING: Failed to compute meta-Löwdin spin densities: {e}", err=True)
        low = None

    iao_ms: Optional[List[float]] = None
    try:
        mo = mf.mo_coeff
        mo_occ = mf.mo_occ
        if isinstance(mo, np.ndarray) and mo.ndim == 2:
            occ_idx = np.asarray(mo_occ) > 0
            orbocc = mo[:, occ_idx]
        else:
            # Per-spin coefficients: the IAOs are built from the occupied
            # orbitals of the first spin channel only.
            occ_idx = np.asarray(mo_occ[0]) > 0
            orbocc = mo[0][:, occ_idx]
        iaos = lo_iao.iao(mol, orbocc, minao="minao")
        _, Ms_iao = fast_iao_mullikan_spin_pop(mol, dm, iaos, verbose=0)
        iao_ms = np.asarray(Ms_iao, dtype=float).tolist()
    except Exception as e:
        click.echo(f"[Spin IAO] WARNING: Failed to compute IAO spin densities: {e}", err=True)
        iao_ms = None

    return {"mulliken": mull, "lowdin": low, "iao": iao_ms}
447
+ # -----------------------------------------------
448
+ # CLI
449
+ # -----------------------------------------------
450
+
451
+
452
def _mm_indices_within_cutoff(positions, ml_indices, mm_indices, cutoff):
    """Return the entries of ``mm_indices`` within ``cutoff`` of any ML atom.

    ``positions`` is indexed by atom index along its first axis; distances are
    measured in the same unit as ``positions``. The loop runs over ML atoms so
    that memory use stays proportional to the number of MM atoms.
    """
    if not mm_indices or not ml_indices:
        return []
    mm_arr = np.asarray(mm_indices, dtype=int)
    mm_pos = positions[mm_arr]
    cutoff_sq = float(cutoff) ** 2
    keep = np.zeros(len(mm_arr), dtype=bool)
    for ml_xyz in positions[np.asarray(ml_indices, dtype=int)]:
        keep |= np.sum((mm_pos - ml_xyz) ** 2, axis=1) <= cutoff_sq
    return mm_arr[keep].tolist()


@click.command(
    help="Single-point ML-region DFT with ML(dft)/MM energy recombination.",
    context_settings={"help_option_names": ["-h", "--help"]},
)
@click.option(
    "-i",
    "--input",
    "input_path",
    type=click.Path(path_type=Path, exists=True, dir_okay=False),
    required=True,
    help="Full enzyme structure (PDB or XYZ). If XYZ, use --ref-pdb for topology.",
)
@click.option(
    "--ref-pdb",
    "ref_pdb",
    type=click.Path(path_type=Path, exists=True, dir_okay=False),
    default=None,
    show_default=False,
    help="Reference PDB topology when input is XYZ. XYZ coordinates are used (higher precision) "
    "while PDB provides atom ordering and residue information.",
)
@click.option(
    "--parm",
    "real_parm7",
    type=click.Path(path_type=Path, exists=True, dir_okay=False),
    required=True,
    help="Amber parm7 topology for the full system.",
)
@click.option(
    "--model-pdb",
    "model_pdb",
    type=click.Path(path_type=Path, exists=True, dir_okay=False),
    required=False,
    help="PDB defining the ML region (atom IDs must match the enzyme PDB). "
    "Optional when --detect-layer is enabled.",
)
@click.option(
    "--model-indices",
    "model_indices_str",
    type=str,
    default=None,
    show_default=False,
    help="Comma-separated atom indices for the ML region (ranges allowed like 1-5). "
    "Used when --model-pdb is omitted.",
)
@click.option(
    "--model-indices-one-based/--model-indices-zero-based",
    "model_indices_one_based",
    default=True,
    show_default=True,
    help="Interpret --model-indices as 1-based (default) or 0-based.",
)
@click.option(
    "--detect-layer/--no-detect-layer",
    "detect_layer",
    default=True,
    show_default=True,
    help="Detect ML/MM layers from input PDB B-factors (B=0/10/20). "
    "If disabled, you must provide --model-pdb or --model-indices.",
)
@click.option("-q", "--charge", type=int, required=True, help="Charge of the ML region.")
@click.option(
    "-m",
    "--multiplicity",
    "spin",
    type=int,
    default=1,
    show_default=True,
    help="Spin multiplicity (2S+1) for the ML region.",
)
@click.option(
    "--freeze-atoms",
    "freeze_atoms_text",
    type=str,
    default=None,
    help="Comma-separated 1-based indices to freeze (e.g., '1,3,5').",
)
@click.option(
    "--func-basis",
    "func_basis",
    type=str,
    default="wb97m-v/def2-tzvpd",
    show_default=True,
    help='Exchange-correlation functional and basis set as "FUNC/BASIS".',
)
@click.option("--max-cycle", type=int, default=DFT_KW["max_cycle"], show_default=True, help="Maximum SCF iterations.")
@click.option("--conv-tol", type=float, default=DFT_KW["conv_tol"], show_default=True, help="SCF convergence tolerance (Hartree).")
@click.option("--grid-level", type=int, default=DFT_KW["grid_level"], show_default=True, help="DFT integration grid level (0=coarse, 3=default, 9=ultrafine).")
@click.option(
    "-o", "--out-dir",
    type=click.Path(path_type=Path, dir_okay=True, file_okay=False),
    default=Path(DFT_KW["out_dir"]),
    show_default=True,
    help="Output directory.",
)
@click.option(
    "--config",
    "config_yaml",
    type=click.Path(path_type=Path, exists=True, dir_okay=False),
    default=None,
    help="Base YAML configuration file applied before explicit CLI options.",
)
@click.option(
    "--show-config/--no-show-config",
    "show_config",
    default=False,
    show_default=True,
    help="Print resolved configuration and continue execution.",
)
@click.option(
    "--dry-run/--no-dry-run",
    "dry_run",
    default=False,
    show_default=True,
    help="Validate options and print the execution plan without running DFT.",
)
@click.option(
    "--convert-files/--no-convert-files",
    "convert_files",
    default=True,
    show_default=True,
    help="Toggle XYZ/TRJ to PDB companions when a PDB template is available.",
)
@click.option(
    "-b", "--backend",
    type=click.Choice(["uma", "orb", "mace", "aimnet2"], case_sensitive=False),
    default=None,
    show_default=False,
    help="ML backend for the ONIOM high-level region (default: uma).",
)
@click.option(
    "--embedcharge/--no-embedcharge",
    "embedcharge",
    default=False,
    show_default=True,
    help="Enable electrostatic embedding: MM point charges are added to the PySCF QM Hamiltonian via pyscf.qmmm.mm_charge().",
)
@click.option(
    "--embedcharge-cutoff",
    "embedcharge_cutoff",
    type=float,
    default=None,
    show_default=False,
    help="Distance cutoff (Å) from ML region for MM point charges in electrostatic embedding. "
    "Default: 12.0 Å when --embedcharge is enabled.",
)
@click.pass_context
def cli(
    ctx: click.Context,
    input_path: Path,
    ref_pdb: Optional[Path],
    real_parm7: Path,
    model_pdb: Optional[Path],
    model_indices_str: Optional[str],
    model_indices_one_based: bool,
    detect_layer: bool,
    charge: int,
    spin: int,
    freeze_atoms_text: Optional[str],
    func_basis: str,
    max_cycle: int,
    conv_tol: float,
    grid_level: int,
    out_dir: Path,
    config_yaml: Optional[Path],
    show_config: bool,
    dry_run: bool,
    convert_files: bool,
    backend: Optional[str],
    embedcharge: bool,
    embedcharge_cutoff: Optional[float],
) -> None:
    """Run a single-point DFT on the ML region and recombine it with MM energies.

    Workflow:
      1. Merge defaults, the optional YAML config and explicit CLI options.
      2. Resolve the ML region (B-factor layers, ``--model-pdb`` or
         ``--model-indices``) and build the link-hydrogen-capped model.
      3. Run RKS/UKS with gpu4pyscf when it can be set up, otherwise with
         CPU PySCF, optionally with MM point charges embedded.
      4. Compute atomic charges and spin densities, evaluate the MM energies
         of the real and model systems, and combine them as
         ``E_real_low + E_DFT - E_model_low``.
      5. Write ``ml_region_with_linkH.xyz`` and ``result.yaml`` to the output
         directory.

    Exit codes: 3 when the SCF did not converge, 130 on keyboard interrupt,
    1 on any other unhandled error. Option errors are raised as click
    exceptions.
    """
    set_convert_file_enabled(convert_files)

    # Resolve XYZ + --ref-pdb → use ref-pdb as topology source
    if input_path.suffix.lower() != ".pdb":
        if ref_pdb is None:
            raise click.BadParameter(
                "Input is not a PDB file. Provide --ref-pdb for topology when using XYZ input."
            )
        if ref_pdb.suffix.lower() != ".pdb":
            raise click.BadParameter("--ref-pdb must be a .pdb file.")
        source_pdb = ref_pdb
    else:
        source_pdb = input_path

    _is_param_explicit = make_is_param_explicit(ctx)

    config_yaml, override_yaml, _ = resolve_yaml_sources(
        config_yaml=config_yaml,
        override_yaml=None,
        args_yaml_legacy=None,
    )
    merged_yaml_cfg, _, _ = load_merged_yaml_cfg(
        config_yaml=config_yaml,
        override_yaml=None,
    )

    # Both are released in the ``finally`` clause below, so they must exist
    # before the ``try`` block starts.
    prepared_input = None
    workspace: Optional[MLRegionWorkspace] = None

    model_indices: Optional[List[int]] = None
    if model_indices_str:
        try:
            model_indices = parse_indices_string(model_indices_str, one_based=model_indices_one_based)
        except click.BadParameter as e:
            raise click.ClickException(str(e)) from e

    try:
        geom_kw = deepcopy(OPT_GEOM_KW)
        calc_kw = deepcopy(OPT_CALC_KW)
        dft_kw = dict(DFT_KW)

        apply_yaml_overrides(
            merged_yaml_cfg,
            [
                (geom_kw, (("geom",),)),
                (calc_kw, (("calc",), ("mlmm",))),
                (dft_kw, (("dft",),)),
            ],
        )

        # CLI explicit overrides (after config YAML)
        if backend is not None:
            calc_kw["backend"] = str(backend).lower()
        if _is_param_explicit("embedcharge"):
            calc_kw["embedcharge"] = bool(embedcharge)
        if _is_param_explicit("embedcharge_cutoff"):
            calc_kw["embedcharge_cutoff"] = embedcharge_cutoff

        if _is_param_explicit("conv_tol"):
            dft_kw["conv_tol"] = float(conv_tol)
        if _is_param_explicit("max_cycle"):
            dft_kw["max_cycle"] = int(max_cycle)
        if _is_param_explicit("grid_level"):
            dft_kw["grid_level"] = int(grid_level)
        if _is_param_explicit("out_dir"):
            dft_kw["out_dir"] = str(out_dir)

        # YAML value wins over the CLI default, an explicit CLI value wins over both.
        func_basis_value = str(dft_kw.get("func_basis", func_basis))
        if _is_param_explicit("func_basis"):
            func_basis_value = func_basis
        xc, basis = _parse_func_basis(func_basis_value)

        geom_kw["coord_type"] = "cart"
        geom_kw["freeze_atoms"] = _normalize_geom_freeze_opt(geom_kw.get("freeze_atoms"))
        freeze_atoms_cli = _parse_freeze_atoms_opt(freeze_atoms_text)
        calc_kw["freeze_atoms"] = merge_freeze_atom_indices(geom_kw, freeze_atoms_cli)

        detect_layer_enabled = bool(calc_kw.get("use_bfactor_layers", True))
        if _is_param_explicit("detect_layer"):
            detect_layer_enabled = bool(detect_layer)

        layer_source_pdb = source_pdb.resolve()
        model_pdb_cfg = calc_kw.get("model_pdb")
        if model_pdb is not None:
            model_pdb_cfg = model_pdb
        model_pdb_path: Optional[Path] = None
        layer_info: Optional[Dict[str, List[int]]] = None

        model_multiplicity = int(calc_kw.get("model_mult", spin))
        if _is_param_explicit("spin"):
            model_multiplicity = int(spin)
        if model_multiplicity < 1:
            # 2S+1 below 1 would hand a negative spin to the molecule builder;
            # reject it here so that --dry-run catches it as well.
            raise click.ClickException(
                f"Spin multiplicity (2S+1) must be a positive integer; got {model_multiplicity}."
            )
        calc_kw["model_mult"] = model_multiplicity
        calc_kw["model_charge"] = int(charge)

        dft_block = {
            "charge": int(charge),
            "multiplicity": int(calc_kw["model_mult"]),
            "xc": xc,
            "basis": basis,
            "conv_tol": dft_kw["conv_tol"],
            "max_cycle": dft_kw["max_cycle"],
            "grid_level": dft_kw["grid_level"],
            "out_dir": str(Path(dft_kw["out_dir"]).resolve()),
        }

        click.echo(pretty_block("geom", format_freeze_atoms_for_echo(geom_kw, key="freeze_atoms")))
        click.echo(pretty_block("calc", {k: calc_kw[k] for k in sorted(calc_kw.keys()) if k not in {"freeze_atoms"}}))
        click.echo(pretty_block("dft", dft_block))

        if show_config:
            click.echo(
                pretty_block(
                    "yaml_layers",
                    {
                        "config": None if config_yaml is None else str(config_yaml),
                        "override_yaml": None if override_yaml is None else str(override_yaml),
                        "merged_keys": sorted(merged_yaml_cfg.keys()),
                    },
                )
            )

        if dry_run:
            if (not detect_layer_enabled) and (model_pdb_cfg is None) and (not model_indices):
                raise click.ClickException(
                    "Provide --model-pdb or --model-indices when --no-detect-layer."
                )
            click.echo(
                pretty_block(
                    "dry_run_plan",
                    {
                        "will_prepare_input": True,
                        "detect_layer": bool(detect_layer_enabled),
                        "model_region_source": (
                            "bfactor"
                            if detect_layer_enabled
                            else ("model_pdb" if model_pdb_cfg is not None else "model_indices")
                        ),
                        "model_indices_count": 0 if not model_indices else len(model_indices),
                        "output_dir": str(Path(dft_kw["out_dir"]).resolve()),
                        "backend": calc_kw.get("backend", "uma"),
                        "embedcharge": bool(calc_kw.get("embedcharge", False)),
                    },
                )
            )
            click.echo("[dry-run] Validation complete. DFT execution was skipped.")
            return

        # NOTE(review): ``prepared_input`` is only referenced again in the
        # cleanup below, while the workspace is built from ``source_pdb``.
        # For XYZ input, confirm that the XYZ coordinates actually reach the
        # DFT geometry.
        prepared_input = prepare_input_structure(input_path)
        out_dir_path = Path(dft_kw["out_dir"]).resolve()
        out_dir_path.mkdir(parents=True, exist_ok=True)

        if detect_layer_enabled:
            try:
                model_pdb_path, layer_info = build_model_pdb_from_bfactors(layer_source_pdb, out_dir_path)
                calc_kw["use_bfactor_layers"] = True
                click.echo(
                    f"[layer] Detected B-factor layers: ML={len(layer_info.get('ml_indices', []))}, "
                    f"MovableMM={len(layer_info.get('movable_mm_indices', []))}, "
                    f"FrozenMM={len(layer_info.get('frozen_indices', []))}"
                )
            except Exception as e:
                # Without an explicit ML region there is nothing to fall back to.
                if model_pdb_cfg is None and not model_indices:
                    raise click.ClickException(str(e)) from e
                click.echo(f"[layer] WARNING: {e} Falling back to explicit ML region.", err=True)
                detect_layer_enabled = False

        if not detect_layer_enabled:
            if model_pdb_cfg is None and not model_indices:
                raise click.ClickException("Provide --model-pdb or --model-indices when --no-detect-layer.")
            if model_pdb_cfg is not None:
                model_pdb_path = Path(model_pdb_cfg)
            else:
                try:
                    model_pdb_path = build_model_pdb_from_indices(layer_source_pdb, out_dir_path, model_indices or [])
                except Exception as e:
                    raise click.ClickException(str(e)) from e
            calc_kw["use_bfactor_layers"] = False

        if model_pdb_path is None:
            raise click.ClickException("Failed to resolve model PDB for the ML region.")

        calc_kw["input_pdb"] = str(source_pdb.resolve())
        calc_kw["real_parm7"] = str(real_parm7.resolve())
        calc_kw["model_pdb"] = str(model_pdb_path.resolve())
        apply_layer_freeze_constraints(
            geom_kw,
            calc_kw,
            layer_info if detect_layer_enabled else None,
            echo_fn=click.echo,
        )
        time_start = time.perf_counter()

        workspace = _prepare_ml_region_workspace(
            input_pdb=Path(calc_kw["input_pdb"]),
            real_parm7=Path(calc_kw["real_parm7"]),
            model_pdb=Path(calc_kw["model_pdb"]),
            link_mlmm=calc_kw.get("link_mlmm"),
        )
        model_charge = int(calc_kw["model_charge"])
        model_mult = int(calc_kw["model_mult"])
        # The molecule builder takes 2S (number of unpaired electrons), not 2S+1.
        model_spin2s = model_mult - 1

        xyz_path = out_dir_path / "ml_region_with_linkH.xyz"
        xyz_path.write_text(_atoms_to_xyz_string(workspace.atoms_model_lh, "ML region + link-H"))
        click.echo(f"[write] Wrote '{xyz_path}'.")

        try:
            from pyscf import gto
        except Exception as exc:
            raise click.ClickException(f"PySCF import failed: {exc}") from exc

        mol = gto.Mole()
        mol.verbose = int(dft_kw.get("verbose", 4))
        mol.build(
            atom=_atoms_to_pyscf_atoms(workspace.atoms_model_lh),
            unit="Angstrom",
            charge=model_charge,
            spin=model_spin2s,
            basis=basis,
        )

        # Prefer the GPU engine; any failure while setting it up falls back
        # to the CPU implementation.
        using_gpu = False
        engine_label = "pyscf(cpu)"
        try:
            import gpu4pyscf

            gpu4pyscf.activate()
            from gpu4pyscf import dft as gdf

            mf = gdf.RKS(mol) if model_spin2s == 0 else gdf.UKS(mol)
            using_gpu = True
            engine_label = "gpu4pyscf"
        except Exception:
            from pyscf import dft as pdft

            mf = pdft.RKS(mol) if model_spin2s == 0 else pdft.UKS(mol)

        mf.xc = xc
        mf.max_cycle = int(dft_kw["max_cycle"])
        mf.conv_tol = float(dft_kw["conv_tol"])
        try:
            mf.grids.level = int(dft_kw["grid_level"])
        except Exception as exc:
            click.echo(f"[grids] WARNING: Could not set grids.level={dft_kw['grid_level']}: {exc}", err=True)
        try:
            mf.chkfile = None
        except Exception:
            logger.debug("Failed to disable chkfile", exc_info=True)
        if xc.lower().endswith("-v") or "vv10" in xc.lower():
            mf.nlc = "vv10"

        # --- Electrostatic embedding (--embedcharge) ---
        n_mm_charges = 0
        if calc_kw.get("embedcharge", False):
            import parmed as pmd
            from pyscf import qmmm as pyscf_qmmm

            real_top = pmd.load_file(str(workspace.real_parm7))
            ml_set = set(workspace.selection_indices)
            mm_indices = [i for i in range(len(workspace.atoms_real)) if i not in ml_set]
            real_positions = np.asarray(workspace.atoms_real.get_positions())

            # Honour the configured cutoff: keep only MM atoms close to the ML
            # region. An unset or non-positive cutoff embeds every MM atom.
            cutoff_cfg = calc_kw.get("embedcharge_cutoff")
            if mm_indices and cutoff_cfg is not None and float(cutoff_cfg) > 0.0:
                n_mm_total = len(mm_indices)
                mm_indices = _mm_indices_within_cutoff(
                    real_positions, sorted(ml_set), mm_indices, float(cutoff_cfg)
                )
                click.echo(
                    f"[embedcharge] Cutoff {float(cutoff_cfg):.2f} Å: "
                    f"kept {len(mm_indices)} of {n_mm_total} MM atoms."
                )

            if mm_indices:
                mm_coords = real_positions[mm_indices]
                mm_charges = np.array([real_top.atoms[i].charge for i in mm_indices])
                mf = pyscf_qmmm.mm_charge(mf, mm_coords, mm_charges, unit="Angstrom")
                n_mm_charges = len(mm_indices)
                click.echo(f"[embedcharge] {n_mm_charges} MM point charges embedded into QM Hamiltonian.")
            else:
                click.echo("[embedcharge] No MM atoms found; skipping embedding.")

        click.echo("\n=== ML-region DFT single-point started ===\n")
        tic_scf = time.perf_counter()
        e_tot = mf.kernel()
        toc_scf = time.perf_counter()
        click.echo("\n=== ML-region DFT single-point finished ===\n")

        converged = bool(getattr(mf, "converged", False))
        if e_tot is None:
            e_tot = float(getattr(mf, "e_tot", np.nan))
        e_h = float(e_tot)
        e_kcal = _hartree_to_kcalmol(e_h)

        charges = _compute_atomic_charges(mol, mf)
        spins = _compute_atomic_spin_densities(mol, mf)

        def _round(xs: Optional[List[float]]) -> Optional[List[float]]:
            # Snap numerical noise to exactly 0.0; ``x == x`` is False for NaN,
            # so NaN values pass through unchanged.
            if xs is None:
                return None
            return [0.0 if (x == x) and abs(x) < 1e-10 else float(x) for x in xs]

        charges = {k: _round(v) for k, v in charges.items()}
        spins = {k: _round(v) for k, v in spins.items()}

        charges_table: List[List[Any]] = []
        spins_table: List[List[Any]] = []
        for i in range(mol.natm):
            elem = mol.atom_symbol(i)
            charges_table.append([
                i,
                elem,
                None if charges["mulliken"] is None else charges["mulliken"][i],
                None if charges["lowdin"] is None else charges["lowdin"][i],
                None if charges["iao"] is None else charges["iao"][i],
            ])
            spins_table.append([
                i,
                elem,
                None if spins["mulliken"] is None else spins["mulliken"][i],
                None if spins["lowdin"] is None else spins["lowdin"][i],
                None if spins["iao"] is None else spins["iao"][i],
            ])

        click.echo("\ncharges [index, element, mulliken, lowdin, iao]:")
        for row in charges_table:
            click.echo(f"- {_format_row_for_echo(row)}")

        click.echo("\nspin_densities [index, element, mulliken, lowdin, iao]:")
        for row in spins_table:
            click.echo(f"- {_format_row_for_echo(row)}")

        mm_device = calc_kw.get("mm_device", "cpu")
        mm_cuda_idx = int(calc_kw.get("mm_cuda_idx", 0))
        mm_threads = int(calc_kw.get("mm_threads", 16))

        atoms_real = workspace.atoms_real.copy()
        atoms_model = workspace.atoms_model.copy()

        calc_real = hessianffCalculator(
            parm7=str(workspace.real_parm7),
            rst7=str(workspace.real_rst7),
            device=mm_device,
            cuda_idx=mm_cuda_idx,
            threads=mm_threads,
        )
        atoms_real.calc = calc_real
        e_real_low = atoms_real.get_potential_energy()

        calc_model = hessianffCalculator(
            parm7=str(workspace.model_parm7),
            rst7=str(workspace.model_rst7),
            device=mm_device,
            cuda_idx=mm_cuda_idx,
            threads=mm_threads,
        )
        atoms_model.calc = calc_model
        e_model_low = atoms_model.get_potential_energy()

        # Subtractive recombination: MM(real) + DFT(model) - MM(model).
        e_real_low_au = e_real_low * EV2AU
        e_model_low_au = e_model_low * EV2AU
        e_total_au = e_real_low_au + e_h - e_model_low_au
        e_total_kcal = _hartree_to_kcalmol(e_total_au)

        result_yaml = {
            "input": {
                "charge": model_charge,
                "multiplicity": model_mult,
                "xc": xc,
                "basis": basis,
                "conv_tol": dft_kw["conv_tol"],
                "max_cycle": dft_kw["max_cycle"],
                "grid_level": dft_kw["grid_level"],
                "out_dir": str(out_dir_path),
                "embedcharge": bool(calc_kw.get("embedcharge", False)),
                "n_mm_charges": n_mm_charges,
            },
            "energy": {
                "hartree": e_h,
                "kcal_per_mol": e_kcal,
                "converged": converged,
                "scf_time_sec": round(toc_scf - tic_scf, 3),
                "engine": engine_label,
                "used_gpu": bool(using_gpu),
            },
            "mlmm_energy": {
                "E_real_low_eV": e_real_low,
                "E_model_low_eV": e_model_low,
                "E_real_low_hartree": e_real_low_au,
                "E_model_low_hartree": e_model_low_au,
                "E_total_ml_dft_mm_hartree": e_total_au,
                "E_total_ml_dft_mm_kcal_per_mol": e_total_kcal,
            },
            "charges [index, element, mulliken, lowdin, iao]": [FlowList(r) for r in charges_table],
            "spin_densities [index, element, mulliken, lowdin, iao]": [FlowList(r) for r in spins_table],
        }

        result_file = out_dir_path / "result.yaml"
        result_file.write_text(yaml.safe_dump(result_yaml, sort_keys=False, allow_unicode=True))
        click.echo(f"[write] Wrote '{result_file}'.")
        # summary.md and key_* outputs are disabled.
        click.echo(f"\nE_DFT (Hartree): {e_h:.12f}")
        click.echo(f"E_DFT (kcal/mol): {e_kcal:.6f}")
        click.echo(f"E_total ML(dft)/MM (Hartree): {e_total_au:.12f}")
        click.echo(f"E_total ML(dft)/MM (kcal/mol): {e_total_kcal:.6f}")

        if not converged:
            # Results are still written above; the exit code signals the failure.
            click.echo("WARNING: SCF did not converge.", err=True)
            sys.exit(3)

        click.echo(format_elapsed("[time] Elapsed Time for DFT", time_start))

    except KeyboardInterrupt:
        click.echo("\nInterrupted by user.", err=True)
        sys.exit(130)
    except click.ClickException:
        raise
    except Exception as exc:
        tb = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
        click.echo("Unhandled error during ML/MM DFT:\n" + textwrap.indent(tb, " "), err=True)
        sys.exit(1)
    finally:
        if prepared_input is not None:
            prepared_input.cleanup()
        if workspace is not None:
            workspace.cleanup()
1038
+
1039
+
1040
# Script entry point: hand control to the click command only when this module
# is executed directly, not when it is imported.
if __name__ == "__main__":
    cli()