mlmm-toolkit 0.2.2.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (372) hide show
  1. hessian_ff/__init__.py +50 -0
  2. hessian_ff/analytical_hessian.py +609 -0
  3. hessian_ff/constants.py +46 -0
  4. hessian_ff/forcefield.py +339 -0
  5. hessian_ff/loaders.py +608 -0
  6. hessian_ff/native/Makefile +8 -0
  7. hessian_ff/native/__init__.py +28 -0
  8. hessian_ff/native/analytical_hessian.py +88 -0
  9. hessian_ff/native/analytical_hessian_ext.cpp +258 -0
  10. hessian_ff/native/bonded.py +82 -0
  11. hessian_ff/native/bonded_ext.cpp +640 -0
  12. hessian_ff/native/loader.py +349 -0
  13. hessian_ff/native/nonbonded.py +118 -0
  14. hessian_ff/native/nonbonded_ext.cpp +1150 -0
  15. hessian_ff/prmtop_parmed.py +23 -0
  16. hessian_ff/system.py +107 -0
  17. hessian_ff/terms/__init__.py +14 -0
  18. hessian_ff/terms/angle.py +73 -0
  19. hessian_ff/terms/bond.py +44 -0
  20. hessian_ff/terms/cmap.py +406 -0
  21. hessian_ff/terms/dihedral.py +141 -0
  22. hessian_ff/terms/nonbonded.py +209 -0
  23. hessian_ff/tests/__init__.py +0 -0
  24. hessian_ff/tests/conftest.py +75 -0
  25. hessian_ff/tests/data/small/complex.parm7 +1346 -0
  26. hessian_ff/tests/data/small/complex.pdb +125 -0
  27. hessian_ff/tests/data/small/complex.rst7 +63 -0
  28. hessian_ff/tests/test_coords_input.py +44 -0
  29. hessian_ff/tests/test_energy_force.py +49 -0
  30. hessian_ff/tests/test_hessian.py +137 -0
  31. hessian_ff/tests/test_smoke.py +18 -0
  32. hessian_ff/tests/test_validation.py +40 -0
  33. hessian_ff/workflows.py +889 -0
  34. mlmm/__init__.py +36 -0
  35. mlmm/__main__.py +7 -0
  36. mlmm/_version.py +34 -0
  37. mlmm/add_elem_info.py +374 -0
  38. mlmm/advanced_help.py +91 -0
  39. mlmm/align_freeze_atoms.py +601 -0
  40. mlmm/all.py +3535 -0
  41. mlmm/bond_changes.py +231 -0
  42. mlmm/bool_compat.py +223 -0
  43. mlmm/cli.py +574 -0
  44. mlmm/cli_utils.py +166 -0
  45. mlmm/default_group.py +337 -0
  46. mlmm/defaults.py +467 -0
  47. mlmm/define_layer.py +526 -0
  48. mlmm/dft.py +1041 -0
  49. mlmm/energy_diagram.py +253 -0
  50. mlmm/extract.py +2213 -0
  51. mlmm/fix_altloc.py +464 -0
  52. mlmm/freq.py +1406 -0
  53. mlmm/harmonic_constraints.py +140 -0
  54. mlmm/hessian_cache.py +44 -0
  55. mlmm/hessian_calc.py +174 -0
  56. mlmm/irc.py +638 -0
  57. mlmm/mlmm_calc.py +2262 -0
  58. mlmm/mm_parm.py +945 -0
  59. mlmm/oniom_export.py +1983 -0
  60. mlmm/oniom_import.py +457 -0
  61. mlmm/opt.py +1742 -0
  62. mlmm/path_opt.py +1353 -0
  63. mlmm/path_search.py +2299 -0
  64. mlmm/preflight.py +88 -0
  65. mlmm/py.typed +1 -0
  66. mlmm/pysis_runner.py +45 -0
  67. mlmm/scan.py +1047 -0
  68. mlmm/scan2d.py +1226 -0
  69. mlmm/scan3d.py +1265 -0
  70. mlmm/scan_common.py +184 -0
  71. mlmm/summary_log.py +736 -0
  72. mlmm/trj2fig.py +448 -0
  73. mlmm/tsopt.py +2871 -0
  74. mlmm/utils.py +2309 -0
  75. mlmm/xtb_embedcharge_correction.py +475 -0
  76. mlmm_toolkit-0.2.2.dev0.dist-info/METADATA +1159 -0
  77. mlmm_toolkit-0.2.2.dev0.dist-info/RECORD +372 -0
  78. mlmm_toolkit-0.2.2.dev0.dist-info/WHEEL +5 -0
  79. mlmm_toolkit-0.2.2.dev0.dist-info/entry_points.txt +2 -0
  80. mlmm_toolkit-0.2.2.dev0.dist-info/licenses/LICENSE +674 -0
  81. mlmm_toolkit-0.2.2.dev0.dist-info/top_level.txt +4 -0
  82. pysisyphus/Geometry.py +1667 -0
  83. pysisyphus/LICENSE +674 -0
  84. pysisyphus/TableFormatter.py +63 -0
  85. pysisyphus/TablePrinter.py +74 -0
  86. pysisyphus/__init__.py +12 -0
  87. pysisyphus/calculators/AFIR.py +452 -0
  88. pysisyphus/calculators/AnaPot.py +20 -0
  89. pysisyphus/calculators/AnaPot2.py +48 -0
  90. pysisyphus/calculators/AnaPot3.py +12 -0
  91. pysisyphus/calculators/AnaPot4.py +20 -0
  92. pysisyphus/calculators/AnaPotBase.py +337 -0
  93. pysisyphus/calculators/AnaPotCBM.py +25 -0
  94. pysisyphus/calculators/AtomAtomTransTorque.py +154 -0
  95. pysisyphus/calculators/CFOUR.py +250 -0
  96. pysisyphus/calculators/Calculator.py +844 -0
  97. pysisyphus/calculators/CerjanMiller.py +24 -0
  98. pysisyphus/calculators/Composite.py +123 -0
  99. pysisyphus/calculators/ConicalIntersection.py +171 -0
  100. pysisyphus/calculators/DFTBp.py +430 -0
  101. pysisyphus/calculators/DFTD3.py +66 -0
  102. pysisyphus/calculators/DFTD4.py +84 -0
  103. pysisyphus/calculators/Dalton.py +61 -0
  104. pysisyphus/calculators/Dimer.py +681 -0
  105. pysisyphus/calculators/Dummy.py +20 -0
  106. pysisyphus/calculators/EGO.py +76 -0
  107. pysisyphus/calculators/EnergyMin.py +224 -0
  108. pysisyphus/calculators/ExternalPotential.py +264 -0
  109. pysisyphus/calculators/FakeASE.py +35 -0
  110. pysisyphus/calculators/FourWellAnaPot.py +28 -0
  111. pysisyphus/calculators/FreeEndNEBPot.py +39 -0
  112. pysisyphus/calculators/Gaussian09.py +18 -0
  113. pysisyphus/calculators/Gaussian16.py +726 -0
  114. pysisyphus/calculators/HardSphere.py +159 -0
  115. pysisyphus/calculators/IDPPCalculator.py +49 -0
  116. pysisyphus/calculators/IPIClient.py +133 -0
  117. pysisyphus/calculators/IPIServer.py +234 -0
  118. pysisyphus/calculators/LEPSBase.py +24 -0
  119. pysisyphus/calculators/LEPSExpr.py +139 -0
  120. pysisyphus/calculators/LennardJones.py +80 -0
  121. pysisyphus/calculators/MOPAC.py +219 -0
  122. pysisyphus/calculators/MullerBrownSympyPot.py +51 -0
  123. pysisyphus/calculators/MultiCalc.py +85 -0
  124. pysisyphus/calculators/NFK.py +45 -0
  125. pysisyphus/calculators/OBabel.py +87 -0
  126. pysisyphus/calculators/ONIOMv2.py +1129 -0
  127. pysisyphus/calculators/ORCA.py +893 -0
  128. pysisyphus/calculators/ORCA5.py +6 -0
  129. pysisyphus/calculators/OpenMM.py +88 -0
  130. pysisyphus/calculators/OpenMolcas.py +281 -0
  131. pysisyphus/calculators/OverlapCalculator.py +908 -0
  132. pysisyphus/calculators/Psi4.py +218 -0
  133. pysisyphus/calculators/PyPsi4.py +37 -0
  134. pysisyphus/calculators/PySCF.py +341 -0
  135. pysisyphus/calculators/PyXTB.py +73 -0
  136. pysisyphus/calculators/QCEngine.py +106 -0
  137. pysisyphus/calculators/Rastrigin.py +22 -0
  138. pysisyphus/calculators/Remote.py +76 -0
  139. pysisyphus/calculators/Rosenbrock.py +15 -0
  140. pysisyphus/calculators/SocketCalc.py +97 -0
  141. pysisyphus/calculators/TIP3P.py +111 -0
  142. pysisyphus/calculators/TransTorque.py +161 -0
  143. pysisyphus/calculators/Turbomole.py +965 -0
  144. pysisyphus/calculators/VRIPot.py +37 -0
  145. pysisyphus/calculators/WFOWrapper.py +333 -0
  146. pysisyphus/calculators/WFOWrapper2.py +341 -0
  147. pysisyphus/calculators/XTB.py +418 -0
  148. pysisyphus/calculators/__init__.py +81 -0
  149. pysisyphus/calculators/cosmo_data.py +139 -0
  150. pysisyphus/calculators/parser.py +150 -0
  151. pysisyphus/color.py +19 -0
  152. pysisyphus/config.py +133 -0
  153. pysisyphus/constants.py +65 -0
  154. pysisyphus/cos/AdaptiveNEB.py +230 -0
  155. pysisyphus/cos/ChainOfStates.py +725 -0
  156. pysisyphus/cos/FreeEndNEB.py +25 -0
  157. pysisyphus/cos/FreezingString.py +103 -0
  158. pysisyphus/cos/GrowingChainOfStates.py +71 -0
  159. pysisyphus/cos/GrowingNT.py +309 -0
  160. pysisyphus/cos/GrowingString.py +508 -0
  161. pysisyphus/cos/NEB.py +189 -0
  162. pysisyphus/cos/SimpleZTS.py +64 -0
  163. pysisyphus/cos/__init__.py +22 -0
  164. pysisyphus/cos/stiffness.py +199 -0
  165. pysisyphus/drivers/__init__.py +17 -0
  166. pysisyphus/drivers/afir.py +855 -0
  167. pysisyphus/drivers/barriers.py +271 -0
  168. pysisyphus/drivers/birkholz.py +138 -0
  169. pysisyphus/drivers/cluster.py +318 -0
  170. pysisyphus/drivers/diabatization.py +133 -0
  171. pysisyphus/drivers/merge.py +368 -0
  172. pysisyphus/drivers/merge_mol2.py +322 -0
  173. pysisyphus/drivers/opt.py +375 -0
  174. pysisyphus/drivers/perf.py +91 -0
  175. pysisyphus/drivers/pka.py +52 -0
  176. pysisyphus/drivers/precon_pos_rot.py +669 -0
  177. pysisyphus/drivers/rates.py +480 -0
  178. pysisyphus/drivers/replace.py +219 -0
  179. pysisyphus/drivers/scan.py +212 -0
  180. pysisyphus/drivers/spectrum.py +166 -0
  181. pysisyphus/drivers/thermo.py +31 -0
  182. pysisyphus/dynamics/Gaussian.py +103 -0
  183. pysisyphus/dynamics/__init__.py +20 -0
  184. pysisyphus/dynamics/colvars.py +136 -0
  185. pysisyphus/dynamics/driver.py +297 -0
  186. pysisyphus/dynamics/helpers.py +256 -0
  187. pysisyphus/dynamics/lincs.py +105 -0
  188. pysisyphus/dynamics/mdp.py +364 -0
  189. pysisyphus/dynamics/rattle.py +121 -0
  190. pysisyphus/dynamics/thermostats.py +128 -0
  191. pysisyphus/dynamics/wigner.py +266 -0
  192. pysisyphus/elem_data.py +3473 -0
  193. pysisyphus/exceptions.py +2 -0
  194. pysisyphus/filtertrj.py +69 -0
  195. pysisyphus/helpers.py +623 -0
  196. pysisyphus/helpers_pure.py +649 -0
  197. pysisyphus/init_logging.py +50 -0
  198. pysisyphus/intcoords/Bend.py +69 -0
  199. pysisyphus/intcoords/Bend2.py +25 -0
  200. pysisyphus/intcoords/BondedFragment.py +32 -0
  201. pysisyphus/intcoords/Cartesian.py +41 -0
  202. pysisyphus/intcoords/CartesianCoords.py +140 -0
  203. pysisyphus/intcoords/Coords.py +56 -0
  204. pysisyphus/intcoords/DLC.py +197 -0
  205. pysisyphus/intcoords/DistanceFunction.py +34 -0
  206. pysisyphus/intcoords/DummyImproper.py +70 -0
  207. pysisyphus/intcoords/DummyTorsion.py +72 -0
  208. pysisyphus/intcoords/LinearBend.py +105 -0
  209. pysisyphus/intcoords/LinearDisplacement.py +80 -0
  210. pysisyphus/intcoords/OutOfPlane.py +59 -0
  211. pysisyphus/intcoords/PrimTypes.py +286 -0
  212. pysisyphus/intcoords/Primitive.py +137 -0
  213. pysisyphus/intcoords/RedundantCoords.py +659 -0
  214. pysisyphus/intcoords/RobustTorsion.py +59 -0
  215. pysisyphus/intcoords/Rotation.py +147 -0
  216. pysisyphus/intcoords/Stretch.py +31 -0
  217. pysisyphus/intcoords/Torsion.py +101 -0
  218. pysisyphus/intcoords/Torsion2.py +25 -0
  219. pysisyphus/intcoords/Translation.py +45 -0
  220. pysisyphus/intcoords/__init__.py +61 -0
  221. pysisyphus/intcoords/augment_bonds.py +126 -0
  222. pysisyphus/intcoords/derivatives.py +10512 -0
  223. pysisyphus/intcoords/eval.py +80 -0
  224. pysisyphus/intcoords/exceptions.py +37 -0
  225. pysisyphus/intcoords/findiffs.py +48 -0
  226. pysisyphus/intcoords/generate_derivatives.py +414 -0
  227. pysisyphus/intcoords/helpers.py +235 -0
  228. pysisyphus/intcoords/logging_conf.py +10 -0
  229. pysisyphus/intcoords/mp_derivatives.py +10836 -0
  230. pysisyphus/intcoords/setup.py +962 -0
  231. pysisyphus/intcoords/setup_fast.py +176 -0
  232. pysisyphus/intcoords/update.py +272 -0
  233. pysisyphus/intcoords/valid.py +89 -0
  234. pysisyphus/interpolate/Geodesic.py +93 -0
  235. pysisyphus/interpolate/IDPP.py +55 -0
  236. pysisyphus/interpolate/Interpolator.py +116 -0
  237. pysisyphus/interpolate/LST.py +70 -0
  238. pysisyphus/interpolate/Redund.py +152 -0
  239. pysisyphus/interpolate/__init__.py +9 -0
  240. pysisyphus/interpolate/helpers.py +34 -0
  241. pysisyphus/io/__init__.py +22 -0
  242. pysisyphus/io/aomix.py +178 -0
  243. pysisyphus/io/cjson.py +24 -0
  244. pysisyphus/io/crd.py +101 -0
  245. pysisyphus/io/cube.py +220 -0
  246. pysisyphus/io/fchk.py +184 -0
  247. pysisyphus/io/hdf5.py +49 -0
  248. pysisyphus/io/hessian.py +72 -0
  249. pysisyphus/io/mol2.py +146 -0
  250. pysisyphus/io/molden.py +293 -0
  251. pysisyphus/io/orca.py +189 -0
  252. pysisyphus/io/pdb.py +269 -0
  253. pysisyphus/io/psf.py +79 -0
  254. pysisyphus/io/pubchem.py +31 -0
  255. pysisyphus/io/qcschema.py +34 -0
  256. pysisyphus/io/sdf.py +29 -0
  257. pysisyphus/io/xyz.py +61 -0
  258. pysisyphus/io/zmat.py +175 -0
  259. pysisyphus/irc/DWI.py +108 -0
  260. pysisyphus/irc/DampedVelocityVerlet.py +134 -0
  261. pysisyphus/irc/Euler.py +22 -0
  262. pysisyphus/irc/EulerPC.py +345 -0
  263. pysisyphus/irc/GonzalezSchlegel.py +187 -0
  264. pysisyphus/irc/IMKMod.py +164 -0
  265. pysisyphus/irc/IRC.py +878 -0
  266. pysisyphus/irc/IRCDummy.py +10 -0
  267. pysisyphus/irc/Instanton.py +307 -0
  268. pysisyphus/irc/LQA.py +53 -0
  269. pysisyphus/irc/ModeKill.py +136 -0
  270. pysisyphus/irc/ParamPlot.py +53 -0
  271. pysisyphus/irc/RK4.py +36 -0
  272. pysisyphus/irc/__init__.py +31 -0
  273. pysisyphus/irc/initial_displ.py +219 -0
  274. pysisyphus/linalg.py +411 -0
  275. pysisyphus/line_searches/Backtracking.py +88 -0
  276. pysisyphus/line_searches/HagerZhang.py +184 -0
  277. pysisyphus/line_searches/LineSearch.py +232 -0
  278. pysisyphus/line_searches/StrongWolfe.py +108 -0
  279. pysisyphus/line_searches/__init__.py +9 -0
  280. pysisyphus/line_searches/interpol.py +15 -0
  281. pysisyphus/modefollow/NormalMode.py +40 -0
  282. pysisyphus/modefollow/__init__.py +10 -0
  283. pysisyphus/modefollow/davidson.py +199 -0
  284. pysisyphus/modefollow/lanczos.py +95 -0
  285. pysisyphus/optimizers/BFGS.py +99 -0
  286. pysisyphus/optimizers/BacktrackingOptimizer.py +113 -0
  287. pysisyphus/optimizers/ConjugateGradient.py +98 -0
  288. pysisyphus/optimizers/CubicNewton.py +75 -0
  289. pysisyphus/optimizers/FIRE.py +113 -0
  290. pysisyphus/optimizers/HessianOptimizer.py +1176 -0
  291. pysisyphus/optimizers/LBFGS.py +228 -0
  292. pysisyphus/optimizers/LayerOpt.py +411 -0
  293. pysisyphus/optimizers/MicroOptimizer.py +169 -0
  294. pysisyphus/optimizers/NCOptimizer.py +90 -0
  295. pysisyphus/optimizers/Optimizer.py +1084 -0
  296. pysisyphus/optimizers/PreconLBFGS.py +260 -0
  297. pysisyphus/optimizers/PreconSteepestDescent.py +7 -0
  298. pysisyphus/optimizers/QuickMin.py +74 -0
  299. pysisyphus/optimizers/RFOptimizer.py +181 -0
  300. pysisyphus/optimizers/RSA.py +99 -0
  301. pysisyphus/optimizers/StabilizedQNMethod.py +248 -0
  302. pysisyphus/optimizers/SteepestDescent.py +23 -0
  303. pysisyphus/optimizers/StringOptimizer.py +173 -0
  304. pysisyphus/optimizers/__init__.py +41 -0
  305. pysisyphus/optimizers/closures.py +301 -0
  306. pysisyphus/optimizers/cls_map.py +58 -0
  307. pysisyphus/optimizers/exceptions.py +6 -0
  308. pysisyphus/optimizers/gdiis.py +280 -0
  309. pysisyphus/optimizers/guess_hessians.py +311 -0
  310. pysisyphus/optimizers/hessian_updates.py +355 -0
  311. pysisyphus/optimizers/poly_fit.py +285 -0
  312. pysisyphus/optimizers/precon.py +153 -0
  313. pysisyphus/optimizers/restrict_step.py +24 -0
  314. pysisyphus/pack.py +172 -0
  315. pysisyphus/peakdetect.py +948 -0
  316. pysisyphus/plot.py +1031 -0
  317. pysisyphus/run.py +2106 -0
  318. pysisyphus/socket_helper.py +74 -0
  319. pysisyphus/stocastic/FragmentKick.py +132 -0
  320. pysisyphus/stocastic/Kick.py +81 -0
  321. pysisyphus/stocastic/Pipeline.py +303 -0
  322. pysisyphus/stocastic/__init__.py +21 -0
  323. pysisyphus/stocastic/align.py +127 -0
  324. pysisyphus/testing.py +96 -0
  325. pysisyphus/thermo.py +156 -0
  326. pysisyphus/trj.py +824 -0
  327. pysisyphus/tsoptimizers/RSIRFOptimizer.py +56 -0
  328. pysisyphus/tsoptimizers/RSPRFOptimizer.py +182 -0
  329. pysisyphus/tsoptimizers/TRIM.py +59 -0
  330. pysisyphus/tsoptimizers/TSHessianOptimizer.py +463 -0
  331. pysisyphus/tsoptimizers/__init__.py +23 -0
  332. pysisyphus/wavefunction/Basis.py +239 -0
  333. pysisyphus/wavefunction/DIIS.py +76 -0
  334. pysisyphus/wavefunction/__init__.py +25 -0
  335. pysisyphus/wavefunction/build_ext.py +42 -0
  336. pysisyphus/wavefunction/cart2sph.py +190 -0
  337. pysisyphus/wavefunction/diabatization.py +304 -0
  338. pysisyphus/wavefunction/excited_states.py +435 -0
  339. pysisyphus/wavefunction/gen_ints.py +1811 -0
  340. pysisyphus/wavefunction/helpers.py +104 -0
  341. pysisyphus/wavefunction/ints/__init__.py +0 -0
  342. pysisyphus/wavefunction/ints/boys.py +193 -0
  343. pysisyphus/wavefunction/ints/boys_table_N_64_xasym_27.1_step_0.01.npy +0 -0
  344. pysisyphus/wavefunction/ints/cart_gto3d.py +176 -0
  345. pysisyphus/wavefunction/ints/coulomb3d.py +25928 -0
  346. pysisyphus/wavefunction/ints/diag_quadrupole3d.py +10036 -0
  347. pysisyphus/wavefunction/ints/dipole3d.py +8762 -0
  348. pysisyphus/wavefunction/ints/int2c2e3d.py +7198 -0
  349. pysisyphus/wavefunction/ints/int3c2e3d_sph.py +65040 -0
  350. pysisyphus/wavefunction/ints/kinetic3d.py +8240 -0
  351. pysisyphus/wavefunction/ints/ovlp3d.py +3777 -0
  352. pysisyphus/wavefunction/ints/quadrupole3d.py +15054 -0
  353. pysisyphus/wavefunction/ints/self_ovlp3d.py +198 -0
  354. pysisyphus/wavefunction/localization.py +458 -0
  355. pysisyphus/wavefunction/multipole.py +159 -0
  356. pysisyphus/wavefunction/normalization.py +36 -0
  357. pysisyphus/wavefunction/pop_analysis.py +134 -0
  358. pysisyphus/wavefunction/shells.py +1171 -0
  359. pysisyphus/wavefunction/wavefunction.py +504 -0
  360. pysisyphus/wrapper/__init__.py +11 -0
  361. pysisyphus/wrapper/exceptions.py +2 -0
  362. pysisyphus/wrapper/jmol.py +120 -0
  363. pysisyphus/wrapper/mwfn.py +169 -0
  364. pysisyphus/wrapper/packmol.py +71 -0
  365. pysisyphus/xyzloader.py +168 -0
  366. pysisyphus/yaml_mods.py +45 -0
  367. thermoanalysis/LICENSE +674 -0
  368. thermoanalysis/QCData.py +244 -0
  369. thermoanalysis/__init__.py +0 -0
  370. thermoanalysis/config.py +3 -0
  371. thermoanalysis/constants.py +20 -0
  372. thermoanalysis/thermo.py +1011 -0
mlmm/tsopt.py ADDED
@@ -0,0 +1,2871 @@
1
+ # mlmm/tsopt.py
2
+
3
+ """Partial Hessian guided Dimer / RS-I-RFO transition-state search with ML/MM.
4
+
5
+ Example:
6
+ mlmm tsopt -i ts_guess.pdb --parm real.parm7 --model-pdb ml_region.pdb \
7
+ -q 0 -m 1 --max-cycles 8000
8
+
9
+ For detailed documentation, see: docs/tsopt.md
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import contextlib
15
+ import gc
16
+ import io
17
+ import logging
18
+ import sys
19
+ import textwrap
20
+ from copy import deepcopy
21
+
22
+ logger = logging.getLogger(__name__)
23
+ from pathlib import Path
24
+ from typing import Dict, Any, Optional, Tuple, List
25
+
26
+ import click
27
+ import numpy as np
28
+ import torch
29
+ from ase import Atoms
30
+ from ase.io import write
31
+ from ase.data import atomic_masses
32
+ import ase.units as units
33
+ import time
34
+
35
+ # ---------------- pysisyphus / mlmm imports ----------------
36
+ from pysisyphus.helpers import geom_loader
37
+ from pysisyphus.optimizers.LBFGS import LBFGS
38
+ from pysisyphus.optimizers.exceptions import OptimizationError, ZeroStepLength
39
+ from pysisyphus.constants import BOHR2ANG, ANG2BOHR, AMU2AU, AU2EV
40
+ from pysisyphus.calculators.Dimer import Dimer # Dimer calculator (orientation-projected forces)
41
+
42
+ # RS-I-RFO optimizer for heavy mode
43
+ from pysisyphus.tsoptimizers.RSIRFOptimizer import RSIRFOptimizer
44
+ from pysisyphus.TablePrinter import TablePrinter
45
+
46
+ # local helpers from mlmm
47
+ from .mlmm_calc import mlmm, mlmm_mm_only
48
+ from .defaults import OUT_DIR_TSOPT
49
+ from .defaults import (
50
+ GEOM_KW_DEFAULT,
51
+ MLMM_CALC_KW,
52
+ OPT_BASE_KW,
53
+ LBFGS_KW,
54
+ DIMER_KW,
55
+ HESSIAN_DIMER_KW,
56
+ RSIRFO_KW,
57
+ MICROITER_KW,
58
+ TSOPT_MODE_ALIASES,
59
+ BFACTOR_ML,
60
+ BFACTOR_MOVABLE_MM,
61
+ BFACTOR_FROZEN,
62
+ )
63
+ from .opt import (
64
+ _parse_freeze_atoms as _parse_freeze_atoms_opt,
65
+ _normalize_geom_freeze as _normalize_geom_freeze_opt,
66
+ )
67
+ from .utils import (
68
+ append_xyz_trajectory as _append_xyz_trajectory,
69
+ apply_layer_freeze_constraints,
70
+ convert_xyz_to_pdb,
71
+ set_convert_file_enabled,
72
+ is_convert_file_enabled,
73
+ convert_xyz_like_outputs,
74
+ deep_update,
75
+ load_yaml_dict,
76
+ apply_yaml_overrides,
77
+ pretty_block,
78
+ strip_inherited_keys,
79
+ filter_calc_for_echo,
80
+ format_freeze_atoms_for_echo,
81
+ format_elapsed,
82
+ merge_freeze_atom_indices,
83
+ prepare_input_structure,
84
+ apply_ref_pdb_override,
85
+ resolve_charge_spin_or_raise,
86
+ parse_indices_string,
87
+ build_model_pdb_from_bfactors,
88
+ build_model_pdb_from_indices,
89
+ update_pdb_bfactors_from_layers,
90
+ normalize_choice,
91
+ yaml_section_has_key,
92
+ )
93
+ from .cli_utils import resolve_yaml_sources, load_merged_yaml_cfg, make_is_param_explicit
94
+ from .freq import (
95
+ _calc_full_hessian_torch as _freq_calc_full_hessian_torch,
96
+ _torch_device,
97
+ _build_tr_basis,
98
+ _tr_orthonormal_basis,
99
+ _mass_weighted_hessian,
100
+ _align_three_layer_hessian_targets,
101
+ _resolve_active_atom_indices,
102
+ )
103
+
104
+
105
+ # ===================================================================
106
+ # Mass-weighted projection & vib analysis
107
+ # ===================================================================
108
+
109
+
110
def _calc_full_hessian_torch(geom, calc_kwargs: Dict[str, Any], device: torch.device) -> torch.Tensor:
    """Compute the full Cartesian Hessian through the shared freq.py backend.

    Thin adapter around ``freq._calc_full_hessian_torch``. It always requests
    ``refresh_geom_meta=True`` and returns only the first element of the
    backend's result pair (the Hessian tensor).
    """
    hessian, _unused = _freq_calc_full_hessian_torch(
        geom, calc_kwargs, device, refresh_geom_meta=True
    )
    return hessian
+
122
+
123
def _calc_energy(geom, calc_kwargs: Dict[str, Any], calc=None) -> float:
    """Evaluate the ML/MM energy of ``geom``.

    Args:
        geom: Geometry providing ``atoms`` and ``coords``.
        calc_kwargs: Keyword arguments used to build a temporary ``mlmm``
            calculator when ``calc`` is not supplied. ``out_hess_torch`` is
            forced to ``False`` on a copy; the caller's dict is not mutated.
        calc: Optional pre-built calculator to reuse. When omitted, a
            temporary calculator is created and released before returning.

    Returns:
        The ``"energy"`` entry of the calculator result as a float
        (``0.0`` if the result has no ``"energy"`` key).
    """
    owns_calc = calc is None
    if owns_calc:
        kw = dict(calc_kwargs or {})
        kw["out_hess_torch"] = False
        calc = mlmm(**kw)
    try:
        result = calc.get_energy(geom.atoms, geom.coords)
        energy = float(result.get("energy", 0.0))
        del result
    finally:
        # Release the temporary calculator and cached GPU memory even when the
        # evaluation raised. Otherwise the exception's traceback keeps this
        # frame, and with it the calculator, alive.
        if owns_calc:
            del calc
        _clear_cuda_cache()
    return energy
136
+
137
+
138
def _omega2_to_freqs_cm(omega2: torch.Tensor) -> np.ndarray:
    """Turn Hessian eigenvalues (omega^2) into signed wavenumbers in cm^-1.

    Negative eigenvalues are mapped to negative wavenumbers, the usual
    convention for imaginary frequencies.
    """
    # Prefactor mapping sqrt(|eigenvalue|) to an energy. The result is later
    # divided by ase.units.invcm, so presumably it is hbar*omega in eV.
    scale = (units._hbar * 1e10 / np.sqrt(units._e * units._amu)
             * np.sqrt(AU2EV) / BOHR2ANG)
    magnitude = scale * torch.sqrt(torch.abs(omega2))
    signed = torch.where(omega2 < 0, -magnitude, magnitude)
    wavenumbers = signed / units.invcm
    return wavenumbers.detach().cpu().numpy()
144
+
145
+
146
+ def _clear_cuda_cache(tensor: Optional[torch.Tensor] = None) -> None:
147
+ """Clear CUDA cache if available and tensor (if provided) is on CUDA."""
148
+ if torch.cuda.is_available():
149
+ if tensor is None or tensor.is_cuda:
150
+ torch.cuda.empty_cache()
151
+
152
+
153
+
154
+
155
+ def _mw_projected_hessian_inplace(H_t: torch.Tensor,
156
+ coords_bohr_t: torch.Tensor,
157
+ masses_au_t: torch.Tensor,
158
+ freeze_idx: Optional[List[int]] = None) -> torch.Tensor:
159
+ """
160
+ Mass-weight H in-place, optionally restrict to active DOF subspace (PHVA) and
161
+ project out TR motions (in that subspace), also in-place.
162
+ Returns the (possibly reduced) Hessian to be diagonalized.
163
+ """
164
+ dtype, device = H_t.dtype, H_t.device
165
+ with torch.no_grad():
166
+ N = coords_bohr_t.shape[0]
167
+ if freeze_idx:
168
+ frozen = set(int(i) for i in freeze_idx if 0 <= int(i) < N)
169
+ active_idx = [i for i in range(N) if i not in frozen]
170
+ if len(active_idx) == 0:
171
+ raise RuntimeError("All atoms are frozen; no active DOF left for TR projection.")
172
+ # mass-weight first
173
+ H_t = _mass_weighted_hessian(H_t, masses_au_t)
174
+ # take active DOF submatrix
175
+ mask_dof = torch.ones(3 * N, dtype=torch.bool, device=device)
176
+ for i in frozen:
177
+ mask_dof[3 * i:3 * i + 3] = False
178
+ H_t = H_t[mask_dof][:, mask_dof]
179
+ # TR basis and projection in active subspace (in-place)
180
+ coords_act = coords_bohr_t[active_idx, :]
181
+ masses_act = masses_au_t[active_idx]
182
+ Q, _ = _tr_orthonormal_basis(coords_act, masses_act) # (3N_act, r)
183
+ Qt = Q.T
184
+ QtH = Qt @ H_t
185
+ H_t.addmm_(Q, QtH, beta=1.0, alpha=-1.0)
186
+ H_t.addmm_((QtH.T), Qt, beta=1.0, alpha=-1.0)
187
+ QtHQ = QtH @ Q
188
+ H_t.addmm_(Q @ QtHQ, Qt, beta=1.0, alpha=1.0)
189
+ del Q, Qt, QtH, QtHQ, mask_dof, coords_act, masses_act, active_idx, frozen
190
+ else:
191
+ # Full DOF: mass-weight + TR projection in-place
192
+ H_t = _mass_weighted_hessian(H_t, masses_au_t)
193
+ Q, _ = _tr_orthonormal_basis(coords_bohr_t, masses_au_t) # (3N, r)
194
+ Qt = Q.T
195
+ QtH = Qt @ H_t
196
+ H_t.addmm_(Q, QtH, beta=1.0, alpha=-1.0)
197
+ H_t.addmm_(QtH.T, Qt, beta=1.0, alpha=-1.0)
198
+ QtHQ = QtH @ Q
199
+ H_t.addmm_(Q @ QtHQ, Qt, beta=1.0, alpha=1.0)
200
+ del Q, Qt, QtH, QtHQ
201
+ _clear_cuda_cache()
202
+ return H_t
203
+
204
+
205
def _mode_direction_by_root(H_t: torch.Tensor,
                            coords_bohr_t: torch.Tensor,
                            masses_au_t: torch.Tensor,
                            root: int = 0,
                            freeze_idx: Optional[List[int]] = None) -> np.ndarray:
    """Return the Cartesian direction of a selected negative-curvature mode.

    The Hessian is mass-weighted and TR-projected by
    ``_mw_projected_hessian_inplace`` (restricted to active DOF when
    ``freeze_idx`` is given), symmetrized, and diagonalized.

    Args:
        H_t: Cartesian Hessian; consumed by the in-place projection.
        coords_bohr_t: Coordinates in Bohr, shape (N, 3).
        masses_au_t: Atomic masses in atomic units, shape (N,).
        root: Index among negative eigenvalues, counted from the most
            negative (0). For root != 0 it is clamped to the available
            negative roots, and if there are none the lowest eigenvalue is
            used. For root == 0 the lowest eigenvalue is always taken,
            preferably via ``torch.lobpcg`` with a dense ``eigh`` fallback.
        freeze_idx: Optional frozen atom indices; their components are zero
            in the result.

    Returns:
        Unit-norm Cartesian displacement as a numpy array of shape (N, 3).
    """
    with torch.no_grad():
        hess = _mw_projected_hessian_inplace(H_t, coords_bohr_t, masses_au_t, freeze_idx=freeze_idx)

        # Symmetrize in place: H <- (H + H^T) / 2
        hess_T = hess.T.clone()
        hess.add_(hess_T).mul_(0.5)
        del hess_T

        if int(root) == 0:
            # Only the lowest eigenpair is needed. Fall back to a dense solve
            # whenever the iterative solver refuses (e.g. very small matrices).
            try:
                _, eigvecs = torch.lobpcg(hess, k=1, largest=False)
                vec_sub = eigvecs[:, 0]
            except Exception:
                vals, vecs = torch.linalg.eigh(hess, UPLO="U")
                vec_sub = vecs[:, torch.argmin(vals)]
                del vals, vecs
        else:
            vals, vecs = torch.linalg.eigh(hess, UPLO="U")  # ascending order
            negative_positions = torch.nonzero(vals < 0, as_tuple=False).view(-1)
            n_neg = negative_positions.numel()
            if n_neg == 0:
                chosen = int(torch.argmin(vals).item())
            else:
                slot = min(max(int(root), 0), n_neg - 1)
                chosen = int(negative_positions[slot].item())
            vec_sub = vecs[:, chosen]
            del vals, vecs

        # Scatter the active-subspace vector back into the full 3N space.
        n_atoms = coords_bohr_t.shape[0]
        if freeze_idx:
            frozen = {int(i) for i in freeze_idx if 0 <= int(i) < n_atoms}
            keep = torch.ones(3 * n_atoms, dtype=torch.bool, device=hess.device)
            for atom in frozen:
                keep[3 * atom:3 * atom + 3] = False
            vec_mw = torch.zeros(3 * n_atoms, dtype=hess.dtype, device=hess.device)
            vec_mw[keep] = vec_sub
            del keep, frozen
        else:
            vec_mw = vec_sub

        # Undo mass weighting (x = M^{-1/2} q) and normalize.
        m_amu = (masses_au_t / AMU2AU).to(dtype=hess.dtype, device=hess.device)
        inv_sqrt_m = torch.sqrt(1.0 / torch.repeat_interleave(m_amu, 3))
        cart = inv_sqrt_m * vec_mw
        cart = cart / torch.linalg.norm(cart)
        mode = cart.reshape(-1, 3).detach().cpu().numpy()

        del m_amu, inv_sqrt_m, cart, vec_mw, vec_sub
    _clear_cuda_cache()
    return mode
271
+
272
+
273
def _calc_gradient(geom, calc_kwargs: Dict[str, Any]) -> np.ndarray:
    """Return the true Cartesian gradient of ``geom`` as a flat float array.

    A temporary ``mlmm`` calculator (``out_hess_torch=False``, built from a
    copy of ``calc_kwargs``) is attached to ``geom`` only for this
    evaluation. It is detached and released even if the evaluation raises.

    Returns:
        Gradient of shape (3N,), in Hartree/Bohr.
    """
    kw = dict(calc_kwargs or {})
    kw["out_hess_torch"] = False
    calc = mlmm(**kw)
    geom.set_calculator(calc)
    try:
        g = np.array(geom.gradient, dtype=float).reshape(-1)
    finally:
        # Never leave geom bound to this throwaway calculator, on success or on failure.
        geom.set_calculator(None)
        del calc
        _clear_cuda_cache()
    return g
286
+
287
+
288
def _frequencies_cm_and_modes(H_t: torch.Tensor,
                              atomic_numbers: List[int],
                              coords_bohr: np.ndarray,
                              device: torch.device,
                              tol: float = 1e-6,
                              freeze_idx: Optional[List[int]] = None) -> Tuple[np.ndarray, torch.Tensor]:
    """Diagonalize the mass-weighted, TR-projected Hessian.

    ``H_t`` is consumed by the in-place projection, which is restricted to
    the active-DOF subspace when ``freeze_idx`` is given. Eigenpairs whose
    eigenvalue magnitude does not exceed ``tol`` are dropped.

    Returns:
        freqs_cm: numpy array (nmode,) of wavenumbers in cm^-1; negative
            entries denote imaginary modes.
        modes: torch tensor (nmode, 3N) of mass-weighted eigenvectors, with
            zeros at frozen DOF.
    """
    with torch.no_grad():
        numbers = np.array(atomic_numbers, dtype=int)
        n_atoms = int(len(numbers))
        masses_amu = np.array([atomic_masses[z] for z in numbers])  # amu
        masses_au_t = torch.as_tensor(masses_amu * AMU2AU, dtype=H_t.dtype, device=device)
        coords_t = torch.as_tensor(coords_bohr.reshape(-1, 3), dtype=H_t.dtype, device=device)

        hess = _mw_projected_hessian_inplace(H_t, coords_t, masses_au_t, freeze_idx=freeze_idx)
        # Symmetrize in place: H <- (H + H^T) / 2 (the clone is taken before add_)
        hess.add_(hess.T.clone()).mul_(0.5)
        eigvals, eigvecs = torch.linalg.eigh(hess, UPLO="U")

        keep = torch.abs(eigvals) > tol
        eigvals = eigvals[keep]
        eigvecs = eigvecs[:, keep]  # (3N_act or 3N, nsel)

        if freeze_idx:
            frozen = {int(i) for i in freeze_idx if 0 <= int(i) < n_atoms}
            active = torch.ones(3 * n_atoms, dtype=torch.bool, device=hess.device)
            for atom in frozen:
                active[3 * atom:3 * atom + 3] = False
            modes = torch.zeros((eigvecs.shape[1], 3 * n_atoms), dtype=hess.dtype, device=hess.device)
            modes[:, active] = eigvecs.T
            del active, frozen
        else:
            modes = eigvecs.T  # (nsel, 3N)

        freqs_cm = _omega2_to_freqs_cm(eigvals)

        del eigvals, eigvecs, keep, masses_amu, masses_au_t, coords_t, hess
    _clear_cuda_cache(H_t)
    return freqs_cm, modes
338
+
339
+
340
def _write_mode_trj_and_pdb(geom,
                            mode_vec_3N: np.ndarray,
                            out_trj: Path,
                            out_pdb: Path,
                            amplitude_ang: float = 0.25,
                            n_frames: int = 20,
                            comment: str = "imag mode",
                            ref_pdb: Optional[Path] = None) -> None:
    """Write one mode animation as a ``_trj.xyz`` file and a ``.pdb`` file.

    Frame ``i`` displaces the reference geometry along the normalized mode by
    ``sin(2*pi*i/n_frames) * amplitude_ang`` Angstrom.

    The XYZ trajectory is always written. If ``ref_pdb`` is a ``.pdb`` file
    and file conversion is enabled, the PDB is produced by converting the
    XYZ trajectory with ``ref_pdb`` as the template. If that conversion is
    not applicable, or fails (the failure is logged), a multi-model PDB is
    written with ASE instead, without topology.

    Raises:
        ValueError: if ``mode_vec_3N`` has zero or non-finite norm. Otherwise
            every written coordinate would be NaN.
    """
    ref_ang = geom.coords.reshape(-1, 3) * BOHR2ANG
    mode = mode_vec_3N.reshape(-1, 3).copy()
    norm = float(np.linalg.norm(mode))
    if not np.isfinite(norm) or norm <= 0.0:
        raise ValueError(
            f"Cannot animate mode with norm {norm!r}; expected a finite, non-zero vector."
        )
    mode /= norm

    # Compute the displaced frames once; both writers below reuse them.
    frames = [
        ref_ang + np.sin(2.0 * np.pi * i / n_frames) * amplitude_ang * mode
        for i in range(n_frames)
    ]

    # _trj.xyz (XYZ-like concatenation) is always written.
    n_atoms = len(geom.atoms)
    with out_trj.open("w", encoding="utf-8") as f:
        for i, coords in enumerate(frames, start=1):
            lines = [f"{n_atoms}\n{comment} frame={i}/{n_frames}\n"]
            lines.extend(
                f"{sym:2s} {x: .8f} {y: .8f} {z: .8f}\n"
                for sym, (x, y, z) in zip(geom.atoms, coords)
            )
            f.write("".join(lines))

    # .pdb: prefer the topology-preserving template conversion when possible.
    if ref_pdb is not None and ref_pdb.suffix.lower() == ".pdb" and is_convert_file_enabled():
        try:
            convert_xyz_to_pdb(out_trj, ref_pdb, out_pdb)
        except Exception as exc:
            # Best effort: keep the ASE fallback, but do not fail silently.
            logger.warning(
                "Template PDB conversion of %s with %s failed (%s); "
                "falling back to ASE PDB writer without topology.",
                out_trj, ref_pdb, exc,
            )
        else:
            return

    # Fallback: MODEL/ENDMDL via ASE (no topology). The first frame overwrites
    # out_pdb, which discards any partial output from a failed conversion.
    atoms0 = Atoms(geom.atoms, positions=ref_ang, pbc=False)
    for i, coords in enumerate(frames):
        ai = atoms0.copy()
        ai.set_positions(coords)
        write(out_pdb, ai, append=(i != 0))
382
+
383
+
384
def _write_all_imag_modes(
    geom,
    freqs_cm: np.ndarray,
    modes: torch.Tensor,
    neg_freq_thresh_cm: float,
    vib_dir: Path,
    *,
    ref_pdb: Optional[Path] = None,
    filename_prefix: str = "final_imag_mode",
    amplitude_ang: float = 0.8,
    n_frames: int = 20,
) -> int:
    """
    Animate every imaginary mode (frequency below ``-|neg_freq_thresh_cm|``).

    Modes are visited from the most negative frequency upward. Each
    mass-weighted eigenvector is divided by sqrt(mass) per Cartesian component,
    normalised, and passed to ``_write_mode_trj_and_pdb``. That helper writes
    ``<stem>_trj.xyz`` and ``<stem>.pdb`` into ``vib_dir``. Modes whose
    Cartesian norm is not positive are skipped, but they still consume a rank
    number.

    Returns:
        Number of mode trajectories written.
    """
    cutoff = -abs(neg_freq_thresh_cm)
    imag_modes = np.where(freqs_cm < cutoff)[0]
    if len(imag_modes) == 0:
        return 0

    per_atom_mass = np.array(
        [atomic_masses[int(z)] for z in geom.atomic_numbers], dtype=float
    )
    sqrt_mass_xyz = np.sqrt(np.repeat(per_atom_mass, 3))

    n_written = 0
    # argsort of the negative frequencies -> most negative first
    for rank, local in enumerate(np.argsort(freqs_cm[imag_modes]), start=1):
        mode_idx = int(imag_modes[int(local)])
        freq = float(freqs_cm[mode_idx])

        direction = modes[mode_idx].detach().cpu().numpy().reshape(-1) / sqrt_mass_xyz
        length = float(np.linalg.norm(direction))
        if length <= 0.0:
            continue
        direction = direction / length

        stem = f"{filename_prefix}_{rank:02d}_mode{mode_idx:04d}_{freq:+.2f}cm-1"
        _write_mode_trj_and_pdb(
            geom,
            direction,
            vib_dir / f"{stem}_trj.xyz",
            vib_dir / f"{stem}.pdb",
            amplitude_ang=amplitude_ang,
            n_frames=n_frames,
            comment=f"imag#{rank} mode={mode_idx} {freq:+.2f} cm^-1",
            ref_pdb=ref_pdb,
        )
        n_written += 1

    _clear_cuda_cache()
    return n_written
441
+
442
+
443
+ # ===================================================================
444
+ # Active-subspace helpers & Bofill update
445
+ # ===================================================================
446
+
447
def _active_indices(N: int, freeze_idx: Optional[List[int]]) -> List[int]:
    """
    Return the atom indices in ``range(N)`` that are not frozen, in ascending order.

    Entries of ``freeze_idx`` outside ``[0, N)`` are ignored. A falsy
    ``freeze_idx`` (None or empty) means every atom is active.
    """
    frozen = {int(i) for i in freeze_idx if 0 <= int(i) < N} if freeze_idx else set()
    return [atom for atom in range(N) if atom not in frozen]
452
+
453
+
454
def _active_mask_dof(N: int, freeze_idx: Optional[List[int]]) -> np.ndarray:
    """
    Boolean mask over the ``3*N`` Cartesian DOFs that is False for frozen atoms.

    Each frozen atom ``i`` clears entries ``3*i .. 3*i+2``. Indices outside
    ``[0, N)`` are ignored, and a falsy ``freeze_idx`` leaves everything True.
    """
    mask = np.ones(3 * N, dtype=bool)
    for atom in map(int, freeze_idx or ()):
        if 0 <= atom < N:
            mask[3 * atom:3 * atom + 3] = False
    return mask
461
+
462
+
463
def _mask_dof_from_active_idx(N: int, active_idx: List[int]) -> np.ndarray:
    """
    Boolean mask over the ``3*N`` Cartesian DOFs that is True only for the
    x/y/z components of the atoms listed in ``active_idx``.

    Indices outside ``[0, N)`` are ignored.
    """
    mask = np.zeros(3 * N, dtype=bool)
    per_atom = mask.reshape(N, 3)  # view of ``mask``: row writes go straight through
    for atom in (int(a) for a in active_idx):
        if 0 <= atom < N:
            per_atom[atom] = True
    return mask
470
+
471
+
472
def _extract_active_block(H_full: torch.Tensor, mask_dof: np.ndarray) -> torch.Tensor:
    """
    Return the active-DOF block ``H_full[mask][:, mask]`` on the same device/dtype.

    Args:
        H_full: square Hessian whose rows/columns are indexed by ``mask_dof``.
        mask_dof: boolean mask of length ``H_full.shape[0]``; True marks active DOFs.

    Returns:
        A new tensor that owns its storage, so in-place edits never touch ``H_full``.
    """
    m = torch.as_tensor(mask_dof, device=H_full.device, dtype=torch.bool)
    # Boolean-mask indexing always materialises a fresh tensor, so an extra
    # ``clone()`` would only allocate a redundant third copy of the block.
    return H_full[m][:, m]
479
+
480
+
481
def _mw_tr_project_active_inplace(H_act: torch.Tensor,
                                  coords_act_t: torch.Tensor,
                                  masses_act_au_t: torch.Tensor) -> torch.Tensor:
    """
    Mass-weight ``H_act`` and project out translations/rotations, in place.

    Masses arrive in atomic units and are converted to amu (``/ AMU2AU``) for
    the weighting H -> M^-1/2 H M^-1/2. With ``Q`` the orthonormal TR basis from
    ``_tr_orthonormal_basis``, the projection P H P (P = I - Q Q^T) is expanded
    as H - Q(Q^T H) - (Q^T H)^T Q^T + Q(Q^T H Q)Q^T, so the dense projector is
    never formed. Returns ``H_act`` itself.
    """
    with torch.no_grad():
        amu = (masses_act_au_t / AMU2AU).to(dtype=H_act.dtype, device=H_act.device)
        weights = torch.sqrt(1.0 / torch.repeat_interleave(amu, 3))
        H_act.mul_(weights.view(-1, 1))  # scale rows
        H_act.mul_(weights.view(1, -1))  # scale columns

        basis, _ = _tr_orthonormal_basis(coords_act_t, masses_act_au_t)  # (3N_act, r)
        # Both products use the mass-weighted H *before* any projection update.
        proj_rows = basis.T @ H_act
        core = proj_rows @ basis
        H_act.addmm_(basis, proj_rows, beta=1.0, alpha=-1.0)
        H_act.addmm_(proj_rows.T, basis.T, beta=1.0, alpha=-1.0)
        H_act.addmm_(basis @ core, basis.T, beta=1.0, alpha=1.0)
    return H_act
505
+
506
+
507
def _frequencies_from_Hact(H_act: torch.Tensor,
                           atomic_numbers: List[int],
                           coords_bohr: np.ndarray,
                           active_idx: List[int],
                           device: torch.device,
                           tol: float = 1e-6) -> np.ndarray:
    """
    Vibrational frequencies (cm^-1) of an active-block Hessian.

    The steps are:

    1. Gather the coordinates and masses of the ``active_idx`` atoms.
    2. Mass-weight and TR-project a *copy* of ``H_act`` in the active space.
    3. Symmetrise the copy and take its eigenvalues.
    4. Drop eigenvalues with ``|omega^2| <= tol`` and convert the rest with
       ``_omega2_to_freqs_cm``.

    ``H_act`` itself is left unchanged.
    """
    with torch.no_grad():
        act_xyz = torch.as_tensor(coords_bohr.reshape(-1, 3)[active_idx, :],
                                  dtype=H_act.dtype, device=device)
        act_numbers = np.array(atomic_numbers, int)[active_idx]
        act_mass_au = torch.as_tensor([atomic_masses[int(z)] * AMU2AU for z in act_numbers],
                                      dtype=H_act.dtype, device=device)
        work = H_act.clone()
        _mw_tr_project_active_inplace(work, act_xyz, act_mass_au)
        # Symmetrise explicitly; the transpose copy is taken before add_ mutates ``work``.
        work.add_(work.T.clone()).mul_(0.5)
        eigvals = torch.linalg.eigvalsh(work, UPLO="U")
        freqs_cm = _omega2_to_freqs_cm(eigvals[torch.abs(eigvals) > tol])
        del act_xyz, act_mass_au, work, eigvals
        _clear_cuda_cache(H_act)
        return freqs_cm
534
+
535
+
536
def _modes_from_Hact_embedded(H_act: torch.Tensor,
                              atomic_numbers: List[int],
                              coords_bohr: np.ndarray,
                              active_idx: List[int],
                              device: torch.device,
                              tol: float = 1e-6) -> Tuple[np.ndarray, torch.Tensor]:
    """
    Diagonalize active-block Hessian with mass-weight/TR in active space and return:
        freqs_cm : (nmode,)
        modes    : (nmode, 3N) mass-weighted eigenvectors embedded to full 3N (torch)

    Row ``3*k + c`` of ``H_act`` is treated as Cartesian component ``c`` of atom
    ``active_idx[k]``. This is the same convention used to gather the active
    coordinates and masses. Eigenvector components are scattered back to exactly
    those full-system DOFs, so the embedding is correct even when ``active_idx``
    is not sorted. For sorted indices the result is identical to a boolean-mask
    scatter. Frozen/inactive DOFs are zero. Eigenpairs with
    ``|omega^2| <= tol`` are discarded.
    """
    with torch.no_grad():
        N = len(atomic_numbers)
        coords_act = torch.as_tensor(coords_bohr.reshape(-1, 3)[active_idx, :], dtype=H_act.dtype, device=device)
        masses_act_au = torch.as_tensor([atomic_masses[int(z)] * AMU2AU
                                         for z in np.array(atomic_numbers, int)[active_idx]],
                                        dtype=H_act.dtype, device=device)
        Hmw = H_act.clone()
        _mw_tr_project_active_inplace(Hmw, coords_act, masses_act_au)
        # Explicit symmetrization before eigendecomposition
        _t = Hmw.T.clone()
        Hmw.add_(_t).mul_(0.5)
        del _t
        omega2, Vsub = torch.linalg.eigh(Hmw, UPLO="U")
        sel = torch.abs(omega2) > tol
        omega2 = omega2[sel]
        Vsub = Vsub[:, sel]  # (3N_act, nsel)

        # Embed to full 3N by explicit DOF index (keeps active_idx order aligned
        # with the H_act row order instead of assuming ascending atom order).
        act_atoms = np.asarray(active_idx, dtype=np.int64).reshape(-1)
        dof_idx = torch.as_tensor(
            (3 * act_atoms[:, None] + np.arange(3, dtype=np.int64)).reshape(-1),
            device=device,
        )
        modes_full = torch.zeros((Vsub.shape[1], 3 * N), dtype=Hmw.dtype, device=device)
        modes_full[:, dof_idx] = Vsub.T
        # frequencies
        freqs_cm = _omega2_to_freqs_cm(omega2)

        del coords_act, masses_act_au, Hmw, omega2, Vsub, dof_idx
        _clear_cuda_cache(H_act)
        return freqs_cm, modes_full
575
+
576
+
577
def _mode_direction_by_root_from_Hact(H_act: torch.Tensor,
                                      coords_bohr: np.ndarray,
                                      atomic_numbers: List[int],
                                      masses_au_t: torch.Tensor,
                                      active_idx: List[int],
                                      device: torch.device,
                                      root: int = 0) -> np.ndarray:
    """
    TS direction from the *active* Hessian block.

    Mass-weighting and TR projection are done in the active space. The chosen
    eigenvector is un-mass-weighted, normalised, and embedded back into the
    full system in Cartesian space.

    Root selection:
        * ``root == 0``: lowest eigenvector via ``torch.lobpcg``, falling back
          to a dense ``eigh`` if LOBPCG raises.
        * otherwise: the ``root``-th negative eigenvalue (clamped to the last
          negative one), or the lowest eigenvalue when none are negative.

    Row ``3*k + c`` of ``H_act`` is treated as component ``c`` of atom
    ``active_idx[k]``, and the result is scattered to exactly those DOFs.
    This keeps the embedding correct for unsorted ``active_idx``.

    Returns:
        ``(N, 3)`` numpy array with unit norm, zero on inactive atoms.
    """
    with torch.no_grad():
        N = len(atomic_numbers)
        coords_act = torch.as_tensor(coords_bohr.reshape(-1, 3)[active_idx, :], dtype=H_act.dtype, device=device)
        masses_act_au = masses_au_t[active_idx].to(device=device, dtype=H_act.dtype)
        # mass-weight + TR in active space
        Hmw = H_act.clone()
        _mw_tr_project_active_inplace(Hmw, coords_act, masses_act_au)
        # Explicit symmetrization before eigendecomposition
        _t = Hmw.T.clone()
        Hmw.add_(_t).mul_(0.5)
        del _t

        # eigenvector for requested root
        if int(root) == 0:
            try:
                w, V = torch.lobpcg(Hmw, k=1, largest=False)
                u_mw = V[:, 0]
            except Exception:
                vals, vecs = torch.linalg.eigh(Hmw, UPLO="U")
                u_mw = vecs[:, torch.argmin(vals)]
                del vals, vecs
        else:
            vals, vecs = torch.linalg.eigh(Hmw, UPLO="U")
            neg = (vals < 0)
            neg_inds = torch.nonzero(neg, as_tuple=False).view(-1)
            if neg_inds.numel() == 0:
                pick = int(torch.argmin(vals).item())
            else:
                k = max(0, min(int(root), neg_inds.numel() - 1))
                pick = int(neg_inds[k].item())
            u_mw = vecs[:, pick]
            del vals, vecs

        # Mass un-weight to Cartesian in the active space, then embed to full 3N
        masses_act_amu = (masses_act_au / AMU2AU).to(dtype=H_act.dtype, device=device)
        m3 = torch.repeat_interleave(masses_act_amu, 3)
        v_cart_act = u_mw / torch.sqrt(m3)
        v_cart_act = v_cart_act / torch.linalg.norm(v_cart_act)

        # Scatter by explicit DOF index so active_idx order matches H_act row order.
        act_atoms = np.asarray(active_idx, dtype=np.int64).reshape(-1)
        dof_idx = torch.as_tensor(
            (3 * act_atoms[:, None] + np.arange(3, dtype=np.int64)).reshape(-1),
            device=device,
        )
        full = torch.zeros(3 * N, dtype=H_act.dtype, device=device)
        full[dof_idx] = v_cart_act
        mode = full.reshape(-1, 3).detach().cpu().numpy()

        del coords_act, masses_act_au, masses_act_amu, m3, v_cart_act, full, dof_idx, Hmw, u_mw
        _clear_cuda_cache(H_act)
        return mode
636
+
637
+
638
def _representative_atoms_for_mode(mode: torch.Tensor, flatten_k: int) -> np.ndarray:
    """
    Indices of the ``flatten_k`` atoms with the largest per-atom displacement norm.

    ``mode`` is reshaped to ``(-1, 3)``. At most one index per atom is returned,
    ordered by decreasing norm. An empty int array is returned when
    ``flatten_k <= 0``.
    """
    per_atom = torch.linalg.norm(mode.reshape(-1, 3), dim=1)
    k = min(int(flatten_k), per_atom.shape[0])
    if k <= 0:
        return np.zeros(0, dtype=int)
    _, top = torch.topk(per_atom, k=k, largest=True)
    return top.detach().cpu().numpy()
649
+
650
+
651
def _select_flatten_targets_for_geom(
    freqs_cm: np.ndarray,
    modes: torch.Tensor,
    coords_bohr: np.ndarray,
    neg_freq_thresh_cm: float,
    root: int,
    flatten_sep_cutoff: float,
    flatten_k: int,
) -> List[int]:
    """
    Select a subset of imaginary modes to flatten for a geometry.

    The selection works as follows:

    1. Sort imaginary modes (freq < ``-|neg_freq_thresh_cm|``) from most negative.
    2. Treat the ``root``-th one (clamped) as the mode to keep.
    3. Walk the remaining modes in that order. Accept a mode when the minimum
       distance (Angstrom) between its ``flatten_k`` representative atoms and
       those of every previously accepted mode is not below
       ``flatten_sep_cutoff``. The first mode with representatives is always
       accepted.

    Returns an empty list when there are fewer than two imaginary modes.
    """
    imag = np.where(freqs_cm < -abs(neg_freq_thresh_cm))[0]
    if len(imag) <= 1:
        return []

    imag_sorted = imag[np.argsort(freqs_cm[imag])]
    keep = imag_sorted[max(0, min(int(root), len(imag_sorted) - 1))]
    others = [int(i) for i in imag_sorted if int(i) != int(keep)]
    if not others:
        return []

    xyz_ang = torch.as_tensor(
        coords_bohr.reshape(-1, 3) * BOHR2ANG,
        dtype=modes.dtype,
        device=modes.device,
    )
    cutoff = float(flatten_sep_cutoff)

    chosen: List[int] = []
    chosen_reps: List[np.ndarray] = []
    for idx in others:
        rep = _representative_atoms_for_mode(modes[idx], flatten_k)
        if rep.size == 0:
            continue
        here = xyz_ang[rep]
        # any() short-circuits on the first too-close accepted mode;
        # with nothing accepted yet it is False, so the mode is taken.
        too_close = any(
            float(torch.min(torch.cdist(here, xyz_ang[prev])).item()) < cutoff
            for prev in chosen_reps
        )
        if not too_close:
            chosen.append(idx)
            chosen_reps.append(rep)
    return chosen
707
+
708
+
709
def _flatten_once_with_modes_for_geom(
    geom,
    masses_amu: np.ndarray,
    calc_kwargs: dict,
    freqs_cm: np.ndarray,
    modes: torch.Tensor,
    neg_freq_thresh_cm: float,
    flatten_amp_ang: float,
    flatten_sep_cutoff: float,
    flatten_k: int,
    root: int,
) -> bool:
    """
    Flatten extra imaginary modes for a geometry (single pass).

    Targets come from ``_select_flatten_targets_for_geom``. Each target is
    handled in turn:

    1. Un-mass-weight its eigenvector and normalise it.
    2. Scale the displacement by ``flatten_amp_ang`` (Angstrom, converted to
       Bohr) times sqrt(12.011 / mass), so a carbon moves by about the
       amplitude.
    3. Evaluate the energy at +/- the displacement from the *current*
       coordinates and keep the lower-energy side. Targets are applied
       sequentially.

    Targets whose Cartesian direction has zero or non-finite norm are skipped,
    since normalising them would write NaN coordinates.

    Returns:
        True if at least one displacement was applied to ``geom``.
    """
    neg_idx_all = np.where(freqs_cm < -abs(neg_freq_thresh_cm))[0]
    if len(neg_idx_all) <= 1:
        return False

    targets = _select_flatten_targets_for_geom(
        freqs_cm,
        modes,
        geom.cart_coords,
        neg_freq_thresh_cm,
        root,
        flatten_sep_cutoff,
        flatten_k,
    )
    if not targets:
        return False

    mass_scale = np.sqrt(12.011 / masses_amu)[:, None]
    sqrt_m = np.sqrt(masses_amu)[:, None]  # loop-invariant un-weighting factor
    amp_bohr = float(flatten_amp_ang) / BOHR2ANG

    moved = False
    for idx in targets:
        v_mw = modes[idx].detach().cpu().numpy().reshape(-1, 3)
        v_cart = v_mw / sqrt_m
        norm = float(np.linalg.norm(v_cart))
        if not np.isfinite(norm) or norm <= 0.0:
            continue
        v_cart = v_cart / norm

        disp = amp_bohr * mass_scale * v_cart
        ref = geom.cart_coords.reshape(-1, 3)

        plus = ref + disp
        minus = ref - disp

        geom.coords = plus.reshape(-1)
        E_plus = _calc_energy(geom, calc_kwargs)

        geom.coords = minus.reshape(-1)
        E_minus = _calc_energy(geom, calc_kwargs)

        # Move towards lower energy
        if E_plus <= E_minus:
            geom.coords = plus.reshape(-1)
        else:
            geom.coords = minus.reshape(-1)
        moved = True

    return moved
768
+
769
+
770
def _get_active_dof_indices(
    calc_cfg: Dict[str, Any],
    n_atoms: int,
    active_dof_mode: str,
    freeze_atoms_final: List[int],
) -> Optional[List[int]]:
    """
    Sorted indices returned by ``_resolve_active_atom_indices``, or ``None``
    when that helper yields ``None`` for them.

    ``freeze_atoms_final`` is ignored; it is accepted only so existing callers
    keep working.
    """
    del freeze_atoms_final
    resolved, _unused = _resolve_active_atom_indices(calc_cfg, n_atoms, active_dof_mode)
    return None if resolved is None else sorted(resolved)
781
+
782
+
783
def _bofill_update_active(H_act: torch.Tensor,
                          delta_act: np.ndarray,
                          g_new_act: np.ndarray,
                          g_old_act: np.ndarray,
                          eps: float = 1e-12) -> torch.Tensor:
    """
    In-place Bofill update of the *active* Cartesian Hessian block.

    With step ``d``, gradient change ``y = g_new - g_old`` and ``xi = y - H d``,
    the update is E = (1 - phi) * E_MS + phi * E_PSB, where

        E_MS  = xi xi^T / (d . xi)
        E_PSB = (d xi^T + xi d^T) / (d . d) - (d . xi) d d^T / (d . d)^2
        phi   = 1 - (d . xi)^2 / ((d . d)(xi . xi)), clamped to [0, 1].

    Denominators below ``eps`` are replaced by ``eps`` (with the sign of
    ``d . xi`` for the MS term). The update is applied as four in-place rank-1
    ``addr_`` calls. This allocates no NxN temporary and no O(n^2) index
    arrays, and it writes into ``H_act``'s own storage. Advanced indexing such
    as ``H[i, j].add_(...)`` would only modify a copy. The result satisfies
    the secant condition ``H_new d = y``. Tiny floating-point asymmetries may
    remain; explicit symmetrization is applied at eigendecomposition sites.

    Returns:
        ``H_act`` (the same tensor object, updated in place).
    """
    device = H_act.device
    dtype = H_act.dtype

    # as torch vectors
    d = torch.as_tensor(delta_act, dtype=dtype, device=device).reshape(-1)
    g0 = torch.as_tensor(g_old_act, dtype=dtype, device=device).reshape(-1)
    g1 = torch.as_tensor(g_new_act, dtype=dtype, device=device).reshape(-1)
    y = g1 - g0
    xi = y - H_act @ d

    # Scalars as Python floats (double precision); addr_ takes Python scalars.
    d_dot_xi = float(torch.dot(d, xi))
    d_norm2 = float(torch.dot(d, d))
    xi_norm2 = float(torch.dot(xi, xi))

    # guards
    if abs(d_dot_xi) > eps:
        denom_ms = d_dot_xi
    else:
        denom_ms = -eps if d_dot_xi < 0 else eps
    denom_psb_d4 = d_norm2 * d_norm2 if d_norm2 > eps else eps
    denom_psb_d2 = d_norm2 if d_norm2 > eps else eps
    denom_phi = d_norm2 * xi_norm2 if (d_norm2 > eps and xi_norm2 > eps) else 1.0

    phi = 1.0 - (d_dot_xi * d_dot_xi) / denom_phi
    phi = min(max(phi, 0.0), 1.0)

    # coefficients for rank updates
    alpha = (1.0 - phi) / denom_ms          # for xi xi^T
    beta = -phi * (d_dot_xi / denom_psb_d4)  # for d d^T
    gamma = phi / denom_psb_d2               # for d xi^T + xi d^T

    H_act.addr_(xi, xi, alpha=alpha)
    H_act.addr_(d, d, alpha=beta)
    H_act.addr_(d, xi, alpha=gamma)
    H_act.addr_(xi, d, alpha=gamma)
    return H_act
851
+
852
+
853
+ # ===================================================================
854
+ # HessianDimer (extended)
855
+ # ===================================================================
856
+
857
+ class HessianDimer:
858
+ """
859
+ Dimer-based TS search with periodic Hessian updates.
860
+
861
+ Extensions in this implementation:
862
+ - `root` parameter: choose which imaginary mode to follow (0 = most negative).
863
+ - Pass-through kwargs: `dimer_kwargs` and `lbfgs_kwargs` to tune internals.
864
+ - Hard cap on total LBFGS steps across segments: `max_total_cycles`.
865
+ - PHVA (active DOF subspace) + TR projection for mode picking,
866
+ respecting ``freeze_atoms``, with in-place operations. When ``root == 0`` the
867
+ implementation prefers LOBPCG.
868
+ - The flatten loop uses a *Bofill*-updated active Hessian block, so the
869
+ expensive exact Hessian is evaluated only once before the flatten loop and
870
+ once at the end for the final frequency analysis.
871
+ - UMA calculator kwargs accept ``freeze_atoms`` and ``hessian_calc_mode`` and
872
+ default to ``return_partial_hessian=True`` (active-block Hessian when frozen).
873
+ """
874
+
875
+ def __init__(self,
876
+ fn: str,
877
+ out_dir: str = "./result_dimer",
878
+ thresh_loose: str = "gau_loose",
879
+ thresh: str = "baker",
880
+ update_interval_hessian: int = 500,
881
+ neg_freq_thresh_cm: float = 5.0,
882
+ flatten_amp_ang: float = 0.10,
883
+ flatten_max_iter: int = 20,
884
+ mem: int = 100000,
885
+ use_lobpcg: bool = True, # kept for backward compat (not used when root!=0)
886
+ calc_kwargs: Optional[dict] = None,
887
+ device: str = "auto",
888
+ dump: bool = False,
889
+ #
890
+ # New:
891
+ root: int = 0,
892
+ dimer_kwargs: Optional[Dict[str, Any]] = None,
893
+ lbfgs_kwargs: Optional[Dict[str, Any]] = None,
894
+ max_total_cycles: int = 10000,
895
+ #
896
+ # Pass geom kwargs so freeze-atoms and YAML geometry overrides apply on the light path (fix #1)
897
+ geom_kwargs: Optional[Dict[str, Any]] = None,
898
+ # New: Use partial Hessian for imaginary mode detection in flatten loop
899
+ partial_hessian_flatten: bool = True,
900
+ # Spatial separation for flatten mode selection (from pdb2reaction)
901
+ flatten_sep_cutoff: float = 0.0,
902
+ flatten_k: int = 10,
903
+ flatten_loop_bofill: bool = False,
904
+ ml_only_hessian_dimer: bool = False,
905
+ source_path: Optional[Path] = None,
906
+ ) -> None:
907
+
908
+ self.fn = fn
909
+ self.source_path = Path(source_path) if source_path is not None else None
910
+ self.out_dir = Path(out_dir); self.out_dir.mkdir(parents=True, exist_ok=True)
911
+ self.vib_dir = self.out_dir / "vib"; self.vib_dir.mkdir(parents=True, exist_ok=True)
912
+
913
+ self.thresh_loose = thresh_loose
914
+ self.thresh = thresh
915
+ self.update_interval_hessian = int(update_interval_hessian)
916
+ self.neg_freq_thresh_cm = float(neg_freq_thresh_cm)
917
+ self.flatten_amp_ang = float(flatten_amp_ang)
918
+ self.flatten_max_iter = int(flatten_max_iter)
919
+ self.mem = int(mem)
920
+ self.use_lobpcg = bool(use_lobpcg) # used only when root==0 shortcut
921
+ self.root = int(root)
922
+ self.dimer_kwargs = dict(dimer_kwargs or {})
923
+ self.lbfgs_kwargs = dict(lbfgs_kwargs or {})
924
+ self.max_total_cycles = int(max_total_cycles)
925
+ self.partial_hessian_flatten = bool(partial_hessian_flatten)
926
+ # Spatial separation for flatten mode selection
927
+ self.flatten_sep_cutoff = float(flatten_sep_cutoff)
928
+ self.flatten_k = int(flatten_k)
929
+ self.flatten_loop_bofill = bool(flatten_loop_bofill)
930
+ self.ml_only_hessian_dimer = bool(ml_only_hessian_dimer)
931
+
932
+ # Track total cycles globally across ALL loops/segments (fix #2)
933
+ self._cycles_spent = 0
934
+
935
+ # Hessian caching for 0-step convergence (avoid redundant recalculation)
936
+ self._raw_hessian_cache_cpu: Optional[torch.Tensor] = None
937
+ self._raw_hessian_coords_cpu: Optional[np.ndarray] = None
938
+ self._last_active_idx: Optional[List[int]] = None
939
+ self._last_active_mask_dof: Optional[np.ndarray] = None
940
+
941
+ # ML/MM calculator settings
942
+ self.calc_kwargs = dict(calc_kwargs or {})
943
+ self.calc_kwargs.setdefault("out_hess_torch", False)
944
+
945
+ # Geometry & masses (use provided geom kwargs so freeze_atoms etc. apply)
946
+ gkw = dict(geom_kwargs or {})
947
+ coord_type = gkw.pop("coord_type", "cart")
948
+ freeze_geom = list(gkw.get("freeze_atoms", [])) if "freeze_atoms" in gkw else []
949
+ freeze_calc_raw = self.calc_kwargs.get("freeze_atoms") or []
950
+ try:
951
+ freeze_calc = [int(i) for i in freeze_calc_raw]
952
+ except TypeError:
953
+ freeze_calc = [int(freeze_calc_raw)]
954
+ merged_freeze = sorted({int(i) for i in (freeze_geom + freeze_calc)})
955
+ if merged_freeze:
956
+ gkw["freeze_atoms"] = merged_freeze
957
+ elif "freeze_atoms" in gkw:
958
+ gkw["freeze_atoms"] = []
959
+ self.calc_kwargs["freeze_atoms"] = merged_freeze
960
+
961
+ self.calc_kwargs_partial = dict(self.calc_kwargs)
962
+ self.calc_kwargs_partial["mm_fd"] = False
963
+ self.calc_kwargs_partial["return_partial_hessian"] = False
964
+ self.calc_kwargs_partial["out_hess_torch"] = True
965
+ self.calc_kwargs_full = dict(self.calc_kwargs)
966
+ self.calc_kwargs_full.setdefault("mm_fd", True)
967
+ self.calc_kwargs_full["return_partial_hessian"] = False
968
+ self.calc_kwargs_full["out_hess_torch"] = True
969
+ # ML-only Hessian kwargs: skip MM Hessian entirely, use ML partial Hessian only
970
+ self.calc_kwargs_ml_only = dict(self.calc_kwargs)
971
+ self.calc_kwargs_ml_only["mm_fd"] = False
972
+ self.calc_kwargs_ml_only["return_partial_hessian"] = True
973
+ self.calc_kwargs_ml_only["out_hess_torch"] = True
974
+ self.calc_kwargs_ml_only["hess_cutoff"] = 0.0 # ML atoms only in Hessian
975
+ self.geom = geom_loader(fn, coord_type=coord_type, **gkw)
976
+ # If partial Hessian is requested (explicitly or via B-factor layers),
977
+ # avoid full 3N Hessian allocations in light TS dimer runs.
978
+ if self.calc_kwargs.get("return_partial_hessian"):
979
+ self.calc_kwargs_partial["return_partial_hessian"] = True
980
+ self.calc_kwargs_full["return_partial_hessian"] = True
981
+ elif self.partial_hessian_flatten and self.calc_kwargs.get("use_bfactor_layers"):
982
+ self.calc_kwargs_partial["return_partial_hessian"] = True
983
+ self.calc_kwargs_full["return_partial_hessian"] = True
984
+ self.masses_amu = np.array([atomic_masses[z] for z in self.geom.atomic_numbers])
985
+ self.masses_au_t = torch.as_tensor(self.masses_amu * AMU2AU, dtype=torch.float32)
986
+
987
+ # --- Preserve freeze list (for PHVA) ---
988
+ self.freeze_atoms: List[int] = list(gkw.get("freeze_atoms", [])) if "freeze_atoms" in gkw else []
989
+
990
+ # Device
991
+ self.device = _torch_device(device)
992
+ self.masses_au_t = self.masses_au_t.to(self.device)
993
+
994
+ # temp file for Dimer orientation (N_raw)
995
+ self.mode_path = self.out_dir / ".dimer_mode.dat"
996
+
997
+ self.dump = bool(dump)
998
+ self.optim_all_path = self.out_dir / "optimization_all_trj.xyz"
999
+
1000
+ # ----- One dimer segment for up to n_steps; returns (steps_done, converged) -----
1001
+ def _dimer_segment(self, threshold: str, n_steps: int) -> Tuple[int, bool]:
1002
+ # Dimer calculator using current mode as initial N
1003
+ calc_sp = mlmm(**self.calc_kwargs)
1004
+
1005
+ # Merge user dimer kwargs (but enforce N_raw & write_orientations)
1006
+ dimer_kwargs = dict(self.dimer_kwargs)
1007
+ dimer_kwargs.update({
1008
+ "calculator": calc_sp,
1009
+ "N_raw": str(self.mode_path),
1010
+ "write_orientations": False, # runner override to reduce IO
1011
+ "seed": 0, # runner override for determinism
1012
+ "mem": self.mem, # accepted by Calculator base through **kwargs
1013
+ })
1014
+ dimer = Dimer(**dimer_kwargs)
1015
+
1016
+ self.geom.set_calculator(dimer)
1017
+
1018
+ # LBFGS kwargs: enforce thresh/max_cycles/out_dir/dump; allow others
1019
+ lbfgs_kwargs = dict(self.lbfgs_kwargs)
1020
+ lbfgs_kwargs.update({
1021
+ "max_cycles": n_steps,
1022
+ "thresh": threshold,
1023
+ "out_dir": str(self.out_dir),
1024
+ "dump": self.dump,
1025
+ })
1026
+ opt = LBFGS(self.geom, **lbfgs_kwargs)
1027
+ opt.run()
1028
+ # pysisyphus uses 0-indexed cur_cycle; keep budget accounting strict by clamping
1029
+ # to the requested segment step count.
1030
+ steps = min(max(int(opt.cur_cycle) + 1, 1), int(n_steps))
1031
+ converged = opt.is_converged
1032
+ self.geom.set_calculator(None)
1033
+
1034
+ # Free dimer/optimizer GPU resources before next Hessian computation
1035
+ del calc_sp, dimer, opt
1036
+ if torch.cuda.is_available():
1037
+ torch.cuda.empty_cache()
1038
+
1039
+ # Append to concatenated trajectory if dump enabled
1040
+ if self.dump:
1041
+ _append_xyz_trajectory(self.optim_all_path, self.out_dir / "optimization_trj.xyz")
1042
+ return steps, converged
1043
+
1044
+ # ----- Hessian caching for 0-step convergence -----
1045
+ def _cache_raw_hessian_cpu(self, H: torch.Tensor) -> None:
1046
+ """Cache the raw Hessian on CPU for the current geometry."""
1047
+ self._raw_hessian_cache_cpu = H.detach().cpu().clone()
1048
+ self._raw_hessian_coords_cpu = self.geom.cart_coords.copy()
1049
+
1050
+ def _reuse_cached_hessian(self) -> Optional[torch.Tensor]:
1051
+ """If the cached geometry matches current, return cached Hessian on device."""
1052
+ if self._raw_hessian_cache_cpu is None or self._raw_hessian_coords_cpu is None:
1053
+ return None
1054
+ if not np.array_equal(self.geom.cart_coords, self._raw_hessian_coords_cpu):
1055
+ return None
1056
+ H_dev = self._raw_hessian_cache_cpu.to(self.device)
1057
+ if self.device.type == "cpu":
1058
+ H_dev = H_dev.clone()
1059
+ return H_dev
1060
+
1061
+ def _calc_full_hessian_cached(
1062
+ self, calc_kwargs: Dict[str, Any], allow_reuse: bool
1063
+ ) -> torch.Tensor:
1064
+ """Compute Hessian, caching on CPU. Reuse if allow_reuse and geometry unchanged."""
1065
+ if allow_reuse:
1066
+ cached = self._reuse_cached_hessian()
1067
+ if cached is not None:
1068
+ click.echo("[tsopt] Reusing cached raw Hessian (0-step convergence).")
1069
+ return cached
1070
+ H = _calc_full_hessian_torch(self.geom, calc_kwargs, self.device)
1071
+ self._cache_raw_hessian_cpu(H)
1072
+ return H
1073
+
1074
    def _resolve_hessian_active_subspace(self, H_t: torch.Tensor, N: int) -> Tuple[List[int], np.ndarray]:
        """
        Resolve active atoms/DOFs for a Hessian tensor.
        For partial Hessians, prefer geometry metadata populated by the calculator.

        A Hessian whose first dimension equals ``3*N`` is resolved from
        ``self.freeze_atoms`` alone. Otherwise (atoms, dofs) candidates are
        tried in this order:

        1. ``geom.hess_active_atom_indices`` / ``geom.hess_active_dof_indices``
        2. ``geom.within_partial_hessian`` (only if it is a dict)
        3. ``geom._hess_active_atoms_last`` / ``geom._hess_active_dofs_last``
        4. the subspace cached by a previous call (``self._last_active_*``)
        5. the freeze-list-based fallback

        The first candidate whose DOF mask has exactly ``H_t.size(0)`` True
        entries wins. The result is cached on ``self`` in every successful path.

        Args:
            H_t: Hessian tensor; only ``H_t.size(0)`` is inspected.
            N: number of atoms in the full system.

        Returns:
            (active_idx, mask_dof): active atom indices and a boolean mask over
            the ``3*N`` DOFs.

        Raises:
            RuntimeError: if no candidate's DOF count matches ``H_t.size(0)``.
        """
        h_dim = int(H_t.size(0))
        full_dim = 3 * int(N)
        freeze = self.freeze_atoms if len(self.freeze_atoms) > 0 else []

        # Full-size Hessian: the freeze list fully determines the active subspace.
        if h_dim == full_dim:
            active_idx = _active_indices(N, freeze)
            mask_dof = _active_mask_dof(N, freeze)
            self._last_active_idx = list(active_idx)
            self._last_active_mask_dof = mask_dof.copy()
            return active_idx, mask_dof

        # Flatten to a 1-D int array and drop indices outside [0, N).
        def _norm_atoms(vals: Optional[Any]) -> np.ndarray:
            if vals is None:
                return np.zeros(0, dtype=int)
            arr = np.asarray(vals, dtype=int).reshape(-1)
            return arr[(arr >= 0) & (arr < N)]

        # Flatten to a 1-D int array and drop indices outside [0, 3N).
        def _norm_dofs(vals: Optional[Any]) -> np.ndarray:
            if vals is None:
                return np.zeros(0, dtype=int)
            arr = np.asarray(vals, dtype=int).reshape(-1)
            return arr[(arr >= 0) & (arr < full_dim)]

        # De-duplicate while keeping first-occurrence order (np.unique would sort).
        def _stable_unique(vals: np.ndarray) -> np.ndarray:
            seen = set()
            out: List[int] = []
            for v in vals.tolist():
                iv = int(v)
                if iv not in seen:
                    seen.add(iv)
                    out.append(iv)
            return np.asarray(out, dtype=int)

        # Each entry: (label, atom indices, dof indices); the label is informational only.
        candidates: List[Tuple[str, np.ndarray, np.ndarray]] = []
        try:
            # May raise (e.g. the attributes are absent); logged at debug level and skipped.
            candidates.append((
                "geom.hess_active_*",
                _norm_atoms(self.geom.hess_active_atom_indices),
                _norm_dofs(self.geom.hess_active_dof_indices),
            ))
        except Exception:
            logger.debug("Failed to read hess_active_* indices", exc_info=True)

        within = getattr(self.geom, "within_partial_hessian", None)
        if isinstance(within, dict):
            candidates.append((
                "geom.within_partial_hessian",
                _norm_atoms(within.get("active_atoms")),
                _norm_dofs(within.get("active_dofs")),
            ))

        candidates.append((
            "geom._hess_active_*_last",
            _norm_atoms(getattr(self.geom, "_hess_active_atoms_last", None)),
            _norm_dofs(getattr(self.geom, "_hess_active_dofs_last", None)),
        ))

        # Subspace resolved by a previous call (populated on every successful return).
        if self._last_active_idx is not None or self._last_active_mask_dof is not None:
            cached_atoms = _norm_atoms(self._last_active_idx)
            cached_dofs = np.flatnonzero(self._last_active_mask_dof).astype(int) \
                if self._last_active_mask_dof is not None else np.zeros(0, dtype=int)
            candidates.append(("cached_active_subspace", cached_atoms, _norm_dofs(cached_dofs)))

        # Last resort: derive the subspace from the freeze list.
        fallback_atoms = _norm_atoms(_active_indices(N, freeze))
        fallback_dofs = np.flatnonzero(_active_mask_dof(N, freeze)).astype(int)
        candidates.append(("freeze_based", fallback_atoms, _norm_dofs(fallback_dofs)))

        for _, atoms_arr, dofs_arr in candidates:
            # DOF indices take precedence over atom indices when both are present.
            if dofs_arr.size > 0:
                mask_dof = np.zeros(full_dim, dtype=bool)
                mask_dof[dofs_arr] = True
            elif atoms_arr.size > 0:
                mask_dof = _mask_dof_from_active_idx(N, atoms_arr.tolist())
            else:
                continue

            # Accept only a candidate whose DOF count matches the Hessian size.
            if int(mask_dof.sum()) != h_dim:
                continue

            # Atom list follows the candidate's DOF order (first occurrence),
            # whereas mask_dof is positional. NOTE(review): downstream code that
            # scatters by this mask assumes that order is ascending; confirm.
            if dofs_arr.size > 0:
                atoms_arr = _stable_unique((dofs_arr // 3).astype(int))
            elif atoms_arr.size == 0:
                # NOTE(review): unreachable; candidates with both arrays empty
                # were skipped by the ``continue`` above.
                atoms_arr = _stable_unique((np.flatnonzero(mask_dof) // 3).astype(int))
            active_idx = [int(i) for i in atoms_arr.tolist()]
            self._last_active_idx = list(active_idx)
            self._last_active_mask_dof = mask_dof.copy()
            return active_idx, mask_dof

        raise RuntimeError(
            f"Failed to resolve active subspace for partial Hessian: "
            f"H_dim={h_dim}, full_dim={full_dim}, freeze_active_dof={int(fallback_dofs.size)}"
        )
1171
+
1172
+ # ----- Loop dimer segments, updating mode from Hessian every interval -----
1173
+ def _dimer_loop(self, threshold: str) -> Tuple[int, bool]:
1174
+ """
1175
+ Run multiple LBFGS segments separated by periodic Hessian-based mode updates.
1176
+ Consumes from a *global* cycle budget self.max_total_cycles.
1177
+
1178
+ Returns:
1179
+ (steps_in_this_call, zero_step_converged)
1180
+ where `zero_step_converged` is True iff the loop terminated by convergence
1181
+ without changing the geometry (i.e., 0-step convergence).
1182
+ """
1183
+ steps_in_this_call = 0
1184
+ zero_step_converged = False
1185
+ while True:
1186
+ remaining_global = max(0, self.max_total_cycles - self._cycles_spent)
1187
+ if remaining_global == 0:
1188
+ break
1189
+ steps_this = min(self.update_interval_hessian, remaining_global)
1190
+ coords_before = self.geom.cart_coords.copy()
1191
+ steps, ok = self._dimer_segment(threshold, steps_this)
1192
+ self._cycles_spent += steps
1193
+ steps_in_this_call += steps
1194
+ if ok:
1195
+ # Check if geometry unchanged (0-step convergence)
1196
+ if np.array_equal(self.geom.cart_coords, coords_before):
1197
+ zero_step_converged = True
1198
+ break
1199
+ # If budget exhausted after this segment, stop before doing a Hessian update
1200
+ if (self.max_total_cycles - self._cycles_spent) <= 0:
1201
+ break
1202
+ # Update mode from Hessian (respect freeze atoms via PHVA)
1203
+ # Ensure VRAM is fully released after dimer segment before heavy Hessian computation
1204
+ if torch.cuda.is_available():
1205
+ torch.cuda.empty_cache()
1206
+ # Choose ML-only or full active-DOF Hessian for mode direction
1207
+ hess_kw = self.calc_kwargs_ml_only if self.ml_only_hessian_dimer else self.calc_kwargs_partial
1208
+ H_t = _calc_full_hessian_torch(self.geom, hess_kw, self.device)
1209
+ N = len(self.geom.atomic_numbers)
1210
+ coords_bohr_t = torch.as_tensor(self.geom.coords.reshape(-1, 3),
1211
+ dtype=H_t.dtype, device=H_t.device)
1212
+ # full vs active-block Hessian
1213
+ if H_t.size(0) == 3 * N:
1214
+ mode_xyz = _mode_direction_by_root(
1215
+ H_t, coords_bohr_t, self.masses_au_t,
1216
+ root=self.root, freeze_idx=self.freeze_atoms if len(self.freeze_atoms) > 0 else None
1217
+ )
1218
+ else:
1219
+ # partial (active) Hessian returned by UMA
1220
+ active_idx, _ = self._resolve_hessian_active_subspace(H_t, N)
1221
+ mode_xyz = _mode_direction_by_root_from_Hact(
1222
+ H_t, self.geom.coords.reshape(-1, 3), self.geom.atomic_numbers,
1223
+ self.masses_au_t, active_idx, self.device, root=self.root
1224
+ )
1225
+ np.savetxt(self.mode_path, mode_xyz, fmt="%.12f")
1226
+ del H_t, coords_bohr_t, mode_xyz
1227
+ _clear_cuda_cache()
1228
+ return steps_in_this_call, zero_step_converged
1229
+
1230
+ # ----- Flatten (resolve multiple imaginary modes) -----
1231
def _flatten_once(self) -> bool:
    """
    Legacy: exact-Hessian-based flattening (kept for reference / fallback).

    Evaluates an exact Hessian with ``self.calc_kwargs_full`` and collects the
    modes whose frequency is below ``-abs(self.neg_freq_thresh_cm)``. The
    "primary" imaginary mode (the ``self.root``-th most negative, clamped to
    the available range) is kept. Every other imaginary mode is probed by a
    +/- displacement from the *same* reference geometry. The lower-energy
    side of each probe is summed, and the summed displacement is applied once
    at the end.

    Returns
    -------
    bool
        True if the geometry was displaced. False if there are at most one
        imaginary mode, no non-primary targets, or every target direction was
        degenerate (zero/non-finite norm). In that case the geometry is left
        at its starting coordinates.
    """
    H_t = _calc_full_hessian_torch(self.geom, self.calc_kwargs_full, self.device)
    freqs_cm, modes = _frequencies_cm_and_modes(
        H_t, self.geom.atomic_numbers, self.geom.coords.reshape(-1, 3), self.device,
        freeze_idx=self.freeze_atoms if len(self.freeze_atoms) > 0 else None
    )
    del H_t
    neg_idx_all = np.where(freqs_cm < -abs(self.neg_freq_thresh_cm))[0]
    if len(neg_idx_all) <= 1:
        del modes
        return False

    # Identify the "primary" imaginary mode by `root` among the negative modes.
    order = np.argsort(freqs_cm[neg_idx_all])  # ascending (more negative first)
    root_clamped = max(0, min(self.root, len(order) - 1))
    primary_idx = neg_idx_all[order[root_clamped]]

    targets = [i for i in neg_idx_all if i != primary_idx]
    if not targets:
        del modes
        return False

    # Reference structure (private copy) and reference energy.
    ref = self.geom.coords.reshape(-1, 3).copy()
    # E_ref is unused; the evaluation is kept so the calculator call sequence
    # matches the historical behaviour of this legacy routine.
    _ = _calc_energy(self.geom, self.calc_kwargs)

    # Per-atom scale sqrt(m_C / m_i): lighter atoms move more, heavier less.
    mass_scale = np.sqrt(12.011 / self.masses_amu)[:, None]
    # Loop-invariant factor to un-mass-weight eigenvectors (broadcasts over xyz).
    sqrt_m = np.sqrt(self.masses_amu)[:, None]
    amp_bohr = self.flatten_amp_ang / BOHR2ANG

    disp_total = np.zeros_like(ref)
    applied = False
    for idx in targets:
        # Mass-weighted eigenvector (embedded to 3N) -> Cartesian direction.
        v_mw = modes[idx].detach().cpu().numpy().reshape(-1, 3)
        v_cart = v_mw / sqrt_m
        norm = np.linalg.norm(v_cart)
        if not np.isfinite(norm) or norm == 0.0:
            # A degenerate direction would turn the geometry into NaNs.
            continue
        v_cart /= norm
        disp0 = amp_bohr * mass_scale * v_cart

        try:
            self.geom.coords = (ref + disp0).reshape(-1)
            E_plus = _calc_energy(self.geom, self.calc_kwargs)
            self.geom.coords = (ref - disp0).reshape(-1)
            E_minus = _calc_energy(self.geom, self.calc_kwargs)
        finally:
            # Always return to the reference, even if an energy call fails.
            self.geom.coords = ref.reshape(-1)

        # Accumulate the downhill side of this mode.
        disp_total += disp0 if E_plus <= E_minus else -disp0
        applied = True

    del modes
    _clear_cuda_cache()

    if not applied:
        return False
    self.geom.coords = (ref + disp_total).reshape(-1)
    return True
1288
+
1289
def _flatten_once_with_modes(self, freqs_cm: np.ndarray, modes: torch.Tensor) -> bool:
    """
    Flatten using precomputed (approximate) modes (mass-weighted, embedded).

    Target selection depends on ``self.flatten_sep_cutoff``:

    * If it is > 0, ``_select_flatten_targets_for_geom`` keeps only modes
      whose representative atoms are well separated from each other. This
      avoids conflicting displacements in nearby regions.
    * Otherwise, every imaginary mode except the primary one is selected.
      The primary mode is the ``self.root``-th most negative, clamped.

    Modes are applied sequentially. Each one is probed at +/- the scaled
    displacement from the current geometry, and the lower-energy side becomes
    the reference for the next mode.

    Parameters
    ----------
    freqs_cm : np.ndarray
        Frequencies in cm^-1, indexed like ``modes``. Values below
        ``-abs(self.neg_freq_thresh_cm)`` count as imaginary.
    modes : torch.Tensor
        ``modes[i]`` is reshaped to ``(-1, 3)`` and treated as a mass-weighted
        Cartesian eigenvector for the whole system.

    Returns
    -------
    bool
        True if at least one displacement was applied. False if there are at
        most one imaginary mode, no target was selected, or every target
        direction was degenerate.
    """
    neg_idx_all = np.where(freqs_cm < -abs(self.neg_freq_thresh_cm))[0]
    if len(neg_idx_all) <= 1:
        return False

    if self.flatten_sep_cutoff > 0:
        # Spatially separated subset of the non-primary imaginary modes.
        targets = _select_flatten_targets_for_geom(
            freqs_cm,
            modes,
            self.geom.cart_coords,
            self.neg_freq_thresh_cm,
            self.root,
            self.flatten_sep_cutoff,
            self.flatten_k,
        )
    else:
        # Legacy behavior: select all imaginary modes except the primary one.
        order = np.argsort(freqs_cm[neg_idx_all])
        root_clamped = max(0, min(self.root, len(order) - 1))
        primary_idx = neg_idx_all[order[root_clamped]]
        targets = [i for i in neg_idx_all if i != primary_idx]

    # len() is valid for both list and ndarray results; ``not arr`` raises
    # for arrays with more than one element.
    if len(targets) == 0:
        return False

    # Per-atom scale sqrt(m_C / m_i). The unscaled step has total norm
    # flatten_amp_ang (in bohr below), spread over all atoms.
    mass_scale = np.sqrt(12.011 / self.masses_amu)[:, None]
    # Loop-invariant factor to un-mass-weight eigenvectors (broadcasts over xyz).
    sqrt_m = np.sqrt(self.masses_amu)[:, None]
    amp_bohr = self.flatten_amp_ang / BOHR2ANG

    # Energy before any mode is applied. The logged ΔE is cumulative
    # relative to this starting point.
    E_ref = _calc_energy(self.geom, self.calc_kwargs)

    applied = False
    for idx in targets:
        v_mw = modes[idx].detach().cpu().numpy().reshape(-1, 3)
        v_cart = v_mw / sqrt_m
        norm = np.linalg.norm(v_cart)
        if not np.isfinite(norm) or norm == 0.0:
            # Dividing by this norm would fill the geometry with NaNs.
            click.echo(f"[Flatten] mode={idx} skipped (degenerate direction).", err=True)
            continue
        v_cart /= norm
        disp = amp_bohr * mass_scale * v_cart

        ref = self.geom.coords.reshape(-1, 3).copy()
        plus = ref + disp
        minus = ref - disp

        accepted = ref
        try:
            self.geom.coords = plus.reshape(-1)
            E_plus = _calc_energy(self.geom, self.calc_kwargs)

            self.geom.coords = minus.reshape(-1)
            E_minus = _calc_energy(self.geom, self.calc_kwargs)

            # Keep the lower-energy side and continue from there.
            use_plus = E_plus <= E_minus
            accepted = plus if use_plus else minus
        finally:
            # On success this is the chosen side. If an energy call raised,
            # it restores the pre-mode reference instead of a probe point.
            self.geom.coords = accepted.reshape(-1)
        applied = True

        E_keep = E_plus if use_plus else E_minus
        delta_e = E_keep - E_ref
        click.echo(
            f"[Flatten] mode={idx} freq={freqs_cm[idx]:+.2f} cm^-1 "
            f"E_disp={E_keep:.8f} Ha ΔE={delta_e:+.8f} Ha"
        )

    _clear_cuda_cache()
    return applied
1361
+
1362
+ # ----- Run full procedure -----
1363
def run(self) -> None:
    """
    Run the complete HessianDimer TS search and write the final outputs.

    All dimer stages draw on the shared ``self.max_total_cycles`` budget,
    tracked in ``self._cycles_spent``.

    1. Compute an initial Hessian. It uses ``self.calc_kwargs_ml_only`` when
       ``self.ml_only_hessian_dimer`` is set, otherwise
       ``self.calc_kwargs_partial``. Select the direction of mode
       ``self.root`` and save it to ``self.mode_path``.
    2. Run the loose dimer loop with ``self.thresh_loose``. If
       ``self.root != 0``, the instance is mutated first:
       ``self.thresh_loose = self.thresh`` and ``self.root = 0``.
    3. If budget remains, refresh the mode from a partial Hessian and run the
       normal dimer loop with ``self.thresh``. A cached Hessian may be reused
       when the loose loop converged without moving.
    4. If ``self.flatten_max_iter > 0`` and budget remains, run the flatten
       loop. One exact Hessian (``self.calc_kwargs_full``) is evaluated at
       its start. After that, only its active block is updated, via Bofill
       updates across each flatten displacement and each dimer
       re-optimisation.
    5. Write ``final_geometry.xyz`` to ``self.out_dir`` and the imaginary-mode
       files to ``self.vib_dir``. The final Hessian reuses the flatten-start
       Hessian if the geometry has not changed since then.
    """
    # Start a fresh concatenated trajectory when dumping is enabled.
    if self.dump and self.optim_all_path.exists():
        self.optim_all_path.unlink()

    N = len(self.geom.atomic_numbers)
    # CPU copy of the flatten-start Hessian and the coordinates it belongs to.
    # Used to skip the final Hessian evaluation when the geometry never moved.
    H_final_reuse_cpu: Optional[torch.Tensor] = None
    H_final_reuse_coords: Optional[np.ndarray] = None

    # (1) Initial Hessian → pick direction by `root`
    hess_kw_init = self.calc_kwargs_ml_only if self.ml_only_hessian_dimer else self.calc_kwargs_partial
    if self.ml_only_hessian_dimer:
        click.echo("[tsopt] Using ML-only Hessian for dimer orientation.")
    H_t = _calc_full_hessian_torch(self.geom, hess_kw_init, self.device)
    coords_bohr_t = torch.as_tensor(self.geom.coords.reshape(-1, 3),
                                    dtype=H_t.dtype, device=H_t.device)
    # active_idx / mask_dof are resolved from this first Hessian. Later stages
    # may read them again (see the flatten stage below).
    active_idx, mask_dof = self._resolve_hessian_active_subspace(H_t, N)
    if H_t.size(0) != 3 * N:
        click.echo(
            f"[tsopt] H_act={int(H_t.size(0))} active_atoms={len(active_idx)} "
            f"active_dofs={int(mask_dof.sum())} within={self.geom.within_partial_hessian is not None}"
        )

    # A 3N×3N Hessian is the full-space case; any other size is treated as an
    # active-block (partial) Hessian.
    if H_t.size(0) == 3 * N:
        # Skip heavy TR-projection residual check to conserve VRAM.
        click.echo("[tsopt] TR-projection residual check skipped to conserve VRAM.")
        mode_xyz = _mode_direction_by_root(
            H_t, coords_bohr_t, self.masses_au_t,
            root=self.root, freeze_idx=self.freeze_atoms if len(self.freeze_atoms) > 0 else None
        )
    else:
        click.echo("[tsopt] Using active-block Hessian from UMA (partial Hessian). Skip full-space TR check.")
        mode_xyz = _mode_direction_by_root_from_Hact(
            H_t, self.geom.coords.reshape(-1, 3), self.geom.atomic_numbers,
            self.masses_au_t, active_idx, self.device, root=self.root
        )
    # The dimer reads its orientation from this file.
    np.savetxt(self.mode_path, mode_xyz, fmt="%.12f")
    del mode_xyz, coords_bohr_t, H_t
    _clear_cuda_cache()

    # (2) Loose loop
    if self.root!=0:
        # A non-default root only seeds the first loop. Afterwards follow the
        # lowest mode (root=0) and use the normal threshold in place of the
        # loose one. Both attributes stay modified on the instance.
        click.echo("[tsopt] root != 0. Use this 'root' in first dimer loop", err=True)
        click.echo(f"[tsopt] Dimer Loop with initial direction from mode {self.root}...")
        self.root=0
        self.thresh_loose = self.thresh
    else:
        click.echo("[tsopt] Loose Dimer Loop...")

    # zero_step_loose: True when the loop converged without moving the geometry.
    _, zero_step_loose = self._dimer_loop(self.thresh_loose)

    zero_step_normal = False
    if (self.max_total_cycles - self._cycles_spent) > 0:
        # (3) Update mode & normal loop (reuse Hessian if 0-step converged)
        H_t = self._calc_full_hessian_cached(self.calc_kwargs_partial, allow_reuse=zero_step_loose)
        coords_bohr_t = torch.as_tensor(self.geom.coords.reshape(-1, 3),
                                        dtype=H_t.dtype, device=H_t.device)
        if H_t.size(0) == 3 * N:
            click.echo("[tsopt] TR-projection residual check skipped to conserve VRAM.")
            mode_xyz = _mode_direction_by_root(
                H_t, coords_bohr_t, self.masses_au_t,
                root=self.root, freeze_idx=self.freeze_atoms if len(self.freeze_atoms) > 0 else None
            )
        else:
            click.echo("[tsopt] Using active-block Hessian from UMA (partial Hessian). Skip full-space TR check.")
            # Only this branch refreshes active_idx / mask_dof.
            active_idx, mask_dof = self._resolve_hessian_active_subspace(H_t, N)
            mode_xyz = _mode_direction_by_root_from_Hact(
                H_t, self.geom.coords.reshape(-1, 3), self.geom.atomic_numbers,
                self.masses_au_t, active_idx, self.device, root=self.root
            )
        np.savetxt(self.mode_path, mode_xyz, fmt="%.12f")
        del mode_xyz, coords_bohr_t, H_t
        _clear_cuda_cache()

        click.echo("[tsopt] Normal Dimer Loop...")
        _, zero_step_normal = self._dimer_loop(self.thresh)
    else:
        click.echo("[tsopt] Reached --max-cycles budget after loose loop; skipping normal dimer loop.")

    if self.flatten_max_iter > 0 and (self.max_total_cycles - self._cycles_spent) > 0:
        # (4) Flatten Loop — *reduced* exact Hessian calls via Bofill updates (active DOF only)
        click.echo("[tsopt] Flatten Loop with Bofill-updated active Hessian...")

        # (4.1) Evaluate one exact Hessian at the loop start and prepare the active block
        # (reuse Hessian if 0-step converged)
        H_any = self._calc_full_hessian_cached(self.calc_kwargs_full, allow_reuse=zero_step_normal)
        # Keep a CPU copy so we can skip the final Hessian recomputation
        # when the flatten loop leaves geometry unchanged.
        H_final_reuse_cpu = H_any.detach().cpu().clone()
        H_final_reuse_coords = self.geom.cart_coords.copy()
        if H_any.size(0) == 3 * N:
            # full → extract active
            # NOTE(review): mask_dof / active_idx here are whatever stage (1),
            # or the partial branch of stage (3), last resolved. They may come
            # from an ML-only or partial Hessian rather than from H_any.
            # Confirm that this is the intended active subspace.
            H_act = _extract_active_block(H_any, mask_dof)  # torch (3N_act,3N_act)
        else:
            # UMA already returned active-block Hessian
            active_idx, mask_dof = self._resolve_hessian_active_subspace(H_any, N)
            H_act = H_any
        del H_any
        _clear_cuda_cache()

        # Gradient & coordinates snapshot for quasi-Newton updates
        # NOTE(review): x_prev is assigned here and in step (h) but never read.
        x_prev = self.geom.coords.copy().reshape(-1)  # (3N,)
        g_prev = _calc_gradient(self.geom, self.calc_kwargs).reshape(-1)  # (3N,)

        # Flatten iterations with *approximate* Hessian updates
        for it in range(self.flatten_max_iter):
            if (self.max_total_cycles - self._cycles_spent) <= 0:
                break

            # (a) Estimate current imaginary modes using the *active* Hessian
            freqs_est = _frequencies_from_Hact(H_act, self.geom.atomic_numbers,
                                               self.geom.coords.reshape(-1, 3), active_idx, self.device)
            n_imag = int(np.sum(freqs_est < -abs(self.neg_freq_thresh_cm)))
            click.echo(f"[tsopt] n≈{n_imag} (approx imag: {[float(x) for x in freqs_est if x < -abs(self.neg_freq_thresh_cm)]})")
            # Done once at most one (approximate) imaginary mode remains.
            if n_imag <= 1:
                break

            # (b) Get approximate modes for flattening (embedded, mass-weighted)
            freqs_cm_approx, modes_embedded = _modes_from_Hact_embedded(
                H_act, self.geom.atomic_numbers, self.geom.coords.reshape(-1, 3), active_idx, self.device
            )

            # (c) Do flatten step using the approximate modes
            x_before_flat = self.geom.coords.copy().reshape(-1)
            did_flatten = self._flatten_once_with_modes(freqs_cm_approx, modes_embedded)
            # Free GPU tensors from mode computation immediately after use
            del freqs_cm_approx, modes_embedded
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            if not did_flatten:
                break
            x_after_flat = self.geom.coords.copy().reshape(-1)

            # (d) Bofill update using UMA gradients across the flatten displacement
            # Only the active DOFs (mask_dof) of step/gradient enter the update.
            g_after_flat = _calc_gradient(self.geom, self.calc_kwargs).reshape(-1)
            delta_flat_full = x_after_flat - x_before_flat
            delta_flat_act = delta_flat_full[mask_dof]
            g_old_act = g_prev[mask_dof]
            g_new_act = g_after_flat[mask_dof]
            H_act = _bofill_update_active(H_act, delta_flat_act, g_new_act, g_old_act)

            # (e) Refresh dimer direction from updated active Hessian
            mode_xyz = _mode_direction_by_root_from_Hact(
                H_act, self.geom.coords.reshape(-1, 3), self.geom.atomic_numbers,
                self.masses_au_t, active_idx, self.device, root=self.root
            )
            np.savetxt(self.mode_path, mode_xyz, fmt="%.12f")
            del mode_xyz

            # (f) Re-optimize with Dimer (consumes global cycle budget)
            # Clear VRAM before dimer loop to ensure space for Hessian recomputation
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            # zero_step_flat is not used afterwards.
            _, zero_step_flat = self._dimer_loop(self.thresh)

            # (g) Bofill update again across the optimization displacement
            x_after_opt = self.geom.coords.copy().reshape(-1)
            g_after_opt = _calc_gradient(self.geom, self.calc_kwargs).reshape(-1)
            delta_opt_full = x_after_opt - x_after_flat
            delta_opt_act = delta_opt_full[mask_dof]
            g_old_act2 = g_after_flat[mask_dof]
            g_new_act2 = g_after_opt[mask_dof]
            H_act = _bofill_update_active(H_act, delta_opt_act, g_new_act2, g_old_act2)

            # (h) Prepare for next iteration
            # g_prev is the "old" gradient for the next flatten update (d).
            x_prev = x_after_opt
            g_prev = g_after_opt
    elif self.flatten_max_iter > 0:
        click.echo("[tsopt] Reached --max-cycles budget; skipping flatten loop.")

    # (5) Final outputs
    # Coordinates are stored in bohr; convert to Å for the XYZ writer.
    final_xyz = self.out_dir / "final_geometry.xyz"
    atoms_final = Atoms(self.geom.atoms, positions=(self.geom.coords.reshape(-1, 3) * BOHR2ANG), pbc=False)
    write(final_xyz, atoms_final)

    # Final Hessian → imaginary mode animation
    # Reuse only when the flatten stage ran and the geometry is bit-identical.
    reuse_final_hessian = (
        H_final_reuse_cpu is not None
        and H_final_reuse_coords is not None
        and np.array_equal(self.geom.cart_coords, H_final_reuse_coords)
    )
    if reuse_final_hessian:
        click.echo("[tsopt] Reusing flatten-start Hessian for final frequency analysis (geometry unchanged).")
        H_t = H_final_reuse_cpu.to(self.device)
    else:
        H_t = _calc_full_hessian_torch(self.geom, self.calc_kwargs_full, self.device)
    if H_t.size(0) == 3 * N:
        freqs_cm, modes = _frequencies_cm_and_modes(
            H_t, self.geom.atomic_numbers, self.geom.coords.reshape(-1, 3), self.device,
            freeze_idx=self.freeze_atoms if len(self.freeze_atoms) > 0 else None
        )
    else:
        active_idx_final, _ = self._resolve_hessian_active_subspace(H_t, N)
        freqs_cm, modes = _modes_from_Hact_embedded(
            H_t, self.geom.atomic_numbers, self.geom.coords.reshape(-1, 3),
            active_idx_final, self.device
        )

    del H_t
    del H_final_reuse_cpu, H_final_reuse_coords
    # A PDB reference is passed to the mode writer only when the source was a PDB.
    _ref_pdb_light = (
        self.source_path
        if self.source_path is not None and self.source_path.suffix.lower() == ".pdb"
        else None
    )
    n_written = _write_all_imag_modes(
        self.geom,
        freqs_cm,
        modes,
        self.neg_freq_thresh_cm,
        self.vib_dir,
        ref_pdb=_ref_pdb_light,
    )
    if n_written == 0:
        click.echo(
            "[tsopt] No imaginary mode found at the end (nu_min = %.2f cm^-1)." % (float(freqs_cm.min()),),
            err=True,
        )
    else:
        click.echo(f"[tsopt] Wrote {n_written} final imaginary mode(s).")
    del modes, freqs_cm

    _clear_cuda_cache()
    click.echo(f"[tsopt] Saved final geometry → {final_xyz}")
    click.echo(f"[tsopt] Mode files → {self.vib_dir}")
1587
+
1588
+
1589
+ # ===================================================================
1590
+ # Microiteration loop for RS-I-RFO heavy mode
1591
+ # ===================================================================
1592
+
1593
+
1594
def _run_microiter_tsopt(
    geometry,
    calc_cfg: Dict[str, Any],
    rsirfo_cfg: Dict[str, Any],
    lbfgs_cfg: Dict[str, Any],
    opt_cfg: Dict[str, Any],
    microiter_cfg: Dict[str, Any],
    out_dir_path: Path,
    *,
    dump: bool = False,
    thresh: Optional[str] = None,
) -> Optional[Any]:
    """Run macro/micro alternating TS optimization (Gaussian 16-style microiteration).

    Macro step: one RS-I-RFO step that moves only the ML region (full ONIOM
    force; all MM atoms frozen). A single persistent optimizer keeps its
    Hessian-update chain across macro iterations.

    Micro step: LBFGS relaxing the non-frozen MM atoms with MM-only forces
    (ML atoms and the frozen-MM layer held fixed). It runs until the micro
    threshold is met or ``micro_max_cycles`` is reached.

    Parameters
    ----------
    geometry
        Geometry object that is optimised in place.
    calc_cfg
        ML/MM calculator config. Also used to derive the ML / movable-MM /
        frozen-MM layers.
    rsirfo_cfg, lbfgs_cfg
        Keyword arguments for the macro (RS-I-RFO) and micro (LBFGS)
        optimizers. They are copied, never mutated.
    opt_cfg
        ``max_cycles`` caps the number of macro iterations (default 10000).
    microiter_cfg
        ``micro_thresh`` (falls back to the macro threshold) and
        ``micro_max_cycles`` (default 10000).
    out_dir_path
        Output directory for trajectories.
    dump
        If True, append macro/micro frames to ``optimization_all_trj.xyz``.
    thresh
        Macro convergence preset. Overrides ``rsirfo_cfg["thresh"]``, which
        itself defaults to ``"baker"``.

    Returns
    -------
    geometry or None
        ``None`` if no ML atoms were found; per the warning message, standard
        RS-I-RFO is used instead. Otherwise returns the same ``geometry``,
        with ``freeze_atoms`` set to the frozen-MM layer and a fresh
        ``mlmm(**calc_cfg)`` calculator attached.
    """
    from .freq import _collect_layer_atom_sets

    # Resolve layer atom sets
    layer_sets = _collect_layer_atom_sets(calc_cfg)
    ml_indices = sorted(layer_sets["ml"])
    movable_mm = sorted(layer_sets["movable_mm"] | layer_sets["hess_mm"])
    frozen_mm = sorted(layer_sets["frozen_mm"])

    if not ml_indices:
        click.echo("[microiter] WARNING: No ML atoms found. Falling back to standard RS-I-RFO.")
        return None

    n_atoms = len(geometry.atoms)
    mm_indices = sorted(set(range(n_atoms)) - set(ml_indices))

    # Freeze lists: for macro step, freeze all MM; for micro step, freeze ML
    # (plus the frozen-MM layer in both cases).
    macro_freeze = sorted(set(mm_indices) | set(frozen_mm))
    micro_freeze = sorted(set(ml_indices) | set(frozen_mm))

    max_cycles = int(opt_cfg.get("max_cycles", 10000))
    macro_thresh = thresh if thresh is not None else rsirfo_cfg.get("thresh", "baker")
    micro_thresh = microiter_cfg.get("micro_thresh") or macro_thresh
    micro_max_cycles = int(microiter_cfg.get("micro_max_cycles", 10000))

    click.echo(
        f"[microiter] ML atoms: {len(ml_indices)}, "
        f"Movable MM atoms: {len(movable_mm)}, "
        f"Frozen MM atoms: {len(frozen_mm)}"
    )
    click.echo(f"[microiter] Macro thresh: {macro_thresh}, Micro thresh: {micro_thresh}")

    # Create ONIOM calculator (shared core for MM-only calc)
    macro_calc_cfg = dict(calc_cfg)
    macro_calc_cfg["freeze_atoms"] = macro_freeze
    macro_calc_cfg["hess_mm_atoms"] = []  # macro step uses an ML-only Hessian
    macro_calc = mlmm(**macro_calc_cfg)
    mm_calc = mlmm_mm_only(macro_calc.core, freeze_atoms=micro_freeze)

    # Seed initial Hessian for RS-I-RFO (with macro freeze).
    # Try the TS Hessian cache first; fall back to a full Hessian calculation.
    from .hessian_cache import load as _hess_load_ts
    hess_device = _torch_device(calc_cfg.get("ml_device", "auto"))

    cached_ts = _hess_load_ts("ts")
    if cached_ts is not None:
        click.echo("[microiter] Reusing cached TS Hessian for RS-I-RFO macro step.")
        active_dofs = cached_ts.get("active_dofs")
        h_raw = cached_ts["hessian"]
        if isinstance(h_raw, torch.Tensor):
            h_init = h_raw.clone()
        else:
            h_init = torch.as_tensor(h_raw, dtype=torch.float64)
        geometry.freeze_atoms = macro_freeze
        geometry.set_calculator(macro_calc)
        if active_dofs is not None:
            # Describe the cached active-block Hessian's DOF subset.
            geometry.within_partial_hessian = {
                "active_n_dof": len(active_dofs),
                "full_n_dof": geometry.cart_coords.size,
                "active_dofs": active_dofs,
                "active_atoms": sorted({d // 3 for d in active_dofs}),
            }
        geometry.cart_hessian = h_init
        click.echo(f"[microiter] Initial Hessian seeded from cache (shape={h_init.shape[0]}x{h_init.shape[1]}).")
        del h_init
    else:
        click.echo("[microiter] Seeding initial Hessian for RS-I-RFO macro step.")

        geometry.freeze_atoms = macro_freeze
        geometry.set_calculator(macro_calc)

        h_init = _calc_full_hessian_torch(geometry, macro_calc_cfg, hess_device)
        geometry.cart_hessian = h_init
        click.echo(f"[microiter] Initial Hessian seeded (shape={h_init.shape[0]}x{h_init.shape[1]}).")
        del h_init

    optim_all_path = out_dir_path / "optimization_all_trj.xyz"
    macro_trj_path = out_dir_path / "optimization_trj.xyz"
    total_macro_steps = 0

    # Create the persistent RSIRFOptimizer once (LayerOpt pattern). This
    # preserves the BFGS Hessian update chain across macro iterations.
    # NOTE: geometry already has macro_calc set (above). Do NOT call
    # set_calculator() again here, as it clears the pre-computed cart_hessian.
    geometry.freeze_atoms = macro_freeze

    rsirfo_args = dict(rsirfo_cfg)
    rsirfo_args["max_cycles"] = max_cycles
    rsirfo_args["out_dir"] = str(out_dir_path)
    rsirfo_args["dump"] = False  # trajectory dumping handled externally
    if macro_thresh is not None:
        rsirfo_args["thresh"] = str(macro_thresh)
    # RSIRFOptimizer does not accept RFOptimizer-specific DIIS knobs; strip them.
    for _diis_kw in ("gediis", "gdiis", "gdiis_thresh", "gediis_thresh", "gdiis_test_direction", "adapt_step_func"):
        rsirfo_args.pop(_diis_kw, None)

    macro_optimizer = RSIRFOptimizer(geometry, **rsirfo_args)
    macro_optimizer.prepare_opt()  # initialise Hessian from geometry.cart_hessian

    # Micro-step LBFGS settings are identical for every macro iteration.
    # NOTE(review): with dump=True, LBFGS also writes into out_dir_path. Its
    # optimization_trj.xyz is the same file as macro_trj_path; confirm that
    # _append_xyz_trajectory does not duplicate frames in optim_all_path.
    micro_lbfgs_args = dict(lbfgs_cfg)
    micro_lbfgs_args["max_cycles"] = micro_max_cycles
    micro_lbfgs_args["thresh"] = micro_thresh
    micro_lbfgs_args["out_dir"] = str(out_dir_path)
    micro_lbfgs_args["dump"] = dump

    # Microiteration progress table (pysisyphus-style with micro_steps column)
    micro_header = "cycle Δ(energy) max(|force|) rms(force) max(|step|) rms(step) micro_steps s/cycle".split()
    micro_col_fmts = "int float float float float float int float_short".split()
    micro_table = TablePrinter(micro_header, micro_col_fmts, width=12)
    micro_table.print_header()

    def _print_progress_row(opt, cycle: int, n_micro: int, t0: float, info) -> None:
        """Print one table row for the latest macro step of ``opt``."""
        energies = opt.energies
        energy_diff = energies[-1] - energies[-2] if len(energies) >= 2 else float("nan")
        # Columns cycle, micro_steps and s/cycle carry no convergence mark.
        # Assumes get_convergence()[:-1] yields one flag per energy/force/step
        # column — TODO confirm.
        marks = [False, *info.get_convergence()[:-1], False, False]
        cycle_time = time.time() - t0
        micro_table.print_row(
            (cycle, energy_diff, opt.max_forces[-1], opt.rms_forces[-1],
             opt.max_steps[-1], opt.rms_steps[-1], n_micro, cycle_time),
            marks=marks,
        )

    for macro_iter in range(max_cycles):
        # ---- Macro step: 1 RS-I-RFO step with ONIOM forces, MM frozen ----
        geometry.freeze_atoms = macro_freeze
        geometry.set_calculator(macro_calc)

        # Manually feed state to the persistent optimizer (cf. LayerOpt).
        macro_optimizer.coords.append(geometry.coords.copy())
        macro_optimizer.cart_coords.append(geometry.cart_coords.copy())
        macro_optimizer.cur_cycle = macro_iter

        t_start = time.time()
        step = macro_optimizer.optimize()  # housekeeping() triggers BFGS update
        macro_optimizer.steps.append(step)

        # Convergence check
        macro_converged, conv_info = macro_optimizer.check_convergence()
        total_macro_steps += 1

        if dump:
            with open(macro_trj_path, "a") as f:
                f.write(geometry.as_xyz() + "\n")
            _append_xyz_trajectory(optim_all_path, macro_trj_path)

        if macro_converged:
            # Final converged row (no micro steps).
            _print_progress_row(macro_optimizer, macro_iter, 0, t_start, conv_info)
            click.echo(f"[microiter] Macro convergence reached at iteration {macro_iter + 1}.")
            break

        # Apply step to geometry
        geometry.coords = geometry.coords.copy() + step
        # Record actual step (may differ due to coordinate back-transformation)
        macro_optimizer.steps[-1] = geometry.coords - macro_optimizer.coords[-1]

        # ---- Micro step: LBFGS with MM-only forces, ML frozen ----
        geometry.freeze_atoms = micro_freeze
        geometry.set_calculator(mm_calc)

        micro_opt = LBFGS(geometry, **micro_lbfgs_args)
        # Silence the inner optimizer's own progress output.
        with contextlib.redirect_stdout(io.StringIO()):
            micro_opt.run()
        micro_steps = max(int(micro_opt.cur_cycle) + 1, 1)

        if dump:
            _append_xyz_trajectory(optim_all_path, out_dir_path / "optimization_trj.xyz")

        del micro_opt
        _clear_cuda_cache()

        # Progress row with micro_steps (separator every 10 macro iterations).
        if (macro_iter > 1) and (macro_iter % 10 == 0):
            micro_table.print_sep()
        _print_progress_row(macro_optimizer, macro_iter, micro_steps, t_start, conv_info)

    else:
        click.echo(f"[microiter] Reached max macro iterations ({max_cycles}).")

    del macro_optimizer
    _clear_cuda_cache()

    click.echo(f"[microiter] Total macro steps: {total_macro_steps}")
    # Restore the full calculator with only the frozen-MM layer frozen.
    # frozen_mm is already sorted and unique, so this order is deterministic.
    geometry.freeze_atoms = list(frozen_mm)
    base_calc = mlmm(**calc_cfg)
    geometry.set_calculator(base_calc)

    return geometry
1804
+
1805
+
1806
+ # ===================================================================
1807
+ # Defaults for CLI
1808
+ # ===================================================================
1809
+
1810
# Configuration defaults (imported from defaults.py).
# Each module-level default is an independent deep copy, so that later
# CLI/YAML merges cannot mutate the shared objects in the defaults module.
GEOM_KW: Dict[str, Any] = deepcopy(GEOM_KW_DEFAULT)
CALC_KW: Dict[str, Any] = deepcopy(MLMM_CALC_KW)

# HessianDimer defaults: HESSIAN_DIMER_KW plus nested "dimer" (DIMER_KW) and
# "lbfgs" (LBFGS_KW) sections. Deep-copied like GEOM_KW / CALC_KW; a shallow
# ``{**...}`` merge would keep nested values shared with the imported defaults.
hessian_dimer_KW: Dict[str, Any] = deepcopy({
    **HESSIAN_DIMER_KW,
    "dimer": {**DIMER_KW},
    "lbfgs": {**LBFGS_KW},
})
1820
+
1821
+ # ===================================================================
1822
+ # CLI
1823
+ # ===================================================================
1824
+
1825
+ @click.command(
1826
+ help="TS optimization: grad (Dimer) or hess (RS-I-RFO) for the ML/MM calculator.",
1827
+ context_settings={"help_option_names": ["-h", "--help"]},
1828
+ )
1829
+ @click.option(
1830
+ "-i", "--input",
1831
+ "input_path",
1832
+ type=click.Path(path_type=Path, exists=True, dir_okay=False),
1833
+ required=True,
1834
+ help="Starting geometry (PDB or XYZ). XYZ provides higher coordinate precision. "
1835
+ "If XYZ, use --ref-pdb to specify PDB topology for atom ordering and output conversion.",
1836
+ )
1837
+ @click.option(
1838
+ "--ref-pdb",
1839
+ "ref_pdb",
1840
+ type=click.Path(path_type=Path, exists=True, dir_okay=False),
1841
+ default=None,
1842
+ show_default=False,
1843
+ help="Reference PDB topology when input is XYZ. XYZ coordinates are used (higher precision) "
1844
+ "while PDB provides atom ordering and residue information for output conversion.",
1845
+ )
1846
+ @click.option(
1847
+ "--parm",
1848
+ "real_parm7",
1849
+ type=click.Path(path_type=Path, exists=True, dir_okay=False),
1850
+ required=True,
1851
+ help="Amber parm7 topology for the whole enzyme (MM region).",
1852
+ )
1853
+ @click.option(
1854
+ "--model-pdb",
1855
+ "model_pdb",
1856
+ type=click.Path(path_type=Path, exists=True, dir_okay=False),
1857
+ required=False,
1858
+ help="PDB containing the ML-region atoms. Optional when --detect-layer is enabled.",
1859
+ )
1860
+ @click.option(
1861
+ "--model-indices",
1862
+ "model_indices_str",
1863
+ type=str,
1864
+ default=None,
1865
+ show_default=False,
1866
+ help="Comma-separated atom indices for the ML region (ranges allowed like 1-5). "
1867
+ "Used when --model-pdb is omitted.",
1868
+ )
1869
+ @click.option(
1870
+ "--model-indices-one-based/--model-indices-zero-based",
1871
+ "model_indices_one_based",
1872
+ default=True,
1873
+ show_default=True,
1874
+ help="Interpret --model-indices as 1-based (default) or 0-based.",
1875
+ )
1876
+ @click.option(
1877
+ "--detect-layer/--no-detect-layer",
1878
+ "detect_layer",
1879
+ default=True,
1880
+ show_default=True,
1881
+ help="Detect ML/MM layers from input PDB B-factors (B=0/10/20). "
1882
+ "If disabled, you must provide --model-pdb or --model-indices.",
1883
+ )
1884
+ @click.option(
1885
+ "-q",
1886
+ "--charge",
1887
+ type=int,
1888
+ required=False,
1889
+ help="Total charge of the ML region. Required unless --ligand-charge is provided.",
1890
+ )
1891
+ @click.option("-l", "--ligand-charge", type=str, default=None, show_default=False,
1892
+ help="Total charge or per-resname mapping (e.g., GPP:-3,SAM:1) used to derive "
1893
+ "charge when -q is omitted (requires PDB input or --ref-pdb).")
1894
+ @click.option(
1895
+ "-m",
1896
+ "--multiplicity",
1897
+ "spin",
1898
+ type=int,
1899
+ default=None,
1900
+ show_default=False,
1901
+ help="Spin multiplicity (2S+1) for the ML region.",
1902
+ )
1903
+ @click.option(
1904
+ "--freeze-atoms",
1905
+ "freeze_atoms_text",
1906
+ type=str,
1907
+ default=None,
1908
+ show_default=False,
1909
+ help="Comma-separated 1-based indices to freeze (e.g., '1,3,5').",
1910
+ )
1911
+ @click.option(
1912
+ "--radius-hessian",
1913
+ "--hess-cutoff",
1914
+ "hess_cutoff",
1915
+ type=float,
1916
+ default=0.0,
1917
+ show_default=True,
1918
+ help="Distance cutoff (Å) from ML region for MM atoms to include in Hessian calculation. "
1919
+ "Applied to movable MM atoms. Default 0.0 means ML-only partial Hessian.",
1920
+ )
1921
+ @click.option(
1922
+ "--movable-cutoff",
1923
+ "movable_cutoff",
1924
+ type=float,
1925
+ default=None,
1926
+ show_default=False,
1927
+ help="Distance cutoff (Å) from ML region for movable MM atoms. "
1928
+ "MM atoms beyond this are frozen. "
1929
+ "Providing --movable-cutoff disables --detect-layer.",
1930
+ )
1931
+ @click.option(
1932
+ "--hessian-calc-mode",
1933
+ type=click.Choice(["Analytical", "FiniteDifference"], case_sensitive=False),
1934
+ default=None,
1935
+ help="How the ML backend builds the Hessian (Analytical or FiniteDifference); "
1936
+ "overrides calc.hessian_calc_mode from YAML. "
1937
+ "Default: 'FiniteDifference'. Use 'Analytical' when VRAM is sufficient.",
1938
+ )
1939
+ @click.option("--max-cycles", type=int, default=10000, show_default=True, help="Maximum total optimization cycles.")
1940
+ @click.option(
1941
+ "--dump/--no-dump",
1942
+ default=False,
1943
+ show_default=True,
1944
+ help="Write concatenated trajectory 'optimization_all_trj.xyz'.",
1945
+ )
1946
+ @click.option("-o", "--out-dir", type=str, default=OUT_DIR_TSOPT, show_default=True, help="Output directory.")
1947
+ @click.option(
1948
+ "--thresh",
1949
+ type=click.Choice(["gau_loose", "gau", "gau_tight", "gau_vtight", "baker", "never"], case_sensitive=False),
1950
+ default=None,
1951
+ help="Convergence preset.",
1952
+ )
1953
+ @click.option(
1954
+ "--opt-mode",
1955
+ type=click.Choice(["grad", "hess", "light", "heavy", "dimer", "rsirfo"], case_sensitive=False),
1956
+ default="hess",
1957
+ show_default=True,
1958
+ help="grad (dimer) or hess (rsirfo). Aliases light/heavy and dimer/rsirfo are accepted.",
1959
+ )
1960
+ @click.option(
1961
+ "--microiter/--no-microiter",
1962
+ "microiter",
1963
+ default=True,
1964
+ show_default=True,
1965
+ help="Enable microiteration: alternate ML 1-step (RS-I-RFO) and MM relaxation (LBFGS with MM-only forces). "
1966
+ "Only effective in --opt-mode hess. Ignored in grad mode.",
1967
+ )
1968
+ @click.option(
1969
+ "--partial-hessian-flatten/--full-hessian-flatten",
1970
+ "partial_hessian_flatten",
1971
+ default=True,
1972
+ show_default=True,
1973
+ help="Use partial Hessian (ML region only) for imaginary mode detection in flatten loop.",
1974
+ )
1975
+ @click.option(
1976
+ "--flatten/--no-flatten",
1977
+ "flatten",
1978
+ default=None,
1979
+ show_default=False,
1980
+ help="Enable/disable extra imaginary-mode flattening loop. "
1981
+ "--flatten uses the default flatten_max_iter (50); --no-flatten forces it to 0. "
1982
+ "When not provided, the value is determined by the YAML config or defaults.",
1983
+ )
1984
+ @click.option(
1985
+ "--ml-only-hessian-dimer/--no-ml-only-hessian-dimer",
1986
+ "ml_only_hessian_dimer",
1987
+ default=False,
1988
+ show_default=True,
1989
+ help="Use ML-region-only Hessian (no MM Hessian contribution) for dimer orientation "
1990
+ "in grad mode. Faster but less accurate for mode direction.",
1991
+ )
1992
+ @click.option(
1993
+ "--active-dof-mode",
1994
+ type=click.Choice(["all", "ml-only", "partial", "unfrozen"], case_sensitive=False),
1995
+ default="partial",
1996
+ show_default=True,
1997
+ help="Active DOF selection for final frequency analysis: "
1998
+ "all (all atoms), ml-only (ML only), partial (ML + MovableMM, default), "
1999
+ "unfrozen (all except frozen layer).",
2000
+ )
2001
+ @click.option(
2002
+ "--config",
2003
+ "config_yaml",
2004
+ type=click.Path(path_type=Path, exists=True, dir_okay=False),
2005
+ default=None,
2006
+ help="Base YAML configuration file applied before explicit CLI options.",
2007
+ )
2008
+ @click.option(
2009
+ "--show-config/--no-show-config",
2010
+ "show_config",
2011
+ default=False,
2012
+ show_default=True,
2013
+ help="Print resolved configuration and continue execution.",
2014
+ )
2015
+ @click.option(
2016
+ "--dry-run/--no-dry-run",
2017
+ "dry_run",
2018
+ default=False,
2019
+ show_default=True,
2020
+ help="Validate options and print the execution plan without running TS optimization.",
2021
+ )
2022
+ @click.option(
2023
+ "--convert-files/--no-convert-files",
2024
+ "convert_files",
2025
+ default=True,
2026
+ show_default=True,
2027
+ help="Convert XYZ/TRJ outputs into PDB companions based on the input format.",
2028
+ )
2029
+ @click.option(
2030
+ "-b", "--backend",
2031
+ type=click.Choice(["uma", "orb", "mace", "aimnet2"], case_sensitive=False),
2032
+ default=None,
2033
+ show_default=False,
2034
+ help="ML backend for the ONIOM high-level region (default: uma).",
2035
+ )
2036
+ @click.option(
2037
+ "--embedcharge/--no-embedcharge",
2038
+ "embedcharge",
2039
+ default=False,
2040
+ show_default=True,
2041
+ help="Enable xTB point-charge embedding correction for MM→ML environmental effects.",
2042
+ )
2043
+ @click.option(
2044
+ "--embedcharge-cutoff",
2045
+ "embedcharge_cutoff",
2046
+ type=float,
2047
+ default=None,
2048
+ show_default=False,
2049
+ help="Distance cutoff (Å) from ML region for MM point charges in xTB embedding. "
2050
+ "Default: 12.0 Å when --embedcharge is enabled.",
2051
+ )
2052
+ @click.pass_context
2053
+ def cli(
2054
+ ctx: click.Context,
2055
+ input_path: Path,
2056
+ ref_pdb: Optional[Path],
2057
+ real_parm7: Path,
2058
+ model_pdb: Optional[Path],
2059
+ model_indices_str: Optional[str],
2060
+ model_indices_one_based: bool,
2061
+ detect_layer: bool,
2062
+ charge: Optional[int],
2063
+ ligand_charge: Optional[str],
2064
+ spin: Optional[int],
2065
+ freeze_atoms_text: Optional[str],
2066
+ hess_cutoff: Optional[float],
2067
+ movable_cutoff: Optional[float],
2068
+ hessian_calc_mode: Optional[str],
2069
+ max_cycles: int,
2070
+ dump: bool,
2071
+ out_dir: str,
2072
+ thresh: Optional[str],
2073
+ opt_mode: str,
2074
+ microiter: bool,
2075
+ partial_hessian_flatten: bool,
2076
+ flatten: Optional[bool],
2077
+ ml_only_hessian_dimer: bool,
2078
+ active_dof_mode: str,
2079
+ config_yaml: Optional[Path],
2080
+ show_config: bool,
2081
+ dry_run: bool,
2082
+ convert_files: bool,
2083
+ backend: Optional[str],
2084
+ embedcharge: bool,
2085
+ embedcharge_cutoff: Optional[float],
2086
+ ) -> None:
2087
+ set_convert_file_enabled(convert_files)
2088
+ _is_param_explicit = make_is_param_explicit(ctx)
2089
+
2090
+ config_yaml, override_yaml, used_legacy_yaml = resolve_yaml_sources(
2091
+ config_yaml=config_yaml,
2092
+ override_yaml=None,
2093
+ args_yaml_legacy=None,
2094
+ )
2095
+ merged_yaml_cfg, _, _ = load_merged_yaml_cfg(
2096
+ config_yaml=config_yaml,
2097
+ override_yaml=None,
2098
+ )
2099
+
2100
+ # Handle input: PDB directly, or XYZ with --ref-pdb for topology
2101
+ suffix = input_path.suffix.lower()
2102
+ if suffix == ".pdb":
2103
+ # PDB input: use directly
2104
+ prepared_input = prepare_input_structure(input_path)
2105
+ elif suffix == ".xyz":
2106
+ # XYZ input: require --ref-pdb for topology
2107
+ if ref_pdb is None:
2108
+ click.echo("ERROR: XYZ/TRJ input requires --ref-pdb to specify PDB topology.", err=True)
2109
+ sys.exit(1)
2110
+ prepared_input = prepare_input_structure(input_path)
2111
+ apply_ref_pdb_override(prepared_input, ref_pdb)
2112
+ click.echo(f"[input] Using XYZ coordinates from {input_path.name}, PDB topology from {ref_pdb.name}")
2113
+ else:
2114
+ click.echo(f"ERROR: Unsupported input format: {suffix}. Use .pdb or .xyz (with --ref-pdb).", err=True)
2115
+ sys.exit(1)
2116
+
2117
+ geom_input_path = prepared_input.geom_path
2118
+ source_path = prepared_input.source_path
2119
+ charge, spin = resolve_charge_spin_or_raise(
2120
+ prepared_input, charge, spin,
2121
+ ligand_charge=ligand_charge, prefix="[tsopt]",
2122
+ )
2123
+
2124
+ try:
2125
+ freeze_atoms_cli = _parse_freeze_atoms_opt(freeze_atoms_text)
2126
+ except click.BadParameter as e:
2127
+ click.echo(f"ERROR: {e}", err=True)
2128
+ prepared_input.cleanup()
2129
+ sys.exit(1)
2130
+
2131
+ model_indices: Optional[List[int]] = None
2132
+ if model_indices_str:
2133
+ try:
2134
+ model_indices = parse_indices_string(model_indices_str, one_based=model_indices_one_based)
2135
+ except click.BadParameter as e:
2136
+ click.echo(f"ERROR: {e}", err=True)
2137
+ prepared_input.cleanup()
2138
+ sys.exit(1)
2139
+
2140
+ time_start = time.perf_counter()
2141
+
2142
+ # Resolve optimizer mode (default is now hess/RS-I-RFO)
2143
+ mode_resolved = normalize_choice(
2144
+ opt_mode,
2145
+ param="--opt-mode",
2146
+ alias_groups=TSOPT_MODE_ALIASES,
2147
+ allowed_hint="grad|hess|dimer|rsirfo",
2148
+ )
2149
+ use_heavy = (mode_resolved == "rsirfo")
2150
+
2151
+ config_layer_cfg = load_yaml_dict(config_yaml)
2152
+ override_layer_cfg = load_yaml_dict(override_yaml)
2153
+ geom_cfg: Dict[str, Any] = deepcopy(GEOM_KW)
2154
+ calc_cfg: Dict[str, Any] = deepcopy(CALC_KW)
2155
+ opt_cfg: Dict[str, Any] = dict(OPT_BASE_KW)
2156
+ lbfgs_cfg: Dict[str, Any] = dict(LBFGS_KW)
2157
+ simple_cfg: Dict[str, Any] = dict(hessian_dimer_KW)
2158
+ rsirfo_cfg: Dict[str, Any] = dict(RSIRFO_KW)
2159
+
2160
+ apply_yaml_overrides(
2161
+ config_layer_cfg,
2162
+ [
2163
+ (geom_cfg, (("geom",),)),
2164
+ (calc_cfg, (("calc",), ("mlmm",))),
2165
+ (opt_cfg, (("opt",),)),
2166
+ (simple_cfg, (("hessian_dimer",),)),
2167
+ (rsirfo_cfg, (("rsirfo",),)),
2168
+ ],
2169
+ )
2170
+ if _is_param_explicit("hessian_calc_mode") and hessian_calc_mode is not None:
2171
+ calc_cfg["hessian_calc_mode"] = str(hessian_calc_mode)
2172
+ if _is_param_explicit("max_cycles"):
2173
+ opt_cfg["max_cycles"] = int(max_cycles)
2174
+ if _is_param_explicit("dump"):
2175
+ opt_cfg["dump"] = bool(dump)
2176
+ if _is_param_explicit("out_dir"):
2177
+ opt_cfg["out_dir"] = out_dir
2178
+ if _is_param_explicit("thresh") and thresh is not None:
2179
+ opt_cfg["thresh"] = str(thresh)
2180
+ simple_cfg["thresh"] = str(thresh)
2181
+ rsirfo_cfg["thresh"] = str(thresh)
2182
+ # Handle --flatten/--no-flatten CLI toggle
2183
+ if flatten is not None:
2184
+ if flatten:
2185
+ # Use default from HESSIAN_DIMER_KW if not already set
2186
+ simple_cfg.setdefault("flatten_max_iter", HESSIAN_DIMER_KW["flatten_max_iter"])
2187
+ else:
2188
+ simple_cfg["flatten_max_iter"] = 0
2189
+ if _is_param_explicit("detect_layer"):
2190
+ calc_cfg["use_bfactor_layers"] = bool(detect_layer)
2191
+ if _is_param_explicit("hess_cutoff") and hess_cutoff is not None:
2192
+ calc_cfg["hess_cutoff"] = float(hess_cutoff)
2193
+ if _is_param_explicit("movable_cutoff") and movable_cutoff is not None:
2194
+ calc_cfg["movable_cutoff"] = float(movable_cutoff)
2195
+ calc_cfg["use_bfactor_layers"] = False
2196
+
2197
+ model_charge_value = calc_cfg.get("model_charge", charge)
2198
+ if model_charge_value is None:
2199
+ model_charge_value = charge
2200
+ calc_cfg["model_charge"] = int(model_charge_value)
2201
+ if _is_param_explicit("charge"):
2202
+ calc_cfg["model_charge"] = int(charge)
2203
+
2204
+ model_mult_value = calc_cfg.get("model_mult", spin)
2205
+ if model_mult_value is None:
2206
+ model_mult_value = spin
2207
+ calc_cfg["model_mult"] = int(model_mult_value)
2208
+ if _is_param_explicit("spin"):
2209
+ calc_cfg["model_mult"] = int(spin)
2210
+
2211
+ if model_pdb is not None:
2212
+ calc_cfg["model_pdb"] = str(model_pdb)
2213
+ calc_cfg["input_pdb"] = str(source_path)
2214
+ calc_cfg["real_parm7"] = str(real_parm7)
2215
+
2216
+ if backend is not None:
2217
+ calc_cfg["backend"] = str(backend).lower()
2218
+ if _is_param_explicit("embedcharge"):
2219
+ calc_cfg["embedcharge"] = bool(embedcharge)
2220
+ if _is_param_explicit("embedcharge_cutoff"):
2221
+ calc_cfg["embedcharge_cutoff"] = embedcharge_cutoff
2222
+
2223
+ apply_yaml_overrides(
2224
+ override_layer_cfg,
2225
+ [
2226
+ (geom_cfg, (("geom",),)),
2227
+ (calc_cfg, (("calc",), ("mlmm",))),
2228
+ (opt_cfg, (("opt",),)),
2229
+ (simple_cfg, (("hessian_dimer",),)),
2230
+ (rsirfo_cfg, (("rsirfo",),)),
2231
+ ],
2232
+ )
2233
+ calc_paths = (("calc",), ("mlmm",))
2234
+ partial_explicit = (
2235
+ yaml_section_has_key(config_layer_cfg, calc_paths, "return_partial_hessian")
2236
+ or yaml_section_has_key(override_layer_cfg, calc_paths, "return_partial_hessian")
2237
+ )
2238
+ if not partial_explicit:
2239
+ calc_cfg["return_partial_hessian"] = True
2240
+
2241
+ # Resolve microiteration config from YAML
2242
+ microiter_cfg = dict(MICROITER_KW)
2243
+ apply_yaml_overrides(
2244
+ config_layer_cfg,
2245
+ [(microiter_cfg, (("microiter",),))],
2246
+ )
2247
+ apply_yaml_overrides(
2248
+ override_layer_cfg,
2249
+ [(microiter_cfg, (("microiter",),))],
2250
+ )
2251
+
2252
+ use_microiter = bool(microiter) and use_heavy
2253
+ if bool(microiter) and not use_heavy:
2254
+ click.echo("[microiter] --microiter is only effective with --opt-mode hess (RS-I-RFO). Ignoring.")
2255
+
2256
+ try:
2257
+ geom_freeze = _normalize_geom_freeze_opt(geom_cfg.get("freeze_atoms"))
2258
+ except click.BadParameter as e:
2259
+ click.echo(f"ERROR: {e}", err=True)
2260
+ prepared_input.cleanup()
2261
+ sys.exit(1)
2262
+ geom_cfg["freeze_atoms"] = geom_freeze
2263
+ if freeze_atoms_cli:
2264
+ merge_freeze_atom_indices(geom_cfg, freeze_atoms_cli)
2265
+ freeze_atoms_final = list(geom_cfg.get("freeze_atoms") or [])
2266
+ calc_cfg["freeze_atoms"] = freeze_atoms_final
2267
+
2268
+ # Propagate opt.print_every only when it is explicitly different from the
2269
+ # base default. This avoids clobbering optimizer-specific YAML settings
2270
+ # (e.g. hessian_dimer.lbfgs.print_every / rsirfo.print_every) with the
2271
+ # inherited OPT_BASE default value.
2272
+ try:
2273
+ pe_opt = int(opt_cfg.get("print_every", OPT_BASE_KW.get("print_every", 100)))
2274
+ pe_base = int(OPT_BASE_KW.get("print_every", 100))
2275
+ if pe_opt >= 1 and pe_opt != pe_base:
2276
+ simple_cfg.setdefault("lbfgs", {})
2277
+ simple_cfg["lbfgs"]["print_every"] = pe_opt
2278
+ rsirfo_cfg["print_every"] = pe_opt
2279
+ except Exception:
2280
+ logger.debug("Failed to configure print_every", exc_info=True)
2281
+
2282
+ out_dir_path = Path(opt_cfg["out_dir"]).resolve()
2283
+
2284
+ # movable_cutoff implies full distance-based layer assignment.
2285
+ # hess_cutoff alone is allowed with detect-layer and is applied on movable MM atoms.
2286
+ detect_layer_enabled = bool(calc_cfg.get("use_bfactor_layers", True))
2287
+ model_pdb_cfg = calc_cfg.get("model_pdb")
2288
+ if calc_cfg.get("movable_cutoff") is not None:
2289
+ if detect_layer_enabled:
2290
+ click.echo("[layer] movable_cutoff is set; disabling --detect-layer.", err=True)
2291
+ detect_layer_enabled = False
2292
+ calc_cfg["use_bfactor_layers"] = False
2293
+
2294
+ layer_source_pdb = source_path
2295
+ if detect_layer_enabled and layer_source_pdb.suffix.lower() != ".pdb":
2296
+ click.echo("ERROR: --detect-layer requires a PDB input (or --ref-pdb).", err=True)
2297
+ prepared_input.cleanup()
2298
+ sys.exit(1)
2299
+
2300
+ if show_config:
2301
+ click.echo(
2302
+ pretty_block(
2303
+ "yaml_layers",
2304
+ {
2305
+ "config": None if config_yaml is None else str(config_yaml),
2306
+ "override_yaml": None if override_yaml is None else str(override_yaml),
2307
+ "merged_keys": sorted(merged_yaml_cfg.keys()),
2308
+ },
2309
+ )
2310
+ )
2311
+
2312
+ if dry_run:
2313
+ model_region_source = "bfactor"
2314
+ if not detect_layer_enabled:
2315
+ if model_pdb_cfg is not None:
2316
+ model_region_source = "model_pdb"
2317
+ elif model_indices:
2318
+ model_region_source = "model_indices"
2319
+ else:
2320
+ click.echo("ERROR: Provide --model-pdb or --model-indices when --no-detect-layer.", err=True)
2321
+ prepared_input.cleanup()
2322
+ sys.exit(1)
2323
+ if (
2324
+ not detect_layer_enabled
2325
+ and model_pdb_cfg is None
2326
+ and model_indices
2327
+ and layer_source_pdb.suffix.lower() != ".pdb"
2328
+ ):
2329
+ click.echo("ERROR: --model-indices requires a PDB input (or --ref-pdb).", err=True)
2330
+ prepared_input.cleanup()
2331
+ sys.exit(1)
2332
+ click.echo(
2333
+ pretty_block(
2334
+ "dry_run_plan",
2335
+ {
2336
+ "input_geometry": str(geom_input_path),
2337
+ "output_dir": str(out_dir_path),
2338
+ "optimizer_mode": ("hess-rsirfo" if use_heavy else "grad-dimer"),
2339
+ "detect_layer": bool(detect_layer_enabled),
2340
+ "model_region_source": model_region_source,
2341
+ "model_indices_count": 0 if not model_indices else len(model_indices),
2342
+ "hessian_calc_mode": calc_cfg.get("hessian_calc_mode"),
2343
+ "partial_hessian_flatten": bool(partial_hessian_flatten),
2344
+ "active_dof_mode": str(active_dof_mode),
2345
+ "will_run_tsopt": True,
2346
+ "will_write_summary": True,
2347
+ "backend": calc_cfg.get("backend", "uma"),
2348
+ "embedcharge": bool(calc_cfg.get("embedcharge", False)),
2349
+ },
2350
+ )
2351
+ )
2352
+ click.echo("[dry-run] Validation complete. TS optimization execution was skipped.")
2353
+ prepared_input.cleanup()
2354
+ return
2355
+
2356
+ model_pdb_path: Optional[Path] = None
2357
+ layer_info: Optional[Dict[str, List[int]]] = None
2358
+
2359
+ if detect_layer_enabled:
2360
+ try:
2361
+ model_pdb_path, layer_info = build_model_pdb_from_bfactors(layer_source_pdb, out_dir_path)
2362
+ calc_cfg["use_bfactor_layers"] = True
2363
+ click.echo(
2364
+ f"[layer] Detected B-factor layers: ML={len(layer_info.get('ml_indices', []))}, "
2365
+ f"MovableMM={len(layer_info.get('movable_mm_indices', []))}, "
2366
+ f"FrozenMM={len(layer_info.get('frozen_indices', []))}"
2367
+ )
2368
+ except Exception as e:
2369
+ if model_pdb_cfg is None and not model_indices:
2370
+ click.echo(f"ERROR: {e}", err=True)
2371
+ prepared_input.cleanup()
2372
+ sys.exit(1)
2373
+ click.echo(f"[layer] WARNING: {e} Falling back to explicit ML region.", err=True)
2374
+ detect_layer_enabled = False
2375
+
2376
+ if not detect_layer_enabled:
2377
+ if model_pdb_cfg is None and not model_indices:
2378
+ click.echo("ERROR: Provide --model-pdb or --model-indices when --no-detect-layer.", err=True)
2379
+ prepared_input.cleanup()
2380
+ sys.exit(1)
2381
+ if model_pdb_cfg is not None:
2382
+ model_pdb_path = Path(model_pdb_cfg)
2383
+ else:
2384
+ if layer_source_pdb.suffix.lower() != ".pdb":
2385
+ click.echo("ERROR: --model-indices requires a PDB input (or --ref-pdb).", err=True)
2386
+ prepared_input.cleanup()
2387
+ sys.exit(1)
2388
+ try:
2389
+ model_pdb_path = build_model_pdb_from_indices(layer_source_pdb, out_dir_path, model_indices or [])
2390
+ except Exception as e:
2391
+ click.echo(f"ERROR: {e}", err=True)
2392
+ prepared_input.cleanup()
2393
+ sys.exit(1)
2394
+ calc_cfg["use_bfactor_layers"] = False
2395
+
2396
+ if model_pdb_path is None:
2397
+ click.echo("ERROR: Failed to resolve model PDB for the ML region.", err=True)
2398
+ prepared_input.cleanup()
2399
+ sys.exit(1)
2400
+
2401
+ calc_cfg["model_pdb"] = str(model_pdb_path)
2402
+ freeze_atoms_final = apply_layer_freeze_constraints(
2403
+ geom_cfg,
2404
+ calc_cfg,
2405
+ layer_info,
2406
+ echo_fn=click.echo,
2407
+ )
2408
+ _align_three_layer_hessian_targets(calc_cfg, echo_fn=click.echo)
2409
+
2410
+ for key in ("input_pdb", "real_parm7", "model_pdb", "mm_fd_dir"):
2411
+ val = calc_cfg.get(key)
2412
+ if val:
2413
+ calc_cfg[key] = str(Path(val).expanduser().resolve())
2414
+
2415
+ # Pretty-print config summary (only non-default values for concise logging)
2416
+ mode_desc = "RS-I-RFO (hess)" if use_heavy else "Dimer (grad)"
2417
+ if use_microiter:
2418
+ mode_desc += " + Microiteration"
2419
+ click.echo(f"\n[mode] TS Optimizer: {mode_desc}\n")
2420
+ click.echo(pretty_block("geom", format_freeze_atoms_for_echo(geom_cfg, key="freeze_atoms")))
2421
+ echo_calc = format_freeze_atoms_for_echo(filter_calc_for_echo(calc_cfg), key="freeze_atoms")
2422
+ click.echo(pretty_block("calc", echo_calc))
2423
+ echo_opt = strip_inherited_keys({**opt_cfg, "out_dir": str(out_dir_path)}, OPT_BASE_KW, mode="same")
2424
+ click.echo(pretty_block("opt", echo_opt))
2425
+ # Show only optimizer-specific settings, not inherited from opt_cfg
2426
+ if use_heavy:
2427
+ echo_rsirfo = strip_inherited_keys(rsirfo_cfg, opt_cfg)
2428
+ click.echo(pretty_block("rsirfo", echo_rsirfo))
2429
+ else:
2430
+ sd_cfg_for_echo: Dict[str, Any] = {}
2431
+ sd_cfg_for_echo["dimer"] = dict(simple_cfg.get("dimer", {}))
2432
+ sd_cfg_for_echo["lbfgs"] = strip_inherited_keys(
2433
+ dict(simple_cfg.get("lbfgs", {})), opt_cfg
2434
+ )
2435
+ click.echo(pretty_block("hessian_dimer", sd_cfg_for_echo))
2436
+
2437
+ # --------------------------
2438
+ # 2) Prepare geometry dir
2439
+ # --------------------------
2440
+ out_dir_path.mkdir(parents=True, exist_ok=True)
2441
+
2442
+ # --------------------------
2443
+ # 3) Run
2444
+ # --------------------------
2445
+ try:
2446
+ if use_heavy:
2447
+ # Heavy mode: RS-I-RFO with full Hessian
2448
+ rsirfo_label = "RS-I-RFO heavy mode"
2449
+ if use_microiter:
2450
+ rsirfo_label += " + Microiteration"
2451
+ optim_all_path = out_dir_path / "optimization_all_trj.xyz"
2452
+ if bool(opt_cfg["dump"]) and optim_all_path.exists():
2453
+ optim_all_path.unlink()
2454
+
2455
+ coord_type = geom_cfg.get("coord_type", "cart")
2456
+ coord_kwargs = dict(geom_cfg)
2457
+ coord_kwargs.pop("coord_type", None)
2458
+ geometry = geom_loader(
2459
+ geom_input_path,
2460
+ coord_type=coord_type,
2461
+ **coord_kwargs,
2462
+ )
2463
+
2464
+ if use_microiter:
2465
+ # --- Microiteration path ---
2466
+ click.echo(f"\n=== TS optimization ({rsirfo_label}) started ===\n")
2467
+ _run_microiter_tsopt(
2468
+ geometry,
2469
+ calc_cfg,
2470
+ rsirfo_cfg,
2471
+ lbfgs_cfg,
2472
+ opt_cfg,
2473
+ microiter_cfg,
2474
+ out_dir_path,
2475
+ dump=bool(opt_cfg["dump"]),
2476
+ thresh=thresh,
2477
+ )
2478
+ click.echo(f"\n=== TS optimization ({rsirfo_label}) finished ===\n")
2479
+
2480
+ # Write final geometry
2481
+ final_xyz = out_dir_path / "final_geometry.xyz"
2482
+ final_xyz.write_text(geometry.as_xyz(), encoding="utf-8")
2483
+
2484
+ # For post-analysis, get hess_active_atoms from a fresh calc
2485
+ _temp_calc = mlmm(**calc_cfg)
2486
+ _temp_core = _temp_calc.core if hasattr(_temp_calc, "core") else _temp_calc
2487
+ hess_active_atoms = list(getattr(_temp_core, "hess_active_atoms", []))
2488
+ del _temp_calc, _temp_core
2489
+ _clear_cuda_cache()
2490
+ _rsirfo_cycles_spent = int(opt_cfg.get("max_cycles", 10000)) # budget consumed
2491
+ else:
2492
+ # --- Standard RS-I-RFO path ---
2493
+ click.echo(f"\n=== TS optimization ({rsirfo_label}) started ===\n")
2494
+
2495
+ base_calc = mlmm(**calc_cfg)
2496
+ geometry.set_calculator(base_calc)
2497
+
2498
+ click.echo("[tsopt] Seeding initial Hessian via shared freq backend.")
2499
+ hess_device = _torch_device(simple_cfg.get("device", calc_cfg.get("ml_device", "auto")))
2500
+ h_init = _calc_full_hessian_torch(geometry, calc_cfg, hess_device)
2501
+ geometry.cart_hessian = h_init
2502
+ click.echo(
2503
+ f"[tsopt] Initial Hessian seeded (shape={h_init.shape[0]}x{h_init.shape[1]})."
2504
+ )
2505
+ del h_init
2506
+
2507
+ rsirfo_args = {**rsirfo_cfg}
2508
+ rsirfo_args["out_dir"] = str(out_dir_path)
2509
+ rsirfo_args["max_cycles"] = int(opt_cfg["max_cycles"])
2510
+ rsirfo_args["dump"] = bool(opt_cfg["dump"])
2511
+ if thresh is not None:
2512
+ rsirfo_args["thresh"] = str(thresh)
2513
+ # RSIRFOptimizer does not accept RFOptimizer-specific DIIS knobs; strip them.
2514
+ for _diis_kw in ("gediis", "gdiis", "gdiis_thresh", "gediis_thresh", "gdiis_test_direction", "adapt_step_func"):
2515
+ rsirfo_args.pop(_diis_kw, None)
2516
+
2517
+ calc_core = base_calc.core if hasattr(base_calc, "core") else base_calc
2518
+ hess_active_atoms = list(getattr(calc_core, "hess_active_atoms", []))
2519
+ optimizer = RSIRFOptimizer(geometry, **rsirfo_args)
2520
+ optimizer.run()
2521
+ if bool(opt_cfg["dump"]):
2522
+ _append_xyz_trajectory(optim_all_path, out_dir_path / "optimization_trj.xyz")
2523
+
2524
+ click.echo(f"\n=== TS optimization ({rsirfo_label}) finished ===\n")
2525
+
2526
+ # --- Post-RSIRFO: count imaginary modes and optional flatten loop ---
2527
+ # Save cycle count before deleting optimizer for budget check.
2528
+ _rsirfo_cycles_spent = getattr(optimizer, "cur_cycle", 0) + 1
2529
+ geometry.set_calculator(None)
2530
+ del optimizer
2531
+ del calc_core
2532
+ del base_calc
2533
+ _clear_cuda_cache()
2534
+ mlmm_kwargs_for_heavy = dict(calc_cfg)
2535
+ mlmm_kwargs_for_heavy["out_hess_torch"] = True
2536
+ device = _torch_device(simple_cfg.get("device", calc_cfg.get("ml_device", "auto")))
2537
+
2538
+ # Determine active atoms for frequency analysis based on --active-dof-mode.
2539
+ active_atoms_freq = _get_active_dof_indices(
2540
+ calc_cfg, len(geometry.atomic_numbers), active_dof_mode, freeze_atoms_final
2541
+ )
2542
+ n_atoms = len(geometry.atomic_numbers)
2543
+ if active_atoms_freq is not None:
2544
+ active_set = set(active_atoms_freq)
2545
+ freeze_atoms_freq = [i for i in range(n_atoms) if i not in active_set]
2546
+ else:
2547
+ freeze_atoms_freq = freeze_atoms_final if freeze_atoms_final else None
2548
+
2549
+ def _calc_freqs_and_modes() -> Tuple[np.ndarray, torch.Tensor]:
2550
+ H, _energy_ha = _freq_calc_full_hessian_torch(
2551
+ geometry, mlmm_kwargs_for_heavy, device, refresh_geom_meta=True,
2552
+ )
2553
+ from .hessian_cache import store as _hess_store
2554
+ # Determine active_dofs for partial Hessian
2555
+ _n_full_dofs = 3 * len(geometry.atomic_numbers)
2556
+ if H.shape[0] < _n_full_dofs:
2557
+ _freeze_atoms = list(calc_cfg.get("freeze_atoms", []))
2558
+ _all_dofs = set(range(_n_full_dofs))
2559
+ _frozen_dofs = set()
2560
+ for _idx in _freeze_atoms:
2561
+ _frozen_dofs.update([3 * _idx, 3 * _idx + 1, 3 * _idx + 2])
2562
+ _active_dofs = sorted(_all_dofs - _frozen_dofs)
2563
+ else:
2564
+ _active_dofs = None
2565
+ _hess_store("ts", H, active_dofs=_active_dofs, meta={"energy_ha": _energy_ha})
2566
+ n_full = len(geometry.atomic_numbers)
2567
+ if H.shape[0] != 3 * n_full:
2568
+ # Partial Hessian: use Hessian-target atoms only and embed modes back.
2569
+ active_atoms = None
2570
+ if getattr(geometry, "within_partial_hessian", None) is not None:
2571
+ active_atoms = geometry.within_partial_hessian.get("active_atoms")
2572
+ if active_atoms is None:
2573
+ active_atoms = hess_active_atoms
2574
+ if active_atoms is None:
2575
+ active_atoms = active_atoms_freq
2576
+ if active_atoms is None:
2577
+ active_atoms = []
2578
+ else:
2579
+ active_atoms = [int(i) for i in np.asarray(active_atoms, dtype=int).reshape(-1).tolist()]
2580
+ if not active_atoms:
2581
+ raise RuntimeError(
2582
+ "No active atoms available for partial Hessian frequency analysis."
2583
+ )
2584
+ coords_act = geometry.cart_coords.reshape(-1, 3)[active_atoms]
2585
+ nums_act = np.asarray(geometry.atomic_numbers)[active_atoms]
2586
+ freqs_local, modes_act = _frequencies_cm_and_modes(
2587
+ H,
2588
+ nums_act,
2589
+ coords_act,
2590
+ device,
2591
+ freeze_idx=None,
2592
+ )
2593
+ # Embed modes into full 3N on CPU to reduce VRAM peak.
2594
+ modes_local = torch.zeros((modes_act.shape[0], 3 * n_full), dtype=modes_act.dtype, device="cpu")
2595
+ mask_dof = torch.as_tensor(_mask_dof_from_active_idx(n_full, active_atoms), dtype=torch.bool)
2596
+ modes_local[:, mask_dof] = modes_act.detach().cpu()
2597
+ del coords_act, nums_act, modes_act, mask_dof
2598
+ else:
2599
+ freqs_local, modes_gpu = _frequencies_cm_and_modes(
2600
+ H,
2601
+ geometry.atomic_numbers,
2602
+ geometry.cart_coords.reshape(-1, 3),
2603
+ device,
2604
+ freeze_idx=freeze_atoms_freq,
2605
+ )
2606
+ modes_local = modes_gpu.detach().cpu()
2607
+ del modes_gpu
2608
+ del H
2609
+ _clear_cuda_cache()
2610
+ return freqs_local, modes_local
2611
+
2612
+ try:
2613
+ freqs_cm, modes = _calc_freqs_and_modes()
2614
+ except Exception as exc:
2615
+ is_oom = isinstance(exc, torch.OutOfMemoryError) or ("cuda out of memory" in str(exc).lower())
2616
+ if is_oom:
2617
+ click.echo(
2618
+ "[tsopt] WARNING: CUDA OOM during final frequency analysis; "
2619
+ "skipping imaginary-mode analysis/flatten loop.",
2620
+ err=True,
2621
+ )
2622
+ _clear_cuda_cache()
2623
+ freqs_cm, modes = None, None
2624
+ else:
2625
+ raise
2626
+ neg_freq_thresh_cm = float(simple_cfg.get("neg_freq_thresh_cm", 5.0))
2627
+
2628
+ if freqs_cm is not None and modes is not None:
2629
+ neg_mask = freqs_cm < -abs(neg_freq_thresh_cm)
2630
+ n_imag = int(np.sum(neg_mask))
2631
+ ims = [float(x) for x in freqs_cm if x < -abs(neg_freq_thresh_cm)]
2632
+ click.echo(f"[Imaginary modes] n={n_imag} ({ims})")
2633
+
2634
+ flatten_max_iter = int(simple_cfg.get("flatten_max_iter", 0))
2635
+ user_max_cycles = int(opt_cfg.get("max_cycles", 10000))
2636
+ budget_remaining = user_max_cycles - _rsirfo_cycles_spent > 0
2637
+
2638
+ if flatten_max_iter > 0 and n_imag > 1 and not budget_remaining:
2639
+ click.echo("[tsopt] Reached --max-cycles budget; skipping flatten loop.")
2640
+ elif flatten_max_iter > 0 and n_imag > 1 and budget_remaining:
2641
+ click.echo("[flatten] Extra imaginary modes detected; starting RS-I-RFO flatten loop.")
2642
+ masses_amu = np.array([atomic_masses[z] for z in geometry.atomic_numbers])
2643
+ main_root = int(simple_cfg.get("root", 0))
2644
+
2645
+ for it in range(flatten_max_iter):
2646
+ click.echo(f"[flatten] RS-I-RFO iteration {it + 1}/{flatten_max_iter}")
2647
+ did_flatten = _flatten_once_with_modes_for_geom(
2648
+ geometry,
2649
+ masses_amu,
2650
+ mlmm_kwargs_for_heavy,
2651
+ freqs_cm,
2652
+ modes,
2653
+ neg_freq_thresh_cm,
2654
+ float(simple_cfg.get("flatten_amp_ang", 0.10)),
2655
+ float(simple_cfg.get("flatten_sep_cutoff", 0.0)),
2656
+ int(simple_cfg.get("flatten_k", 10)),
2657
+ main_root,
2658
+ )
2659
+ if not did_flatten:
2660
+ click.echo("[flatten] No eligible modes to flatten; stopping.")
2661
+ break
2662
+
2663
+ del freqs_cm, modes
2664
+ _clear_cuda_cache()
2665
+ base_calc = mlmm(**calc_cfg)
2666
+ geometry.set_calculator(base_calc)
2667
+ optimizer = RSIRFOptimizer(geometry, **rsirfo_args)
2668
+ click.echo("\n=== TS optimization (RS-I-RFO) restarted ===\n")
2669
+ optimizer.run()
2670
+ click.echo("\n=== TS optimization (RS-I-RFO) finished ===\n")
2671
+ if bool(opt_cfg["dump"]):
2672
+ _append_xyz_trajectory(optim_all_path, out_dir_path / "optimization_trj.xyz")
2673
+ geometry.set_calculator(None)
2674
+ del optimizer, base_calc
2675
+ _clear_cuda_cache()
2676
+
2677
+ try:
2678
+ freqs_cm, modes = _calc_freqs_and_modes()
2679
+ except Exception as exc:
2680
+ is_oom = isinstance(exc, torch.OutOfMemoryError) or ("cuda out of memory" in str(exc).lower())
2681
+ if is_oom:
2682
+ click.echo(
2683
+ "[tsopt] WARNING: CUDA OOM during final frequency analysis; "
2684
+ "stopping flatten loop.",
2685
+ err=True,
2686
+ )
2687
+ _clear_cuda_cache()
2688
+ freqs_cm, modes = None, None
2689
+ break
2690
+ raise
2691
+
2692
+ neg_mask = freqs_cm < -abs(neg_freq_thresh_cm)
2693
+ n_imag = int(np.sum(neg_mask))
2694
+ ims = [float(x) for x in freqs_cm if x < -abs(neg_freq_thresh_cm)]
2695
+ click.echo(f"[Imaginary modes] n={n_imag} ({ims})")
2696
+ if n_imag <= 1:
2697
+ break
2698
+
2699
+ if freqs_cm is not None and modes is not None:
2700
+ # --- Write all final imaginary modes like light mode ---
2701
+ vib_dir = out_dir_path / "vib"
2702
+ vib_dir.mkdir(parents=True, exist_ok=True)
2703
+ _ref_pdb_for_modes = source_path if source_path.suffix.lower() == ".pdb" else None
2704
+ n_written = _write_all_imag_modes(
2705
+ geometry,
2706
+ freqs_cm,
2707
+ modes,
2708
+ neg_freq_thresh_cm,
2709
+ vib_dir,
2710
+ ref_pdb=_ref_pdb_for_modes,
2711
+ )
2712
+ if n_written == 0:
2713
+ click.echo("[INFO] No imaginary mode found at the end for RS-I-RFO.")
2714
+ else:
2715
+ click.echo(f"[DONE] Wrote {n_written} final imaginary mode(s).")
2716
+ click.echo(f"[DONE] Mode files → {vib_dir}")
2717
+ else:
2718
+ click.echo("[INFO] Skipped final imaginary-mode export due to frequency-analysis fallback.")
2719
+
2720
+ if modes is not None:
2721
+ del modes
2722
+ if freqs_cm is not None:
2723
+ del freqs_cm
2724
+ _clear_cuda_cache()
2725
+
2726
+ # Ensure final_geometry.xyz exists (partial-micro path may not write it).
2727
+ final_xyz = out_dir_path / "final_geometry.xyz"
2728
+ if not final_xyz.exists():
2729
+ final_xyz.write_text(geometry.as_xyz(), encoding="utf-8")
2730
+
2731
+ else:
2732
+ # Light mode: Partial Hessian guided Dimer
2733
+ runner = HessianDimer(
2734
+ fn=str(geom_input_path),
2735
+ out_dir=str(out_dir_path),
2736
+ thresh_loose=simple_cfg.get("thresh_loose", "gau_loose"),
2737
+ thresh=simple_cfg.get("thresh", "gau"),
2738
+ update_interval_hessian=int(simple_cfg.get("update_interval_hessian", 500)),
2739
+ neg_freq_thresh_cm=float(simple_cfg.get("neg_freq_thresh_cm", 5.0)),
2740
+ flatten_amp_ang=float(simple_cfg.get("flatten_amp_ang", 0.10)),
2741
+ flatten_max_iter=int(simple_cfg.get("flatten_max_iter", 20)),
2742
+ mem=int(simple_cfg.get("mem", 100000)),
2743
+ use_lobpcg=bool(simple_cfg.get("use_lobpcg", True)),
2744
+ calc_kwargs=dict(calc_cfg),
2745
+ device=str(simple_cfg.get("device", calc_cfg.get("ml_device", "auto"))),
2746
+ dump=bool(opt_cfg["dump"]),
2747
+ root=int(simple_cfg.get("root", 0)),
2748
+ dimer_kwargs=dict(simple_cfg.get("dimer", {})),
2749
+ lbfgs_kwargs=dict(simple_cfg.get("lbfgs", {})),
2750
+ max_total_cycles=int(opt_cfg["max_cycles"]),
2751
+ geom_kwargs=dict(geom_cfg),
2752
+ partial_hessian_flatten=partial_hessian_flatten,
2753
+ flatten_sep_cutoff=float(simple_cfg.get("flatten_sep_cutoff", 0.0)),
2754
+ flatten_k=int(simple_cfg.get("flatten_k", 10)),
2755
+ flatten_loop_bofill=bool(simple_cfg.get("flatten_loop_bofill", False)),
2756
+ ml_only_hessian_dimer=bool(simple_cfg.get("ml_only_hessian_dimer", ml_only_hessian_dimer)),
2757
+ source_path=source_path,
2758
+ )
2759
+
2760
+ click.echo("\n=== TS optimization (Partial Hessian Dimer) started ===\n")
2761
+ runner.run()
2762
+ click.echo("\n=== TS optimization (Partial Hessian Dimer) finished ===\n")
2763
+
2764
+ if is_convert_file_enabled() and source_path.suffix.lower() == ".pdb":
2765
+ ref_pdb = source_path.resolve()
2766
+ final_xyz = out_dir_path / "final_geometry.xyz"
2767
+ final_pdb = out_dir_path / "final_geometry.pdb"
2768
+
2769
+ # Get layer indices for B-factor annotation
2770
+ # For heavy mode, base_calc is available; for light mode, create temporary calc
2771
+ layer_indices = None
2772
+ if use_heavy and 'base_calc' in dir():
2773
+ calc_core = base_calc.core if hasattr(base_calc, 'core') else base_calc
2774
+ layer_indices = {
2775
+ "ml": getattr(calc_core, 'ml_indices', None),
2776
+ "hess_mm": getattr(calc_core, 'hess_mm_indices', None),
2777
+ "movable_mm": getattr(calc_core, 'movable_mm_indices', None),
2778
+ "frozen": getattr(calc_core, 'frozen_layer_indices', None),
2779
+ }
2780
+ else:
2781
+ # For light mode, create a temporary calculator to get layer indices
2782
+ try:
2783
+ temp_calc = mlmm(**calc_cfg)
2784
+ calc_core = temp_calc.core if hasattr(temp_calc, 'core') else temp_calc
2785
+ layer_indices = {
2786
+ "ml": getattr(calc_core, 'ml_indices', None),
2787
+ "hess_mm": getattr(calc_core, 'hess_mm_indices', None),
2788
+ "movable_mm": getattr(calc_core, 'movable_mm_indices', None),
2789
+ "frozen": getattr(calc_core, 'frozen_layer_indices', None),
2790
+ }
2791
+ del temp_calc
2792
+ except Exception:
2793
+ layer_indices = None
2794
+
2795
+ try:
2796
+ convert_xyz_to_pdb(final_xyz, ref_pdb, final_pdb)
2797
+ click.echo(f"[convert] Wrote '{final_pdb}'.")
2798
+
2799
+ # Annotate B-factors with layer-based encoding
2800
+ if layer_indices and layer_indices.get("ml") is not None:
2801
+ update_pdb_bfactors_from_layers(
2802
+ final_pdb,
2803
+ ml_indices=layer_indices["ml"] or [],
2804
+ hess_mm_indices=layer_indices.get("hess_mm"),
2805
+ movable_mm_indices=layer_indices.get("movable_mm"),
2806
+ frozen_indices=layer_indices.get("frozen"),
2807
+ )
2808
+ click.echo(
2809
+ f"[annot] B-factors set in '{final_pdb}' "
2810
+ f"(ML={BFACTOR_ML:.0f}, MovableMM={BFACTOR_MOVABLE_MM:.0f}, "
2811
+ f"FrozenMM={BFACTOR_FROZEN:.0f})."
2812
+ )
2813
+ except Exception as e:
2814
+ click.echo(f"[convert] WARNING: Failed to convert final geometry to PDB: {e}", err=True)
2815
+
2816
+ all_trj = out_dir_path / "optimization_all_trj.xyz"
2817
+ if all_trj.exists():
2818
+ try:
2819
+ opt_pdb = out_dir_path / "optimization_all.pdb"
2820
+ convert_xyz_to_pdb(all_trj, ref_pdb, opt_pdb)
2821
+ click.echo(f"[convert] Wrote '{opt_pdb}'.")
2822
+
2823
+ # Annotate B-factors with layer-based encoding
2824
+ if layer_indices and layer_indices.get("ml") is not None:
2825
+ update_pdb_bfactors_from_layers(
2826
+ opt_pdb,
2827
+ ml_indices=layer_indices["ml"] or [],
2828
+ hess_mm_indices=layer_indices.get("hess_mm"),
2829
+ movable_mm_indices=layer_indices.get("movable_mm"),
2830
+ frozen_indices=layer_indices.get("frozen"),
2831
+ )
2832
+ click.echo(
2833
+ f"[annot] B-factors set in '{opt_pdb}' "
2834
+ f"(ML={BFACTOR_ML:.0f}, MovableMM={BFACTOR_MOVABLE_MM:.0f}, "
2835
+ f"FrozenMM={BFACTOR_FROZEN:.0f})."
2836
+ )
2837
+ except Exception as e:
2838
+ click.echo(f"[convert] WARNING: Failed to convert optimization trajectory to PDB: {e}", err=True)
2839
+ else:
2840
+ final_xyz = out_dir_path / "final_geometry.xyz"
2841
+
2842
+ # summary.md and key_* outputs are disabled.
2843
+ click.echo(format_elapsed("[time] Elapsed Time for TS Opt", time_start))
2844
+
2845
+ except ZeroStepLength:
2846
+ click.echo("ERROR: Proposed step length dropped below the minimum allowed (ZeroStepLength).", err=True)
2847
+ sys.exit(2)
2848
+ except OptimizationError as e:
2849
+ click.echo(f"ERROR: Optimization failed — {e}", err=True)
2850
+ sys.exit(3)
2851
+ except KeyboardInterrupt:
2852
+ click.echo("\nInterrupted by user.", err=True)
2853
+ sys.exit(130)
2854
+ except Exception as e:
2855
+ import traceback
2856
+ tb = "".join(traceback.format_exception(type(e), e, e.__traceback__))
2857
+ click.echo("Unhandled error during optimization:\n" + textwrap.indent(tb, " "), err=True)
2858
+ sys.exit(1)
2859
+ finally:
2860
+ prepared_input.cleanup()
2861
+ # Release GPU memory (model + Hessian) so subsequent stages don't OOM
2862
+ base_calc = geometry = optimizer = last_optimizer = None
2863
+ macro_calc = macro_optimizer = mm_calc = None
2864
+ gc.collect() # break cyclic refs inside torch.nn.Module
2865
+ if torch.cuda.is_available():
2866
+ torch.cuda.empty_cache()
2867
+
2868
+
2869
+ # Entry point: when this module runs as __main__ (`python -m mlmm.tsopt` or direct script execution), invoke the `cli` command; importing the module does not trigger it.
2870
+ if __name__ == "__main__":
2871
+ cli()