mlmm-toolkit 0.2.2.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (372) hide show
  1. hessian_ff/__init__.py +50 -0
  2. hessian_ff/analytical_hessian.py +609 -0
  3. hessian_ff/constants.py +46 -0
  4. hessian_ff/forcefield.py +339 -0
  5. hessian_ff/loaders.py +608 -0
  6. hessian_ff/native/Makefile +8 -0
  7. hessian_ff/native/__init__.py +28 -0
  8. hessian_ff/native/analytical_hessian.py +88 -0
  9. hessian_ff/native/analytical_hessian_ext.cpp +258 -0
  10. hessian_ff/native/bonded.py +82 -0
  11. hessian_ff/native/bonded_ext.cpp +640 -0
  12. hessian_ff/native/loader.py +349 -0
  13. hessian_ff/native/nonbonded.py +118 -0
  14. hessian_ff/native/nonbonded_ext.cpp +1150 -0
  15. hessian_ff/prmtop_parmed.py +23 -0
  16. hessian_ff/system.py +107 -0
  17. hessian_ff/terms/__init__.py +14 -0
  18. hessian_ff/terms/angle.py +73 -0
  19. hessian_ff/terms/bond.py +44 -0
  20. hessian_ff/terms/cmap.py +406 -0
  21. hessian_ff/terms/dihedral.py +141 -0
  22. hessian_ff/terms/nonbonded.py +209 -0
  23. hessian_ff/tests/__init__.py +0 -0
  24. hessian_ff/tests/conftest.py +75 -0
  25. hessian_ff/tests/data/small/complex.parm7 +1346 -0
  26. hessian_ff/tests/data/small/complex.pdb +125 -0
  27. hessian_ff/tests/data/small/complex.rst7 +63 -0
  28. hessian_ff/tests/test_coords_input.py +44 -0
  29. hessian_ff/tests/test_energy_force.py +49 -0
  30. hessian_ff/tests/test_hessian.py +137 -0
  31. hessian_ff/tests/test_smoke.py +18 -0
  32. hessian_ff/tests/test_validation.py +40 -0
  33. hessian_ff/workflows.py +889 -0
  34. mlmm/__init__.py +36 -0
  35. mlmm/__main__.py +7 -0
  36. mlmm/_version.py +34 -0
  37. mlmm/add_elem_info.py +374 -0
  38. mlmm/advanced_help.py +91 -0
  39. mlmm/align_freeze_atoms.py +601 -0
  40. mlmm/all.py +3535 -0
  41. mlmm/bond_changes.py +231 -0
  42. mlmm/bool_compat.py +223 -0
  43. mlmm/cli.py +574 -0
  44. mlmm/cli_utils.py +166 -0
  45. mlmm/default_group.py +337 -0
  46. mlmm/defaults.py +467 -0
  47. mlmm/define_layer.py +526 -0
  48. mlmm/dft.py +1041 -0
  49. mlmm/energy_diagram.py +253 -0
  50. mlmm/extract.py +2213 -0
  51. mlmm/fix_altloc.py +464 -0
  52. mlmm/freq.py +1406 -0
  53. mlmm/harmonic_constraints.py +140 -0
  54. mlmm/hessian_cache.py +44 -0
  55. mlmm/hessian_calc.py +174 -0
  56. mlmm/irc.py +638 -0
  57. mlmm/mlmm_calc.py +2262 -0
  58. mlmm/mm_parm.py +945 -0
  59. mlmm/oniom_export.py +1983 -0
  60. mlmm/oniom_import.py +457 -0
  61. mlmm/opt.py +1742 -0
  62. mlmm/path_opt.py +1353 -0
  63. mlmm/path_search.py +2299 -0
  64. mlmm/preflight.py +88 -0
  65. mlmm/py.typed +1 -0
  66. mlmm/pysis_runner.py +45 -0
  67. mlmm/scan.py +1047 -0
  68. mlmm/scan2d.py +1226 -0
  69. mlmm/scan3d.py +1265 -0
  70. mlmm/scan_common.py +184 -0
  71. mlmm/summary_log.py +736 -0
  72. mlmm/trj2fig.py +448 -0
  73. mlmm/tsopt.py +2871 -0
  74. mlmm/utils.py +2309 -0
  75. mlmm/xtb_embedcharge_correction.py +475 -0
  76. mlmm_toolkit-0.2.2.dev0.dist-info/METADATA +1159 -0
  77. mlmm_toolkit-0.2.2.dev0.dist-info/RECORD +372 -0
  78. mlmm_toolkit-0.2.2.dev0.dist-info/WHEEL +5 -0
  79. mlmm_toolkit-0.2.2.dev0.dist-info/entry_points.txt +2 -0
  80. mlmm_toolkit-0.2.2.dev0.dist-info/licenses/LICENSE +674 -0
  81. mlmm_toolkit-0.2.2.dev0.dist-info/top_level.txt +4 -0
  82. pysisyphus/Geometry.py +1667 -0
  83. pysisyphus/LICENSE +674 -0
  84. pysisyphus/TableFormatter.py +63 -0
  85. pysisyphus/TablePrinter.py +74 -0
  86. pysisyphus/__init__.py +12 -0
  87. pysisyphus/calculators/AFIR.py +452 -0
  88. pysisyphus/calculators/AnaPot.py +20 -0
  89. pysisyphus/calculators/AnaPot2.py +48 -0
  90. pysisyphus/calculators/AnaPot3.py +12 -0
  91. pysisyphus/calculators/AnaPot4.py +20 -0
  92. pysisyphus/calculators/AnaPotBase.py +337 -0
  93. pysisyphus/calculators/AnaPotCBM.py +25 -0
  94. pysisyphus/calculators/AtomAtomTransTorque.py +154 -0
  95. pysisyphus/calculators/CFOUR.py +250 -0
  96. pysisyphus/calculators/Calculator.py +844 -0
  97. pysisyphus/calculators/CerjanMiller.py +24 -0
  98. pysisyphus/calculators/Composite.py +123 -0
  99. pysisyphus/calculators/ConicalIntersection.py +171 -0
  100. pysisyphus/calculators/DFTBp.py +430 -0
  101. pysisyphus/calculators/DFTD3.py +66 -0
  102. pysisyphus/calculators/DFTD4.py +84 -0
  103. pysisyphus/calculators/Dalton.py +61 -0
  104. pysisyphus/calculators/Dimer.py +681 -0
  105. pysisyphus/calculators/Dummy.py +20 -0
  106. pysisyphus/calculators/EGO.py +76 -0
  107. pysisyphus/calculators/EnergyMin.py +224 -0
  108. pysisyphus/calculators/ExternalPotential.py +264 -0
  109. pysisyphus/calculators/FakeASE.py +35 -0
  110. pysisyphus/calculators/FourWellAnaPot.py +28 -0
  111. pysisyphus/calculators/FreeEndNEBPot.py +39 -0
  112. pysisyphus/calculators/Gaussian09.py +18 -0
  113. pysisyphus/calculators/Gaussian16.py +726 -0
  114. pysisyphus/calculators/HardSphere.py +159 -0
  115. pysisyphus/calculators/IDPPCalculator.py +49 -0
  116. pysisyphus/calculators/IPIClient.py +133 -0
  117. pysisyphus/calculators/IPIServer.py +234 -0
  118. pysisyphus/calculators/LEPSBase.py +24 -0
  119. pysisyphus/calculators/LEPSExpr.py +139 -0
  120. pysisyphus/calculators/LennardJones.py +80 -0
  121. pysisyphus/calculators/MOPAC.py +219 -0
  122. pysisyphus/calculators/MullerBrownSympyPot.py +51 -0
  123. pysisyphus/calculators/MultiCalc.py +85 -0
  124. pysisyphus/calculators/NFK.py +45 -0
  125. pysisyphus/calculators/OBabel.py +87 -0
  126. pysisyphus/calculators/ONIOMv2.py +1129 -0
  127. pysisyphus/calculators/ORCA.py +893 -0
  128. pysisyphus/calculators/ORCA5.py +6 -0
  129. pysisyphus/calculators/OpenMM.py +88 -0
  130. pysisyphus/calculators/OpenMolcas.py +281 -0
  131. pysisyphus/calculators/OverlapCalculator.py +908 -0
  132. pysisyphus/calculators/Psi4.py +218 -0
  133. pysisyphus/calculators/PyPsi4.py +37 -0
  134. pysisyphus/calculators/PySCF.py +341 -0
  135. pysisyphus/calculators/PyXTB.py +73 -0
  136. pysisyphus/calculators/QCEngine.py +106 -0
  137. pysisyphus/calculators/Rastrigin.py +22 -0
  138. pysisyphus/calculators/Remote.py +76 -0
  139. pysisyphus/calculators/Rosenbrock.py +15 -0
  140. pysisyphus/calculators/SocketCalc.py +97 -0
  141. pysisyphus/calculators/TIP3P.py +111 -0
  142. pysisyphus/calculators/TransTorque.py +161 -0
  143. pysisyphus/calculators/Turbomole.py +965 -0
  144. pysisyphus/calculators/VRIPot.py +37 -0
  145. pysisyphus/calculators/WFOWrapper.py +333 -0
  146. pysisyphus/calculators/WFOWrapper2.py +341 -0
  147. pysisyphus/calculators/XTB.py +418 -0
  148. pysisyphus/calculators/__init__.py +81 -0
  149. pysisyphus/calculators/cosmo_data.py +139 -0
  150. pysisyphus/calculators/parser.py +150 -0
  151. pysisyphus/color.py +19 -0
  152. pysisyphus/config.py +133 -0
  153. pysisyphus/constants.py +65 -0
  154. pysisyphus/cos/AdaptiveNEB.py +230 -0
  155. pysisyphus/cos/ChainOfStates.py +725 -0
  156. pysisyphus/cos/FreeEndNEB.py +25 -0
  157. pysisyphus/cos/FreezingString.py +103 -0
  158. pysisyphus/cos/GrowingChainOfStates.py +71 -0
  159. pysisyphus/cos/GrowingNT.py +309 -0
  160. pysisyphus/cos/GrowingString.py +508 -0
  161. pysisyphus/cos/NEB.py +189 -0
  162. pysisyphus/cos/SimpleZTS.py +64 -0
  163. pysisyphus/cos/__init__.py +22 -0
  164. pysisyphus/cos/stiffness.py +199 -0
  165. pysisyphus/drivers/__init__.py +17 -0
  166. pysisyphus/drivers/afir.py +855 -0
  167. pysisyphus/drivers/barriers.py +271 -0
  168. pysisyphus/drivers/birkholz.py +138 -0
  169. pysisyphus/drivers/cluster.py +318 -0
  170. pysisyphus/drivers/diabatization.py +133 -0
  171. pysisyphus/drivers/merge.py +368 -0
  172. pysisyphus/drivers/merge_mol2.py +322 -0
  173. pysisyphus/drivers/opt.py +375 -0
  174. pysisyphus/drivers/perf.py +91 -0
  175. pysisyphus/drivers/pka.py +52 -0
  176. pysisyphus/drivers/precon_pos_rot.py +669 -0
  177. pysisyphus/drivers/rates.py +480 -0
  178. pysisyphus/drivers/replace.py +219 -0
  179. pysisyphus/drivers/scan.py +212 -0
  180. pysisyphus/drivers/spectrum.py +166 -0
  181. pysisyphus/drivers/thermo.py +31 -0
  182. pysisyphus/dynamics/Gaussian.py +103 -0
  183. pysisyphus/dynamics/__init__.py +20 -0
  184. pysisyphus/dynamics/colvars.py +136 -0
  185. pysisyphus/dynamics/driver.py +297 -0
  186. pysisyphus/dynamics/helpers.py +256 -0
  187. pysisyphus/dynamics/lincs.py +105 -0
  188. pysisyphus/dynamics/mdp.py +364 -0
  189. pysisyphus/dynamics/rattle.py +121 -0
  190. pysisyphus/dynamics/thermostats.py +128 -0
  191. pysisyphus/dynamics/wigner.py +266 -0
  192. pysisyphus/elem_data.py +3473 -0
  193. pysisyphus/exceptions.py +2 -0
  194. pysisyphus/filtertrj.py +69 -0
  195. pysisyphus/helpers.py +623 -0
  196. pysisyphus/helpers_pure.py +649 -0
  197. pysisyphus/init_logging.py +50 -0
  198. pysisyphus/intcoords/Bend.py +69 -0
  199. pysisyphus/intcoords/Bend2.py +25 -0
  200. pysisyphus/intcoords/BondedFragment.py +32 -0
  201. pysisyphus/intcoords/Cartesian.py +41 -0
  202. pysisyphus/intcoords/CartesianCoords.py +140 -0
  203. pysisyphus/intcoords/Coords.py +56 -0
  204. pysisyphus/intcoords/DLC.py +197 -0
  205. pysisyphus/intcoords/DistanceFunction.py +34 -0
  206. pysisyphus/intcoords/DummyImproper.py +70 -0
  207. pysisyphus/intcoords/DummyTorsion.py +72 -0
  208. pysisyphus/intcoords/LinearBend.py +105 -0
  209. pysisyphus/intcoords/LinearDisplacement.py +80 -0
  210. pysisyphus/intcoords/OutOfPlane.py +59 -0
  211. pysisyphus/intcoords/PrimTypes.py +286 -0
  212. pysisyphus/intcoords/Primitive.py +137 -0
  213. pysisyphus/intcoords/RedundantCoords.py +659 -0
  214. pysisyphus/intcoords/RobustTorsion.py +59 -0
  215. pysisyphus/intcoords/Rotation.py +147 -0
  216. pysisyphus/intcoords/Stretch.py +31 -0
  217. pysisyphus/intcoords/Torsion.py +101 -0
  218. pysisyphus/intcoords/Torsion2.py +25 -0
  219. pysisyphus/intcoords/Translation.py +45 -0
  220. pysisyphus/intcoords/__init__.py +61 -0
  221. pysisyphus/intcoords/augment_bonds.py +126 -0
  222. pysisyphus/intcoords/derivatives.py +10512 -0
  223. pysisyphus/intcoords/eval.py +80 -0
  224. pysisyphus/intcoords/exceptions.py +37 -0
  225. pysisyphus/intcoords/findiffs.py +48 -0
  226. pysisyphus/intcoords/generate_derivatives.py +414 -0
  227. pysisyphus/intcoords/helpers.py +235 -0
  228. pysisyphus/intcoords/logging_conf.py +10 -0
  229. pysisyphus/intcoords/mp_derivatives.py +10836 -0
  230. pysisyphus/intcoords/setup.py +962 -0
  231. pysisyphus/intcoords/setup_fast.py +176 -0
  232. pysisyphus/intcoords/update.py +272 -0
  233. pysisyphus/intcoords/valid.py +89 -0
  234. pysisyphus/interpolate/Geodesic.py +93 -0
  235. pysisyphus/interpolate/IDPP.py +55 -0
  236. pysisyphus/interpolate/Interpolator.py +116 -0
  237. pysisyphus/interpolate/LST.py +70 -0
  238. pysisyphus/interpolate/Redund.py +152 -0
  239. pysisyphus/interpolate/__init__.py +9 -0
  240. pysisyphus/interpolate/helpers.py +34 -0
  241. pysisyphus/io/__init__.py +22 -0
  242. pysisyphus/io/aomix.py +178 -0
  243. pysisyphus/io/cjson.py +24 -0
  244. pysisyphus/io/crd.py +101 -0
  245. pysisyphus/io/cube.py +220 -0
  246. pysisyphus/io/fchk.py +184 -0
  247. pysisyphus/io/hdf5.py +49 -0
  248. pysisyphus/io/hessian.py +72 -0
  249. pysisyphus/io/mol2.py +146 -0
  250. pysisyphus/io/molden.py +293 -0
  251. pysisyphus/io/orca.py +189 -0
  252. pysisyphus/io/pdb.py +269 -0
  253. pysisyphus/io/psf.py +79 -0
  254. pysisyphus/io/pubchem.py +31 -0
  255. pysisyphus/io/qcschema.py +34 -0
  256. pysisyphus/io/sdf.py +29 -0
  257. pysisyphus/io/xyz.py +61 -0
  258. pysisyphus/io/zmat.py +175 -0
  259. pysisyphus/irc/DWI.py +108 -0
  260. pysisyphus/irc/DampedVelocityVerlet.py +134 -0
  261. pysisyphus/irc/Euler.py +22 -0
  262. pysisyphus/irc/EulerPC.py +345 -0
  263. pysisyphus/irc/GonzalezSchlegel.py +187 -0
  264. pysisyphus/irc/IMKMod.py +164 -0
  265. pysisyphus/irc/IRC.py +878 -0
  266. pysisyphus/irc/IRCDummy.py +10 -0
  267. pysisyphus/irc/Instanton.py +307 -0
  268. pysisyphus/irc/LQA.py +53 -0
  269. pysisyphus/irc/ModeKill.py +136 -0
  270. pysisyphus/irc/ParamPlot.py +53 -0
  271. pysisyphus/irc/RK4.py +36 -0
  272. pysisyphus/irc/__init__.py +31 -0
  273. pysisyphus/irc/initial_displ.py +219 -0
  274. pysisyphus/linalg.py +411 -0
  275. pysisyphus/line_searches/Backtracking.py +88 -0
  276. pysisyphus/line_searches/HagerZhang.py +184 -0
  277. pysisyphus/line_searches/LineSearch.py +232 -0
  278. pysisyphus/line_searches/StrongWolfe.py +108 -0
  279. pysisyphus/line_searches/__init__.py +9 -0
  280. pysisyphus/line_searches/interpol.py +15 -0
  281. pysisyphus/modefollow/NormalMode.py +40 -0
  282. pysisyphus/modefollow/__init__.py +10 -0
  283. pysisyphus/modefollow/davidson.py +199 -0
  284. pysisyphus/modefollow/lanczos.py +95 -0
  285. pysisyphus/optimizers/BFGS.py +99 -0
  286. pysisyphus/optimizers/BacktrackingOptimizer.py +113 -0
  287. pysisyphus/optimizers/ConjugateGradient.py +98 -0
  288. pysisyphus/optimizers/CubicNewton.py +75 -0
  289. pysisyphus/optimizers/FIRE.py +113 -0
  290. pysisyphus/optimizers/HessianOptimizer.py +1176 -0
  291. pysisyphus/optimizers/LBFGS.py +228 -0
  292. pysisyphus/optimizers/LayerOpt.py +411 -0
  293. pysisyphus/optimizers/MicroOptimizer.py +169 -0
  294. pysisyphus/optimizers/NCOptimizer.py +90 -0
  295. pysisyphus/optimizers/Optimizer.py +1084 -0
  296. pysisyphus/optimizers/PreconLBFGS.py +260 -0
  297. pysisyphus/optimizers/PreconSteepestDescent.py +7 -0
  298. pysisyphus/optimizers/QuickMin.py +74 -0
  299. pysisyphus/optimizers/RFOptimizer.py +181 -0
  300. pysisyphus/optimizers/RSA.py +99 -0
  301. pysisyphus/optimizers/StabilizedQNMethod.py +248 -0
  302. pysisyphus/optimizers/SteepestDescent.py +23 -0
  303. pysisyphus/optimizers/StringOptimizer.py +173 -0
  304. pysisyphus/optimizers/__init__.py +41 -0
  305. pysisyphus/optimizers/closures.py +301 -0
  306. pysisyphus/optimizers/cls_map.py +58 -0
  307. pysisyphus/optimizers/exceptions.py +6 -0
  308. pysisyphus/optimizers/gdiis.py +280 -0
  309. pysisyphus/optimizers/guess_hessians.py +311 -0
  310. pysisyphus/optimizers/hessian_updates.py +355 -0
  311. pysisyphus/optimizers/poly_fit.py +285 -0
  312. pysisyphus/optimizers/precon.py +153 -0
  313. pysisyphus/optimizers/restrict_step.py +24 -0
  314. pysisyphus/pack.py +172 -0
  315. pysisyphus/peakdetect.py +948 -0
  316. pysisyphus/plot.py +1031 -0
  317. pysisyphus/run.py +2106 -0
  318. pysisyphus/socket_helper.py +74 -0
  319. pysisyphus/stocastic/FragmentKick.py +132 -0
  320. pysisyphus/stocastic/Kick.py +81 -0
  321. pysisyphus/stocastic/Pipeline.py +303 -0
  322. pysisyphus/stocastic/__init__.py +21 -0
  323. pysisyphus/stocastic/align.py +127 -0
  324. pysisyphus/testing.py +96 -0
  325. pysisyphus/thermo.py +156 -0
  326. pysisyphus/trj.py +824 -0
  327. pysisyphus/tsoptimizers/RSIRFOptimizer.py +56 -0
  328. pysisyphus/tsoptimizers/RSPRFOptimizer.py +182 -0
  329. pysisyphus/tsoptimizers/TRIM.py +59 -0
  330. pysisyphus/tsoptimizers/TSHessianOptimizer.py +463 -0
  331. pysisyphus/tsoptimizers/__init__.py +23 -0
  332. pysisyphus/wavefunction/Basis.py +239 -0
  333. pysisyphus/wavefunction/DIIS.py +76 -0
  334. pysisyphus/wavefunction/__init__.py +25 -0
  335. pysisyphus/wavefunction/build_ext.py +42 -0
  336. pysisyphus/wavefunction/cart2sph.py +190 -0
  337. pysisyphus/wavefunction/diabatization.py +304 -0
  338. pysisyphus/wavefunction/excited_states.py +435 -0
  339. pysisyphus/wavefunction/gen_ints.py +1811 -0
  340. pysisyphus/wavefunction/helpers.py +104 -0
  341. pysisyphus/wavefunction/ints/__init__.py +0 -0
  342. pysisyphus/wavefunction/ints/boys.py +193 -0
  343. pysisyphus/wavefunction/ints/boys_table_N_64_xasym_27.1_step_0.01.npy +0 -0
  344. pysisyphus/wavefunction/ints/cart_gto3d.py +176 -0
  345. pysisyphus/wavefunction/ints/coulomb3d.py +25928 -0
  346. pysisyphus/wavefunction/ints/diag_quadrupole3d.py +10036 -0
  347. pysisyphus/wavefunction/ints/dipole3d.py +8762 -0
  348. pysisyphus/wavefunction/ints/int2c2e3d.py +7198 -0
  349. pysisyphus/wavefunction/ints/int3c2e3d_sph.py +65040 -0
  350. pysisyphus/wavefunction/ints/kinetic3d.py +8240 -0
  351. pysisyphus/wavefunction/ints/ovlp3d.py +3777 -0
  352. pysisyphus/wavefunction/ints/quadrupole3d.py +15054 -0
  353. pysisyphus/wavefunction/ints/self_ovlp3d.py +198 -0
  354. pysisyphus/wavefunction/localization.py +458 -0
  355. pysisyphus/wavefunction/multipole.py +159 -0
  356. pysisyphus/wavefunction/normalization.py +36 -0
  357. pysisyphus/wavefunction/pop_analysis.py +134 -0
  358. pysisyphus/wavefunction/shells.py +1171 -0
  359. pysisyphus/wavefunction/wavefunction.py +504 -0
  360. pysisyphus/wrapper/__init__.py +11 -0
  361. pysisyphus/wrapper/exceptions.py +2 -0
  362. pysisyphus/wrapper/jmol.py +120 -0
  363. pysisyphus/wrapper/mwfn.py +169 -0
  364. pysisyphus/wrapper/packmol.py +71 -0
  365. pysisyphus/xyzloader.py +168 -0
  366. pysisyphus/yaml_mods.py +45 -0
  367. thermoanalysis/LICENSE +674 -0
  368. thermoanalysis/QCData.py +244 -0
  369. thermoanalysis/__init__.py +0 -0
  370. thermoanalysis/config.py +3 -0
  371. thermoanalysis/constants.py +20 -0
  372. thermoanalysis/thermo.py +1011 -0
mlmm/opt.py ADDED
@@ -0,0 +1,1742 @@
1
+ # mlmm/opt.py
2
+
3
+ """
4
+ ML/MM geometry optimization (LBFGS or RFO) with UMA + hessian_ff calculator.
5
+
6
+ Example:
7
+ mlmm opt -i pocket.pdb --parm real.parm7 --model-pdb ml_region.pdb -q 0
8
+
9
+ For detailed documentation, see: docs/opt.md
10
+ """
11
+
12
+ from pathlib import Path
13
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Set
14
+
15
+ import ast
16
+ import contextlib
17
+ import gc
18
+ import io
19
+ import logging
20
+
21
+ import sys
22
+ import textwrap
23
+ import traceback
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+ import click
28
+ import numpy as np
29
+ import torch
30
+ import time
31
+
32
+ from pysisyphus.helpers import geom_loader
33
+ from pysisyphus.optimizers.LBFGS import LBFGS
34
+ from pysisyphus.optimizers.RFOptimizer import RFOptimizer
35
+ from pysisyphus.optimizers.exceptions import OptimizationError, ZeroStepLength
36
+ from pysisyphus.constants import ANG2BOHR, BOHR2ANG, AU2EV
37
+ from pysisyphus.TablePrinter import TablePrinter
38
+
39
+ from .mlmm_calc import mlmm, mlmm_mm_only
40
+ from .defaults import (
41
+ GEOM_KW_DEFAULT,
42
+ HESSIAN_DIMER_KW,
43
+ MLMM_CALC_KW,
44
+ OPT_BASE_KW,
45
+ LBFGS_KW,
46
+ RFO_KW,
47
+ OPT_MODE_ALIASES,
48
+ MICROITER_KW,
49
+ BFACTOR_ML,
50
+ BFACTOR_MOVABLE_MM,
51
+ BFACTOR_FROZEN,
52
+ )
53
+ from .utils import (
54
+ append_xyz_trajectory as _append_xyz_trajectory,
55
+ convert_xyz_to_pdb,
56
+ set_convert_file_enabled,
57
+ is_convert_file_enabled,
58
+ convert_xyz_like_outputs,
59
+ deep_update,
60
+ load_yaml_dict,
61
+ apply_yaml_overrides,
62
+ pretty_block,
63
+ strip_inherited_keys,
64
+ filter_calc_for_echo,
65
+ format_freeze_atoms_for_echo,
66
+ format_elapsed,
67
+ merge_freeze_atom_indices,
68
+ prepare_input_structure,
69
+ apply_ref_pdb_override,
70
+ resolve_charge_spin_or_raise,
71
+ parse_indices_string,
72
+ build_model_pdb_from_bfactors,
73
+ build_model_pdb_from_indices,
74
+ update_pdb_bfactors_from_layers,
75
+ normalize_choice,
76
+ yaml_section_has_key,
77
+ is_scan_spec_file,
78
+ parse_dist_freeze_list,
79
+ parse_dist_freeze_spec,
80
+ load_pdb_atom_metadata,
81
+ )
82
+ from .cli_utils import resolve_yaml_sources, load_merged_yaml_cfg, make_is_param_explicit
83
+
84
+ EV2AU = 1.0 / AU2EV # eV → Hartree
85
+ H_EVAA_2_AU = EV2AU / (ANG2BOHR * ANG2BOHR) # (eV/Å^2) → (Hartree/Bohr^2)
86
+
87
+ # Flatten-loop constants (sourced from defaults.py)
88
+ OPT_FLATTEN_NEG_FREQ_THRESH_CM = HESSIAN_DIMER_KW["neg_freq_thresh_cm"]
89
+ OPT_FLATTEN_AMP_ANG = HESSIAN_DIMER_KW["flatten_amp_ang"]
90
+ OPT_FLATTEN_MAX_ITER = HESSIAN_DIMER_KW["flatten_max_iter"]
91
+
92
+
93
+ # -----------------------------------------------
94
+ # Default settings (imported from defaults.py, aliased for compatibility)
95
+ # -----------------------------------------------
96
+
97
+ GEOM_KW: Dict[str, Any] = dict(GEOM_KW_DEFAULT)
98
+ CALC_KW: Dict[str, Any] = dict(MLMM_CALC_KW)
99
+
100
+ # Note: OPT_BASE_KW, LBFGS_KW, RFO_KW are imported from defaults.py
101
+
102
+
103
class HarmonicBiasCalculator:
    """Wrap a base UMA calculator with harmonic distance restraints.

    Every restraint ``(i, j, target)`` contributes
    ``0.5 * k * (|r_i - r_j| - target)**2`` to the energy reported by the
    wrapped calculator, together with the matching forces on atoms ``i`` and
    ``j``. ``k`` is supplied in eV/Å^2 and ``target`` in Å; both are converted
    to atomic units internally because the bias is evaluated on Bohr
    coordinates.

    Attributes that this wrapper does not define are forwarded to the wrapped
    calculator.
    """

    def __init__(self, base_calc, k: float = 10.0, pairs: Optional[List[Tuple[int, int, float]]] = None):
        """Store the wrapped calculator, the force constant and the restraints.

        Args:
            base_calc: Calculator whose ``get_forces`` / ``get_energy`` results
                are augmented with the bias.
            k: Force constant in eV/Å^2.
            pairs: Optional ``(i, j, target_in_angstrom)`` restraints.
        """
        self.base = base_calc
        self.k_evAA = float(k)
        self.k_au_bohr2 = self.k_evAA * H_EVAA_2_AU
        self._pairs: List[Tuple[int, int, float]] = list(pairs or [])

    def set_pairs(self, pairs: List[Tuple[int, int, float]]) -> None:
        """Replace the restraint list, coercing entries to ``(int, int, float)``."""
        self._pairs = [(int(i), int(j), float(t)) for (i, j, t) in pairs]

    def _bias_energy_forces_bohr(self, coords_bohr: np.ndarray) -> Tuple[float, np.ndarray]:
        """Return the bias energy and the flattened bias forces.

        Restraints that reference an out-of-range atom, or whose two atoms
        coincide (distance below 1e-14), are skipped.
        """
        coords = np.array(coords_bohr, dtype=float).reshape(-1, 3)
        n = coords.shape[0]
        E_bias = 0.0
        F_bias = np.zeros((n, 3), dtype=float)
        k = self.k_au_bohr2
        for (i, j, target_ang) in self._pairs:
            if not (0 <= i < n and 0 <= j < n):
                continue
            rij_vec = coords[i] - coords[j]
            rij = float(np.linalg.norm(rij_vec))
            if rij < 1e-14:
                # No well-defined direction for the restoring force.
                continue
            target_bohr = float(target_ang) * ANG2BOHR
            diff_bohr = rij - target_bohr
            E_bias += 0.5 * k * diff_bohr * diff_bohr
            # rij >= 1e-14 is guaranteed by the guard above, so divide directly.
            u = rij_vec / rij
            Fi = -k * diff_bohr * u
            F_bias[i] += Fi
            F_bias[j] -= Fi
        return E_bias, F_bias.reshape(-1)

    def get_forces(self, elem, coords):
        """Return ``{"energy", "forces"}`` of the base calculator plus the bias."""
        coords_bohr = np.asarray(coords, dtype=float).reshape(-1, 3)
        base = self.base.get_forces(elem, coords_bohr)
        E0 = float(base["energy"])
        F0 = np.asarray(base["forces"], dtype=float).reshape(-1)
        Ebias, Fbias = self._bias_energy_forces_bohr(coords_bohr)
        return {"energy": E0 + Ebias, "forces": F0 + Fbias}

    def get_energy(self, elem, coords):
        """Return ``{"energy"}`` of the base calculator plus the bias energy."""
        coords_bohr = np.asarray(coords, dtype=float).reshape(-1, 3)
        E0 = float(self.base.get_energy(elem, coords_bohr)["energy"])
        Ebias, _ = self._bias_energy_forces_bohr(coords_bohr)
        return {"energy": E0 + Ebias}

    def get_energy_and_forces(self, elem, coords):
        """Return the biased ``(energy, forces)`` tuple."""
        res = self.get_forces(elem, coords)
        return res["energy"], res["forces"]

    def get_energy_and_gradient(self, elem, coords):
        """Return the biased ``(energy, gradient)`` tuple (gradient = -forces)."""
        res = self.get_forces(elem, coords)
        return res["energy"], -np.asarray(res["forces"], dtype=float).reshape(-1)

    def __getattr__(self, name: str):
        """Forward unknown attribute lookups to the wrapped calculator.

        ``__getattr__`` only runs after normal lookup has failed. If ``base``
        itself is not yet in the instance dict (e.g. while ``copy``/``pickle``
        rebuild the object), touching ``self.base`` here would re-enter this
        method forever; read it from ``__dict__`` and raise ``AttributeError``
        instead.
        """
        try:
            base = self.__dict__["base"]
        except KeyError:
            raise AttributeError(name) from None
        return getattr(base, name)
161
+
162
+
163
+ def _parse_freeze_atoms(arg: Optional[str]) -> List[int]:
164
+ """Parse comma-separated 1-based indices (e.g., "1,3,5") into 0-based ints."""
165
+ if arg is None:
166
+ return []
167
+
168
+ items = [chunk.strip() for chunk in str(arg).split(",")]
169
+ indices: List[int] = []
170
+ for idx, chunk in enumerate(items, start=1):
171
+ if not chunk:
172
+ continue
173
+ try:
174
+ value = int(chunk)
175
+ except ValueError as exc:
176
+ raise click.BadParameter(
177
+ f"Invalid integer in --freeze-atoms entry #{idx}: '{chunk}'"
178
+ ) from exc
179
+ if value <= 0:
180
+ raise click.BadParameter(
181
+ f"--freeze-atoms expects 1-based positive indices; got {value}"
182
+ )
183
+ indices.append(value - 1)
184
+ return sorted(set(indices))
185
+
186
+
187
+ def _normalize_geom_freeze(value: Any) -> List[int]:
188
+ """Normalize YAML-provided geom.freeze_atoms to a sorted 0-based list."""
189
+ if value is None:
190
+ return []
191
+ if isinstance(value, str):
192
+ tokens = [tok.strip() for tok in value.split(",") if tok.strip()]
193
+ try:
194
+ return sorted({int(tok) for tok in tokens})
195
+ except ValueError as exc:
196
+ raise click.BadParameter(
197
+ "geom.freeze_atoms must contain integers (string form)."
198
+ ) from exc
199
+ try:
200
+ return sorted({int(idx) for idx in value})
201
+ except TypeError as exc:
202
+ raise click.BadParameter("geom.freeze_atoms must be iterable of integers.") from exc
203
+
204
+
205
+ def _parse_dist_freeze_args(
206
+ raw_args: Sequence[str],
207
+ one_based: bool,
208
+ atom_meta: Optional[Sequence[Dict[str, Any]]],
209
+ ) -> List[Tuple[int, int, Optional[float]]]:
210
+ """Parse all ``--dist-freeze`` arguments (inline literal or spec file).
211
+
212
+ Accepts the same format as ``--scan-lists``: inline Python literal
213
+ (e.g. ``'[(1,5,1.4)]'``) or a YAML/JSON spec file path. String atom
214
+ specs (e.g. ``'A:SER123:OG'``) are supported when *atom_meta* is
215
+ available. Target distance is optional — omit to freeze at the current
216
+ distance.
217
+ """
218
+ all_pairs: List[Tuple[int, int, Optional[float]]] = []
219
+ for raw in raw_args:
220
+ if is_scan_spec_file(raw):
221
+ all_pairs.extend(parse_dist_freeze_spec(
222
+ Path(raw),
223
+ one_based_default=one_based,
224
+ atom_meta=atom_meta,
225
+ ))
226
+ else:
227
+ all_pairs.extend(parse_dist_freeze_list(
228
+ raw,
229
+ one_based=one_based,
230
+ atom_meta=atom_meta,
231
+ ))
232
+ return all_pairs
233
+
234
+
235
def _resolve_dist_freeze_targets(
    geometry,
    tuples: List[Tuple[int, int, Optional[float]]],
) -> List[Tuple[int, int, float]]:
    """Give every ``(i, j, target)`` restraint a concrete target distance.

    A ``None`` target is replaced by the current i-j distance, computed from
    ``geometry.coords3d`` after scaling by ``BOHR2ANG`` (the coordinates are
    treated as Bohr here). Explicit targets are only coerced to ``float``.

    Raises:
        click.BadParameter: if an index falls outside the loaded geometry.
    """
    xyz_ang = np.array(geometry.coords3d, dtype=float).reshape(-1, 3) * BOHR2ANG
    n = xyz_ang.shape[0]

    resolved: List[Tuple[int, int, float]] = []
    for i, j, target in tuples:
        if i < 0 or i >= n or j < 0 or j >= n:
            raise click.BadParameter(
                f"--dist-freeze indices {(i, j)} are out of bounds for the loaded geometry (N={n})."
            )
        if target is not None:
            dist = float(target)
        else:
            dist = float(np.linalg.norm(xyz_ang[i] - xyz_ang[j]))
        resolved.append((i, j, dist))
    return resolved
255
+
256
+
257
+ # -------------------------------
258
+ # PDB helpers for B-factor patch
259
+ # -------------------------------
260
+
261
+ def _pdb_keys_from_line(line: str) -> Tuple[Tuple, Tuple]:
262
+ """
263
+ Extract robust keys from a PDB ATOM/HETATM record.
264
+
265
+ Returns:
266
+ key_full: (chain, resseq, icode, resname, atomname, altloc)
267
+ key_simple: (chain, resseq, icode, atomname)
268
+ """
269
+ atom_name = line[12:16].strip()
270
+ altloc = line[16:17].strip()
271
+ resname = line[17:20].strip()
272
+ chain = line[21:22].strip()
273
+ resseq_str = line[22:26].strip()
274
+ try:
275
+ resseq = int(resseq_str)
276
+ except ValueError:
277
+ resseq = -10**9 # unlikely sentinel when missing
278
+ icode = line[26:27].strip()
279
+ key_full = (chain, resseq, icode, resname, atom_name, altloc)
280
+ key_simple = (chain, resseq, icode, atom_name)
281
+ return key_full, key_simple
282
+
283
+
284
+ def _collect_ml_atom_keys(model_pdb: Path) -> Tuple[Set[Tuple], Set[Tuple]]:
285
+ """Collect ML-region atom keys from model_pdb."""
286
+ keys_full: Set[Tuple] = set()
287
+ keys_simple: Set[Tuple] = set()
288
+ try:
289
+ with model_pdb.open("r") as fh:
290
+ for line in fh:
291
+ if line.startswith("ATOM") or line.startswith("HETATM"):
292
+ kf, ks = _pdb_keys_from_line(line)
293
+ keys_full.add(kf)
294
+ keys_simple.add(ks)
295
+ except Exception:
296
+ logger.debug("Failed to collect ML atom keys from %s", model_pdb, exc_info=True)
297
+ return keys_full, keys_simple
298
+
299
+
300
+ def _format_with_bfactor(line: str, b: float) -> str:
301
+ """Return PDB line with B-factor field (cols 61-66) set to b (6.2f)."""
302
+ if len(line) < 66:
303
+ line = line.rstrip("\n")
304
+ line = line + " " * max(0, 66 - len(line))
305
+ line = line + "\n"
306
+ bf_str = f"{b:6.2f}"
307
+ # Preserve occupancy (cols 55-60), overwrite tempFactor (61-66).
308
+ new_line = line[:60] + bf_str + line[66:]
309
+ return new_line
310
+
311
+
312
def _annotate_b_factors_inplace(
    pdb_path: Path,
    model_pdb: Path,
    freeze_indices_0based: Sequence[int],
    beta_ml: float = 100.0,
    beta_frz: float = 50.0,
    beta_both: float = 150.0,
) -> None:
    """
    Rewrite pdb_path with B-factors that mark ML-region and frozen atoms:
      - ML-region atoms (matched against model_pdb): beta_ml   (default 100.00)
      - frozen atoms:                                beta_frz  (default 50.00)
      - ML ∩ frozen:                                 beta_both (default 150.00)
    Atoms in neither group keep their line unchanged.
    Indexing for 'frozen' is 0-based and resets at each MODEL.

    Best effort: read or write failures are logged at debug level and the
    function returns without raising.
    """
    ml_full, ml_simple = _collect_ml_atom_keys(model_pdb)
    frozen = {int(i) for i in (freeze_indices_0based or [])}

    try:
        records = pdb_path.read_text().splitlines(keepends=True)
    except Exception:
        logger.debug("Failed to read PDB file for B-factor annotation: %s", pdb_path, exc_info=True)
        return

    # (is_ml, is_frozen) -> value to stamp; (False, False) is deliberately absent.
    beta_for = {
        (True, True): beta_both,
        (True, False): beta_ml,
        (False, True): beta_frz,
    }

    patched: List[str] = []
    atom_counter = 0  # position of the atom within the current MODEL
    for record in records:
        tag = record[:6]
        if tag.startswith("MODEL"):
            atom_counter = 0
        elif tag.startswith(("ATOM ", "HETATM")):
            key_full, key_simple = _pdb_keys_from_line(record)
            flags = (
                (key_full in ml_full) or (key_simple in ml_simple),
                atom_counter in frozen,
            )
            if flags in beta_for:
                record = _format_with_bfactor(record, beta_for[flags])
            atom_counter += 1
        patched.append(record)

    try:
        pdb_path.write_text("".join(patched))
    except Exception:
        logger.debug("Failed to write B-factor annotated PDB: %s", pdb_path, exc_info=True)
366
+
367
+
368
+ def _maybe_convert_outputs_to_pdb(
369
+ input_path: Path,
370
+ out_dir: Path,
371
+ dump: bool,
372
+ get_trj_fn,
373
+ final_xyz_path: Path,
374
+ model_pdb: Path,
375
+ freeze_indices_0based: Sequence[int],
376
+ ml_indices: Optional[List[int]] = None,
377
+ hess_mm_indices: Optional[List[int]] = None,
378
+ movable_mm_indices: Optional[List[int]] = None,
379
+ frozen_layer_indices: Optional[List[int]] = None,
380
+ ) -> None:
381
+ """
382
+ If the input is a PDB, convert outputs (final_geometry.xyz and, if dump, optimization_all_trj.xyz /
383
+ optimization_trj.xyz) to PDB,
384
+ and annotate B-factors for the 3-layer ML/MM system.
385
+
386
+ B-factor encoding (3-layer system):
387
+ ML atoms: 0.0
388
+ Movable MM atoms: 10.0
389
+ Frozen MM atoms: 20.0
390
+
391
+ If layer indices are not provided, falls back to legacy encoding:
392
+ ML atoms: 100.0
393
+ Frozen atoms: 50.0
394
+ ML ∩ frozen: 150.0
395
+ """
396
+ if not is_convert_file_enabled():
397
+ return
398
+ if input_path.suffix.lower() != ".pdb":
399
+ return
400
+
401
+ # Determine if we should use the layer-based B-factor encoding
402
+ use_layer_bfactors = ml_indices is not None
403
+
404
+ ref_pdb = input_path.resolve()
405
+ # final_geometry.xyz → final_geometry.pdb
406
+ final_pdb = out_dir / "final_geometry.pdb"
407
+ try:
408
+ convert_xyz_to_pdb(final_xyz_path, ref_pdb, final_pdb)
409
+ click.echo(f"[convert] Wrote '{final_pdb}'.")
410
+
411
+ if use_layer_bfactors:
412
+ update_pdb_bfactors_from_layers(
413
+ final_pdb,
414
+ ml_indices=ml_indices or [],
415
+ hess_mm_indices=hess_mm_indices,
416
+ movable_mm_indices=movable_mm_indices,
417
+ frozen_indices=frozen_layer_indices,
418
+ )
419
+ click.echo(
420
+ f"[annot] B-factors set in '{final_pdb}' "
421
+ f"(ML={BFACTOR_ML:.0f}, MovableMM={BFACTOR_MOVABLE_MM:.0f}, "
422
+ f"FrozenMM={BFACTOR_FROZEN:.0f})."
423
+ )
424
+ else:
425
+ # Fall back to legacy encoding
426
+ _annotate_b_factors_inplace(
427
+ final_pdb,
428
+ model_pdb=model_pdb,
429
+ freeze_indices_0based=freeze_indices_0based,
430
+ )
431
+ click.echo(f"[annot] B-factors set in '{final_pdb}' (ML=100, frozen=50, both=150).")
432
+ except Exception as e:
433
+ click.echo(f"[convert] WARNING: Failed to convert final geometry to PDB: {e}", err=True)
434
+
435
+ # optimization_all_trj.xyz / optimization_trj.xyz → PDB (if dump)
436
+ if dump:
437
+ try:
438
+ wrote_any = False
439
+ all_trj_path = get_trj_fn("optimization_all_trj.xyz")
440
+ if all_trj_path.exists():
441
+ all_opt_pdb = out_dir / "optimization_all.pdb"
442
+ convert_xyz_to_pdb(all_trj_path, ref_pdb, all_opt_pdb)
443
+ click.echo(f"[convert] Wrote '{all_opt_pdb}'.")
444
+ wrote_any = True
445
+
446
+ if use_layer_bfactors:
447
+ update_pdb_bfactors_from_layers(
448
+ all_opt_pdb,
449
+ ml_indices=ml_indices or [],
450
+ hess_mm_indices=hess_mm_indices,
451
+ movable_mm_indices=movable_mm_indices,
452
+ frozen_indices=frozen_layer_indices,
453
+ )
454
+ click.echo(
455
+ f"[annot] B-factors set in '{all_opt_pdb}' "
456
+ f"(ML={BFACTOR_ML:.0f}, MovableMM={BFACTOR_MOVABLE_MM:.0f}, "
457
+ f"FrozenMM={BFACTOR_FROZEN:.0f})."
458
+ )
459
+ else:
460
+ _annotate_b_factors_inplace(
461
+ all_opt_pdb,
462
+ model_pdb=model_pdb,
463
+ freeze_indices_0based=freeze_indices_0based,
464
+ )
465
+ click.echo(f"[annot] B-factors set in '{all_opt_pdb}' (ML=100, frozen=50, both=150).")
466
+
467
+ trj_path = get_trj_fn("optimization_trj.xyz")
468
+ if trj_path.exists():
469
+ opt_pdb = out_dir / "optimization.pdb"
470
+ convert_xyz_to_pdb(trj_path, ref_pdb, opt_pdb)
471
+ click.echo(f"[convert] Wrote '{opt_pdb}'.")
472
+ wrote_any = True
473
+
474
+ if use_layer_bfactors:
475
+ update_pdb_bfactors_from_layers(
476
+ opt_pdb,
477
+ ml_indices=ml_indices or [],
478
+ hess_mm_indices=hess_mm_indices,
479
+ movable_mm_indices=movable_mm_indices,
480
+ frozen_indices=frozen_layer_indices,
481
+ )
482
+ click.echo(
483
+ f"[annot] B-factors set in '{opt_pdb}' "
484
+ f"(ML={BFACTOR_ML:.0f}, MovableMM={BFACTOR_MOVABLE_MM:.0f}, "
485
+ f"FrozenMM={BFACTOR_FROZEN:.0f})."
486
+ )
487
+ else:
488
+ _annotate_b_factors_inplace(
489
+ opt_pdb,
490
+ model_pdb=model_pdb,
491
+ freeze_indices_0based=freeze_indices_0based,
492
+ )
493
+ click.echo(f"[annot] B-factors set in '{opt_pdb}' (ML=100, frozen=50, both=150).")
494
+
495
+ if not wrote_any:
496
+ click.echo(
497
+ "[convert] WARNING: neither 'optimization_all_trj.xyz' nor 'optimization_trj.xyz' was found; "
498
+ "skipping trajectory PDB conversion.",
499
+ err=True,
500
+ )
501
+ except Exception as e:
502
+ click.echo(f"[convert] WARNING: Failed to convert optimization trajectory to PDB: {e}", err=True)
503
+
504
+
505
+ # -----------------------------------------------
506
+ # Flatten helpers
507
+ # -----------------------------------------------
508
+
509
+
510
+ def _calc_energy(geom, calc_kwargs: dict, calc=None) -> float:
511
+ """Compute energy (Hartree) from ML/MM calculator."""
512
+ owns_calc = calc is None
513
+ if owns_calc:
514
+ kw = dict(calc_kwargs or {})
515
+ kw["out_hess_torch"] = False
516
+ calc = mlmm(**kw)
517
+ result = calc.get_energy(geom.atoms, geom.coords)
518
+ energy = float(result.get("energy", 0.0))
519
+ del result
520
+ if owns_calc:
521
+ del calc
522
+ if torch.cuda.is_available():
523
+ torch.cuda.empty_cache()
524
+ return energy
525
+
526
+
527
+ def _flatten_all_imag_modes_for_geom(
528
+ geom,
529
+ masses_amu: np.ndarray,
530
+ calc_kwargs: dict,
531
+ freqs_cm: np.ndarray,
532
+ modes: torch.Tensor,
533
+ neg_freq_thresh_cm: float,
534
+ flatten_amp_ang: float,
535
+ ) -> bool:
536
+ """
537
+ Flatten all imaginary modes for a geometry in a single pass.
538
+ """
539
+ neg_idx_all = np.where(freqs_cm < -abs(neg_freq_thresh_cm))[0]
540
+ if len(neg_idx_all) == 0:
541
+ return False
542
+
543
+ order = np.argsort(freqs_cm[neg_idx_all]) # most negative first
544
+ targets = [int(x) for x in neg_idx_all[order]]
545
+ mass_scale = np.sqrt(12.011 / masses_amu)[:, None]
546
+ amp_bohr = float(flatten_amp_ang) / BOHR2ANG
547
+ E_ref = _calc_energy(geom, calc_kwargs)
548
+
549
+ m3 = np.repeat(masses_amu, 3).reshape(-1, 3)
550
+ for idx in targets:
551
+ v_mw = modes[idx].detach().cpu().numpy().reshape(-1, 3)
552
+ v_cart = v_mw / np.sqrt(m3)
553
+ v_cart /= np.linalg.norm(v_cart)
554
+
555
+ disp = amp_bohr * mass_scale * v_cart
556
+ ref = geom.cart_coords.reshape(-1, 3)
557
+
558
+ plus = ref + disp
559
+ minus = ref - disp
560
+
561
+ geom.coords = plus.reshape(-1)
562
+ E_plus = _calc_energy(geom, calc_kwargs)
563
+
564
+ geom.coords = minus.reshape(-1)
565
+ E_minus = _calc_energy(geom, calc_kwargs)
566
+
567
+ use_plus = E_plus <= E_minus
568
+ geom.coords = (plus if use_plus else minus).reshape(-1)
569
+ E_keep = E_plus if use_plus else E_minus
570
+ delta_e = E_keep - E_ref
571
+ click.echo(
572
+ f"[Flatten] mode={idx} freq={freqs_cm[idx]:+.2f} cm^-1 "
573
+ f"E_disp={E_keep:.8f} Ha \u0394E={delta_e:+.8f} Ha"
574
+ )
575
+
576
+ if torch.cuda.is_available():
577
+ torch.cuda.empty_cache()
578
+ return True
579
+
580
+
581
+ # -----------------------------------------------
582
+ # Microiteration optimizer
583
+ # -----------------------------------------------
584
+
585
+
586
def _run_microiter_opt(
    geometry,
    calc_cfg: Dict[str, Any],
    rfo_cfg: Dict[str, Any],
    lbfgs_cfg: Dict[str, Any],
    opt_cfg: Dict[str, Any],
    microiter_cfg: Dict[str, Any],
    out_dir_path: Path,
    *,
    dump: bool = False,
) -> Optional[Any]:
    """Run macro/micro alternating optimization (Gaussian 16-style microiteration).

    Macro step: 1 RFO step moving only ML region (full ONIOM force).
    Micro step: LBFGS relaxing MM region with MM-only forces until convergence.

    Args:
        geometry: Geometry to optimize in place.
        calc_cfg: Keyword arguments for the ``mlmm`` calculator; also used to
            resolve the ML / movable-MM / Hessian-MM / frozen-MM atom sets.
        rfo_cfg: Base keyword arguments for the macro ``RFOptimizer``
            (``max_cycles``, ``out_dir``, ``dump`` and ``thresh`` are overridden).
        lbfgs_cfg: Base keyword arguments for each micro ``LBFGS`` run
            (``max_cycles``, ``thresh``, ``out_dir`` and ``dump`` are overridden).
        opt_cfg: Supplies ``max_cycles`` (default 10000) and ``thresh``
            (default ``"gau"``) for the macro loop.
        microiter_cfg: Supplies ``micro_thresh`` (falls back to the macro
            threshold when missing/falsy) and ``micro_max_cycles`` (default 10000).
        out_dir_path: Directory for trajectory files and optimizer output.
        dump: When true, append macro/micro geometries to
            ``optimization_trj.xyz`` / ``optimization_all_trj.xyz``.

    Returns:
        ``None`` when no ML atoms are found (the caller is expected to fall
        back to a standard optimization); otherwise ``geometry``, with the
        full calculator re-attached and only the frozen-MM atoms frozen.
    """
    from .freq import _collect_layer_atom_sets

    # Resolve layer atom sets
    layer_sets = _collect_layer_atom_sets(calc_cfg)
    ml_indices = sorted(layer_sets["ml"])
    movable_mm = sorted(layer_sets["movable_mm"] | layer_sets["hess_mm"])
    frozen_mm = sorted(layer_sets["frozen_mm"])

    if not ml_indices:
        click.echo("[microiter] WARNING: No ML atoms found. Falling back to standard optimization.")
        return None

    n_atoms = len(geometry.atoms)
    all_indices = list(range(n_atoms))
    mm_indices = sorted(set(all_indices) - set(ml_indices))

    # Freeze lists: for macro step, freeze all MM; for micro step, freeze ML
    macro_freeze = sorted(set(mm_indices) | set(frozen_mm))
    micro_freeze = sorted(set(ml_indices) | set(frozen_mm))

    max_cycles = int(opt_cfg.get("max_cycles", 10000))
    thresh = opt_cfg.get("thresh", "gau")
    micro_thresh = microiter_cfg.get("micro_thresh") or thresh
    micro_max_cycles = int(microiter_cfg.get("micro_max_cycles", 10000))

    click.echo(
        f"[microiter] ML atoms: {len(ml_indices)}, "
        f"Movable MM atoms: {len(movable_mm)}, "
        f"Frozen MM atoms: {len(frozen_mm)}"
    )
    click.echo(f"[microiter] Macro thresh: {thresh}, Micro thresh: {micro_thresh}")

    # Create ONIOM calculator (shared core for MM-only calc)
    base_calc = mlmm(**calc_cfg)
    mm_calc = mlmm_mm_only(base_calc.core, freeze_atoms=micro_freeze)

    # Seed initial Hessian for RFO (with macro freeze)
    # Try IRC endpoint cache first; fall back to full Hessian calculation.
    from .hessian_cache import load as _hess_load
    from .freq import (
        _calc_full_hessian_torch as _freq_calc_full_hessian_torch,
        _torch_device as _freq_torch_device,
    )
    hess_device = _freq_torch_device(calc_cfg.get("ml_device", "auto"))

    # Always create macro calculator (needed for optimization loop below)
    macro_calc_cfg = dict(calc_cfg)
    macro_calc_cfg["freeze_atoms"] = macro_freeze
    macro_calc_cfg["hess_mm_atoms"] = []  # the macro step uses an ML-only Hessian
    macro_calc = mlmm(**macro_calc_cfg)

    cached = _hess_load("irc_endpoint")
    _cache_used = False
    if cached is not None:
        active_dofs = cached.get("active_dofs")
        h_raw = cached["hessian"]
        if isinstance(h_raw, torch.Tensor):
            h_init = h_raw.clone()
        else:
            h_init = torch.as_tensor(h_raw, dtype=torch.float64)

        # Macro step freezes MM atoms → only ML DOFs are free.
        # The cached IRC Hessian covers ML+MovableMM DOFs and is generally
        # larger. Extract the ML-only sub-block when active_dofs are known;
        # otherwise fall back to a fresh Hessian calculation.
        n_free = geometry.cart_coords.size - 3 * len(macro_freeze)
        if h_init.shape[0] == n_free:
            # Size already matches (e.g. non-microiter or same freeze set)
            geometry.set_calculator(macro_calc)
            if active_dofs is not None:
                geometry.within_partial_hessian = {
                    "active_n_dof": len(active_dofs),
                    "full_n_dof": geometry.cart_coords.size,
                    "active_dofs": active_dofs,
                    "active_atoms": sorted(set(d // 3 for d in active_dofs)),
                }
            geometry.cart_hessian = h_init
            click.echo(f"[microiter] Reusing IRC endpoint Hessian for RFO macro step (shape={h_init.shape[0]}x{h_init.shape[1]}).")
            _cache_used = True
        elif active_dofs is not None:
            # Extract ML-only sub-block from the larger cached Hessian.
            macro_free_atoms = sorted(set(range(geometry.cart_coords.size // 3)) - set(macro_freeze))
            macro_free_dofs = []
            for a in macro_free_atoms:
                macro_free_dofs.extend([3 * a, 3 * a + 1, 3 * a + 2])
            # Map macro free DOFs to indices within the cached active_dofs.
            # A one-off DOF → position table replaces a linear ``.index()``
            # search per DOF; ``setdefault`` keeps the first occurrence, which
            # is what ``.index()`` would have returned.
            dof_position: Dict[int, int] = {}
            for pos, dof in enumerate(active_dofs):
                dof_position.setdefault(int(dof), pos)
            sub_indices = [dof_position[d] for d in macro_free_dofs if d in dof_position]
            if len(sub_indices) == n_free:
                idx = torch.tensor(sub_indices, dtype=torch.long)
                h_sub = h_init[idx][:, idx]
                macro_active_dofs = macro_free_dofs
                geometry.set_calculator(macro_calc)
                geometry.within_partial_hessian = {
                    "active_n_dof": len(macro_active_dofs),
                    "full_n_dof": geometry.cart_coords.size,
                    "active_dofs": macro_active_dofs,
                    "active_atoms": macro_free_atoms,
                }
                geometry.cart_hessian = h_sub
                click.echo(
                    f"[microiter] Reusing IRC endpoint Hessian (sub-block) for RFO macro step "
                    f"(cached {h_init.shape[0]}x{h_init.shape[1]} → extracted {h_sub.shape[0]}x{h_sub.shape[1]})."
                )
                _cache_used = True
                del h_sub
            else:
                click.echo(
                    f"[microiter] IRC endpoint Hessian sub-block extraction failed "
                    f"(expected {n_free}, got {len(sub_indices)}). Falling back to fresh Hessian."
                )
        else:
            click.echo(
                f"[microiter] IRC endpoint Hessian size mismatch "
                f"(cached={h_init.shape[0]}, needed={n_free}). Falling back to fresh Hessian."
            )
        del h_init
    if not _cache_used:
        click.echo("[microiter] Seeding initial Hessian for RFO macro step.")
        geometry.set_calculator(macro_calc)

        h_init, _ = _freq_calc_full_hessian_torch(
            geometry, macro_calc_cfg, hess_device, refresh_geom_meta=True,
        )
        geometry.cart_hessian = h_init
        click.echo(f"[microiter] Initial Hessian seeded (shape={h_init.shape[0]}x{h_init.shape[1]}).")
        del h_init

    optim_all_path = out_dir_path / "optimization_all_trj.xyz"
    macro_trj_path = out_dir_path / "optimization_trj.xyz"
    total_macro_steps = 0

    # Create persistent RFOptimizer once (LayerOpt pattern).
    # This preserves the BFGS Hessian update chain across macro iterations.
    # NOTE: geometry already has macro_calc set (line above); do NOT call
    # set_calculator() again as it clears the pre-computed cart_hessian.
    geometry.freeze_atoms = macro_freeze

    rfo_args = dict(rfo_cfg)
    rfo_args["max_cycles"] = max_cycles
    rfo_args["out_dir"] = str(out_dir_path)
    rfo_args["dump"] = False  # trajectory dumping handled externally
    rfo_args["thresh"] = thresh

    macro_optimizer = RFOptimizer(geometry, **rfo_args)
    macro_optimizer.prepare_opt()  # initialise Hessian from geometry.cart_hessian

    # Microiteration progress table (pysisyphus-style with micro_steps column)
    micro_header = "cycle Δ(energy) max(|force|) rms(force) max(|step|) rms(step) micro_steps s/cycle".split()
    micro_col_fmts = "int float float float float float int float_short".split()
    micro_table = TablePrinter(micro_header, micro_col_fmts, width=12)
    micro_table.print_header()

    for macro_iter in range(max_cycles):
        # ---- Macro step: 1 RFO step with ONIOM forces, MM frozen ----
        geometry.freeze_atoms = macro_freeze
        geometry.set_calculator(macro_calc)

        # Manually feed state to the persistent optimizer (cf. LayerOpt lines 358-364)
        macro_optimizer.coords.append(geometry.coords.copy())
        macro_optimizer.cart_coords.append(geometry.cart_coords.copy())
        macro_optimizer.cur_cycle = macro_iter

        t_start = time.time()
        step = macro_optimizer.optimize()  # housekeeping() triggers BFGS update
        macro_optimizer.steps.append(step)

        # Convergence check
        macro_converged, conv_info = macro_optimizer.check_convergence()
        total_macro_steps += 1

        if dump:
            with open(macro_trj_path, "a") as f:
                f.write(geometry.as_xyz() + "\n")
            _append_xyz_trajectory(optim_all_path, macro_trj_path)

        if macro_converged:
            # Print final converged row (no micro steps)
            energy_diff = macro_optimizer.energies[-1] - macro_optimizer.energies[-2] if len(macro_optimizer.energies) >= 2 else float("nan")
            marks = [False, *conv_info.get_convergence()[:-1], False, False]
            cycle_time = time.time() - t_start
            micro_table.print_row(
                (macro_iter, energy_diff, macro_optimizer.max_forces[-1], macro_optimizer.rms_forces[-1],
                 macro_optimizer.max_steps[-1], macro_optimizer.rms_steps[-1], 0, cycle_time),
                marks=marks,
            )
            click.echo(f"[microiter] Macro convergence reached at iteration {macro_iter + 1}.")
            break

        # Apply step to geometry
        new_coords = geometry.coords.copy() + step
        geometry.coords = new_coords
        # Record actual step (may differ due to coordinate back-transformation)
        macro_optimizer.steps[-1] = geometry.coords - macro_optimizer.coords[-1]

        # ---- Micro step: LBFGS with MM-only forces, ML frozen ----
        geometry.freeze_atoms = micro_freeze
        geometry.set_calculator(mm_calc)

        micro_lbfgs_args = dict(lbfgs_cfg)
        micro_lbfgs_args["max_cycles"] = micro_max_cycles
        micro_lbfgs_args["thresh"] = micro_thresh
        micro_lbfgs_args["out_dir"] = str(out_dir_path)
        micro_lbfgs_args["dump"] = dump

        micro_opt = LBFGS(geometry, **micro_lbfgs_args)
        # Silence the inner optimizer's own progress output; only the
        # combined macro/micro table is shown.
        with contextlib.redirect_stdout(io.StringIO()):
            micro_opt.run()
        micro_steps = max(int(micro_opt.cur_cycle) + 1, 1)

        if dump:
            _append_xyz_trajectory(optim_all_path, out_dir_path / "optimization_trj.xyz")

        del micro_opt
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Print progress row with micro_steps
        cycle_time = time.time() - t_start
        energy_diff = macro_optimizer.energies[-1] - macro_optimizer.energies[-2] if len(macro_optimizer.energies) >= 2 else float("nan")
        marks = [False, *conv_info.get_convergence()[:-1], False, False]
        if (macro_iter > 1) and (macro_iter % 10 == 0):
            micro_table.print_sep()
        micro_table.print_row(
            (macro_iter, energy_diff, macro_optimizer.max_forces[-1], macro_optimizer.rms_forces[-1],
             macro_optimizer.max_steps[-1], macro_optimizer.rms_steps[-1], micro_steps, cycle_time),
            marks=marks,
        )

    else:
        # Loop ran out of iterations without hitting the convergence break.
        click.echo(f"[microiter] Reached max macro iterations ({max_cycles}).")

    del macro_optimizer
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    click.echo(f"[microiter] Total macro steps: {total_macro_steps}")
    # Restore full calculator
    geometry.freeze_atoms = list(set(frozen_mm))
    geometry.set_calculator(base_calc)

    return geometry
847
+
848
+
849
+ # -----------------------------------------------
850
+ # CLI
851
+ # -----------------------------------------------
852
+
853
+ @click.command(
854
+ help="ML/MM geometry optimization with LBFGS (light) or RFO (heavy).",
855
+ context_settings={"help_option_names": ["-h", "--help"]},
856
+ )
857
+ @click.option(
858
+ "-i", "--input",
859
+ "input_path",
860
+ type=click.Path(path_type=Path, exists=True, dir_okay=False),
861
+ required=True,
862
+ help="Input structure file (PDB, XYZ). XYZ provides higher coordinate precision. "
863
+ "If XYZ, use --ref-pdb to specify PDB topology for atom ordering and output conversion.",
864
+ )
865
+ @click.option(
866
+ "--ref-pdb",
867
+ "ref_pdb",
868
+ type=click.Path(path_type=Path, exists=True, dir_okay=False),
869
+ default=None,
870
+ show_default=False,
871
+ help="Reference PDB topology when input is XYZ. XYZ coordinates are used (higher precision) "
872
+ "while PDB provides atom ordering and residue information for output conversion.",
873
+ )
874
+ @click.option(
875
+ "--parm",
876
+ "real_parm7",
877
+ type=click.Path(path_type=Path, exists=True, dir_okay=False),
878
+ required=True,
879
+ help="Amber parm7 topology covering the whole enzyme complex.",
880
+ )
881
+ @click.option(
882
+ "--model-pdb",
883
+ "model_pdb",
884
+ type=click.Path(path_type=Path, exists=True, dir_okay=False),
885
+ required=False,
886
+ help="PDB defining atoms that belong to the ML (high-level) region. "
887
+ "Optional when --detect-layer is enabled.",
888
+ )
889
+ @click.option(
890
+ "--model-indices",
891
+ "model_indices_str",
892
+ type=str,
893
+ default=None,
894
+ show_default=False,
895
+ help="Comma-separated atom indices for the ML region (ranges allowed like 1-5). "
896
+ "Used when --model-pdb is omitted.",
897
+ )
898
+ @click.option(
899
+ "--model-indices-one-based/--model-indices-zero-based",
900
+ "model_indices_one_based",
901
+ default=True,
902
+ show_default=True,
903
+ help="Interpret --model-indices as 1-based (default) or 0-based.",
904
+ )
905
+ @click.option(
906
+ "--detect-layer/--no-detect-layer",
907
+ "detect_layer",
908
+ default=True,
909
+ show_default=True,
910
+ help="Detect ML/MM layers from input PDB B-factors (B=0/10/20). "
911
+ "If disabled, you must provide --model-pdb or --model-indices.",
912
+ )
913
+ @click.option("-q", "--charge", type=int, required=False,
914
+ help="ML region charge. Required unless --ligand-charge is provided.")
915
+ @click.option("-l", "--ligand-charge", type=str, default=None, show_default=False,
916
+ help="Total charge or per-resname mapping (e.g., GPP:-3,SAM:1) used to derive "
917
+ "charge when -q is omitted (requires PDB input or --ref-pdb).")
918
+ @click.option(
919
+ "-m",
920
+ "--multiplicity",
921
+ "spin",
922
+ type=int,
923
+ default=None,
924
+ show_default=False,
925
+ help="Spin multiplicity (2S+1) for the ML region. Defaults to 1 when omitted.",
926
+ )
927
+ @click.option(
928
+ "--freeze-atoms",
929
+ "freeze_atoms_text",
930
+ type=str,
931
+ default=None,
932
+ show_default=False,
933
+ help="Comma-separated 1-based atom indices to freeze (e.g., '1,3,5').",
934
+ )
935
+ @click.option(
936
+ "--radius-partial-hessian",
937
+ "--hess-cutoff",
938
+ "radius_partial_hessian",
939
+ type=float,
940
+ default=None,
941
+ show_default=False,
942
+ help="Distance cutoff (Å) from ML region for MM atoms to include in Hessian calculation. "
943
+ "Applied to movable MM atoms and can be combined with --detect-layer. "
944
+ "`--hess-cutoff` is a compatibility alias.",
945
+ )
946
+ @click.option(
947
+ "--radius-freeze",
948
+ "--movable-cutoff",
949
+ "radius_freeze",
950
+ type=float,
951
+ default=None,
952
+ show_default=False,
953
+ help="Distance cutoff (Å) from ML region for movable MM atoms. "
954
+ "MM atoms beyond this are frozen. "
955
+ "Providing --radius-freeze disables --detect-layer and uses distance-based layer assignment. "
956
+ "`--movable-cutoff` is a compatibility alias.",
957
+ )
958
+ @click.option(
959
+ "--dist-freeze",
960
+ "dist_freeze_raw",
961
+ type=str,
962
+ multiple=True,
963
+ default=(),
964
+ show_default=False,
965
+ help="Distance restraints: inline Python literal (e.g. '[(1,5,1.4)]') or a YAML/JSON spec file path. "
966
+ "Same format as --scan-lists: (i,j,target_A) triples. "
967
+ "Target may be omitted to freeze at the current distance: (i,j).",
968
+ )
969
+ @click.option(
970
+ "--one-based/--zero-based",
971
+ "one_based",
972
+ default=True,
973
+ show_default=True,
974
+ help="Interpret --dist-freeze / --scan-lists indices as 1-based (default) or 0-based.",
975
+ )
976
+ @click.option(
977
+ "--bias-k",
978
+ type=float,
979
+ default=300.0,
980
+ show_default=True,
981
+ help="Harmonic restraint strength k [eV/Å^2] for --dist-freeze.",
982
+ )
983
+ @click.option("--max-cycles", type=int, default=10000, show_default=True, help="Maximum number of optimization cycles.")
984
+ @click.option(
985
+ "--dump/--no-dump",
986
+ default=False,
987
+ show_default=True,
988
+ help="Write optimization trajectories ('optimization_trj.xyz' and 'optimization_all_trj.xyz').",
989
+ )
990
+ @click.option("-o", "--out-dir", type=str, default="./result_opt/", show_default=True, help="Output directory.")
991
+ @click.option(
992
+ "--thresh",
993
+ type=click.Choice(["gau_loose", "gau", "gau_tight", "gau_vtight", "baker", "never"], case_sensitive=False),
994
+ default=None,
995
+ help="Convergence preset.",
996
+ )
997
+ @click.option(
998
+ "--opt-mode",
999
+ type=click.Choice(["grad", "hess", "light", "heavy", "lbfgs", "rfo"], case_sensitive=False),
1000
+ default="grad",
1001
+ show_default=True,
1002
+ help="Optimization mode: grad (lbfgs) or hess (rfo). Aliases light/heavy and lbfgs/rfo are accepted.",
1003
+ )
1004
+ @click.option(
1005
+ "--microiter/--no-microiter",
1006
+ "microiter",
1007
+ default=True,
1008
+ show_default=True,
1009
+ help="Enable microiteration: alternate ML 1-step (RFO) and MM relaxation (LBFGS with MM-only forces). "
1010
+ "Only effective in --opt-mode hess (RFO). Ignored in grad mode.",
1011
+ )
1012
+ @click.option(
1013
+ "--flatten/--no-flatten",
1014
+ "flatten",
1015
+ default=False,
1016
+ show_default=True,
1017
+ help="Enable/disable imaginary-mode flatten loop after optimization.",
1018
+ )
1019
+ @click.option(
1020
+ "--config",
1021
+ "config_yaml",
1022
+ type=click.Path(path_type=Path, exists=True, dir_okay=False),
1023
+ default=None,
1024
+ help="Base YAML configuration file applied before explicit CLI options.",
1025
+ )
1026
+ @click.option(
1027
+ "--show-config/--no-show-config",
1028
+ "show_config",
1029
+ default=False,
1030
+ show_default=True,
1031
+ help="Print resolved configuration and continue execution.",
1032
+ )
1033
+ @click.option(
1034
+ "--dry-run/--no-dry-run",
1035
+ "dry_run",
1036
+ default=False,
1037
+ show_default=True,
1038
+ help="Validate options and print the execution plan without running optimization.",
1039
+ )
1040
+ @click.option(
1041
+ "--convert-files/--no-convert-files",
1042
+ "convert_files",
1043
+ default=True,
1044
+ show_default=True,
1045
+ help="Convert XYZ/TRJ outputs into PDB companions based on the input format.",
1046
+ )
1047
+ @click.option(
1048
+ "-b", "--backend",
1049
+ type=click.Choice(["uma", "orb", "mace", "aimnet2"], case_sensitive=False),
1050
+ default=None,
1051
+ show_default=False,
1052
+ help="ML backend for the ONIOM high-level region (default: uma).",
1053
+ )
1054
+ @click.option(
1055
+ "--embedcharge/--no-embedcharge",
1056
+ "embedcharge",
1057
+ default=False,
1058
+ show_default=True,
1059
+ help="Enable xTB point-charge embedding correction for MM→ML environmental effects.",
1060
+ )
1061
+ @click.option(
1062
+ "--embedcharge-cutoff",
1063
+ "embedcharge_cutoff",
1064
+ type=float,
1065
+ default=None,
1066
+ show_default=False,
1067
+ help="Distance cutoff (Å) from ML region for MM point charges in xTB embedding. "
1068
+ "Default: 12.0 Å when --embedcharge is enabled.",
1069
+ )
1070
+ @click.pass_context
1071
+ def cli(
1072
+ ctx: click.Context,
1073
+ input_path: Path,
1074
+ ref_pdb: Optional[Path],
1075
+ real_parm7: Path,
1076
+ model_pdb: Optional[Path],
1077
+ model_indices_str: Optional[str],
1078
+ model_indices_one_based: bool,
1079
+ detect_layer: bool,
1080
+ charge: Optional[int],
1081
+ ligand_charge: Optional[str],
1082
+ spin: Optional[int],
1083
+ freeze_atoms_text: Optional[str],
1084
+ radius_partial_hessian: Optional[float],
1085
+ radius_freeze: Optional[float],
1086
+ dist_freeze_raw: Sequence[str],
1087
+ one_based: bool,
1088
+ bias_k: float,
1089
+ max_cycles: int,
1090
+ dump: bool,
1091
+ out_dir: str,
1092
+ thresh: Optional[str],
1093
+ opt_mode: str,
1094
+ microiter: bool,
1095
+ flatten: bool,
1096
+ config_yaml: Optional[Path],
1097
+ show_config: bool,
1098
+ dry_run: bool,
1099
+ convert_files: bool,
1100
+ backend: Optional[str],
1101
+ embedcharge: bool,
1102
+ embedcharge_cutoff: Optional[float],
1103
+ ) -> None:
1104
+ set_convert_file_enabled(convert_files)
1105
+ time_start = time.perf_counter()
1106
+ prepared_input = None
1107
+
1108
+ _is_param_explicit = make_is_param_explicit(ctx)
1109
+
1110
+ config_yaml, override_yaml, used_legacy_yaml = resolve_yaml_sources(
1111
+ config_yaml=config_yaml,
1112
+ override_yaml=None,
1113
+ args_yaml_legacy=None,
1114
+ )
1115
+ merged_yaml_cfg, _, _ = load_merged_yaml_cfg(
1116
+ config_yaml=config_yaml,
1117
+ override_yaml=None,
1118
+ )
1119
+
1120
+ # Handle input: PDB directly, or XYZ with --ref-pdb for topology
1121
+ suffix = input_path.suffix.lower()
1122
+ if suffix == ".pdb":
1123
+ # PDB input: use directly
1124
+ prepared_input = prepare_input_structure(input_path)
1125
+ elif suffix == ".xyz":
1126
+ # XYZ input: require --ref-pdb for topology
1127
+ if ref_pdb is None:
1128
+ click.echo("ERROR: XYZ/TRJ input requires --ref-pdb to specify PDB topology.", err=True)
1129
+ sys.exit(1)
1130
+ prepared_input = prepare_input_structure(input_path)
1131
+ apply_ref_pdb_override(prepared_input, ref_pdb)
1132
+ click.echo(f"[input] Using XYZ coordinates from {input_path.name}, PDB topology from {ref_pdb.name}")
1133
+ else:
1134
+ click.echo(f"ERROR: Unsupported input format: {suffix}. Use .pdb or .xyz (with --ref-pdb).", err=True)
1135
+ sys.exit(1)
1136
+
1137
+ geom_input_path = prepared_input.geom_path
1138
+ charge, spin = resolve_charge_spin_or_raise(
1139
+ prepared_input, charge, spin,
1140
+ ligand_charge=ligand_charge, prefix="[opt]",
1141
+ )
1142
+
1143
+ try:
1144
+ freeze_atoms_cli = _parse_freeze_atoms(freeze_atoms_text)
1145
+ except click.BadParameter as e:
1146
+ click.echo(f"ERROR: {e}", err=True)
1147
+ prepared_input.cleanup()
1148
+ sys.exit(1)
1149
+
1150
+ model_indices: Optional[List[int]] = None
1151
+ if model_indices_str:
1152
+ try:
1153
+ model_indices = parse_indices_string(model_indices_str, one_based=model_indices_one_based)
1154
+ except click.BadParameter as e:
1155
+ click.echo(f"ERROR: {e}", err=True)
1156
+ prepared_input.cleanup()
1157
+ sys.exit(1)
1158
+
1159
+ pdb_atom_meta: List[Dict[str, Any]] = []
1160
+ if prepared_input.source_path.suffix.lower() == ".pdb":
1161
+ pdb_atom_meta = load_pdb_atom_metadata(prepared_input.source_path)
1162
+
1163
+ try:
1164
+ dist_freeze = _parse_dist_freeze_args(
1165
+ dist_freeze_raw, one_based=bool(one_based), atom_meta=pdb_atom_meta,
1166
+ )
1167
+ except click.BadParameter as e:
1168
+ click.echo(f"ERROR: {e}", err=True)
1169
+ prepared_input.cleanup()
1170
+ sys.exit(1)
1171
+
1172
+ # Resolve optimizer mode
1173
+ mode_resolved = normalize_choice(
1174
+ opt_mode,
1175
+ param="--opt-mode",
1176
+ alias_groups=OPT_MODE_ALIASES,
1177
+ allowed_hint="grad|hess|lbfgs|rfo",
1178
+ )
1179
+ use_rfo = (mode_resolved == "rfo")
1180
+
1181
+ try:
1182
+ config_layer_cfg = load_yaml_dict(config_yaml)
1183
+ override_layer_cfg = load_yaml_dict(override_yaml)
1184
+ geom_cfg = dict(GEOM_KW)
1185
+ calc_cfg = dict(CALC_KW)
1186
+ opt_cfg = dict(OPT_BASE_KW)
1187
+ lbfgs_cfg = dict(LBFGS_KW)
1188
+ rfo_cfg = dict(RFO_KW)
1189
+
1190
+ apply_yaml_overrides(
1191
+ config_layer_cfg,
1192
+ [
1193
+ (geom_cfg, (("geom",),)),
1194
+ (calc_cfg, (("calc",), ("mlmm",))),
1195
+ (opt_cfg, (("opt",),)),
1196
+ (lbfgs_cfg, (("lbfgs",), ("opt", "lbfgs"))),
1197
+ (rfo_cfg, (("rfo",), ("opt", "rfo"))),
1198
+ ],
1199
+ )
1200
+
1201
+ if _is_param_explicit("max_cycles"):
1202
+ opt_cfg["max_cycles"] = int(max_cycles)
1203
+ if _is_param_explicit("dump"):
1204
+ opt_cfg["dump"] = bool(dump)
1205
+ if _is_param_explicit("out_dir"):
1206
+ opt_cfg["out_dir"] = out_dir
1207
+ if _is_param_explicit("thresh") and thresh is not None:
1208
+ opt_cfg["thresh"] = str(thresh)
1209
+
1210
+ if _is_param_explicit("detect_layer"):
1211
+ calc_cfg["use_bfactor_layers"] = bool(detect_layer)
1212
+ if _is_param_explicit("radius_partial_hessian") and radius_partial_hessian is not None:
1213
+ calc_cfg["hess_cutoff"] = float(radius_partial_hessian)
1214
+ if _is_param_explicit("radius_freeze") and radius_freeze is not None:
1215
+ calc_cfg["movable_cutoff"] = float(radius_freeze)
1216
+ calc_cfg["use_bfactor_layers"] = False
1217
+
1218
+ model_charge_value = calc_cfg.get("model_charge", charge)
1219
+ if model_charge_value is None:
1220
+ model_charge_value = charge
1221
+ calc_cfg["model_charge"] = int(model_charge_value)
1222
+ if _is_param_explicit("charge"):
1223
+ calc_cfg["model_charge"] = int(charge)
1224
+ model_mult_value = calc_cfg.get("model_mult", spin)
1225
+ if model_mult_value is None:
1226
+ model_mult_value = spin
1227
+ calc_cfg["model_mult"] = int(model_mult_value)
1228
+ if _is_param_explicit("spin"):
1229
+ calc_cfg["model_mult"] = int(spin)
1230
+ if model_pdb is not None:
1231
+ calc_cfg["model_pdb"] = str(model_pdb)
1232
+ calc_cfg["input_pdb"] = str(prepared_input.source_path)
1233
+ calc_cfg["real_parm7"] = str(real_parm7)
1234
+ if backend is not None:
1235
+ calc_cfg["backend"] = str(backend).lower()
1236
+ if _is_param_explicit("embedcharge"):
1237
+ calc_cfg["embedcharge"] = bool(embedcharge)
1238
+ if _is_param_explicit("embedcharge_cutoff"):
1239
+ calc_cfg["embedcharge_cutoff"] = embedcharge_cutoff
1240
+
1241
+ apply_yaml_overrides(
1242
+ override_layer_cfg,
1243
+ [
1244
+ (geom_cfg, (("geom",),)),
1245
+ (calc_cfg, (("calc",), ("mlmm",))),
1246
+ (opt_cfg, (("opt",),)),
1247
+ (lbfgs_cfg, (("lbfgs",), ("opt", "lbfgs"))),
1248
+ (rfo_cfg, (("rfo",), ("opt", "rfo"))),
1249
+ ],
1250
+ )
1251
+ calc_paths = (("calc",), ("mlmm",))
1252
+ partial_explicit = (
1253
+ yaml_section_has_key(config_layer_cfg, calc_paths, "return_partial_hessian")
1254
+ or yaml_section_has_key(override_layer_cfg, calc_paths, "return_partial_hessian")
1255
+ )
1256
+ if not partial_explicit:
1257
+ calc_cfg["return_partial_hessian"] = True
1258
+
1259
+ try:
1260
+ geom_freeze = _normalize_geom_freeze(geom_cfg.get("freeze_atoms"))
1261
+ except click.BadParameter as e:
1262
+ click.echo(f"ERROR: {e}", err=True)
1263
+ prepared_input.cleanup()
1264
+ sys.exit(1)
1265
+ geom_cfg["freeze_atoms"] = geom_freeze
1266
+ if freeze_atoms_cli:
1267
+ merge_freeze_atom_indices(geom_cfg, freeze_atoms_cli)
1268
+ freeze_atoms_final = list(geom_cfg.get("freeze_atoms") or [])
1269
+ calc_cfg["freeze_atoms"] = freeze_atoms_final
1270
+
1271
+ out_dir_path = Path(opt_cfg["out_dir"]).resolve()
1272
+
1273
+ # radius_freeze implies full distance-based layer assignment.
1274
+ # radius_partial_hessian alone can be combined with --detect-layer.
1275
+ detect_layer_enabled = bool(calc_cfg.get("use_bfactor_layers", True))
1276
+ model_pdb_cfg = calc_cfg.get("model_pdb")
1277
+ if radius_freeze is not None:
1278
+ if detect_layer_enabled:
1279
+ click.echo("[layer] --radius-freeze provided; disabling --detect-layer.", err=True)
1280
+ detect_layer_enabled = False
1281
+ calc_cfg["use_bfactor_layers"] = False
1282
+
1283
+ layer_source_pdb = prepared_input.source_path
1284
+ if detect_layer_enabled and layer_source_pdb.suffix.lower() != ".pdb":
1285
+ click.echo("ERROR: --detect-layer requires a PDB input (or --ref-pdb).", err=True)
1286
+ prepared_input.cleanup()
1287
+ sys.exit(1)
1288
+
1289
+ if show_config:
1290
+ click.echo(
1291
+ pretty_block(
1292
+ "yaml_layers",
1293
+ {
1294
+ "config": None if config_yaml is None else str(config_yaml),
1295
+ "override_yaml": None if override_yaml is None else str(override_yaml),
1296
+ "merged_keys": sorted(merged_yaml_cfg.keys()),
1297
+ },
1298
+ )
1299
+ )
1300
+
1301
+ if dry_run:
1302
+ model_region_source = "bfactor"
1303
+ if not detect_layer_enabled:
1304
+ if model_pdb_cfg is not None:
1305
+ model_region_source = "model_pdb"
1306
+ elif model_indices:
1307
+ model_region_source = "model_indices"
1308
+ else:
1309
+ click.echo("ERROR: Provide --model-pdb or --model-indices when --no-detect-layer.", err=True)
1310
+ prepared_input.cleanup()
1311
+ sys.exit(1)
1312
+ if (
1313
+ not detect_layer_enabled
1314
+ and model_pdb_cfg is None
1315
+ and model_indices
1316
+ and layer_source_pdb.suffix.lower() != ".pdb"
1317
+ ):
1318
+ click.echo("ERROR: --model-indices requires a PDB input (or --ref-pdb).", err=True)
1319
+ prepared_input.cleanup()
1320
+ sys.exit(1)
1321
+ click.echo(
1322
+ pretty_block(
1323
+ "dry_run_plan",
1324
+ {
1325
+ "input_geometry": str(geom_input_path),
1326
+ "output_dir": str(out_dir_path),
1327
+ "optimizer_mode": "rfo" if use_rfo else "lbfgs",
1328
+ "detect_layer": bool(detect_layer_enabled),
1329
+ "model_region_source": model_region_source,
1330
+ "model_indices_count": 0 if not model_indices else len(model_indices),
1331
+ "will_run_optimization": True,
1332
+ "will_convert_outputs": True,
1333
+ "backend": calc_cfg.get("backend", "uma"),
1334
+ "embedcharge": bool(calc_cfg.get("embedcharge", False)),
1335
+ },
1336
+ )
1337
+ )
1338
+ click.echo("[dry-run] Validation complete. Optimization execution was skipped.")
1339
+ return
1340
+
1341
+ model_pdb_path: Optional[Path] = None
1342
+ layer_info: Optional[Dict[str, List[int]]] = None
1343
+
1344
+ if detect_layer_enabled:
1345
+ try:
1346
+ model_pdb_path, layer_info = build_model_pdb_from_bfactors(layer_source_pdb, out_dir_path)
1347
+ calc_cfg["use_bfactor_layers"] = True
1348
+ click.echo(
1349
+ f"[layer] Detected B-factor layers: ML={len(layer_info.get('ml_indices', []))}, "
1350
+ f"MovableMM={len(layer_info.get('movable_mm_indices', []))}, "
1351
+ f"FrozenMM={len(layer_info.get('frozen_indices', []))}"
1352
+ )
1353
+ except Exception as e:
1354
+ if model_pdb_cfg is None and not model_indices:
1355
+ click.echo(f"ERROR: {e}", err=True)
1356
+ prepared_input.cleanup()
1357
+ sys.exit(1)
1358
+ click.echo(f"[layer] WARNING: {e} Falling back to explicit ML region.", err=True)
1359
+ detect_layer_enabled = False
1360
+
1361
+ if not detect_layer_enabled:
1362
+ if model_pdb_cfg is None and not model_indices:
1363
+ click.echo("ERROR: Provide --model-pdb or --model-indices when --no-detect-layer.", err=True)
1364
+ prepared_input.cleanup()
1365
+ sys.exit(1)
1366
+ if model_pdb_cfg is not None:
1367
+ model_pdb_path = Path(model_pdb_cfg)
1368
+ else:
1369
+ if layer_source_pdb.suffix.lower() != ".pdb":
1370
+ click.echo("ERROR: --model-indices requires a PDB input (or --ref-pdb).", err=True)
1371
+ prepared_input.cleanup()
1372
+ sys.exit(1)
1373
+ try:
1374
+ model_pdb_path = build_model_pdb_from_indices(layer_source_pdb, out_dir_path, model_indices or [])
1375
+ except Exception as e:
1376
+ click.echo(f"ERROR: {e}", err=True)
1377
+ prepared_input.cleanup()
1378
+ sys.exit(1)
1379
+ calc_cfg["use_bfactor_layers"] = False
1380
+
1381
+ if model_pdb_path is None:
1382
+ click.echo("ERROR: Failed to resolve model PDB for the ML region.", err=True)
1383
+ prepared_input.cleanup()
1384
+ sys.exit(1)
1385
+
1386
+ calc_cfg["model_pdb"] = str(model_pdb_path)
1387
+
1388
+ # When layer detection is enabled, also freeze frozen-layer atoms at the
1389
+ # optimizer geometry level (not only inside the calculator).
1390
+ # Otherwise LBFGS may still move those coordinates through coupled
1391
+ # inverse-Hessian updates, even if raw forces are zeroed there.
1392
+ if layer_info is not None:
1393
+ frozen_from_layer = [int(i) for i in layer_info.get("frozen_indices", [])]
1394
+ if frozen_from_layer:
1395
+ before = set(freeze_atoms_final)
1396
+ merged = sorted(before | set(frozen_from_layer))
1397
+ added = len(set(merged) - before)
1398
+ freeze_atoms_final = merged
1399
+ geom_cfg["freeze_atoms"] = freeze_atoms_final
1400
+ calc_cfg["freeze_atoms"] = freeze_atoms_final
1401
+ click.echo(
1402
+ f"[layer] Applied optimizer freeze constraints: "
1403
+ f"total={len(freeze_atoms_final)} (added_from_layer={added})"
1404
+ )
1405
+
1406
+ # Distance-based overrides for Hessian-target and movable MM selection.
1407
+ hess_cutoff_final = calc_cfg.get("hess_cutoff")
1408
+ movable_cutoff_final = calc_cfg.get("movable_cutoff")
1409
+ if hess_cutoff_final is not None or movable_cutoff_final is not None:
1410
+ click.echo(
1411
+ f"[layer] Applied distance cutoffs: "
1412
+ f"hess={hess_cutoff_final} Å, freeze={movable_cutoff_final} Å"
1413
+ )
1414
+ from .freq import _align_three_layer_hessian_targets as _freq_align_three_layer_hessian_targets
1415
+ _freq_align_three_layer_hessian_targets(calc_cfg, echo_fn=click.echo)
1416
+
1417
+ for key in ("input_pdb", "real_parm7", "model_pdb", "mm_fd_dir"):
1418
+ val = calc_cfg.get(key)
1419
+ if val:
1420
+ calc_cfg[key] = str(Path(val).expanduser().resolve())
1421
+
1422
+ mode_str = "RFO (hess)" if use_rfo else "LBFGS (grad)"
1423
+ click.echo(f"\n[mode] Optimizer: {mode_str}\n")
1424
+ click.echo(pretty_block("geom", format_freeze_atoms_for_echo(geom_cfg, key="freeze_atoms")))
1425
+ echo_calc = format_freeze_atoms_for_echo(filter_calc_for_echo(calc_cfg), key="freeze_atoms")
1426
+ click.echo(pretty_block("calc", echo_calc))
1427
+ # Show only non-default opt settings
1428
+ echo_opt = strip_inherited_keys({**opt_cfg, "out_dir": str(out_dir_path)}, OPT_BASE_KW, mode="same")
1429
+ click.echo(pretty_block("opt", echo_opt))
1430
+ # Show only optimizer-specific settings, not inherited from opt_cfg
1431
+ if use_rfo:
1432
+ echo_rfo = strip_inherited_keys(rfo_cfg, opt_cfg)
1433
+ click.echo(pretty_block("rfo", echo_rfo))
1434
+ else:
1435
+ echo_lbfgs = strip_inherited_keys(lbfgs_cfg, opt_cfg)
1436
+ click.echo(pretty_block("lbfgs", echo_lbfgs))
1437
+ if dist_freeze:
1438
+ display_pairs = []
1439
+ for (i, j, target) in dist_freeze:
1440
+ label = (f"{target:.4f}" if target is not None else "<current>")
1441
+ display_pairs.append((int(i) + 1, int(j) + 1, label))
1442
+ click.echo(
1443
+ pretty_block(
1444
+ "dist_freeze (input)",
1445
+ {
1446
+ "k (eV/Å^2)": float(bias_k),
1447
+ "pairs_1based": display_pairs,
1448
+ },
1449
+ )
1450
+ )
1451
+
1452
+ out_dir_path.mkdir(parents=True, exist_ok=True)
1453
+ coord_type = geom_cfg.get("coord_type", "cart")
1454
+ coord_kwargs = dict(geom_cfg)
1455
+ coord_kwargs.pop("coord_type", None)
1456
+ geometry = geom_loader(
1457
+ geom_input_path,
1458
+ coord_type=coord_type,
1459
+ **coord_kwargs,
1460
+ )
1461
+
1462
+ base_calc = mlmm(**calc_cfg)
1463
+ geometry.set_calculator(base_calc)
1464
+
1465
+ resolved_dist_freeze: List[Tuple[int, int, float]] = []
1466
+ if dist_freeze:
1467
+ try:
1468
+ resolved_dist_freeze = _resolve_dist_freeze_targets(geometry, dist_freeze)
1469
+ except click.BadParameter as e:
1470
+ click.echo(f"ERROR: {e}", err=True)
1471
+ sys.exit(1)
1472
+ click.echo(
1473
+ pretty_block(
1474
+ "dist_freeze (active)",
1475
+ {
1476
+ "k (eV/Å^2)": float(bias_k),
1477
+ "pairs_1based": [
1478
+ (int(i) + 1, int(j) + 1, float(f"{t:.4f}"))
1479
+ for (i, j, t) in resolved_dist_freeze
1480
+ ],
1481
+ },
1482
+ )
1483
+ )
1484
+ bias_calc = HarmonicBiasCalculator(base_calc, k=float(bias_k))
1485
+ bias_calc.set_pairs(resolved_dist_freeze)
1486
+ geometry.set_calculator(bias_calc)
1487
+
1488
+ # Pass only opt-level values that differ from OPT_BASE defaults, so
1489
+ # optimizer-specific YAML (e.g. rfo.print_every / lbfgs.print_every)
1490
+ # is not overwritten by inherited defaults such as opt.print_every=100.
1491
+ common_kwargs = strip_inherited_keys(dict(opt_cfg), OPT_BASE_KW, mode="same")
1492
+ common_kwargs["out_dir"] = str(out_dir_path)
1493
+
1494
def _build_optimizer(run_kind: str):
    """Create the optimizer named by ``run_kind`` for the enclosing ``geometry``.

    ``"lbfgs"`` produces an ``LBFGS`` instance and ``"rfo"`` an
    ``RFOptimizer``. In both cases the optimizer-specific settings are
    overlaid with ``common_kwargs``, so shared values win on key collisions.

    Raises:
        click.BadParameter: ``run_kind`` is neither ``"lbfgs"`` nor ``"rfo"``.
    """
    # Reject unknown kinds up front so the happy paths below stay flat.
    if run_kind not in ("lbfgs", "rfo"):
        raise click.BadParameter(f"Unknown optimizer kind '{run_kind}'.")
    if run_kind == "lbfgs":
        return LBFGS(geometry, **{**lbfgs_cfg, **common_kwargs})
    return RFOptimizer(geometry, **{**rfo_cfg, **common_kwargs})
1502
+
1503
def _seed_rfo_hessian():
    """Give ``geometry`` an initial Cartesian Hessian before RFO starts.

    A Hessian stored in the Hessian cache under the ``"irc_endpoint"`` key is
    reused when available; if that cache entry also carries ``active_dofs``,
    the partial-Hessian metadata on ``geometry`` is filled in from it.
    Without a cache entry, a Hessian is computed through the helpers
    imported from ``.freq``.
    """
    from .hessian_cache import load as _hess_load

    cache_entry = _hess_load("irc_endpoint")
    if cache_entry is None:
        click.echo("[opt] Seeding initial Hessian via shared freq backend.")
        from .freq import (
            _calc_full_hessian_torch as _freq_calc_full_hessian_torch,
            _torch_device as _freq_torch_device,
        )

        hess_device = _freq_torch_device(calc_cfg.get("ml_device", "auto"))
        seed_hessian, _ = _freq_calc_full_hessian_torch(
            geometry,
            calc_cfg,
            hess_device,
            refresh_geom_meta=True,
        )
    else:
        click.echo("[opt] Reusing IRC endpoint Hessian for RFO seeding.")
        dof_ids = cache_entry.get("active_dofs")
        stored = cache_entry["hessian"]
        # A cached tensor is cloned; anything else is converted to a
        # float64 tensor.
        seed_hessian = (
            stored.clone()
            if isinstance(stored, torch.Tensor)
            else torch.as_tensor(stored, dtype=torch.float64)
        )
        if dof_ids is not None:
            geometry.within_partial_hessian = {
                "active_n_dof": len(dof_ids),
                "full_n_dof": geometry.cart_coords.size,
                "active_dofs": dof_ids,
                # Integer-dividing a DOF index by 3 yields its atom index.
                "active_atoms": sorted({dof // 3 for dof in dof_ids}),
            }

    # Shared tail: install the Hessian and report its shape.
    geometry.cart_hessian = seed_hessian
    click.echo(f"[opt] Initial Hessian seeded (shape={seed_hessian.shape[0]}x{seed_hessian.shape[1]}).")
    del seed_hessian
1541
+
1542
+ # Resolve microiteration config from YAML
1543
+ microiter_cfg = dict(MICROITER_KW)
1544
+ apply_yaml_overrides(
1545
+ config_layer_cfg,
1546
+ [(microiter_cfg, (("microiter",),))],
1547
+ )
1548
+ apply_yaml_overrides(
1549
+ override_layer_cfg,
1550
+ [(microiter_cfg, (("microiter",),))],
1551
+ )
1552
+
1553
+ use_microiter = bool(microiter) and use_rfo and not dist_freeze
1554
+ if bool(microiter) and not use_rfo:
1555
+ click.echo("[microiter] --microiter is only effective with --opt-mode hess (RFO). Ignoring.")
1556
+ if bool(microiter) and use_rfo and dist_freeze:
1557
+ click.echo("[microiter] --microiter is not compatible with --dist-freeze. Falling back to standard RFO.")
1558
+
1559
+ if use_microiter:
1560
+ click.echo("\n====== Optimization (RFO + Microiteration) started ======\n")
1561
+ _run_microiter_opt(
1562
+ geometry,
1563
+ calc_cfg,
1564
+ rfo_cfg,
1565
+ lbfgs_cfg,
1566
+ opt_cfg,
1567
+ microiter_cfg,
1568
+ out_dir_path,
1569
+ dump=bool(opt_cfg["dump"]),
1570
+ )
1571
+ click.echo("\n====== Optimization (RFO + Microiteration) finished ======\n")
1572
+
1573
+ # Write final geometry
1574
+ from ase import Atoms as _Atoms
1575
+ from ase.io import write as _write
1576
+ final_xyz_path = out_dir_path / "final_geometry.xyz"
1577
+ final_coords_ang = geometry.coords.reshape(-1, 3) * BOHR2ANG
1578
+ atoms_final = _Atoms(geometry.atoms, positions=final_coords_ang, pbc=False)
1579
+ _write(final_xyz_path, atoms_final)
1580
+
1581
+ else:
1582
+ main_kind = "rfo" if use_rfo else "lbfgs"
1583
+ if use_rfo:
1584
+ _seed_rfo_hessian()
1585
+
1586
+ main_label = "RFO" if use_rfo else "LBFGS"
1587
+ optimizer = _build_optimizer(main_kind)
1588
+ click.echo(f"\n====== Optimization ({main_label}) started ======\n")
1589
+ optimizer.run()
1590
+ click.echo(f"\n====== Optimization ({main_label}) finished ======\n")
1591
+
1592
+ # Get final geometry path
1593
+ final_xyz_path = optimizer.final_fn if isinstance(optimizer.final_fn, Path) else Path(optimizer.final_fn)
1594
+
1595
+ if bool(opt_cfg["dump"]):
1596
+ optim_all_path = out_dir_path / "optimization_all_trj.xyz"
1597
+ if not optim_all_path.exists():
1598
+ trj_path = optimizer.get_path_for_fn("optimization_trj.xyz")
1599
+ _append_xyz_trajectory(optim_all_path, trj_path, reset=True)
1600
+
1601
+ # --------------------------
1602
+ # Flatten loop (all imaginary modes)
1603
+ # --------------------------
1604
+ if flatten:
1605
+ from .freq import (
1606
+ _torch_device,
1607
+ _calc_full_hessian_torch,
1608
+ _frequencies_cm_and_modes,
1609
+ _safe_masses_amu,
1610
+ )
1611
+
1612
+ click.echo("\n====== Optimization (Flatten loop) started ======\n")
1613
+
1614
+ geometry.set_calculator(None)
1615
+ uma_kwargs_for_flatten = dict(calc_cfg)
1616
+ uma_kwargs_for_flatten["out_hess_torch"] = True
1617
+ device = _torch_device(calc_cfg.get("ml_device", "auto"))
1618
+ freeze_idx = list(geom_cfg.get("freeze_atoms", [])) if len(geom_cfg.get("freeze_atoms", [])) > 0 else None
1619
+ masses_amu = _safe_masses_amu(geometry.atomic_numbers)
1620
+
1621
def _attach_opt_calc() -> None:
    """Put the optimization calculator back on ``geometry``.

    The bias calculator is used when ``resolved_dist_freeze`` is non-empty;
    otherwise the base calculator is attached.
    """
    # ``bias_calc`` is only looked up when restraints were resolved.
    if resolved_dist_freeze:
        active_calc = bias_calc
    else:
        active_calc = base_calc
    geometry.set_calculator(active_calc)
1625
+
1626
def _calc_freqs_and_modes() -> Tuple[np.ndarray, torch.Tensor]:
    """Compute frequencies and modes for the current state of ``geometry``.

    A Hessian is obtained from ``_calc_full_hessian_torch`` using the
    flatten-specific calculator settings and then handed to
    ``_frequencies_cm_and_modes`` together with the atomic numbers,
    the coordinates reshaped to ``(-1, 3)``, the device and ``freeze_idx``.

    Returns:
        The ``(frequencies, modes)`` pair produced by
        ``_frequencies_cm_and_modes`` (presumably cm^-1, going by the
        helper's name -- confirm against ``.freq``).
    """
    hessian, _energy = _calc_full_hessian_torch(geometry, uma_kwargs_for_flatten, device)
    atomic_nums = geometry.atomic_numbers
    coords_n3 = geometry.cart_coords.reshape(-1, 3)
    freqs_out, modes_out = _frequencies_cm_and_modes(
        hessian,
        atomic_nums,
        coords_n3,
        device,
        freeze_idx=freeze_idx,
    )
    # Drop the local Hessian reference before returning.
    del hessian
    return freqs_out, modes_out
1637
+
1638
+ freqs_cm, modes = _calc_freqs_and_modes()
1639
+ neg_mask = freqs_cm < -abs(OPT_FLATTEN_NEG_FREQ_THRESH_CM)
1640
+ n_imag = int(np.sum(neg_mask))
1641
+ ims = [float(x) for x in freqs_cm if x < -abs(OPT_FLATTEN_NEG_FREQ_THRESH_CM)]
1642
+ click.echo(f"[Imaginary modes] n={n_imag} ({ims})")
1643
+
1644
+ flatten_kind = mode_resolved # reuse same optimizer type
1645
+ for it in range(OPT_FLATTEN_MAX_ITER):
1646
+ if n_imag == 0:
1647
+ break
1648
+ click.echo(f"[flatten] iteration {it + 1}/{OPT_FLATTEN_MAX_ITER}")
1649
+ did_flatten = _flatten_all_imag_modes_for_geom(
1650
+ geometry,
1651
+ masses_amu,
1652
+ uma_kwargs_for_flatten,
1653
+ freqs_cm,
1654
+ modes,
1655
+ OPT_FLATTEN_NEG_FREQ_THRESH_CM,
1656
+ OPT_FLATTEN_AMP_ANG,
1657
+ )
1658
+ if not did_flatten:
1659
+ click.echo("[flatten] No eligible imaginary modes to flatten; stopping.")
1660
+ break
1661
+
1662
+ _attach_opt_calc()
1663
+ opt_restart = _build_optimizer(flatten_kind)
1664
+ restart_label = "LBFGS" if flatten_kind == "lbfgs" else "RFO"
1665
+ click.echo(f"\n====== Optimization ({restart_label}) restarted ======\n")
1666
+ opt_restart.run()
1667
+ click.echo(f"\n====== Optimization ({restart_label}) finished ======\n")
1668
+
1669
+ geometry.set_calculator(None)
1670
+ freqs_cm, modes = _calc_freqs_and_modes()
1671
+ neg_mask = freqs_cm < -abs(OPT_FLATTEN_NEG_FREQ_THRESH_CM)
1672
+ n_imag = int(np.sum(neg_mask))
1673
+ ims = [float(x) for x in freqs_cm if x < -abs(OPT_FLATTEN_NEG_FREQ_THRESH_CM)]
1674
+ click.echo(f"[Imaginary modes] n={n_imag} ({ims})")
1675
+
1676
+ if n_imag > 0:
1677
+ click.echo(
1678
+ f"[flatten] WARNING: Remaining imaginary modes after {OPT_FLATTEN_MAX_ITER} iterations: {n_imag}",
1679
+ err=True,
1680
+ )
1681
+ if torch.cuda.is_available():
1682
+ torch.cuda.empty_cache()
1683
+ click.echo("\n====== Optimization (Flatten loop) finished ======\n")
1684
+
1685
+ # Update final geometry after flatten
1686
+ final_xyz_path = out_dir_path / "final_geometry.xyz"
1687
+ from ase import Atoms as _Atoms
1688
+ from ase.io import write as _write
1689
+ final_coords_ang = geometry.coords.reshape(-1, 3) * BOHR2ANG
1690
+ atoms_final = _Atoms(geometry.atoms, positions=final_coords_ang, pbc=False)
1691
+ _write(final_xyz_path, atoms_final)
1692
+
1693
+ # Extract layer indices from calculator for layer-based B-factor encoding
1694
+ calc_core = base_calc.core if hasattr(base_calc, 'core') else base_calc
1695
+ ml_indices = getattr(calc_core, 'ml_indices', None)
1696
+ hess_mm_indices = getattr(calc_core, 'hess_mm_indices', None)
1697
+ movable_mm_indices = getattr(calc_core, 'movable_mm_indices', None)
1698
+ frozen_layer_indices = getattr(calc_core, 'frozen_layer_indices', None)
1699
+
1700
+ _maybe_convert_outputs_to_pdb(
1701
+ input_path=prepared_input.source_path, # Use PDB topology for conversion
1702
+ out_dir=out_dir_path,
1703
+ dump=bool(opt_cfg["dump"]),
1704
+ get_trj_fn=(lambda fn: out_dir_path / fn) if use_microiter else optimizer.get_path_for_fn,
1705
+ final_xyz_path=final_xyz_path,
1706
+ model_pdb=Path(calc_cfg["model_pdb"]),
1707
+ freeze_indices_0based=freeze_atoms_final,
1708
+ ml_indices=ml_indices,
1709
+ hess_mm_indices=hess_mm_indices,
1710
+ movable_mm_indices=movable_mm_indices,
1711
+ frozen_layer_indices=frozen_layer_indices,
1712
+ )
1713
+
1714
+ # summary.md and key_* outputs are disabled.
1715
+ click.echo(format_elapsed("[time] Elapsed Time for Opt", time_start))
1716
+
1717
+ except ZeroStepLength:
1718
+ click.echo("ERROR: Step length fell below the minimum allowed (ZeroStepLength).", err=True)
1719
+ sys.exit(2)
1720
+ except OptimizationError as e:
1721
+ click.echo(f"ERROR: Optimization failed - {e}", err=True)
1722
+ sys.exit(3)
1723
+ except KeyboardInterrupt:
1724
+ click.echo("\nInterrupted by user.", err=True)
1725
+ sys.exit(130)
1726
+ except Exception as e:
1727
+ tb = "".join(traceback.format_exception(type(e), e, e.__traceback__))
1728
+ click.echo("Unhandled exception during optimization:\n" + textwrap.indent(tb, " "), err=True)
1729
+ sys.exit(1)
1730
+ finally:
1731
+ if prepared_input is not None:
1732
+ prepared_input.cleanup()
1733
+ # Release GPU memory so subsequent pipeline stages don't OOM
1734
+ base_calc = bias_calc = geometry = optimizer = mm_calc = macro_calc = macro_optimizer = None
1735
+ gc.collect() # break cyclic refs inside torch.nn.Module
1736
+ if torch.cuda.is_available():
1737
+ torch.cuda.empty_cache()
1738
+
1739
+
1740
+ # Script entry point: when this module is executed directly (e.g. `python -m mlmm.opt`), invoke `cli`.
1741
+ if __name__ == "__main__":
1742
+ cli()