mlmm-toolkit 0.2.2.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (372) hide show
  1. hessian_ff/__init__.py +50 -0
  2. hessian_ff/analytical_hessian.py +609 -0
  3. hessian_ff/constants.py +46 -0
  4. hessian_ff/forcefield.py +339 -0
  5. hessian_ff/loaders.py +608 -0
  6. hessian_ff/native/Makefile +8 -0
  7. hessian_ff/native/__init__.py +28 -0
  8. hessian_ff/native/analytical_hessian.py +88 -0
  9. hessian_ff/native/analytical_hessian_ext.cpp +258 -0
  10. hessian_ff/native/bonded.py +82 -0
  11. hessian_ff/native/bonded_ext.cpp +640 -0
  12. hessian_ff/native/loader.py +349 -0
  13. hessian_ff/native/nonbonded.py +118 -0
  14. hessian_ff/native/nonbonded_ext.cpp +1150 -0
  15. hessian_ff/prmtop_parmed.py +23 -0
  16. hessian_ff/system.py +107 -0
  17. hessian_ff/terms/__init__.py +14 -0
  18. hessian_ff/terms/angle.py +73 -0
  19. hessian_ff/terms/bond.py +44 -0
  20. hessian_ff/terms/cmap.py +406 -0
  21. hessian_ff/terms/dihedral.py +141 -0
  22. hessian_ff/terms/nonbonded.py +209 -0
  23. hessian_ff/tests/__init__.py +0 -0
  24. hessian_ff/tests/conftest.py +75 -0
  25. hessian_ff/tests/data/small/complex.parm7 +1346 -0
  26. hessian_ff/tests/data/small/complex.pdb +125 -0
  27. hessian_ff/tests/data/small/complex.rst7 +63 -0
  28. hessian_ff/tests/test_coords_input.py +44 -0
  29. hessian_ff/tests/test_energy_force.py +49 -0
  30. hessian_ff/tests/test_hessian.py +137 -0
  31. hessian_ff/tests/test_smoke.py +18 -0
  32. hessian_ff/tests/test_validation.py +40 -0
  33. hessian_ff/workflows.py +889 -0
  34. mlmm/__init__.py +36 -0
  35. mlmm/__main__.py +7 -0
  36. mlmm/_version.py +34 -0
  37. mlmm/add_elem_info.py +374 -0
  38. mlmm/advanced_help.py +91 -0
  39. mlmm/align_freeze_atoms.py +601 -0
  40. mlmm/all.py +3535 -0
  41. mlmm/bond_changes.py +231 -0
  42. mlmm/bool_compat.py +223 -0
  43. mlmm/cli.py +574 -0
  44. mlmm/cli_utils.py +166 -0
  45. mlmm/default_group.py +337 -0
  46. mlmm/defaults.py +467 -0
  47. mlmm/define_layer.py +526 -0
  48. mlmm/dft.py +1041 -0
  49. mlmm/energy_diagram.py +253 -0
  50. mlmm/extract.py +2213 -0
  51. mlmm/fix_altloc.py +464 -0
  52. mlmm/freq.py +1406 -0
  53. mlmm/harmonic_constraints.py +140 -0
  54. mlmm/hessian_cache.py +44 -0
  55. mlmm/hessian_calc.py +174 -0
  56. mlmm/irc.py +638 -0
  57. mlmm/mlmm_calc.py +2262 -0
  58. mlmm/mm_parm.py +945 -0
  59. mlmm/oniom_export.py +1983 -0
  60. mlmm/oniom_import.py +457 -0
  61. mlmm/opt.py +1742 -0
  62. mlmm/path_opt.py +1353 -0
  63. mlmm/path_search.py +2299 -0
  64. mlmm/preflight.py +88 -0
  65. mlmm/py.typed +1 -0
  66. mlmm/pysis_runner.py +45 -0
  67. mlmm/scan.py +1047 -0
  68. mlmm/scan2d.py +1226 -0
  69. mlmm/scan3d.py +1265 -0
  70. mlmm/scan_common.py +184 -0
  71. mlmm/summary_log.py +736 -0
  72. mlmm/trj2fig.py +448 -0
  73. mlmm/tsopt.py +2871 -0
  74. mlmm/utils.py +2309 -0
  75. mlmm/xtb_embedcharge_correction.py +475 -0
  76. mlmm_toolkit-0.2.2.dev0.dist-info/METADATA +1159 -0
  77. mlmm_toolkit-0.2.2.dev0.dist-info/RECORD +372 -0
  78. mlmm_toolkit-0.2.2.dev0.dist-info/WHEEL +5 -0
  79. mlmm_toolkit-0.2.2.dev0.dist-info/entry_points.txt +2 -0
  80. mlmm_toolkit-0.2.2.dev0.dist-info/licenses/LICENSE +674 -0
  81. mlmm_toolkit-0.2.2.dev0.dist-info/top_level.txt +4 -0
  82. pysisyphus/Geometry.py +1667 -0
  83. pysisyphus/LICENSE +674 -0
  84. pysisyphus/TableFormatter.py +63 -0
  85. pysisyphus/TablePrinter.py +74 -0
  86. pysisyphus/__init__.py +12 -0
  87. pysisyphus/calculators/AFIR.py +452 -0
  88. pysisyphus/calculators/AnaPot.py +20 -0
  89. pysisyphus/calculators/AnaPot2.py +48 -0
  90. pysisyphus/calculators/AnaPot3.py +12 -0
  91. pysisyphus/calculators/AnaPot4.py +20 -0
  92. pysisyphus/calculators/AnaPotBase.py +337 -0
  93. pysisyphus/calculators/AnaPotCBM.py +25 -0
  94. pysisyphus/calculators/AtomAtomTransTorque.py +154 -0
  95. pysisyphus/calculators/CFOUR.py +250 -0
  96. pysisyphus/calculators/Calculator.py +844 -0
  97. pysisyphus/calculators/CerjanMiller.py +24 -0
  98. pysisyphus/calculators/Composite.py +123 -0
  99. pysisyphus/calculators/ConicalIntersection.py +171 -0
  100. pysisyphus/calculators/DFTBp.py +430 -0
  101. pysisyphus/calculators/DFTD3.py +66 -0
  102. pysisyphus/calculators/DFTD4.py +84 -0
  103. pysisyphus/calculators/Dalton.py +61 -0
  104. pysisyphus/calculators/Dimer.py +681 -0
  105. pysisyphus/calculators/Dummy.py +20 -0
  106. pysisyphus/calculators/EGO.py +76 -0
  107. pysisyphus/calculators/EnergyMin.py +224 -0
  108. pysisyphus/calculators/ExternalPotential.py +264 -0
  109. pysisyphus/calculators/FakeASE.py +35 -0
  110. pysisyphus/calculators/FourWellAnaPot.py +28 -0
  111. pysisyphus/calculators/FreeEndNEBPot.py +39 -0
  112. pysisyphus/calculators/Gaussian09.py +18 -0
  113. pysisyphus/calculators/Gaussian16.py +726 -0
  114. pysisyphus/calculators/HardSphere.py +159 -0
  115. pysisyphus/calculators/IDPPCalculator.py +49 -0
  116. pysisyphus/calculators/IPIClient.py +133 -0
  117. pysisyphus/calculators/IPIServer.py +234 -0
  118. pysisyphus/calculators/LEPSBase.py +24 -0
  119. pysisyphus/calculators/LEPSExpr.py +139 -0
  120. pysisyphus/calculators/LennardJones.py +80 -0
  121. pysisyphus/calculators/MOPAC.py +219 -0
  122. pysisyphus/calculators/MullerBrownSympyPot.py +51 -0
  123. pysisyphus/calculators/MultiCalc.py +85 -0
  124. pysisyphus/calculators/NFK.py +45 -0
  125. pysisyphus/calculators/OBabel.py +87 -0
  126. pysisyphus/calculators/ONIOMv2.py +1129 -0
  127. pysisyphus/calculators/ORCA.py +893 -0
  128. pysisyphus/calculators/ORCA5.py +6 -0
  129. pysisyphus/calculators/OpenMM.py +88 -0
  130. pysisyphus/calculators/OpenMolcas.py +281 -0
  131. pysisyphus/calculators/OverlapCalculator.py +908 -0
  132. pysisyphus/calculators/Psi4.py +218 -0
  133. pysisyphus/calculators/PyPsi4.py +37 -0
  134. pysisyphus/calculators/PySCF.py +341 -0
  135. pysisyphus/calculators/PyXTB.py +73 -0
  136. pysisyphus/calculators/QCEngine.py +106 -0
  137. pysisyphus/calculators/Rastrigin.py +22 -0
  138. pysisyphus/calculators/Remote.py +76 -0
  139. pysisyphus/calculators/Rosenbrock.py +15 -0
  140. pysisyphus/calculators/SocketCalc.py +97 -0
  141. pysisyphus/calculators/TIP3P.py +111 -0
  142. pysisyphus/calculators/TransTorque.py +161 -0
  143. pysisyphus/calculators/Turbomole.py +965 -0
  144. pysisyphus/calculators/VRIPot.py +37 -0
  145. pysisyphus/calculators/WFOWrapper.py +333 -0
  146. pysisyphus/calculators/WFOWrapper2.py +341 -0
  147. pysisyphus/calculators/XTB.py +418 -0
  148. pysisyphus/calculators/__init__.py +81 -0
  149. pysisyphus/calculators/cosmo_data.py +139 -0
  150. pysisyphus/calculators/parser.py +150 -0
  151. pysisyphus/color.py +19 -0
  152. pysisyphus/config.py +133 -0
  153. pysisyphus/constants.py +65 -0
  154. pysisyphus/cos/AdaptiveNEB.py +230 -0
  155. pysisyphus/cos/ChainOfStates.py +725 -0
  156. pysisyphus/cos/FreeEndNEB.py +25 -0
  157. pysisyphus/cos/FreezingString.py +103 -0
  158. pysisyphus/cos/GrowingChainOfStates.py +71 -0
  159. pysisyphus/cos/GrowingNT.py +309 -0
  160. pysisyphus/cos/GrowingString.py +508 -0
  161. pysisyphus/cos/NEB.py +189 -0
  162. pysisyphus/cos/SimpleZTS.py +64 -0
  163. pysisyphus/cos/__init__.py +22 -0
  164. pysisyphus/cos/stiffness.py +199 -0
  165. pysisyphus/drivers/__init__.py +17 -0
  166. pysisyphus/drivers/afir.py +855 -0
  167. pysisyphus/drivers/barriers.py +271 -0
  168. pysisyphus/drivers/birkholz.py +138 -0
  169. pysisyphus/drivers/cluster.py +318 -0
  170. pysisyphus/drivers/diabatization.py +133 -0
  171. pysisyphus/drivers/merge.py +368 -0
  172. pysisyphus/drivers/merge_mol2.py +322 -0
  173. pysisyphus/drivers/opt.py +375 -0
  174. pysisyphus/drivers/perf.py +91 -0
  175. pysisyphus/drivers/pka.py +52 -0
  176. pysisyphus/drivers/precon_pos_rot.py +669 -0
  177. pysisyphus/drivers/rates.py +480 -0
  178. pysisyphus/drivers/replace.py +219 -0
  179. pysisyphus/drivers/scan.py +212 -0
  180. pysisyphus/drivers/spectrum.py +166 -0
  181. pysisyphus/drivers/thermo.py +31 -0
  182. pysisyphus/dynamics/Gaussian.py +103 -0
  183. pysisyphus/dynamics/__init__.py +20 -0
  184. pysisyphus/dynamics/colvars.py +136 -0
  185. pysisyphus/dynamics/driver.py +297 -0
  186. pysisyphus/dynamics/helpers.py +256 -0
  187. pysisyphus/dynamics/lincs.py +105 -0
  188. pysisyphus/dynamics/mdp.py +364 -0
  189. pysisyphus/dynamics/rattle.py +121 -0
  190. pysisyphus/dynamics/thermostats.py +128 -0
  191. pysisyphus/dynamics/wigner.py +266 -0
  192. pysisyphus/elem_data.py +3473 -0
  193. pysisyphus/exceptions.py +2 -0
  194. pysisyphus/filtertrj.py +69 -0
  195. pysisyphus/helpers.py +623 -0
  196. pysisyphus/helpers_pure.py +649 -0
  197. pysisyphus/init_logging.py +50 -0
  198. pysisyphus/intcoords/Bend.py +69 -0
  199. pysisyphus/intcoords/Bend2.py +25 -0
  200. pysisyphus/intcoords/BondedFragment.py +32 -0
  201. pysisyphus/intcoords/Cartesian.py +41 -0
  202. pysisyphus/intcoords/CartesianCoords.py +140 -0
  203. pysisyphus/intcoords/Coords.py +56 -0
  204. pysisyphus/intcoords/DLC.py +197 -0
  205. pysisyphus/intcoords/DistanceFunction.py +34 -0
  206. pysisyphus/intcoords/DummyImproper.py +70 -0
  207. pysisyphus/intcoords/DummyTorsion.py +72 -0
  208. pysisyphus/intcoords/LinearBend.py +105 -0
  209. pysisyphus/intcoords/LinearDisplacement.py +80 -0
  210. pysisyphus/intcoords/OutOfPlane.py +59 -0
  211. pysisyphus/intcoords/PrimTypes.py +286 -0
  212. pysisyphus/intcoords/Primitive.py +137 -0
  213. pysisyphus/intcoords/RedundantCoords.py +659 -0
  214. pysisyphus/intcoords/RobustTorsion.py +59 -0
  215. pysisyphus/intcoords/Rotation.py +147 -0
  216. pysisyphus/intcoords/Stretch.py +31 -0
  217. pysisyphus/intcoords/Torsion.py +101 -0
  218. pysisyphus/intcoords/Torsion2.py +25 -0
  219. pysisyphus/intcoords/Translation.py +45 -0
  220. pysisyphus/intcoords/__init__.py +61 -0
  221. pysisyphus/intcoords/augment_bonds.py +126 -0
  222. pysisyphus/intcoords/derivatives.py +10512 -0
  223. pysisyphus/intcoords/eval.py +80 -0
  224. pysisyphus/intcoords/exceptions.py +37 -0
  225. pysisyphus/intcoords/findiffs.py +48 -0
  226. pysisyphus/intcoords/generate_derivatives.py +414 -0
  227. pysisyphus/intcoords/helpers.py +235 -0
  228. pysisyphus/intcoords/logging_conf.py +10 -0
  229. pysisyphus/intcoords/mp_derivatives.py +10836 -0
  230. pysisyphus/intcoords/setup.py +962 -0
  231. pysisyphus/intcoords/setup_fast.py +176 -0
  232. pysisyphus/intcoords/update.py +272 -0
  233. pysisyphus/intcoords/valid.py +89 -0
  234. pysisyphus/interpolate/Geodesic.py +93 -0
  235. pysisyphus/interpolate/IDPP.py +55 -0
  236. pysisyphus/interpolate/Interpolator.py +116 -0
  237. pysisyphus/interpolate/LST.py +70 -0
  238. pysisyphus/interpolate/Redund.py +152 -0
  239. pysisyphus/interpolate/__init__.py +9 -0
  240. pysisyphus/interpolate/helpers.py +34 -0
  241. pysisyphus/io/__init__.py +22 -0
  242. pysisyphus/io/aomix.py +178 -0
  243. pysisyphus/io/cjson.py +24 -0
  244. pysisyphus/io/crd.py +101 -0
  245. pysisyphus/io/cube.py +220 -0
  246. pysisyphus/io/fchk.py +184 -0
  247. pysisyphus/io/hdf5.py +49 -0
  248. pysisyphus/io/hessian.py +72 -0
  249. pysisyphus/io/mol2.py +146 -0
  250. pysisyphus/io/molden.py +293 -0
  251. pysisyphus/io/orca.py +189 -0
  252. pysisyphus/io/pdb.py +269 -0
  253. pysisyphus/io/psf.py +79 -0
  254. pysisyphus/io/pubchem.py +31 -0
  255. pysisyphus/io/qcschema.py +34 -0
  256. pysisyphus/io/sdf.py +29 -0
  257. pysisyphus/io/xyz.py +61 -0
  258. pysisyphus/io/zmat.py +175 -0
  259. pysisyphus/irc/DWI.py +108 -0
  260. pysisyphus/irc/DampedVelocityVerlet.py +134 -0
  261. pysisyphus/irc/Euler.py +22 -0
  262. pysisyphus/irc/EulerPC.py +345 -0
  263. pysisyphus/irc/GonzalezSchlegel.py +187 -0
  264. pysisyphus/irc/IMKMod.py +164 -0
  265. pysisyphus/irc/IRC.py +878 -0
  266. pysisyphus/irc/IRCDummy.py +10 -0
  267. pysisyphus/irc/Instanton.py +307 -0
  268. pysisyphus/irc/LQA.py +53 -0
  269. pysisyphus/irc/ModeKill.py +136 -0
  270. pysisyphus/irc/ParamPlot.py +53 -0
  271. pysisyphus/irc/RK4.py +36 -0
  272. pysisyphus/irc/__init__.py +31 -0
  273. pysisyphus/irc/initial_displ.py +219 -0
  274. pysisyphus/linalg.py +411 -0
  275. pysisyphus/line_searches/Backtracking.py +88 -0
  276. pysisyphus/line_searches/HagerZhang.py +184 -0
  277. pysisyphus/line_searches/LineSearch.py +232 -0
  278. pysisyphus/line_searches/StrongWolfe.py +108 -0
  279. pysisyphus/line_searches/__init__.py +9 -0
  280. pysisyphus/line_searches/interpol.py +15 -0
  281. pysisyphus/modefollow/NormalMode.py +40 -0
  282. pysisyphus/modefollow/__init__.py +10 -0
  283. pysisyphus/modefollow/davidson.py +199 -0
  284. pysisyphus/modefollow/lanczos.py +95 -0
  285. pysisyphus/optimizers/BFGS.py +99 -0
  286. pysisyphus/optimizers/BacktrackingOptimizer.py +113 -0
  287. pysisyphus/optimizers/ConjugateGradient.py +98 -0
  288. pysisyphus/optimizers/CubicNewton.py +75 -0
  289. pysisyphus/optimizers/FIRE.py +113 -0
  290. pysisyphus/optimizers/HessianOptimizer.py +1176 -0
  291. pysisyphus/optimizers/LBFGS.py +228 -0
  292. pysisyphus/optimizers/LayerOpt.py +411 -0
  293. pysisyphus/optimizers/MicroOptimizer.py +169 -0
  294. pysisyphus/optimizers/NCOptimizer.py +90 -0
  295. pysisyphus/optimizers/Optimizer.py +1084 -0
  296. pysisyphus/optimizers/PreconLBFGS.py +260 -0
  297. pysisyphus/optimizers/PreconSteepestDescent.py +7 -0
  298. pysisyphus/optimizers/QuickMin.py +74 -0
  299. pysisyphus/optimizers/RFOptimizer.py +181 -0
  300. pysisyphus/optimizers/RSA.py +99 -0
  301. pysisyphus/optimizers/StabilizedQNMethod.py +248 -0
  302. pysisyphus/optimizers/SteepestDescent.py +23 -0
  303. pysisyphus/optimizers/StringOptimizer.py +173 -0
  304. pysisyphus/optimizers/__init__.py +41 -0
  305. pysisyphus/optimizers/closures.py +301 -0
  306. pysisyphus/optimizers/cls_map.py +58 -0
  307. pysisyphus/optimizers/exceptions.py +6 -0
  308. pysisyphus/optimizers/gdiis.py +280 -0
  309. pysisyphus/optimizers/guess_hessians.py +311 -0
  310. pysisyphus/optimizers/hessian_updates.py +355 -0
  311. pysisyphus/optimizers/poly_fit.py +285 -0
  312. pysisyphus/optimizers/precon.py +153 -0
  313. pysisyphus/optimizers/restrict_step.py +24 -0
  314. pysisyphus/pack.py +172 -0
  315. pysisyphus/peakdetect.py +948 -0
  316. pysisyphus/plot.py +1031 -0
  317. pysisyphus/run.py +2106 -0
  318. pysisyphus/socket_helper.py +74 -0
  319. pysisyphus/stocastic/FragmentKick.py +132 -0
  320. pysisyphus/stocastic/Kick.py +81 -0
  321. pysisyphus/stocastic/Pipeline.py +303 -0
  322. pysisyphus/stocastic/__init__.py +21 -0
  323. pysisyphus/stocastic/align.py +127 -0
  324. pysisyphus/testing.py +96 -0
  325. pysisyphus/thermo.py +156 -0
  326. pysisyphus/trj.py +824 -0
  327. pysisyphus/tsoptimizers/RSIRFOptimizer.py +56 -0
  328. pysisyphus/tsoptimizers/RSPRFOptimizer.py +182 -0
  329. pysisyphus/tsoptimizers/TRIM.py +59 -0
  330. pysisyphus/tsoptimizers/TSHessianOptimizer.py +463 -0
  331. pysisyphus/tsoptimizers/__init__.py +23 -0
  332. pysisyphus/wavefunction/Basis.py +239 -0
  333. pysisyphus/wavefunction/DIIS.py +76 -0
  334. pysisyphus/wavefunction/__init__.py +25 -0
  335. pysisyphus/wavefunction/build_ext.py +42 -0
  336. pysisyphus/wavefunction/cart2sph.py +190 -0
  337. pysisyphus/wavefunction/diabatization.py +304 -0
  338. pysisyphus/wavefunction/excited_states.py +435 -0
  339. pysisyphus/wavefunction/gen_ints.py +1811 -0
  340. pysisyphus/wavefunction/helpers.py +104 -0
  341. pysisyphus/wavefunction/ints/__init__.py +0 -0
  342. pysisyphus/wavefunction/ints/boys.py +193 -0
  343. pysisyphus/wavefunction/ints/boys_table_N_64_xasym_27.1_step_0.01.npy +0 -0
  344. pysisyphus/wavefunction/ints/cart_gto3d.py +176 -0
  345. pysisyphus/wavefunction/ints/coulomb3d.py +25928 -0
  346. pysisyphus/wavefunction/ints/diag_quadrupole3d.py +10036 -0
  347. pysisyphus/wavefunction/ints/dipole3d.py +8762 -0
  348. pysisyphus/wavefunction/ints/int2c2e3d.py +7198 -0
  349. pysisyphus/wavefunction/ints/int3c2e3d_sph.py +65040 -0
  350. pysisyphus/wavefunction/ints/kinetic3d.py +8240 -0
  351. pysisyphus/wavefunction/ints/ovlp3d.py +3777 -0
  352. pysisyphus/wavefunction/ints/quadrupole3d.py +15054 -0
  353. pysisyphus/wavefunction/ints/self_ovlp3d.py +198 -0
  354. pysisyphus/wavefunction/localization.py +458 -0
  355. pysisyphus/wavefunction/multipole.py +159 -0
  356. pysisyphus/wavefunction/normalization.py +36 -0
  357. pysisyphus/wavefunction/pop_analysis.py +134 -0
  358. pysisyphus/wavefunction/shells.py +1171 -0
  359. pysisyphus/wavefunction/wavefunction.py +504 -0
  360. pysisyphus/wrapper/__init__.py +11 -0
  361. pysisyphus/wrapper/exceptions.py +2 -0
  362. pysisyphus/wrapper/jmol.py +120 -0
  363. pysisyphus/wrapper/mwfn.py +169 -0
  364. pysisyphus/wrapper/packmol.py +71 -0
  365. pysisyphus/xyzloader.py +168 -0
  366. pysisyphus/yaml_mods.py +45 -0
  367. thermoanalysis/LICENSE +674 -0
  368. thermoanalysis/QCData.py +244 -0
  369. thermoanalysis/__init__.py +0 -0
  370. thermoanalysis/config.py +3 -0
  371. thermoanalysis/constants.py +20 -0
  372. thermoanalysis/thermo.py +1011 -0
mlmm/all.py ADDED
@@ -0,0 +1,3535 @@
1
+ # mlmm/all.py
2
+
3
+ """
4
+ End-to-end enzymatic reaction workflow: extract, scan, MEP, TS, IRC, freq, DFT.
5
+
6
+ Example:
7
+ mlmm all -i R.pdb P.pdb -c 'GPP,MMT' -l 'GPP:-3,MMT:-1'
8
+
9
+ For detailed documentation, see: docs/all.md
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from collections import defaultdict
15
+ from pathlib import Path
16
+ from typing import List, Sequence, Optional, Tuple, Dict, Any
17
+ import shutil
18
+ import tempfile
19
+
20
+ import gc
21
+ import logging
22
+ import sys
23
+ import math
24
+ import click
25
+ from .cli_utils import make_is_param_explicit
26
+ import time # timing
27
+ import yaml
28
+ import numpy as np
29
+ import torch
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+ # Biopython for PDB parsing (post-processing helpers)
34
+ from Bio import PDB
35
+
36
+ # pysisyphus helpers/constants
37
+ from pysisyphus.helpers import geom_loader
38
+ from pysisyphus.constants import BOHR2ANG, AU2KCALPERMOL
39
+
40
+ # Local imports from the package
41
+ from .extract import extract_api, compute_charge_summary, log_charge_summary
42
+ from . import path_search as _path_search
43
+ from . import path_opt as _path_opt
44
+ from . import opt as _opt_cli
45
+ from . import tsopt as _ts_opt
46
+ from . import freq as _freq_cli
47
+ from . import dft as _dft_cli
48
+ from . import irc as _irc_cli
49
+
50
+ from .trj2fig import run_trj2fig
51
+ from .summary_log import write_summary_log
52
+ from .utils import (
53
+ apply_ref_pdb_override,
54
+ build_energy_diagram,
55
+ close_matplotlib_figures,
56
+ deep_update,
57
+ ensure_dir,
58
+ format_elapsed,
59
+ parse_xyz_block,
60
+ prepare_input_structure,
61
+ collect_single_option_values,
62
+ load_yaml_dict,
63
+ load_pdb_atom_metadata,
64
+ parse_scan_list_triples,
65
+ read_xyz_as_blocks,
66
+ read_xyz_first_last,
67
+ xyz_blocks_first_last,
68
+ )
69
+ from .cli_utils import resolve_yaml_sources, load_merged_yaml_cfg
70
+ from .preflight import validate_existing_files, ensure_commands_available
71
+ from . import scan as _scan_cli
72
+ from .add_elem_info import assign_elements as _assign_elem_info
73
+ from .define_layer import define_layers as _define_layers
74
+ from .mlmm_calc import mlmm as _mlmm_calc
75
+ from .mm_parm import (
76
+ Args as _AutoMMArgs,
77
+ parse_ligand_charge as _mm_parse_ligand_charge,
78
+ parse_ligand_mult as _mm_parse_ligand_mult,
79
+ run_pipeline as _mm_run,
80
+ )
81
+
82
# Structural identity of a PDB atom record, in the order produced by
# _parse_atom_key_from_line: (chain, resname, resseq, icode, atomname, altloc).
# Every field is a whitespace-stripped string; columns missing from a short line are "".
AtomKey = Tuple[str, str, str, str, str, str]
83
+
84
class _EchoState:
    """Remember whether anything has been printed, so section headers can be spaced.

    The very first message is printed flush; every later :meth:`section` call
    is preceded by one blank line.
    """

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        """Forget earlier output; the next section header gets no leading blank line."""
        self._started = False

    def echo(self, *args, **kwargs) -> None:
        """Forward everything to ``click.echo`` and mark output as started."""
        click.echo(*args, **kwargs)
        self._started = True

    def section(self, message: str, **kwargs) -> None:
        """Print ``message`` as a header, separated by a blank line from prior output."""
        needs_gap = self._started
        if needs_gap:
            click.echo()
        click.echo(message, **kwargs)
        self._started = True


# Shared state used by the module-level _echo / _echo_section helpers.
_echo_state = _EchoState()
105
+
106
+
107
def _echo(*args, **kwargs) -> None:
    """Print via ``click.echo`` through the shared ``_echo_state``.

    Routing output through the state object lets later section headers know
    that something has already been printed. Arguments are passed through unchanged.
    """
    return _echo_state.echo(*args, **kwargs)
110
+
111
+
112
def _echo_section(message: str, **kwargs) -> None:
    """Print ``message`` as a section header.

    A blank line is emitted first unless nothing has been printed yet, as
    tracked by the module-level ``_echo_state``.
    """
    state = _echo_state
    state.section(message, **kwargs)
115
+
116
+
117
def _run_cli_main(
    cmd_name: str,
    cli_obj,
    args: Sequence[str],
    *,
    on_nonzero: str = "warn",
    on_exception: str = "raise",
    prefix: Optional[str] = None,
) -> None:
    """Run a Click command in-process with a temporary ``sys.argv`` and uniform error handling.

    Args:
        cmd_name: Subcommand name; placed in ``sys.argv[1]`` and used in messages.
        cli_obj: Click command whose ``main`` is called with ``standalone_mode=False``.
        args: Arguments forwarded to the command (materialized once, so any
            iterable works).
        on_nonzero: ``"raise"`` turns a non-zero exit code into
            :class:`click.ClickException`; any other value only prints a warning.
        on_exception: ``"raise"`` wraps an ``Exception`` from the command in
            :class:`click.ClickException`; any other value only prints a warning.
        prefix: Tag shown in messages; defaults to ``cmd_name``.

    Raises:
        click.ClickException: As selected by ``on_nonzero`` / ``on_exception``.
    """
    saved_argv = list(sys.argv)
    label = prefix or cmd_name
    # Materialize once: calling list() twice on an iterator would pass [] to main().
    arg_list = list(args)

    def _handle_exit_code(code: Any) -> None:
        # None and 0 both denote success (SystemExit(), sys.exit(0), ctx.exit(0)).
        if code in (None, 0):
            return
        if on_nonzero == "raise":
            raise click.ClickException(f"[{label}] {cmd_name} exit code {code}.")
        _echo(f"[{label}] WARNING: {cmd_name} exited with code {code}")

    try:
        sys.argv = ["mlmm", cmd_name] + arg_list
        _echo("\n")
        result = cli_obj.main(args=arg_list, standalone_mode=False)
    except SystemExit as e:
        _handle_exit_code(getattr(e, "code", 1))
    except Exception as e:
        if on_exception == "raise":
            raise click.ClickException(f"[{label}] {cmd_name} failed: {e}") from e
        _echo(f"[{label}] WARNING: {cmd_name} failed: {e}")
    else:
        # With standalone_mode=False, Click returns the code given to ctx.exit(n)
        # instead of raising SystemExit; without this check it would be lost.
        # Checked in ``else`` so a raised ClickException is not re-wrapped above.
        if type(result) is int:
            _handle_exit_code(result)
    finally:
        sys.argv = saved_argv
        # Release GPU memory between pipeline stages to prevent OOM.
        # Subcommand finally blocks unbind their heavy locals (= None).
        # gc.collect() is needed to break cyclic refs inside torch.nn.Module,
        # then empty_cache() reclaims the CUDA allocator cache.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        _echo("\n")
153
+
154
+
155
+ # -----------------------------
156
+ # Helpers
157
+ # -----------------------------
158
+
159
def _append_cli_arg(args: List[str], flag: str, value: Any | None) -> None:
    """Add ``flag`` followed by ``value`` (as a string) to ``args`` in place.

    Nothing is appended when ``value`` is ``None``. Booleans are rendered as
    ``"True"`` / ``"False"``; every other value goes through ``str()``.
    """
    if value is None:
        return
    rendered = ("True" if value else "False") if isinstance(value, bool) else str(value)
    args.extend((flag, rendered))
167
+
168
+
169
def _append_toggle_arg(args: List[str], flag: str, value: Any | None) -> None:
    """Append a Click boolean toggle as ``--flag`` (True) or ``--no-flag`` (False).

    ``flag`` may be given in either its positive or ``--no-`` form. ``None``
    appends nothing.

    Raises:
        TypeError: ``value`` is neither ``None`` nor a ``bool``.
    """
    if value is None:
        return
    if not isinstance(value, bool):
        raise TypeError(f"Toggle flag '{flag}' requires bool value, got {type(value).__name__}.")
    if flag.startswith("--no-"):
        positive = "--" + flag[len("--no-"):]
    else:
        positive = flag
    negative = "--no-" + positive[2:]
    args.append(positive if value else negative)
178
+
179
+
180
def _resolve_override_dir(default: Path, override: Path | None) -> Path:
    """Pick the directory to use: ``default`` unless an ``override`` is supplied.

    An absolute ``override`` is used as-is; a relative one is interpreted as a
    sibling of ``default`` (i.e. relative to ``default.parent``).
    """
    if override is None:
        return default
    return override if override.is_absolute() else default.parent / override
187
+
188
+
189
def _build_effective_args_yaml(
    config_yaml: Optional[Path],
    override_yaml: Optional[Path],
    *,
    tmp_prefix: str,
) -> Tuple[Optional[Path], Dict[str, Any]]:
    """Return the YAML file a subcommand should read, plus its parsed content.

    Precedence for file layering:
        config_yaml < override_yaml

    * Neither given: ``(None, {})``.
    * Only one given: that path unchanged, with its parsed mapping.
    * Both given: a new temporary YAML (name prefixed with ``tmp_prefix``)
      holding the merged mapping. It is removed at interpreter exit, or
      immediately if writing it fails.
    """
    import atexit

    merged, base_cfg, override_cfg = load_merged_yaml_cfg(config_yaml, override_yaml)

    if config_yaml is None and override_yaml is None:
        return None, {}
    if config_yaml is None:
        return override_yaml, override_cfg
    if override_yaml is None:
        return config_yaml, base_cfg

    with tempfile.NamedTemporaryFile(
        mode="w",
        encoding="utf-8",
        suffix=".yaml",
        prefix=tmp_prefix,
        delete=False,
    ) as tf:
        effective = Path(tf.name).resolve()
        try:
            yaml.safe_dump(merged, tf, sort_keys=False, allow_unicode=True)
        except BaseException:
            # delete=False: nothing else would remove a half-written file.
            # Close first so unlinking also works on platforms that lock open files.
            tf.close()
            effective.unlink(missing_ok=True)
            raise

    # Register cleanup so the temp file is removed when the process exits.
    atexit.register(lambda p=effective: p.unlink(missing_ok=True))

    return effective, merged
225
+
226
+
227
def _write_ml_region_definition(pocket_pdb: Path, dest: Path) -> Path:
    """
    Copy ``pocket_pdb`` to ``dest`` for downstream ML/MM commands and return ``dest`` resolved.

    Parent directories of ``dest`` are created as needed. If ``dest`` already is
    ``pocket_pdb`` (same file), no copy is made.

    The copy preserves whatever link-hydrogen policy was used during extraction; set ``--add-linkH False``
    if you need a link-free ML-region definition.

    Raises:
        click.ClickException: ``pocket_pdb`` does not exist.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    try:
        shutil.copyfile(pocket_pdb, dest)
    except shutil.SameFileError:
        # Source and destination are the same file: content is already in place.
        pass
    except FileNotFoundError as exc:
        raise click.ClickException(
            f"[all] Pocket PDB not found while building ML region: {pocket_pdb}"
        ) from exc
    return dest.resolve()
240
+
241
+
242
def _mm_charge_mapping(expr: Optional[str]) -> Dict[str, int]:
    """Parse a ligand-charge mapping for mm_parm, or return ``{}``.

    Only expressions containing ``=`` or ``:`` (``RES=Q`` / ``RES:Q`` syntax) are
    handed to mm_parm's parser. ``None``, ``""`` and bare values yield ``{}``.

    Raises:
        click.ClickException: The parser rejected ``expr``.
    """
    is_mapping = bool(expr) and any(sep in expr for sep in ("=", ":"))
    if not is_mapping:
        return {}
    try:
        return _mm_parse_ligand_charge(expr)
    except Exception as exc:  # pragma: no cover - defensive
        raise click.ClickException(f"[all] Invalid --ligand-charge mapping for mm_parm: {exc}")
252
+
253
+
254
def _mm_mult_mapping(expr: Optional[str]) -> Dict[str, int]:
    """Parse a ligand-multiplicity mapping (``RES=M`` / ``RES:M``) for mm_parm.

    ``None`` or ``""`` yields ``{}``. Any other text is passed to mm_parm's parser.

    Raises:
        click.ClickException: The parser rejected ``expr``.
    """
    if expr:
        try:
            return _mm_parse_ligand_mult(expr)
        except Exception as exc:  # pragma: no cover - defensive
            raise click.ClickException(f"[all] Invalid --auto-mm-ligand-mult mapping for mm_parm: {exc}")
    return {}
262
+
263
+
264
def _build_mm_parm7(
    pdb: Path,
    ligand_charge_expr: Optional[str],
    ligand_mult_expr: Optional[str],
    out_dir: Path,
    ff_set: str,
    add_ter: bool,
    keep_temp: bool,
) -> Tuple[Path, Path]:
    """Run mm_parm on ``pdb`` and return the generated ``(parm7, rst7)`` paths.

    Outputs are expected at ``<out_dir>/<pdb.stem>.parm7`` and ``.rst7``. The
    mm_parm arguments are built here with ``add_h=False`` and ``ph=7.0`` fixed.

    Args:
        pdb: Input structure.
        ligand_charge_expr: ``--ligand-charge`` text; only ``RES=Q``/``RES:Q``
            mappings are forwarded.
        ligand_mult_expr: Ligand-multiplicity mapping text.
        out_dir: Output directory (created if missing).
        ff_set: Force-field set name passed through to mm_parm.
        add_ter: Forwarded to mm_parm's ``add_ter``.
        keep_temp: Forwarded to mm_parm's ``keep_temp``.

    Raises:
        click.ClickException: mm_parm exited with a non-zero code, raised an
            exception, or did not produce both output files.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    out_prefix = (out_dir / pdb.stem).resolve()
    out_prefix.parent.mkdir(parents=True, exist_ok=True)
    args = _AutoMMArgs(
        pdb=pdb.resolve(),
        out_prefix=str(out_prefix),
        ligand_charge=_mm_charge_mapping(ligand_charge_expr),
        ligand_mult=_mm_mult_mapping(ligand_mult_expr),
        keep_temp=bool(keep_temp),
        add_ter=bool(add_ter),
        add_h=False,
        ph=7.0,
        ff_set=str(ff_set),
        out_prefix_given=True,
    )
    try:
        _mm_run(args)
    except SystemExit as exc:  # pragma: no cover - click exit translation
        code = getattr(exc, "code", 1)
        # sys.exit() / sys.exit(0) mean success; fall through to the output
        # checks below instead of reporting "exited with code 0".
        if code not in (None, 0):
            raise click.ClickException(f"[all] mm_parm exited with code {code}.") from exc
    except Exception as exc:
        raise click.ClickException(f"[all] mm_parm failed: {exc}") from exc

    parm7 = Path(f"{args.out_prefix}.parm7").resolve()
    rst7 = Path(f"{args.out_prefix}.rst7").resolve()
    if not parm7.exists():
        raise click.ClickException(f"[all] mm_parm did not produce parm7 at {parm7}")
    if not rst7.exists():
        raise click.ClickException(f"[all] mm_parm did not produce rst7 at {rst7}")
    return parm7, rst7
304
+
305
+
306
def _parse_atom_key_from_line(line: str) -> Optional[AtomKey]:
    """Return the structural identity key of a PDB ATOM/HETATM record, else ``None``.

    Fields come from fixed PDB columns (1-based) and are whitespace-stripped:
    atom name 13-16, altLoc 17, residue name 18-20, chain 22, residue number
    23-26, and insertion code 27. Columns beyond the end of a short line read as ``""``.

    Returns:
        ``(chain, resname, resseq, icode, atomname, altloc)`` or ``None``.
    """
    if not line.startswith(("ATOM", "HETATM")):
        return None
    return (
        line[21:22].strip(),  # chain ID
        line[17:20].strip(),  # residue name
        line[22:26].strip(),  # residue sequence number
        line[26:27].strip(),  # insertion code
        line[12:16].strip(),  # atom name
        line[16:17].strip(),  # alternate location indicator
    )
317
+
318
+
319
def _key_variants(key: AtomKey) -> List[AtomKey]:
    """Return ``key`` followed by progressively relaxed copies, without duplicates.

    Order: the exact key, then altLoc cleared, then insertion code cleared, then
    both cleared. Copies that coincide (e.g. when a field is already empty) are
    kept only at their first position.
    """
    chain, resn, resseq, icode, atom, alt = key
    candidates = [
        (chain, resn, resseq, ic, atom, al)
        for ic, al in ((icode, alt), (icode, ""), ("", alt), ("", ""))
    ]
    # dict preserves insertion order, so this dedupes while keeping first occurrences.
    return list(dict.fromkeys(candidates))
336
+
337
+
338
def _build_variant_occurrence_table(keys: Sequence[AtomKey]) -> List[Dict[AtomKey, int]]:
    """For each atom, record how often each of its relaxed key variants has been seen so far.

    Entry ``i`` maps every variant of ``keys[i]`` to its running occurrence count
    over ``keys[0..i]`` inclusive (so the first occurrence is 1).
    """
    tally: Dict[AtomKey, int] = defaultdict(int)
    table: List[Dict[AtomKey, int]] = []
    for atom_key in keys:
        variants = _key_variants(atom_key)
        # Variants of one key are unique, so bumping all before snapshotting is equivalent.
        for variant in variants:
            tally[variant] += 1
        table.append({variant: tally[variant] for variant in variants})
    return table
349
+
350
+
351
def _pocket_key_to_index(pocket_pdb: Path) -> Dict[AtomKey, List[int]]:
    """Map each relaxed key variant of every pocket atom to its 1-based position(s).

    Positions count ATOM/HETATM records in file order. A variant shared by
    several atoms lists all of their positions, in ascending order.

    Raises:
        click.ClickException: The file is missing or has no ATOM/HETATM records.
    """
    try:
        with open(pocket_pdb, "r", encoding="utf-8", errors="ignore") as fh:
            pocket_keys = [k for k in map(_parse_atom_key_from_line, fh) if k is not None]
    except FileNotFoundError:
        raise click.ClickException(f"[all] Pocket PDB not found: {pocket_pdb}")

    index_of: Dict[AtomKey, List[int]] = defaultdict(list)
    for position, atom_key in enumerate(pocket_keys, start=1):
        for variant in _key_variants(atom_key):
            index_of[variant].append(position)

    if not index_of:
        raise click.ClickException(f"[all] Pocket PDB {pocket_pdb} has no ATOM/HETATM records.")
    return dict(index_of)
370
+
371
+
372
def _read_full_atom_keys_in_file_order(full_pdb: Path) -> List[AtomKey]:
    """Return the structural key of every ATOM/HETATM record, in file order.

    Raises:
        click.ClickException: The file is missing or contains no ATOM/HETATM records.
    """
    try:
        with open(full_pdb, "r", encoding="utf-8", errors="ignore") as fh:
            parsed = (_parse_atom_key_from_line(row) for row in fh)
            keys: List[AtomKey] = [k for k in parsed if k is not None]
    except FileNotFoundError:
        raise click.ClickException(f"[all] File not found while parsing PDB: {full_pdb}")
    if not keys:
        raise click.ClickException(f"[all] No ATOM/HETATM records detected in {full_pdb}.")
    return keys
387
+
388
+
389
+ def _format_atom_key_for_msg(key: AtomKey) -> str:
390
+ """Pretty string for diagnostics."""
391
+ chain, resn, resseq, icode, atom, alt = key
392
+ res = f"{chain}:{resn}{resseq}{(icode if icode else '')}"
393
+ alt_sfx = f",alt={alt}" if alt else ""
394
+ return f"{res}:{atom}{alt_sfx}"
395
+
396
+
397
def _parse_scan_lists_literals(
    scan_lists_raw: Sequence[str],
    atom_meta: Optional[Sequence[Dict[str, Any]]] = None,
) -> List[List[Tuple[int, int, float]]]:
    """Parse each ``--scan-lists`` literal into a stage of ``(i, j, target)`` triples.

    Indices are read as 1-based and returned 1-based (no re-basing).

    Raises:
        click.BadParameter: A literal yields no triples.
    """
    parsed_stages: List[List[Tuple[int, int, float]]] = []
    for stage_no, literal in enumerate(scan_lists_raw, start=1):
        label = f"--scan-lists #{stage_no}"
        triples, _unused = parse_scan_list_triples(
            literal,
            one_based=True,
            atom_meta=atom_meta,
            option_name=label,
            return_one_based=True,
        )
        if not triples:
            raise click.BadParameter(f"{label} must contain at least one (i,j,target) triple.")
        parsed_stages.append(triples)
    return parsed_stages
417
+
418
+
419
+ def _format_scan_stage(stage: List[Tuple[int, int, float]]) -> str:
420
+ """Serialize a scan stage back into a Python-like literal string."""
421
+ return "[" + ", ".join(f"({i},{j},{target})" for (i, j, target) in stage) + "]"
422
+
423
+
424
+ def _round_charge_with_note(q: float) -> int:
425
+ """
426
+ Cast the extractor's total charge (float) to an integer suitable for the path search.
427
+ If it is not already an integer within 1e-6, round to the nearest integer with a console note.
428
+ """
429
+ q_rounded = int(round(float(q)))
430
+ if not math.isfinite(q):
431
+ raise click.BadParameter(f"Computed total charge is non-finite: {q!r}")
432
+ if abs(float(q) - q_rounded) > 1e-6:
433
+ click.echo(f"[all] NOTE: extractor total charge = {q:g} → rounded to integer {q_rounded} for the path search.")
434
+ return q_rounded
435
+
436
+
437
+ def _derive_charge_from_ligand_charge_when_extract_skipped(
438
+ pdb_path: Path,
439
+ ligand_charge: Optional[str],
440
+ ) -> Optional[int]:
441
+ """Derive total charge from a PDB using extract-style charge summary.
442
+
443
+ *pdb_path* may be a full-complex PDB or a --model-pdb pocket.
444
+ """
445
+ if ligand_charge is None:
446
+ return None
447
+ try:
448
+ parser = PDB.PDBParser(QUIET=True)
449
+ complex_struct = parser.get_structure("complex", str(pdb_path))
450
+ selected_ids = {res.get_full_id() for res in complex_struct.get_residues()}
451
+ summary = compute_charge_summary(complex_struct, selected_ids, set(), ligand_charge)
452
+ log_charge_summary("[all]", summary)
453
+ q_total = float(summary.get("total_charge", 0.0))
454
+ click.echo(f"[all] Charge summary from {pdb_path.name} (--ligand-charge without extraction):")
455
+ click.echo(
456
+ f" Protein: {summary.get('protein_charge', 0.0):+g}, "
457
+ f"Ligand: {summary.get('ligand_total_charge', 0.0):+g}, "
458
+ f"Ions: {summary.get('ion_total_charge', 0.0):+g}, "
459
+ f"Total: {q_total:+g}"
460
+ )
461
+ return _round_charge_with_note(q_total)
462
+ except Exception as e:
463
+ click.echo(
464
+ f"[all] NOTE: failed to derive total charge from --ligand-charge: {e}",
465
+ err=True,
466
+ )
467
+ return None
468
+
469
+
470
+ def _pdb_needs_elem_fix(p: Path) -> bool:
471
+ """
472
+ Return True if the PDB has at least one ATOM/HETATM record whose element field (cols 77–78) is empty.
473
+ This is a light-weight check to decide whether to run add_elem_info.
474
+ """
475
+ try:
476
+ with p.open("r", encoding="utf-8", errors="ignore") as fh:
477
+ saw_atom = False
478
+ for line in fh:
479
+ if line.startswith("ATOM") or line.startswith("HETATM"):
480
+ saw_atom = True
481
+ if len(line) < 78 or not line[76:78].strip():
482
+ return True
483
+ # If no ATOM/HETATM was seen, fall back to "no fix"
484
+ return False
485
+ except Exception:
486
+ # On I/O errors, skip fixing (use original)
487
+ return False
488
+
489
+
490
+ # ---------- Post-processing helpers (minimal, reuse internals) ----------
491
+
492
+ def _read_summary(summary_yaml: Path) -> List[Dict[str, Any]]:
493
+ """
494
+ Read path_search/summary.yaml and return segments list (empty if not found).
495
+ """
496
+ try:
497
+ if not summary_yaml.exists():
498
+ return []
499
+ data = yaml.safe_load(summary_yaml.read_text(encoding="utf-8")) or {}
500
+ segs = data.get("segments", []) or []
501
+ if not isinstance(segs, list):
502
+ return []
503
+ return segs
504
+ except Exception:
505
+ return []
506
+
507
+
508
+ def _pdb_models_to_coords_and_elems(pdb_path: Path) -> Tuple[List[np.ndarray], List[str]]:
509
+ """
510
+ Return ([coords_model1, coords_model2, ...] in Å), [elements] from a multi-model PDB.
511
+ """
512
+ parser = PDB.PDBParser(QUIET=True)
513
+ st = parser.get_structure("seg", str(pdb_path))
514
+ models = list(st.get_models())
515
+ if not models:
516
+ raise click.ClickException(f"[post] No MODEL found in PDB: {pdb_path}")
517
+ # atom order taken from first model
518
+ atoms0 = [a for a in models[0].get_atoms()]
519
+ elems: List[str] = []
520
+ for a in atoms0:
521
+ el = (a.element or "").strip()
522
+ if not el:
523
+ # fall back: derive from atom name
524
+ nm = a.get_name().strip()
525
+ el = "".join([c for c in nm if c.isalpha()])[:2].title() or "C"
526
+ elems.append(el)
527
+ coords_list: List[np.ndarray] = []
528
+ for m in models:
529
+ atoms = [a for a in m.get_atoms()]
530
+ if len(atoms) != len(atoms0):
531
+ raise click.ClickException(f"[post] Atom count mismatch across models in {pdb_path}")
532
+ coords = np.array([a.get_coord() for a in atoms], dtype=float)
533
+ coords_list.append(coords)
534
+ return coords_list, elems
535
+
536
+
537
+ def _geom_from_angstrom(elems: Sequence[str],
538
+ coords_ang: np.ndarray,
539
+ freeze_atoms: Sequence[int]) -> Any:
540
+ """
541
+ Create a Geometry from Å coordinates using _path_search._new_geom_from_coords (expects Bohr).
542
+ """
543
+ coords_bohr = np.asarray(coords_ang, dtype=float) / BOHR2ANG
544
+ return _path_search._new_geom_from_coords(elems, coords_bohr, coord_type="cart", freeze_atoms=freeze_atoms)
545
+
546
+
547
+ def _load_segment_end_geoms(seg_pdb: Path, freeze_atoms: Sequence[int]) -> Tuple[Any, Any]:
548
+ """
549
+ Load first/last model as Geometries from a per-segment pocket PDB.
550
+ """
551
+ coords_list, elems = _pdb_models_to_coords_and_elems(seg_pdb)
552
+ gL = _geom_from_angstrom(elems, coords_list[0], freeze_atoms)
553
+ gR = _geom_from_angstrom(elems, coords_list[-1], freeze_atoms)
554
+ return gL, gR
555
+
556
+
557
+ def _irc_and_match(seg_idx: int,
558
+ seg_dir: Path,
559
+ ref_pdb_for_seg: Path,
560
+ seg_pocket_pdb: Path,
561
+ g_ts: Any,
562
+ q_int: int,
563
+ spin: int,
564
+ real_parm7: Optional[Path] = None,
565
+ model_pdb: Optional[Path] = None,
566
+ detect_layer: bool = False,
567
+ backend: Optional[str] = None,
568
+ embedcharge: bool = False,
569
+ embedcharge_cutoff: Optional[float] = None,
570
+ args_yaml: Optional[Path] = None) -> Dict[str, Any]:
571
+ """
572
+ Run EulerPC IRC from a TS geometry, then map the IRC endpoints to (left, right)
573
+ by comparing bond states with the GSM segment endpoints (when available).
574
+ Falls back to raw IRC orientation in TSOPT-only mode.
575
+ """
576
+ irc_dir = seg_dir / "irc"
577
+ ensure_dir(irc_dir)
578
+
579
+ # Build irc CLI arguments
580
+ irc_args: List[str] = [
581
+ "-i", str(ref_pdb_for_seg),
582
+ "--parm", str(real_parm7),
583
+ "--model-pdb", str(model_pdb),
584
+ "-q", str(int(q_int)),
585
+ "-m", str(int(spin)),
586
+ "--out-dir", str(irc_dir),
587
+ ]
588
+ irc_args.append("--detect-layer" if detect_layer else "--no-detect-layer")
589
+ if backend is not None:
590
+ irc_args.extend(["--backend", str(backend)])
591
+ if embedcharge:
592
+ irc_args.append("--embedcharge")
593
+ if embedcharge_cutoff is not None:
594
+ irc_args.extend(["--embedcharge-cutoff", str(embedcharge_cutoff)])
595
+ else:
596
+ irc_args.append("--no-embedcharge")
597
+ if args_yaml is not None:
598
+ irc_args.extend(["--config", str(args_yaml)])
599
+
600
+ _echo(f"[irc] Running EulerPC IRC → out={irc_dir}")
601
+ _run_cli_main("irc", _irc_cli.cli, irc_args, on_nonzero="raise", prefix="irc")
602
+
603
+ # Read IRC endpoints
604
+ finished_trj = irc_dir / "finished_irc_trj.xyz"
605
+ finished_pdb = irc_dir / "finished_irc.pdb"
606
+ irc_plot = irc_dir / "irc_plot.png"
607
+
608
+ if not finished_trj.exists():
609
+ raise click.ClickException(f"[irc] IRC trajectory not found: {finished_trj}")
610
+
611
+ # Convert to PDB if not already done
612
+ if not finished_pdb.exists():
613
+ _path_search._maybe_convert_to_pdb(finished_trj, ref_pdb_path=seg_pocket_pdb, out_path=finished_pdb)
614
+
615
+ elems, c_first, c_last = read_xyz_first_last(finished_trj)
616
+
617
+ # Create geometries from IRC endpoints
618
+ calc = _mlmm_calc(
619
+ model_charge=int(q_int),
620
+ model_mult=int(spin),
621
+ input_pdb=str(ref_pdb_for_seg),
622
+ real_parm7=str(real_parm7) if real_parm7 else None,
623
+ model_pdb=str(model_pdb) if model_pdb else None,
624
+ use_bfactor_layers=detect_layer,
625
+ backend=backend,
626
+ embedcharge=embedcharge,
627
+ )
628
+
629
+ g_left = _path_search._new_geom_from_coords(
630
+ elems, c_first / BOHR2ANG, coord_type="cart", freeze_atoms=[])
631
+ g_right = _path_search._new_geom_from_coords(
632
+ elems, c_last / BOHR2ANG, coord_type="cart", freeze_atoms=[])
633
+ g_left.set_calculator(calc)
634
+ g_right.set_calculator(calc)
635
+ _ = float(g_left.energy)
636
+ _ = float(g_right.energy)
637
+
638
+ # Reload TS geometry with energy
639
+ if g_ts.calculator is None:
640
+ g_ts.set_calculator(calc)
641
+ _ = float(g_ts.energy)
642
+
643
+ left_tag = "backward"
644
+ right_tag = "forward"
645
+ reverse_irc = False
646
+
647
+ # Try to load segment endpoints for mapping
648
+ gL_end = None
649
+ gR_end = None
650
+ seg_pocket_path = seg_dir.parent / f"mep_seg_{seg_idx:02d}.pdb"
651
+ if seg_pocket_path.exists():
652
+ try:
653
+ gL_end, gR_end = _load_segment_end_geoms(seg_pocket_path, [])
654
+ except Exception as e:
655
+ click.echo(f"[post] WARNING: failed to load segment endpoints: {e}", err=True)
656
+
657
+ # Map IRC endpoints to left/right using bond-change analysis
658
+ if gL_end is not None and gR_end is not None:
659
+ bond_cfg = dict(_path_search.BOND_KW)
660
+
661
+ def _matches(x, y) -> bool:
662
+ try:
663
+ chg, _ = _path_search._has_bond_change(x, y, bond_cfg)
664
+ return not chg
665
+ except Exception:
666
+ return False
667
+
668
+ def _rmsd_cart(g1, g2) -> float:
669
+ c1 = np.asarray(g1.coords).reshape(-1, 3)
670
+ c2 = np.asarray(g2.coords).reshape(-1, 3)
671
+ n = min(len(c1), len(c2))
672
+ return float(np.sqrt(np.mean((c1[:n] - c2[:n]) ** 2)))
673
+
674
+ # Check if IRC endpoints need swapping
675
+ match_LL = _matches(g_left, gL_end)
676
+ match_LR = _matches(g_left, gR_end)
677
+ match_RL = _matches(g_right, gL_end)
678
+ match_RR = _matches(g_right, gR_end)
679
+
680
+ if match_LR and match_RL and not (match_LL and match_RR):
681
+ # Swap: IRC backward→right, forward→left
682
+ g_left, g_right = g_right, g_left
683
+ left_tag, right_tag = right_tag, left_tag
684
+ reverse_irc = True
685
+ elif not (match_LL and match_RR):
686
+ # RMSD-based fallback
687
+ d_LL = _rmsd_cart(g_left, gL_end)
688
+ d_LR = _rmsd_cart(g_left, gR_end)
689
+ d_RL = _rmsd_cart(g_right, gL_end)
690
+ d_RR = _rmsd_cart(g_right, gR_end)
691
+ if (d_LR + d_RL) < (d_LL + d_RR):
692
+ g_left, g_right = g_right, g_left
693
+ left_tag, right_tag = right_tag, left_tag
694
+ reverse_irc = True
695
+
696
+ return {
697
+ "left_min_geom": g_left,
698
+ "right_min_geom": g_right,
699
+ "ts_geom": g_ts,
700
+ "left_tag": left_tag,
701
+ "right_tag": right_tag,
702
+ "irc_trj": str(finished_trj) if finished_trj.exists() else None,
703
+ "irc_plot": str(irc_plot) if irc_plot.exists() else None,
704
+ "reverse_irc": reverse_irc,
705
+ }
706
+
707
+
708
+ def _save_single_geom_for_tools(g: Any, ref_pdb: Path, out_dir: Path, name: str) -> Tuple[Path, Path]:
709
+ """
710
+ Write XYZ (primary, full precision) + PDB (companion) for a single geometry.
711
+ Returns (xyz_path, pdb_path).
712
+ """
713
+ out_dir.mkdir(parents=True, exist_ok=True)
714
+ # XYZ — full precision
715
+ xyz_out = out_dir / f"{name}.xyz"
716
+ with open(xyz_out, "w") as f:
717
+ f.write(g.as_xyz() + "\n")
718
+ # TRJ with energy (for PDB conversion and trajectory viewers)
719
+ xyz_trj = out_dir / f"{name}_trj.xyz"
720
+ _path_search._write_xyz_trj_with_energy([g], [float(g.energy)], xyz_trj)
721
+ # PDB companion
722
+ pdb_out = out_dir / f"{name}.pdb"
723
+ _path_search._maybe_convert_to_pdb(xyz_trj, ref_pdb_path=ref_pdb, out_path=pdb_out)
724
+ return xyz_out, pdb_out
725
+
726
+
727
+ def _save_single_geom_as_pdb_for_tools(g: Any, ref_pdb: Path, out_dir: Path, name: str) -> Path:
728
+ """Backward-compatible wrapper: returns PDB path only."""
729
+ _, pdb = _save_single_geom_for_tools(g, ref_pdb, out_dir, name)
730
+ return pdb
731
+
732
+
733
+ def _run_tsopt_on_hei(hei_pdb: Path,
734
+ charge: int,
735
+ spin: int,
736
+ real_parm7: Path,
737
+ model_pdb: Path,
738
+ detect_layer: bool,
739
+ args_yaml: Optional[Path],
740
+ out_dir: Path,
741
+ opt_mode_default: str,
742
+ overrides: Optional[Dict[str, Any]] = None,
743
+ backend: Optional[str] = None,
744
+ embedcharge: bool = False,
745
+ embedcharge_cutoff: Optional[float] = None,
746
+ ref_pdb: Optional[Path] = None) -> Tuple[Path, Any]:
747
+ """
748
+ Run tsopt CLI on a HEI structure; return (final_ts_pdb_path, ts_geom).
749
+
750
+ When *ref_pdb* (layered PDB with B-factor layer info) is given, the HEI XYZ
751
+ is used as input and *ref_pdb* is passed via ``--ref-pdb`` so that the
752
+ calculator correctly detects ML/MM layers from B-factors.
753
+ """
754
+ overrides = overrides or {}
755
+ # Prefer HEI XYZ (full precision) + layered ref-pdb (B-factor layer info)
756
+ hei_xyz = hei_pdb.with_suffix(".xyz")
757
+ if ref_pdb is not None and hei_xyz.exists():
758
+ input_file = hei_xyz
759
+ topology_pdb = ref_pdb
760
+ else:
761
+ input_file = hei_pdb
762
+ topology_pdb = hei_pdb
763
+ prepared_input = prepare_input_structure(input_file)
764
+ if input_file.suffix.lower() == ".xyz" and ref_pdb is not None:
765
+ apply_ref_pdb_override(prepared_input, ref_pdb)
766
+ try:
767
+ ts_dir = _resolve_override_dir(out_dir / "ts", overrides.get("out_dir"))
768
+ ensure_dir(ts_dir)
769
+
770
+ opt_mode = overrides.get("opt_mode", opt_mode_default)
771
+
772
+ ts_args: List[str] = [
773
+ "-i", str(prepared_input.geom_path),
774
+ ]
775
+ if input_file.suffix.lower() == ".xyz" and ref_pdb is not None:
776
+ ts_args.extend(["--ref-pdb", str(ref_pdb)])
777
+ ts_args.extend([
778
+ "--parm", str(real_parm7),
779
+ "--model-pdb", str(model_pdb),
780
+ "-q", str(int(charge)),
781
+ "-m", str(int(spin)),
782
+ "--out-dir", str(ts_dir),
783
+ ])
784
+ ts_args.append("--detect-layer" if detect_layer else "--no-detect-layer")
785
+
786
+ if opt_mode is not None:
787
+ ts_args.extend(["--opt-mode", str(opt_mode)])
788
+
789
+ _append_cli_arg(ts_args, "--max-cycles", overrides.get("max_cycles"))
790
+ _append_toggle_arg(ts_args, "--dump", overrides.get("dump"))
791
+ _append_toggle_arg(ts_args, "--convert-files", overrides.get("convert_files"))
792
+ _append_cli_arg(ts_args, "--thresh", overrides.get("thresh"))
793
+ _append_toggle_arg(ts_args, "--flatten", overrides.get("flatten"))
794
+
795
+ hess_mode = overrides.get("hessian_calc_mode")
796
+ if hess_mode:
797
+ ts_args.extend(["--hessian-calc-mode", str(hess_mode)])
798
+
799
+ if args_yaml is not None:
800
+ ts_args.extend(["--config", str(args_yaml)])
801
+
802
+ if backend is not None:
803
+ ts_args.extend(["--backend", str(backend)])
804
+ if embedcharge:
805
+ ts_args.append("--embedcharge")
806
+ if embedcharge_cutoff is not None:
807
+ ts_args.extend(["--embedcharge-cutoff", str(embedcharge_cutoff)])
808
+ else:
809
+ ts_args.append("--no-embedcharge")
810
+
811
+ _echo(f"[tsopt] Running tsopt on HEI → out={ts_dir}")
812
+ _run_cli_main("tsopt", _ts_opt.cli, ts_args, on_nonzero="raise", prefix="tsopt")
813
+
814
+ # Prefer XYZ (full precision) for geometry loading; PDB for topology
815
+ final_xyz = ts_dir / "final_geometry.xyz"
816
+ ts_pdb = ts_dir / "final_geometry.pdb"
817
+ if not ts_pdb.exists() and final_xyz.exists():
818
+ _path_search._maybe_convert_to_pdb(final_xyz, topology_pdb, ts_pdb)
819
+ if not final_xyz.exists() and not ts_pdb.exists():
820
+ raise click.ClickException("[tsopt] TS outputs not found.")
821
+ geom_src = final_xyz if final_xyz.exists() else ts_pdb
822
+ g_ts = geom_loader(geom_src, coord_type="cart")
823
+
824
+ # Ensure calculator to have energy on g_ts
825
+ calc = _mlmm_calc(
826
+ model_charge=int(charge),
827
+ model_mult=int(spin),
828
+ input_pdb=str(topology_pdb),
829
+ real_parm7=str(real_parm7),
830
+ model_pdb=str(model_pdb),
831
+ use_bfactor_layers=detect_layer,
832
+ backend=backend,
833
+ embedcharge=embedcharge,
834
+ )
835
+ g_ts.set_calculator(calc)
836
+ _ = float(g_ts.energy)
837
+
838
+ return ts_pdb, g_ts
839
+ finally:
840
+ prepared_input.cleanup()
841
+
842
+
843
+ def _pseudo_irc_and_match(seg_idx: int,
844
+ seg_dir: Path,
845
+ ref_pdb_for_seg: Path,
846
+ seg_pocket_pdb: Path,
847
+ g_ts: Any,
848
+ q_int: int,
849
+ spin: int,
850
+ real_parm7: Optional[Path] = None,
851
+ model_pdb: Optional[Path] = None,
852
+ detect_layer: bool = False,
853
+ backend: Optional[str] = None,
854
+ embedcharge: bool = False,
855
+ embedcharge_cutoff: Optional[float] = None) -> Dict[str, Any]:
856
+ """
857
+ From a TS pocket geometry, perform pseudo-IRC:
858
+ - compute imag. mode
859
+ - displace ± (0.25 Å) and optimize both to minima (LBFGS)
860
+ - map each min to left/right segment endpoint by bond-change check (if segment endpoints exist)
861
+ *If no segment endpoints are available (single-structure TSOPT-only mode), fall back to (plus, minus).*
862
+ Returns dict with paths/energies/geoms for {left, ts, right}, and small IRC plots.
863
+ """
864
+ # Mode direction
865
+ calc_kwargs = dict(
866
+ model_charge=int(q_int),
867
+ model_mult=int(spin),
868
+ input_pdb=str(ref_pdb_for_seg),
869
+ real_parm7=str(real_parm7) if real_parm7 else None,
870
+ model_pdb=str(model_pdb) if model_pdb else None,
871
+ use_bfactor_layers=detect_layer,
872
+ )
873
+ if backend is not None:
874
+ calc_kwargs["backend"] = backend
875
+ calc_kwargs["embedcharge"] = embedcharge
876
+ mode_xyz = _compute_imag_mode_direction(g_ts, calc_kwargs=calc_kwargs, freeze_atoms=[])
877
+
878
+ # Displace ± and optimize
879
+ irc_dir = seg_dir / "irc"
880
+ ensure_dir(irc_dir)
881
+ amp = 0.25 # Å; small stable displacement
882
+ g_plus0 = _displaced_geometry_along_mode(g_ts, mode_xyz, +amp, [])
883
+ g_minus0 = _displaced_geometry_along_mode(g_ts, mode_xyz, -amp, [])
884
+
885
+ # Shared ML/MM calc
886
+ shared_calc = _mlmm_calc(**calc_kwargs)
887
+ # LBFGS settings (reuse defaults)
888
+ sopt_cfg = dict(_path_search.LBFGS_KW)
889
+ sopt_cfg["dump"] = True
890
+ sopt_cfg["out_dir"] = str(irc_dir)
891
+
892
+ # Optimize
893
+ g_plus = _path_search._optimize_single(g_plus0, shared_calc, sopt_cfg, irc_dir, tag=f"seg_{seg_idx:02d}_irc_plus", ref_pdb_path=seg_pocket_pdb)
894
+ g_minus = _path_search._optimize_single(g_minus0, shared_calc, sopt_cfg, irc_dir, tag=f"seg_{seg_idx:02d}_irc_minus", ref_pdb_path=seg_pocket_pdb)
895
+
896
+ # IRC mini plots (TS→min)
897
+ try:
898
+ trj_plus = irc_dir / f"seg_{seg_idx:02d}_irc_plus_opt/optimization_trj.xyz"
899
+ trj_minus = irc_dir / f"seg_{seg_idx:02d}_irc_minus_opt/optimization_trj.xyz"
900
+ if trj_plus.exists():
901
+ run_trj2fig(trj_plus, [irc_dir / f"irc_plus_plot.png"], unit="kcal", reference="init", reverse_x=False)
902
+ if trj_minus.exists():
903
+ run_trj2fig(trj_minus, [irc_dir / f"irc_minus_plot.png"], unit="kcal", reference="init", reverse_x=False)
904
+ except Exception as e:
905
+ click.echo(f"[irc] WARNING: failed to plot IRC mini plots: {e}", err=True)
906
+
907
+ # Try to load segment endpoints (pocket-only) — available only after path_search
908
+ gL_end = None
909
+ gR_end = None
910
+ seg_pocket_path = seg_dir.parent / f"mep_seg_{seg_idx:02d}.pdb"
911
+ if seg_pocket_path.exists():
912
+ try:
913
+ gL_end, gR_end = _load_segment_end_geoms(seg_pocket_path, [])
914
+ except Exception as e:
915
+ click.echo(f"[post] WARNING: failed to load segment endpoints: {e}", err=True)
916
+
917
+ # Decide mapping
918
+ bond_cfg = dict(_path_search.BOND_KW)
919
+
920
+ def _rmsd_cart(g1, g2) -> float:
921
+ """Cartesian RMSD (Bohr) between two geometries."""
922
+ c1 = np.asarray(g1.coords).reshape(-1, 3)
923
+ c2 = np.asarray(g2.coords).reshape(-1, 3)
924
+ n = min(len(c1), len(c2))
925
+ return float(np.sqrt(np.mean((c1[:n] - c2[:n]) ** 2)))
926
+
927
+ def _matches(x, y) -> bool:
928
+ try:
929
+ chg, _ = _path_search._has_bond_change(x, y, bond_cfg)
930
+ return (not chg)
931
+ except Exception:
932
+ # fallback: small RMSD threshold
933
+ return (_rmsd_cart(x, y) < 1e-3)
934
+
935
+ candidates = [("plus", g_plus), ("minus", g_minus)]
936
+ mapping: Dict[str, Any] = {"left": None, "right": None}
937
+
938
+ if (gL_end is not None) and (gR_end is not None):
939
+ # First pass: exact match on bond changes
940
+ for tag, g in candidates:
941
+ if _matches(g, gL_end) and not _matches(g, gR_end):
942
+ mapping["left"] = (tag, g)
943
+ elif _matches(g, gR_end) and not _matches(g, gL_end):
944
+ mapping["right"] = (tag, g)
945
+ # Second pass: fill missing by RMSD
946
+ for side, g_end in (("left", gL_end), ("right", gR_end)):
947
+ if mapping[side] is None:
948
+ remain = [(t, gg) for (t, gg) in candidates if (mapping["left"] is None or mapping["left"][0] != t) and (mapping["right"] is None or mapping["right"][0] != t)]
949
+ if not remain:
950
+ remain = candidates
951
+ best = min(remain, key=lambda p: _rmsd_cart(p[1], g_end))
952
+ mapping[side] = best
953
+ else:
954
+ # Fallback (single-structure TSOPT-only mode): keep a deterministic assignment
955
+ mapping["left"] = ("minus", g_minus)
956
+ mapping["right"] = ("plus", g_plus)
957
+
958
+ # Energies (ensure calculator is attached)
959
+ for _, g in candidates:
960
+ if g.calculator is None:
961
+ g.set_calculator(shared_calc)
962
+ _ = float(g.energy)
963
+ if g_ts.calculator is None:
964
+ g_ts.set_calculator(shared_calc)
965
+ _ = float(g_ts.energy)
966
+
967
+ # Dump tiny TS↔min trj for each direction
968
+ try:
969
+ for side in ("left", "right"):
970
+ tag, gmin = mapping[side]
971
+ trj = irc_dir / f"irc_{side}_trj.xyz"
972
+ _path_search._write_xyz_trj_with_energy([g_ts, gmin], [float(g_ts.energy), float(gmin.energy)], trj)
973
+ run_trj2fig(trj, [irc_dir / f"irc_{side}_plot.png"], unit="kcal", reference="init", reverse_x=False)
974
+ except Exception:
975
+ logger.debug("Failed to write IRC mini-trajectory plot", exc_info=True)
976
+
977
+ irc_trj = None
978
+ irc_plot = None
979
+ try:
980
+ g_left = mapping["left"][1]
981
+ g_right = mapping["right"][1]
982
+ irc_trj = irc_dir / "finished_irc_trj.xyz"
983
+ _path_search._write_xyz_trj_with_energy(
984
+ [g_left, g_ts, g_right],
985
+ [float(g_left.energy), float(g_ts.energy), float(g_right.energy)],
986
+ irc_trj,
987
+ )
988
+ irc_plot = irc_dir / "irc_plot.png"
989
+ run_trj2fig(irc_trj, [irc_plot], unit="kcal", reference="init", reverse_x=False)
990
+ except Exception as e:
991
+ click.echo(f"[irc] WARNING: failed to build combined IRC plot: {e}", err=True)
992
+
993
+ return {
994
+ "left_min_geom": mapping["left"][1],
995
+ "right_min_geom": mapping["right"][1],
996
+ "ts_geom": g_ts,
997
+ "left_tag": mapping["left"][0],
998
+ "right_tag": mapping["right"][0],
999
+ "irc_trj": str(irc_trj) if irc_trj else None,
1000
+ "irc_plot": str(irc_plot) if irc_plot else None,
1001
+ }
1002
+
1003
+
1004
+ def _write_segment_energy_diagram(
1005
+ prefix: Path,
1006
+ labels: List[str],
1007
+ energies_eh: List[float],
1008
+ title_note: str,
1009
+ ylabel: str = "ΔE (kcal/mol)",
1010
+ write_html: bool = False,
1011
+ ) -> Optional[Dict[str, Any]]:
1012
+ """
1013
+ Write energy diagram (PNG only) using utils.build_energy_diagram.
1014
+ """
1015
+ if not energies_eh:
1016
+ return None
1017
+ e0 = energies_eh[0]
1018
+ energies_kcal = [(e - e0) * AU2KCALPERMOL for e in energies_eh]
1019
+ fig = build_energy_diagram(
1020
+ energies=energies_kcal,
1021
+ labels=labels,
1022
+ ylabel=ylabel,
1023
+ baseline=True,
1024
+ showgrid=False,
1025
+ )
1026
+ if title_note:
1027
+ fig.update_layout(title=title_note)
1028
+ png = prefix.with_suffix(".png")
1029
+ try:
1030
+ fig.write_image(str(png), scale=2)
1031
+ except Exception as e:
1032
+ click.echo(f"[diagram] NOTE: PNG export skipped (install 'kaleido' to enable): {e}", err=True)
1033
+ else:
1034
+ click.echo(f"[diagram] Wrote energy diagram → {png.name}")
1035
+
1036
+ payload: Dict[str, Any] = {
1037
+ "name": prefix.stem,
1038
+ "labels": labels,
1039
+ "energies_kcal": energies_kcal,
1040
+ "ylabel": ylabel,
1041
+ "energies_au": list(energies_eh),
1042
+ "image": str(png),
1043
+ }
1044
+ if title_note:
1045
+ payload["title"] = title_note
1046
+ return payload
1047
+
1048
+
1049
+ def _build_global_segment_labels(n_segments: int) -> List[str]:
1050
+ """
1051
+ Build GSM-like labels for aggregated R/TS/P diagrams over multiple segments.
1052
+
1053
+ Pattern:
1054
+ - n = 1: ["R", "TS1", "P"]
1055
+ - n >= 2: R, TS1, IM1_1, IM1_2, TS2, IM2_1, IM2_2, ..., TSN, P
1056
+ """
1057
+ if n_segments <= 0:
1058
+ return []
1059
+ if n_segments == 1:
1060
+ return ["R", "TS1", "P"]
1061
+
1062
+ labels: List[str] = []
1063
+ for seg_idx in range(1, n_segments + 1):
1064
+ if seg_idx == 1:
1065
+ labels.extend(["R", "TS1", "IM1_1"])
1066
+ elif seg_idx == n_segments:
1067
+ labels.extend([f"IM{seg_idx - 1}_2", f"TS{seg_idx}", "P"])
1068
+ else:
1069
+ labels.extend(
1070
+ [f"IM{seg_idx - 1}_2", f"TS{seg_idx}", f"IM{seg_idx}_1"]
1071
+ )
1072
+ return labels
1073
+
1074
+
1075
+ def _merge_irc_trajectories_to_single_plot(
1076
+ trj_and_flags: Sequence[Tuple[Path, bool]],
1077
+ out_png: Path,
1078
+ ) -> None:
1079
+ """
1080
+ Build a single IRC plot over all reactive segments using trj2fig.
1081
+ """
1082
+ all_blocks: List[str] = []
1083
+ for trj_path, reverse in trj_and_flags:
1084
+ if not isinstance(trj_path, Path) or not trj_path.exists():
1085
+ continue
1086
+ try:
1087
+ blocks = read_xyz_as_blocks(trj_path)
1088
+ except click.ClickException as e:
1089
+ click.echo(str(e), err=True)
1090
+ continue
1091
+ if not blocks:
1092
+ continue
1093
+ if reverse:
1094
+ blocks = list(reversed(blocks))
1095
+ all_blocks.extend("\n".join(b) for b in blocks)
1096
+
1097
+ if not all_blocks:
1098
+ return
1099
+
1100
+ tmp_trj = out_png.with_name(f"{out_png.stem}_trj.xyz")
1101
+ ensure_dir(tmp_trj.parent)
1102
+ try:
1103
+ tmp_trj.write_text("\n".join(all_blocks) + "\n", encoding="utf-8")
1104
+ except Exception as e:
1105
+ click.echo(f"[irc_all] WARNING: Failed to write concatenated IRC trajectory: {e}", err=True)
1106
+ return
1107
+
1108
+ try:
1109
+ run_trj2fig(tmp_trj, [out_png], unit="kcal", reference="init", reverse_x=False)
1110
+ click.echo(f"[irc_all] Wrote aggregated IRC plot → {out_png}")
1111
+ except Exception as e:
1112
+ click.echo(f"[irc_all] WARNING: failed to plot concatenated IRC trajectory: {e}", err=True)
1113
+ finally:
1114
+ try:
1115
+ tmp_trj.unlink()
1116
+ except Exception:
1117
+ logger.debug("Failed to unlink temp trajectory file", exc_info=True)
1118
+
1119
+
1120
+ def _run_freq_for_state(pdb_path: Path,
1121
+ q_int: int,
1122
+ spin: int,
1123
+ real_parm7: Path,
1124
+ model_pdb: Path,
1125
+ detect_layer: bool,
1126
+ out_dir: Path,
1127
+ args_yaml: Optional[Path],
1128
+ overrides: Optional[Dict[str, Any]] = None,
1129
+ backend: Optional[str] = None,
1130
+ embedcharge: bool = False,
1131
+ embedcharge_cutoff: Optional[float] = None,
1132
+ xyz_path: Optional[Path] = None) -> Dict[str, Any]:
1133
+ """
1134
+ Run freq CLI; return parsed thermo dict (may be empty).
1135
+ When *xyz_path* is given, use it for full-precision coordinates with
1136
+ *pdb_path* as topology reference (--ref-pdb).
1137
+ """
1138
+ fdir = out_dir
1139
+ ensure_dir(fdir)
1140
+ overrides = overrides or {}
1141
+
1142
+ dump_use = overrides.get("dump")
1143
+ if dump_use is None:
1144
+ dump_use = True
1145
+
1146
+ # Prefer XYZ (full precision) with --ref-pdb for topology
1147
+ if xyz_path is not None and xyz_path.exists():
1148
+ args = ["-i", str(xyz_path), "--ref-pdb", str(pdb_path)]
1149
+ else:
1150
+ args = ["-i", str(pdb_path)]
1151
+ args.extend([
1152
+ "--parm", str(real_parm7),
1153
+ "--model-pdb", str(model_pdb),
1154
+ "-q", str(int(q_int)),
1155
+ "-m", str(int(spin)),
1156
+ "--out-dir", str(fdir),
1157
+ ])
1158
+ args.append("--detect-layer" if detect_layer else "--no-detect-layer")
1159
+
1160
+ _append_cli_arg(args, "--max-write", overrides.get("max_write"))
1161
+ _append_cli_arg(args, "--amplitude-ang", overrides.get("amplitude_ang"))
1162
+ _append_cli_arg(args, "--n-frames", overrides.get("n_frames"))
1163
+ if overrides.get("sort") is not None:
1164
+ args.extend(["--sort", str(overrides.get("sort"))])
1165
+ _append_cli_arg(args, "--temperature", overrides.get("temperature"))
1166
+ _append_cli_arg(args, "--pressure", overrides.get("pressure"))
1167
+ _append_toggle_arg(args, "--dump", dump_use)
1168
+ _append_toggle_arg(args, "--convert-files", overrides.get("convert_files"))
1169
+
1170
+ hess_mode = overrides.get("hessian_calc_mode")
1171
+ if hess_mode:
1172
+ args.extend(["--hessian-calc-mode", str(hess_mode)])
1173
+
1174
+ if args_yaml is not None:
1175
+ args.extend(["--config", str(args_yaml)])
1176
+ if backend is not None:
1177
+ args.extend(["--backend", str(backend)])
1178
+ if embedcharge:
1179
+ args.append("--embedcharge")
1180
+ if embedcharge_cutoff is not None:
1181
+ args.extend(["--embedcharge-cutoff", str(embedcharge_cutoff)])
1182
+ else:
1183
+ args.append("--no-embedcharge")
1184
+ _run_cli_main("freq", _freq_cli.cli, args, on_nonzero="warn", on_exception="raise", prefix="freq")
1185
+ # parse thermoanalysis.yaml if any
1186
+ y = fdir / "thermoanalysis.yaml"
1187
+ if y.exists():
1188
+ try:
1189
+ return yaml.safe_load(y.read_text(encoding="utf-8")) or {}
1190
+ except Exception:
1191
+ return {}
1192
+ return {}
1193
+
1194
+
1195
+ def _run_opt_for_state(
1196
+ pdb_path: Path,
1197
+ q_int: int,
1198
+ spin: int,
1199
+ real_parm7: Path,
1200
+ model_pdb: Path,
1201
+ detect_layer: bool,
1202
+ out_dir: Path,
1203
+ args_yaml: Optional[Path],
1204
+ opt_mode_default: str,
1205
+ convert_files: Optional[bool] = None,
1206
+ backend: Optional[str] = None,
1207
+ embedcharge: bool = False,
1208
+ embedcharge_cutoff: Optional[float] = None,
1209
+ thresh: Optional[str] = None,
1210
+ xyz_path: Optional[Path] = None,
1211
+ ) -> Tuple[Any, Path]:
1212
+ """
1213
+ Run opt CLI for a single endpoint and return (optimized Geometry, final geometry path).
1214
+ When *xyz_path* is given, pass it as ``-i`` with ``--ref-pdb pdb_path`` to
1215
+ preserve full coordinate precision.
1216
+ """
1217
+ opt_dir = out_dir
1218
+ ensure_dir(opt_dir)
1219
+
1220
+ # Use XYZ (full precision) when available; fall back to PDB
1221
+ if xyz_path is not None and xyz_path.exists():
1222
+ prepared_input = prepare_input_structure(xyz_path)
1223
+ apply_ref_pdb_override(prepared_input, pdb_path)
1224
+ input_label = xyz_path.name
1225
+ else:
1226
+ prepared_input = prepare_input_structure(pdb_path)
1227
+ input_label = pdb_path.name
1228
+ try:
1229
+ opt_mode = str(opt_mode_default or "heavy").lower()
1230
+ args = [
1231
+ "-i", str(prepared_input.geom_path),
1232
+ ]
1233
+ # Add --ref-pdb when input is XYZ
1234
+ if prepared_input.geom_path.suffix.lower() == ".xyz":
1235
+ args.extend(["--ref-pdb", str(prepared_input.source_path)])
1236
+ args.extend([
1237
+ "--parm", str(real_parm7),
1238
+ "--model-pdb", str(model_pdb),
1239
+ "-q", str(int(q_int)),
1240
+ "-m", str(int(spin)),
1241
+ "--out-dir", str(opt_dir),
1242
+ "--opt-mode", opt_mode,
1243
+ ])
1244
+ args.append("--detect-layer" if detect_layer else "--no-detect-layer")
1245
+ _append_toggle_arg(args, "--convert-files", convert_files)
1246
+ _append_cli_arg(args, "--thresh", thresh)
1247
+
1248
+ if args_yaml is not None:
1249
+ args.extend(["--config", str(args_yaml)])
1250
+
1251
+ if backend is not None:
1252
+ args.extend(["--backend", str(backend)])
1253
+ if embedcharge:
1254
+ args.append("--embedcharge")
1255
+ if embedcharge_cutoff is not None:
1256
+ args.extend(["--embedcharge-cutoff", str(embedcharge_cutoff)])
1257
+ else:
1258
+ args.append("--no-embedcharge")
1259
+
1260
+ _echo(f"[endpoint-opt] Running opt on {input_label} (mode={opt_mode}) → out={opt_dir}")
1261
+ _run_cli_main("opt", _opt_cli.cli, args, on_nonzero="raise", on_exception="raise", prefix="endpoint-opt")
1262
+
1263
+ final_pdb = opt_dir / "final_geometry.pdb"
1264
+ final_xyz = opt_dir / "final_geometry.xyz"
1265
+ # Prefer XYZ (full precision) for geometry loading
1266
+ if final_xyz.exists():
1267
+ final_geom_path = final_xyz
1268
+ elif final_pdb.exists():
1269
+ final_geom_path = final_pdb
1270
+ else:
1271
+ raise click.ClickException(f"[endpoint-opt] opt outputs not found under {opt_dir}")
1272
+
1273
+ g_opt = geom_loader(final_geom_path, coord_type="cart")
1274
+ calc_input_pdb = final_pdb if final_pdb.exists() else pdb_path
1275
+ calc = _mlmm_calc(
1276
+ model_charge=int(q_int),
1277
+ model_mult=int(spin),
1278
+ input_pdb=str(calc_input_pdb),
1279
+ real_parm7=str(real_parm7),
1280
+ model_pdb=str(model_pdb),
1281
+ use_bfactor_layers=detect_layer,
1282
+ backend=backend,
1283
+ embedcharge=embedcharge,
1284
+ )
1285
+ g_opt.set_calculator(calc)
1286
+ _ = float(g_opt.energy)
1287
+
1288
+ return g_opt, final_geom_path
1289
+ finally:
1290
+ prepared_input.cleanup()
1291
+
1292
+
1293
def _run_dft_for_state(pdb_path: Path,
                       q_int: int,
                       spin: int,
                       real_parm7: Path,
                       model_pdb: Path,
                       detect_layer: bool,
                       out_dir: Path,
                       args_yaml: Optional[Path],
                       func_basis: str = "wb97m-v/def2-tzvpd",
                       overrides: Optional[Dict[str, Any]] = None,
                       backend: Optional[str] = None,
                       embedcharge: bool = False,
                       embedcharge_cutoff: Optional[float] = None,
                       xyz_path: Optional[Path] = None) -> Dict[str, Any]:
    """
    Run the ``dft`` CLI for one state and return the parsed ``result.yaml``.

    When *xyz_path* is given and exists, it is passed as ``-i`` for
    full-precision coordinates, with *pdb_path* as the topology reference
    (``--ref-pdb``); otherwise *pdb_path* itself is the ``-i`` input.

    Parameters
    ----------
    pdb_path : Path
        Structure (or topology reference when *xyz_path* is used).
    q_int, spin : int
        Charge and spin multiplicity forwarded as ``-q`` / ``-m``.
    real_parm7, model_pdb : Path
        Forwarded as ``--parm`` / ``--model-pdb``.
    detect_layer : bool
        Selects ``--detect-layer`` or ``--no-detect-layer``.
    out_dir : Path
        Output directory (created if needed); ``result.yaml`` is read from here.
    args_yaml : Path or None
        Forwarded as ``--config`` when given.
    func_basis : str
        Default functional/basis; an override under ``overrides["func_basis"]``
        wins unless it is ``None`` or empty.
    overrides : dict or None
        Optional keys: ``func_basis``, ``max_cycle``, ``conv_tol``,
        ``grid_level``, ``convert_files``.
    backend : str or None
        Forwarded as ``--backend`` when given.
    embedcharge : bool
        Selects ``--embedcharge`` / ``--no-embedcharge``.
    embedcharge_cutoff : float or None
        Forwarded as ``--embedcharge-cutoff`` only when *embedcharge* is True.
    xyz_path : Path or None
        Optional full-precision coordinate file (see above).

    Returns
    -------
    dict
        Mapping loaded from ``<out_dir>/result.yaml``. It is empty when the
        file is missing, cannot be read or parsed, or its top level is not a
        mapping.

    Notes
    -----
    A non-zero exit of the dft CLI only warns (``on_nonzero="warn"``).
    Exceptions raised inside it propagate (``on_exception="raise"``).
    """
    ddir = out_dir
    ensure_dir(ddir)
    overrides = overrides or {}

    # Treat an explicit None/empty override as "not provided" so the CLI
    # never receives the literal text "None" as a functional/basis.
    func_basis_use = overrides.get("func_basis") or func_basis

    # Prefer XYZ (full precision) with --ref-pdb for topology
    if xyz_path is not None and xyz_path.exists():
        args = ["-i", str(xyz_path), "--ref-pdb", str(pdb_path)]
    else:
        args = ["-i", str(pdb_path)]
    args.extend([
        "--parm", str(real_parm7),
        "--model-pdb", str(model_pdb),
        "-q", str(int(q_int)),
        "-m", str(int(spin)),
        "--func-basis", str(func_basis_use),
        "--out-dir", str(ddir),
    ])
    args.append("--detect-layer" if detect_layer else "--no-detect-layer")

    _append_cli_arg(args, "--max-cycle", overrides.get("max_cycle"))
    _append_cli_arg(args, "--conv-tol", overrides.get("conv_tol"))
    _append_cli_arg(args, "--grid-level", overrides.get("grid_level"))
    _append_toggle_arg(args, "--convert-files", overrides.get("convert_files"))

    if args_yaml is not None:
        args.extend(["--config", str(args_yaml)])
    if backend is not None:
        args.extend(["--backend", str(backend)])
    if embedcharge:
        args.append("--embedcharge")
        if embedcharge_cutoff is not None:
            args.extend(["--embedcharge-cutoff", str(embedcharge_cutoff)])
    else:
        args.append("--no-embedcharge")
    _run_cli_main("dft", _dft_cli.cli, args, on_nonzero="warn", on_exception="raise", prefix="dft")

    result_path = ddir / "result.yaml"
    if not result_path.exists():
        return {}
    try:
        data = yaml.safe_load(result_path.read_text(encoding="utf-8"))
    except Exception as exc:
        # Best-effort: a broken result file must not abort the pipeline,
        # but the failure should be visible rather than silently swallowed.
        _echo(f"[dft] WARNING: could not parse {result_path}: {exc}", err=True)
        return {}
    # safe_load may yield None, a list, or a scalar; callers expect a mapping.
    if not isinstance(data, dict):
        return {}
    return data
+
1357
+
1358
+ # -----------------------------
1359
+ # CLI
1360
+ # -----------------------------
1361
+
1362
+ _ALL_PRIMARY_HELP_OPTIONS = frozenset(
1363
+ {
1364
+ "-i",
1365
+ "--input",
1366
+ "-c",
1367
+ "--center",
1368
+ "-l",
1369
+ "--ligand-charge",
1370
+ "-q",
1371
+ "--charge",
1372
+ "--out-dir",
1373
+ "--tsopt",
1374
+ "--thermo",
1375
+ "--dft",
1376
+ "--config",
1377
+ "--dry-run",
1378
+ "--embedcharge",
1379
+ "-s",
1380
+ "--scan-lists",
1381
+ "-b",
1382
+ "--backend",
1383
+ "-o",
1384
+ "--help-advanced",
1385
+ }
1386
+ )
1387
+
1388
+
1389
def _show_advanced_help(
    ctx: click.Context, _param: click.Parameter, value: bool
) -> None:
    """Eager ``--help-advanced`` callback: print the full help text, then exit.

    The options recorded on the command as ``_advanced_hidden_options`` are
    temporarily un-hidden while the help text is rendered. They are hidden
    again afterwards, even if rendering fails, and the context then exits.
    Nothing happens when the flag is unset or click is parsing resiliently.
    """
    if value and not ctx.resilient_parsing:
        advanced = getattr(ctx.command, "_advanced_hidden_options", ())
        # Remember exactly which options this call un-hides so only those are re-hidden.
        unhidden: list[click.Option] = [opt for opt in advanced if opt.hidden]
        for opt in unhidden:
            opt.hidden = False
        try:
            click.echo(ctx.command.get_help(ctx))
        finally:
            for opt in unhidden:
                opt.hidden = True
        ctx.exit()
+
1409
+
1410
def _configure_all_help_visibility(command: click.Command) -> None:
    """Hide every non-primary option from the default ``--help`` output.

    Hiding affects only the help text; the options still parse normally.
    An option stays visible when any of its flag spellings (primary or
    secondary, e.g. ``--no-...``) is listed in ``_ALL_PRIMARY_HELP_OPTIONS``.
    Options that were already hidden are left untouched and are not recorded.
    The options hidden here are stored as the tuple
    ``command._advanced_hidden_options`` for ``_show_advanced_help``.
    """
    newly_hidden: list[click.Option] = []
    for candidate in command.params:
        # Skip arguments and options that are already hidden.
        if not isinstance(candidate, click.Option) or candidate.hidden:
            continue
        spellings = candidate.opts + candidate.secondary_opts
        if not _ALL_PRIMARY_HELP_OPTIONS.isdisjoint(spellings):
            continue  # primary option: keep it in the default help
        candidate.hidden = True
        newly_hidden.append(candidate)
    command._advanced_hidden_options = tuple(newly_hidden)
+
1425
+
1426
+ @click.command(
1427
+ help="Run pocket extraction → (optional single-structure staged scan) → MEP search in one shot.\n"
1428
+ "If exactly one input is provided: (a) with --scan-lists, stage results feed into path_search; "
1429
+ "(b) with --tsopt True and no --scan-lists, run TSOPT-only mode.",
1430
+ context_settings={
1431
+ "help_option_names": ["-h", "--help"],
1432
+ "ignore_unknown_options": True,
1433
+ "allow_extra_args": True,
1434
+ },
1435
+ )
1436
+ @click.option(
1437
+ "--help-advanced",
1438
+ is_flag=True,
1439
+ is_eager=True,
1440
+ expose_value=False,
1441
+ callback=_show_advanced_help,
1442
+ help="Show all options (including advanced settings) and exit.",
1443
+ )
1444
+ # ===== Inputs =====
1445
+ @click.option(
1446
+ "-i", "--input", "input_paths",
1447
+ type=click.Path(path_type=Path, exists=True, dir_okay=False),
1448
+ multiple=True, required=True,
1449
+ help=("Two or more **full** PDBs in reaction order (reactant [intermediates ...] product), "
1450
+ "or a single **full** PDB (with --scan-lists or with --tsopt True). "
1451
+ "You may pass a single '-i' followed by multiple space-separated files (e.g., '-i A.pdb B.pdb C.pdb').")
1452
+ )
1453
+ @click.option(
1454
+ "-c", "--center", "center_spec",
1455
+ type=str, required=False, default=None,
1456
+ help=("Substrate specification for the extractor: "
1457
+ "a PDB path, a residue-ID list like '123,124' or 'A:123,B:456' "
1458
+ "(insertion codes OK: '123A' / 'A:123A'), "
1459
+ "or a residue-name list like 'GPP,MMT'. "
1460
+ "When omitted, extraction is skipped and full structures are used directly.")
1461
+ )
1462
+ @click.option(
1463
+ "-o", "--out-dir", "out_dir",
1464
+ type=click.Path(path_type=Path, file_okay=False),
1465
+ default=Path("./result_all/"), show_default=True,
1466
+ help="Top-level output directory for the pipeline."
1467
+ )
1468
+ # ===== Extractor knobs (subset of extract.parse_args) =====
1469
+ @click.option("-r", "--radius", type=float, default=2.6, show_default=True,
1470
+ help="Inclusion cutoff (Å) around substrate atoms.")
1471
+ @click.option("--radius-het2het", type=float, default=0.0, show_default=True,
1472
+ help="Independent hetero–hetero cutoff (Å) for non‑C/H pairs.")
1473
+ @click.option("--include-H2O", "--include-h2o", "include_h2o", type=click.BOOL, default=True, show_default=True,
1474
+ help="Include waters (HOH/WAT/TIP3/SOL) in the pocket.")
1475
+ @click.option("--exclude-backbone", "exclude_backbone", type=click.BOOL, default=True, show_default=True,
1476
+ help="Remove backbone atoms on non‑substrate amino acids (with PRO/HYP safeguards).")
1477
+ @click.option("--add-linkH", "add_linkh", type=click.BOOL, default=False, show_default=True,
1478
+ help="Add link hydrogens for severed bonds (carbon-only) in pockets.")
1479
+ @click.option("--selected-resn", type=str, default="", show_default=True,
1480
+ help="Force-include residues (comma/space separated; chain/insertion codes allowed).")
1481
+ @click.option("-l", "--ligand-charge", type=str, default=None,
1482
+ help=("Either a total charge (number) to distribute across unknown residues "
1483
+ "or a mapping like 'GPP:-3,MMT:-1'."))
1484
+ @click.option(
1485
+ "-q",
1486
+ "--charge",
1487
+ "charge_override",
1488
+ type=int,
1489
+ default=None,
1490
+ help="Force total system charge. Highest priority over derived charges.",
1491
+ )
1492
+ @click.option(
1493
+ "--parm",
1494
+ "parm7_override",
1495
+ type=click.Path(path_type=Path, exists=True, dir_okay=False),
1496
+ default=None,
1497
+ help="Pre-built AMBER parm7 topology file. When provided, mm_parm generation is skipped.",
1498
+ )
1499
+ @click.option(
1500
+ "--model-pdb",
1501
+ "model_pdb_override",
1502
+ type=click.Path(path_type=Path, exists=True, dir_okay=False),
1503
+ default=None,
1504
+ help="Pre-built ML-region PDB (with B-factor layer info). When provided, ml_region generation is skipped.",
1505
+ )
1506
+ @click.option("--auto-mm-ff-set", "mm_ff_set",
1507
+ type=click.Choice(["ff19SB", "ff14SB"], case_sensitive=False),
1508
+ default="ff19SB", show_default=True,
1509
+ help="Force-field set forwarded to mm_parm (ff19SB uses OPC3; ff14SB uses TIP3P).")
1510
+ @click.option("--auto-mm-add-ter/--auto-mm-no-add-ter", "mm_add_ter",
1511
+ default=True, show_default=True,
1512
+ help="Control mm_parm TER insertion around ligand/water/ion blocks.")
1513
+ @click.option("--auto-mm-keep-temp", "mm_keep_temp", is_flag=True, default=False, show_default=True,
1514
+ help="Keep the mm_parm temporary working directory (for debugging).")
1515
+ @click.option(
1516
+ "--auto-mm-ligand-mult",
1517
+ "mm_ligand_mult",
1518
+ type=str,
1519
+ default=None,
1520
+ help=("Spin multiplicity mapping forwarded to mm_parm (e.g., 'GPP:2,SAM:1'). "
1521
+ "If omitted, mm_parm defaults to 1 for all ligands.")
1522
+ )
1523
+ @click.option("--verbose", type=click.BOOL, default=True, show_default=True, help="Enable INFO-level logging inside extractor.")
1524
+ # ===== Path search knobs (subset of path_search.cli) =====
1525
+ @click.option("-m", "--multiplicity", "spin", type=int, default=1, show_default=True, help="Multiplicity (2S+1).")
1526
+ @click.option("--max-nodes", type=int, default=_path_opt.GS_KW["max_nodes"], show_default=True,
1527
+ help="Max internal nodes for **segment** GSM (String has max_nodes+2 images including endpoints).")
1528
+ @click.option("--max-cycles", type=int, default=300, show_default=True, help="Maximum GSM optimization cycles.")
1529
+ @click.option("--climb", type=click.BOOL, default=True, show_default=True,
1530
+ help="Enable transition-state climbing after growth for the **first** segment in each pair.")
1531
+ @click.option(
1532
+ "--opt-mode",
1533
+ type=click.Choice(["grad", "hess"], case_sensitive=False),
1534
+ default="grad",
1535
+ show_default=True,
1536
+ help=(
1537
+ "Optimizer mode forwarded to scan/path-search and used for single optimizations: "
1538
+ "grad (=LBFGS/Dimer) or hess (=RFO/RSIRFO)."
1539
+ ),
1540
+ )
1541
+ @click.option(
1542
+ "--opt-mode-post",
1543
+ type=click.Choice(["grad", "hess"], case_sensitive=False),
1544
+ default="hess",
1545
+ show_default=True,
1546
+ help=(
1547
+ "Optimizer mode override for TSOPT/post-IRC endpoint optimizations. "
1548
+ "If unset, uses --opt-mode when explicitly provided; otherwise falls back to tsopt defaults."
1549
+ ),
1550
+ )
1551
+ @click.option("--dump", type=click.BOOL, default=False, show_default=True,
1552
+ help="Dump GSM / single-structure trajectories during the run, forwarding the same flag to scan/tsopt/freq.")
1553
+ @click.option(
1554
+ "--refine-path/--no-refine-path",
1555
+ "refine_path",
1556
+ default=True,
1557
+ show_default=True,
1558
+ help=(
1559
+ "If True, run recursive path_search on the full ordered series; if False, run a single-pass "
1560
+ "path-opt GSM between each adjacent pair and concatenate the segments (no path_search)."
1561
+ ),
1562
+ )
1563
+ @click.option(
1564
+ "--thresh",
1565
+ type=str,
1566
+ default=None,
1567
+ show_default=False,
1568
+ help=(
1569
+ "Convergence preset (gau_loose|gau|gau_tight|gau_vtight|baker|never). "
1570
+ "Defaults to 'gau' when not provided."
1571
+ ),
1572
+ )
1573
+ @click.option(
1574
+ "--thresh-post",
1575
+ type=str,
1576
+ default="baker",
1577
+ show_default=True,
1578
+ help=(
1579
+ "Convergence preset for post-IRC endpoint optimizations "
1580
+ "(gau_loose|gau|gau_tight|gau_vtight|baker|never)."
1581
+ ),
1582
+ )
1583
+ @click.option("--config", "config_yaml", type=click.Path(path_type=Path, exists=True, dir_okay=False),
1584
+ default=None, help="Base YAML configuration file applied before explicit CLI options.")
1585
+ @click.option("--show-config/--no-show-config", "show_config", default=False, show_default=True,
1586
+ help="Print resolved configuration and continue execution.")
1587
+ @click.option("--dry-run/--no-dry-run", "dry_run", default=False, show_default=True,
1588
+ help="Validate options and print the execution plan without running any stage.")
1589
+ @click.option("--preopt", "pre_opt", type=click.BOOL, default=True, show_default=True,
1590
+ help="If False, skip initial single-structure optimizations of the pocket inputs.")
1591
+ @click.option("--hessian-calc-mode",
1592
+ type=click.Choice(["Analytical", "FiniteDifference"], case_sensitive=False),
1593
+ default=None,
1594
+ help="Common MLIP Hessian calculation mode forwarded to tsopt and freq. Default: 'FiniteDifference'. Use 'Analytical' when VRAM is sufficient.")
1595
+ @click.option(
1596
+ "--detect-layer/--no-detect-layer",
1597
+ "detect_layer",
1598
+ default=True,
1599
+ show_default=True,
1600
+ help="Detect ML/MM layers from input PDB B-factors (B=0/10/20) in downstream tools. "
1601
+ "If disabled, downstream tools require --model-pdb or --model-indices.",
1602
+ )
1603
+ # ===== Post-processing toggles =====
1604
+ @click.option("--tsopt", "do_tsopt", type=click.BOOL, default=False, show_default=True,
1605
+ help="TS optimization + pseudo-IRC per reactive segment (or TSOPT-only mode for single-structure), and build energy diagrams.")
1606
+ @click.option("--thermo", "do_thermo", type=click.BOOL, default=False, show_default=True,
1607
+ help="Run freq on (R,TS,P) per reactive segment (or TSOPT-only mode) and build Gibbs free-energy diagram (MLIP).")
1608
+ @click.option("--dft", "do_dft", type=click.BOOL, default=False, show_default=True,
1609
+ help="Run DFT single-point on (R,TS,P) and build DFT energy diagram. With --thermo True, also generate a DFT//MLIP Gibbs diagram.")
1610
+ @click.option("--tsopt-max-cycles", type=int, default=None,
1611
+ help="Override tsopt --max-cycles value.")
1612
+ @click.option(
1613
+ "--flatten/--no-flatten",
1614
+ "flatten",
1615
+ default=False,
1616
+ show_default=True,
1617
+ help="Enable the extra-imaginary-mode flattening loop in tsopt (grad: dimer loop, hess: post-RSIRFO); --no-flatten forces flatten_max_iter=0.",
1618
+ )
1619
+ @click.option("--tsopt-out-dir", type=click.Path(path_type=Path, file_okay=False), default=None,
1620
+ help="Override tsopt output subdirectory (relative paths are resolved against the default).")
1621
+ @click.option("--freq-out-dir", type=click.Path(path_type=Path, file_okay=False), default=None,
1622
+ help="Override freq output base directory (relative paths resolved against the default).")
1623
+ @click.option("--freq-max-write", type=int, default=None,
1624
+ help="Override freq --max-write value.")
1625
+ @click.option("--freq-amplitude-ang", type=float, default=None,
1626
+ help="Override freq --amplitude-ang (Å).")
1627
+ @click.option("--freq-n-frames", type=int, default=None,
1628
+ help="Override freq --n-frames value.")
1629
+ @click.option("--freq-sort", type=click.Choice(["value", "abs"], case_sensitive=False), default=None,
1630
+ help="Override freq mode sorting.")
1631
+ @click.option("--freq-temperature", type=float, default=None,
1632
+ help="Override freq thermochemistry temperature (K).")
1633
+ @click.option("--freq-pressure", type=float, default=None,
1634
+ help="Override freq thermochemistry pressure (atm).")
1635
+ @click.option("--dft-out-dir", type=click.Path(path_type=Path, file_okay=False), default=None,
1636
+ help="Override dft output base directory (relative paths resolved against the default).")
1637
+ @click.option("--dft-func-basis", type=str, default=None,
1638
+ help="Override dft --func-basis value.")
1639
+ @click.option("--dft-max-cycle", type=int, default=None,
1640
+ help="Override dft --max-cycle value.")
1641
+ @click.option("--dft-conv-tol", type=float, default=None,
1642
+ help="Override dft --conv-tol value.")
1643
+ @click.option("--dft-grid-level", type=int, default=None,
1644
+ help="Override dft --grid-level value.")
1645
+ # ===== NEW: staged scan specification for single-structure route =====
1646
+ @click.option(
1647
+ "-s", "--scan-lists",
1648
+ "scan_lists_raw",
1649
+ type=str, multiple=True, required=False,
1650
+ help='Scan targets: inline Python literal or a YAML/JSON spec file path. '
1651
+ 'Multiple inline literals define sequential stages, e.g. '
1652
+ '"[(12,45,1.35)]" "[(10,55,2.20),(23,34,1.80)]". '
1653
+ 'Indices refer to the original full PDB (1-based) or PDB atom selectors like "TYR,285,CA"; '
1654
+ 'they are auto-mapped to the pocket after extraction.',
1655
+ )
1656
+ @click.option("--scan-out-dir", type=click.Path(path_type=Path, file_okay=False), default=None,
1657
+ help="Override the scan output directory (default: <out-dir>/scan/). Relative paths are resolved against the default parent.")
1658
+ @click.option("--scan-one-based", type=click.BOOL, default=None,
1659
+ help="Override scan indexing interpretation (True = 1-based, False = 0-based).")
1660
+ @click.option("--scan-max-step-size", type=float, default=None,
1661
+ help="Override scan --max-step-size (Å).")
1662
+ @click.option("--scan-bias-k", type=float, default=None,
1663
+ help="Override scan harmonic bias strength k (eV/Å^2).")
1664
+ @click.option("--scan-relax-max-cycles", type=int, default=None,
1665
+ help="Override scan relaxation max cycles per step.")
1666
+ @click.option("--scan-preopt", "scan_preopt_override", type=click.BOOL, default=None,
1667
+ help="Override scan --preopt flag.")
1668
+ @click.option("--scan-endopt", "scan_endopt_override", type=click.BOOL, default=None,
1669
+ help="Override scan --endopt flag.")
1670
+ @click.option("--convert-files/--no-convert-files", "convert_files", default=True, show_default=True,
1671
+ help="Convert XYZ/TRJ outputs to PDB format using reference topology; forwarded to all subcommands.")
1672
+ @click.option(
1673
+ "--ref-pdb",
1674
+ "ref_pdb_cli",
1675
+ type=click.Path(path_type=Path, exists=True, dir_okay=False),
1676
+ default=None,
1677
+ help=(
1678
+ "Reference PDB for topology/B-factor layer information when -i provides XYZ inputs. "
1679
+ "Used for define-layer, mm_parm, ml_region, and forwarded to downstream tools "
1680
+ "(tsopt, irc, freq, path_search) as --ref-pdb."
1681
+ ),
1682
+ )
1683
+ @click.option(
1684
+ "-b", "--backend",
1685
+ type=click.Choice(["uma", "orb", "mace", "aimnet2"], case_sensitive=False),
1686
+ default=None,
1687
+ show_default=False,
1688
+ help="ML backend for the ONIOM high-level region (default: uma).",
1689
+ )
1690
+ @click.option(
1691
+ "--embedcharge/--no-embedcharge",
1692
+ "embedcharge",
1693
+ default=False,
1694
+ show_default=True,
1695
+ help="Enable xTB point-charge embedding correction for MM→ML environmental effects.",
1696
+ )
1697
+ @click.option(
1698
+ "--embedcharge-cutoff",
1699
+ "embedcharge_cutoff",
1700
+ type=float,
1701
+ default=None,
1702
+ show_default=False,
1703
+ help="Distance cutoff (Å) from ML region for MM point charges in xTB embedding. "
1704
+ "Default: 12.0 Å when --embedcharge is enabled.",
1705
+ )
1706
+ @click.pass_context
1707
+ def cli(
1708
+ ctx: click.Context,
1709
+ input_paths: Sequence[Path],
1710
+ center_spec: Optional[str],
1711
+ out_dir: Path,
1712
+ radius: float,
1713
+ radius_het2het: float,
1714
+ include_h2o: bool,
1715
+ exclude_backbone: bool,
1716
+ add_linkh: bool,
1717
+ selected_resn: str,
1718
+ ligand_charge: Optional[str],
1719
+ charge_override: Optional[int],
1720
+ parm7_override: Optional[Path],
1721
+ model_pdb_override: Optional[Path],
1722
+ mm_ff_set: str,
1723
+ mm_add_ter: bool,
1724
+ mm_keep_temp: bool,
1725
+ mm_ligand_mult: Optional[str],
1726
+ verbose: bool,
1727
+ spin: int,
1728
+ max_nodes: int,
1729
+ max_cycles: int,
1730
+ climb: bool,
1731
+ opt_mode: str,
1732
+ opt_mode_post: Optional[str],
1733
+ dump: bool,
1734
+ refine_path: bool,
1735
+ thresh: Optional[str],
1736
+ thresh_post: str,
1737
+ config_yaml: Optional[Path],
1738
+ show_config: bool,
1739
+ dry_run: bool,
1740
+ pre_opt: bool,
1741
+ hessian_calc_mode: Optional[str],
1742
+ detect_layer: bool,
1743
+ do_tsopt: bool,
1744
+ do_thermo: bool,
1745
+ do_dft: bool,
1746
+ scan_lists_raw: Sequence[str],
1747
+ scan_out_dir: Optional[Path],
1748
+ scan_one_based: Optional[bool],
1749
+ scan_max_step_size: Optional[float],
1750
+ scan_bias_k: Optional[float],
1751
+ scan_relax_max_cycles: Optional[int],
1752
+ scan_preopt_override: Optional[bool],
1753
+ scan_endopt_override: Optional[bool],
1754
+ convert_files: bool,
1755
+ ref_pdb_cli: Optional[Path],
1756
+ backend: Optional[str],
1757
+ embedcharge: bool,
1758
+ embedcharge_cutoff: Optional[float],
1759
+ tsopt_max_cycles: Optional[int],
1760
+ flatten: bool,
1761
+ tsopt_out_dir: Optional[Path],
1762
+ freq_out_dir: Optional[Path],
1763
+ freq_max_write: Optional[int],
1764
+ freq_amplitude_ang: Optional[float],
1765
+ freq_n_frames: Optional[int],
1766
+ freq_sort: Optional[str],
1767
+ freq_temperature: Optional[float],
1768
+ freq_pressure: Optional[float],
1769
+ dft_out_dir: Optional[Path],
1770
+ dft_func_basis: Optional[str],
1771
+ dft_max_cycle: Optional[int],
1772
+ dft_conv_tol: Optional[float],
1773
+ dft_grid_level: Optional[int],
1774
+ ) -> None:
1775
+ """
1776
+ The **all** command composes `extract` → (optional `scan` on pocket) → `path_search` and hides ref-template bookkeeping.
1777
+ It also accepts the sloppy `-i A B C` style like `path_search` does. With single input:
1778
+ - with --scan-lists: run staged scan on the pocket and use stage results as inputs for path_search,
1779
+ - with --tsopt True and no --scan-lists: run TSOPT-only mode (no path_search).
1780
+ """
1781
+ _echo_state.reset()
1782
+
1783
+ time_start = time.perf_counter()
1784
+ command_str = "mlmm all " + " ".join(sys.argv[1:])
1785
+
1786
+ _is_param_explicit = make_is_param_explicit(ctx)
1787
+ dump_override_requested = _is_param_explicit("dump")
1788
+ opt_mode_set = _is_param_explicit("opt_mode")
1789
+ opt_mode_post_set = _is_param_explicit("opt_mode_post")
1790
+
1791
+ config_yaml, override_yaml, _ = resolve_yaml_sources(config_yaml, None, None)
1792
+ args_yaml, merged_yaml_cfg = _build_effective_args_yaml(
1793
+ config_yaml=config_yaml,
1794
+ override_yaml=None,
1795
+ tmp_prefix="mlmm_all_merged_",
1796
+ )
1797
+
1798
+ mm_ff_set = "ff14SB" if str(mm_ff_set).lower().startswith("ff14") else "ff19SB"
1799
+
1800
+ # --- Robustly accept a single "-i" followed by multiple paths (like path_search.cli) ---
1801
+ argv_all = sys.argv[1:]
1802
+ i_vals = collect_single_option_values(argv_all, ("-i", "--input"), label="-i/--input")
1803
+ if i_vals:
1804
+ i_parsed = validate_existing_files(
1805
+ i_vals,
1806
+ option_name="-i/--input",
1807
+ hint="When using '-i', list only existing file paths (multiple paths may follow a single '-i').",
1808
+ )
1809
+ input_paths = tuple(i_parsed)
1810
+
1811
+ scan_vals = collect_single_option_values(argv_all, ("-s", "--scan-lists"), "--scan-lists")
1812
+ if scan_vals:
1813
+ scan_lists_raw = tuple(scan_vals)
1814
+
1815
+ # --------------------------
1816
+ # Validate input count / single-structure modes
1817
+ # --------------------------
1818
+ is_single = (len(input_paths) == 1)
1819
+ has_scan = bool(scan_lists_raw)
1820
+ single_tsopt_mode = (is_single and (not has_scan) and do_tsopt)
1821
+
1822
+ if (len(input_paths) < 2) and (not (is_single and (has_scan or do_tsopt))):
1823
+ raise click.BadParameter(
1824
+ "Provide at least two PDBs with -i/--input in reaction order, "
1825
+ "or use a single PDB with --scan-lists, or a single PDB with --tsopt True."
1826
+ )
1827
+
1828
+ _mode_alias = {
1829
+ "grad": "grad",
1830
+ "hess": "hess",
1831
+ "light": "grad",
1832
+ "heavy": "hess",
1833
+ }
1834
+ opt_mode_norm = _mode_alias.get(str(opt_mode).strip().lower(), "grad")
1835
+ path_search_opt_mode = opt_mode_norm
1836
+ opt_mode_post_norm = (
1837
+ None
1838
+ if opt_mode_post is None
1839
+ else _mode_alias.get(str(opt_mode_post).strip().lower(), "hess")
1840
+ )
1841
+ endpoint_opt_mode_default = (
1842
+ opt_mode_post_norm if (opt_mode_post_set and opt_mode_post_norm is not None)
1843
+ else (opt_mode_norm if opt_mode_set else "hess")
1844
+ )
1845
+ if opt_mode_post_norm in {"grad", "hess"}:
1846
+ tsopt_opt_mode_default = opt_mode_post_norm
1847
+ elif opt_mode_set:
1848
+ tsopt_opt_mode_default = opt_mode_norm
1849
+ else:
1850
+ tsopt_opt_mode_default = "hess"
1851
+ tsopt_overrides: Dict[str, Any] = {}
1852
+ if tsopt_max_cycles is not None:
1853
+ tsopt_overrides["max_cycles"] = int(tsopt_max_cycles)
1854
+ if dump_override_requested:
1855
+ tsopt_overrides["dump"] = bool(dump)
1856
+ if tsopt_out_dir is not None:
1857
+ tsopt_overrides["out_dir"] = tsopt_out_dir
1858
+ if hessian_calc_mode is not None:
1859
+ tsopt_overrides["hessian_calc_mode"] = hessian_calc_mode
1860
+ if opt_mode_post_norm in {"grad", "hess"}:
1861
+ tsopt_overrides["opt_mode"] = opt_mode_post_norm
1862
+ elif opt_mode_set:
1863
+ tsopt_overrides["opt_mode"] = tsopt_opt_mode_default
1864
+ tsopt_overrides["convert_files"] = bool(convert_files)
1865
+ if thresh_post is not None:
1866
+ tsopt_overrides["thresh"] = str(thresh_post)
1867
+ if _is_param_explicit("flatten"):
1868
+ tsopt_overrides["flatten"] = bool(flatten)
1869
+
1870
+ freq_overrides: Dict[str, Any] = {}
1871
+ if freq_max_write is not None:
1872
+ freq_overrides["max_write"] = int(freq_max_write)
1873
+ if freq_amplitude_ang is not None:
1874
+ freq_overrides["amplitude_ang"] = float(freq_amplitude_ang)
1875
+ if freq_n_frames is not None:
1876
+ freq_overrides["n_frames"] = int(freq_n_frames)
1877
+ if freq_sort is not None:
1878
+ freq_overrides["sort"] = freq_sort.lower()
1879
+ if freq_temperature is not None:
1880
+ freq_overrides["temperature"] = float(freq_temperature)
1881
+ if freq_pressure is not None:
1882
+ freq_overrides["pressure"] = float(freq_pressure)
1883
+ if dump_override_requested:
1884
+ freq_overrides["dump"] = bool(dump)
1885
+ if hessian_calc_mode is not None:
1886
+ freq_overrides["hessian_calc_mode"] = hessian_calc_mode
1887
+ freq_overrides["convert_files"] = bool(convert_files)
1888
+
1889
+ dft_overrides: Dict[str, Any] = {}
1890
+ if dft_max_cycle is not None:
1891
+ dft_overrides["max_cycle"] = int(dft_max_cycle)
1892
+ if dft_conv_tol is not None:
1893
+ dft_overrides["conv_tol"] = float(dft_conv_tol)
1894
+ if dft_grid_level is not None:
1895
+ dft_overrides["grid_level"] = int(dft_grid_level)
1896
+ dft_overrides["convert_files"] = bool(convert_files)
1897
+
1898
+ dft_func_basis_use = dft_func_basis or "wb97m-v/def2-tzvpd"
1899
+ dft_method_fallback = dft_func_basis_use
1900
+
1901
+ if show_config or dry_run:
1902
+ config_payload: Dict[str, Any] = {
1903
+ "yaml": {
1904
+ "config": str(config_yaml) if config_yaml else None,
1905
+ "override_yaml": str(override_yaml) if override_yaml else None,
1906
+ "effective_args_yaml": str(args_yaml) if args_yaml else None,
1907
+ },
1908
+ "all": {
1909
+ "inputs": [str(p) for p in input_paths],
1910
+ "center": center_spec,
1911
+ "charge_override": charge_override,
1912
+ "skip_extract": bool(center_spec is None or str(center_spec).strip() == ""),
1913
+ "out_dir": str(out_dir),
1914
+ "spin": int(spin),
1915
+ "max_nodes": int(max_nodes),
1916
+ "max_cycles": int(max_cycles),
1917
+ "climb": bool(climb),
1918
+ "opt_mode": str(opt_mode),
1919
+ "opt_mode_post": (None if opt_mode_post is None else str(opt_mode_post)),
1920
+ "path_search_opt_mode": str(path_search_opt_mode),
1921
+ "endpoint_opt_mode": str(endpoint_opt_mode_default),
1922
+ "dump": bool(dump),
1923
+ "refine_path": bool(refine_path),
1924
+ "thresh": thresh,
1925
+ "thresh_post": thresh_post,
1926
+ "flatten": bool(flatten),
1927
+ "pre_opt": bool(pre_opt),
1928
+ "detect_layer": bool(detect_layer),
1929
+ "tsopt": bool(do_tsopt),
1930
+ "thermo": bool(do_thermo),
1931
+ "dft": bool(do_dft),
1932
+ },
1933
+ "overrides": {
1934
+ "tsopt": tsopt_overrides,
1935
+ "freq": freq_overrides,
1936
+ "dft": dft_overrides,
1937
+ },
1938
+ }
1939
+ if merged_yaml_cfg:
1940
+ config_payload["effective_yaml"] = merged_yaml_cfg
1941
+ _echo_section("=== [all] Effective configuration ===")
1942
+ click.echo(
1943
+ yaml.safe_dump(config_payload, sort_keys=False, allow_unicode=True).rstrip()
1944
+ )
1945
+
1946
+ if dry_run:
1947
+ _echo("[all] Dry-run mode: no extraction/search/post-processing was executed.")
1948
+ _echo(
1949
+ "[all] Planned stages: extract -> mm_parm -> optional scan -> path_search -> optional tsopt/freq/dft."
1950
+ )
1951
+ _echo(format_elapsed("[all] Elapsed for Whole Pipeline", time_start))
1952
+ return
1953
+
1954
+ # --------------------------
1955
+ # Prepare directories
1956
+ # --------------------------
1957
+ out_dir = out_dir.resolve()
1958
+ pockets_dir = out_dir / "pockets"
1959
+ path_dir = out_dir / ("path_search" if refine_path else "path_opt")
1960
+ scan_dir = _resolve_override_dir(out_dir / "scan", scan_out_dir) # for single-structure scan mode
1961
+ ensure_dir(out_dir)
1962
+ if not single_tsopt_mode:
1963
+ ensure_dir(path_dir) # path_search might be skipped only in tsopt-only mode
1964
+
1965
+ # --------------------------
1966
+ # Preflight: add_elem_info only for inputs lacking element fields
1967
+ # → Create fixed copies under a temporary folder inside out_dir (used ONLY for extraction)
1968
+ # --------------------------
1969
+ elem_tmp_dir = out_dir / "add_elem_info"
1970
+ inputs_for_extract: List[Path] = []
1971
+ elem_fix_echo=False
1972
+ for p in input_paths:
1973
+ if _pdb_needs_elem_fix(p):
1974
+ if elem_fix_echo==False:
1975
+ _echo_section("=== [all] Preflight — add_elem_info (only when element fields are missing) ===")
1976
+ elem_fix_echo=True
1977
+ ensure_dir(elem_tmp_dir)
1978
+ out_p = (elem_tmp_dir / p.name).resolve()
1979
+ try:
1980
+ _assign_elem_info(str(p), str(out_p), overwrite=False)
1981
+ _echo(f"[all] add_elem_info: fixed elements → {out_p}")
1982
+ inputs_for_extract.append(out_p)
1983
+ except SystemExit as e:
1984
+ code = getattr(e, "code", 1)
1985
+ _echo(f"[all] WARNING: add_elem_info exited with code {code} for {p}; using original.", err=True)
1986
+ inputs_for_extract.append(p.resolve())
1987
+ except Exception as e:
1988
+ _echo(f"[all] WARNING: add_elem_info failed for {p}: {e} — using original file.", err=True)
1989
+ inputs_for_extract.append(p.resolve())
1990
+ else:
1991
+ inputs_for_extract.append(p.resolve())
1992
+
1993
+ extract_inputs = tuple(inputs_for_extract)
1994
+ skip_extract = center_spec is None or str(center_spec).strip() == ""
1995
+
1996
+ # When inputs are XYZ and --ref-pdb is provided, use it for topology-requiring steps
1997
+ ref_pdb_for_topology: Optional[Path] = None
1998
+ if ref_pdb_cli is not None:
1999
+ ref_pdb_for_topology = ref_pdb_cli.resolve()
2000
+ _echo(f"[all] --ref-pdb provided: {ref_pdb_for_topology}")
2001
+
2002
+ resolved_charge: Optional[int] = None
2003
+ pocket_outputs: List[Path] = []
2004
+
2005
+ if skip_extract:
2006
+ _echo_section(
2007
+ "=== [all] Stage 1/3 — Extraction skipped (no -c/--center); using full structures as pockets ==="
2008
+ )
2009
+ pocket_outputs = [p.resolve() for p in extract_inputs]
2010
+ _echo("[all] Pocket inputs (full structures):")
2011
+ for op in pocket_outputs:
2012
+ _echo(f" - {op}")
2013
+ # Use --model-pdb for charge derivation when provided (pocket charge),
2014
+ # otherwise fall back to full input PDB.
2015
+ charge_source_pdb = model_pdb_override if model_pdb_override is not None else extract_inputs[0]
2016
+ resolved_charge = _derive_charge_from_ligand_charge_when_extract_skipped(
2017
+ charge_source_pdb, ligand_charge
2018
+ )
2019
+ else:
2020
+ _echo_section(
2021
+ "=== [all] Stage 1/3 — Active-site pocket extraction (multi-structure union when applicable) ==="
2022
+ )
2023
+ ensure_dir(pockets_dir)
2024
+ for p in extract_inputs:
2025
+ pocket_outputs.append((pockets_dir / f"pocket_{p.stem}.pdb").resolve())
2026
+
2027
+ try:
2028
+ ex_res = extract_api(
2029
+ complex_pdb=[str(p) for p in extract_inputs],
2030
+ center=center_spec,
2031
+ output=[str(p) for p in pocket_outputs],
2032
+ radius=float(radius),
2033
+ radius_het2het=float(radius_het2het),
2034
+ include_H2O=bool(include_h2o),
2035
+ exclude_backbone=bool(exclude_backbone),
2036
+ add_linkH=bool(add_linkh),
2037
+ selected_resn=selected_resn or "",
2038
+ ligand_charge=ligand_charge,
2039
+ verbose=bool(verbose),
2040
+ )
2041
+ except Exception as e:
2042
+ raise click.ClickException(f"[all] Extractor failed: {e}")
2043
+
2044
+ _echo("[all] Pocket files:")
2045
+ for op in pocket_outputs:
2046
+ _echo(f" - {op}")
2047
+
2048
+ try:
2049
+ cs = ex_res.get("charge_summary", {})
2050
+ q_total = float(cs.get("total_charge", 0.0))
2051
+ q_prot = float(cs.get("protein_charge", 0.0))
2052
+ q_lig = float(cs.get("ligand_total_charge", 0.0))
2053
+ q_ion = float(cs.get("ion_total_charge", 0.0))
2054
+ _echo("\n[all] Charge summary from extractor (model #1):")
2055
+ _echo(
2056
+ f" Protein: {q_prot:+g}, Ligand: {q_lig:+g}, Ions: {q_ion:+g}, Total: {q_total:+g}"
2057
+ )
2058
+ resolved_charge = _round_charge_with_note(q_total)
2059
+ except Exception as e:
2060
+ raise click.ClickException(f"[all] Could not obtain total charge from extractor: {e}")
2061
+
2062
+ if charge_override is not None:
2063
+ q_int = int(charge_override)
2064
+ override_msg = f"[all] WARNING: -q/--charge override supplied; forcing TOTAL system charge to {q_int:+d}"
2065
+ if resolved_charge is not None:
2066
+ override_msg += f" (would otherwise use {int(resolved_charge):+d} from workflow)"
2067
+ _echo(override_msg)
2068
+ else:
2069
+ if resolved_charge is None:
2070
+ raise click.ClickException(
2071
+ "[all] Total charge could not be resolved. Provide -q/--charge, "
2072
+ "or provide --ligand-charge when extraction is skipped."
2073
+ )
2074
+ q_int = int(resolved_charge)
2075
+
2076
+ # --------------------------
2077
+ # Stage 1b: ML-region definition (copy first pocket) and mm_parm on the first full input
2078
+ # --------------------------
2079
+ _echo_section("=== [all] ML/MM preparation — ML region + parm7 ===")
2080
+ first_pocket = pocket_outputs[0]
2081
+ first_full_input = extract_inputs[0]
2082
+ # When --ref-pdb is provided, use it for PDB-requiring topology operations
2083
+ pocket_for_ml_region = ref_pdb_for_topology if ref_pdb_for_topology is not None else first_pocket
2084
+ pdb_for_mm_parm = ref_pdb_for_topology if ref_pdb_for_topology is not None else first_full_input
2085
+
2086
+ # ML region definition: use --model-pdb if provided, otherwise generate from pocket
2087
+ if model_pdb_override is not None:
2088
+ ml_region_pdb = model_pdb_override.resolve()
2089
+ _echo(f"[all] ML region definition (--model-pdb override) → {ml_region_pdb}")
2090
+ else:
2091
+ ml_region_pdb = _write_ml_region_definition(pocket_for_ml_region, out_dir / "ml_region.pdb")
2092
+ _echo(f"[all] ML region definition → {ml_region_pdb}")
2093
+
2094
+ # mm_parm: use --parm if provided, otherwise run tleap
2095
+ if parm7_override is not None:
2096
+ real_parm7_path = parm7_override.resolve()
2097
+ _echo(f"[all] parm7 (--parm override) → {real_parm7_path}")
2098
+ else:
2099
+ _echo(f"[all] mm_parm source PDB → {pdb_for_mm_parm}")
2100
+ mm_dir = out_dir / "mm_parm"
2101
+ ensure_commands_available(
2102
+ ("tleap", "antechamber", "parmchk2"),
2103
+ context="mm_parm (AmberTools)",
2104
+ )
2105
+ real_parm7_path, real_rst7_path = _build_mm_parm7(
2106
+ pdb=pdb_for_mm_parm,
2107
+ ligand_charge_expr=ligand_charge,
2108
+ ligand_mult_expr=mm_ligand_mult,
2109
+ out_dir=mm_dir,
2110
+ ff_set=mm_ff_set,
2111
+ add_ter=mm_add_ter,
2112
+ keep_temp=mm_keep_temp,
2113
+ )
2114
+ _echo(f"[all] mm_parm outputs → parm7: {real_parm7_path.name}, rst7: {real_rst7_path.name}")
2115
+
2116
+ # --------------------------
2117
+ # define-layer: assign 3-layer B-factors to each full-system PDB
2118
+ # --------------------------
2119
+ _echo_section("=== [all] define-layer — assign 3-layer B-factors to full-system PDBs ===")
2120
+ layered_dir = out_dir / "layered"
2121
+ ensure_dir(layered_dir)
2122
+ layered_inputs: List[Path] = []
2123
+ for idx, full_pdb in enumerate(extract_inputs):
2124
+ # When --ref-pdb is given and input is not PDB, use ref_pdb for define-layer
2125
+ pdb_for_layer = full_pdb
2126
+ if ref_pdb_for_topology is not None and full_pdb.suffix.lower() != ".pdb":
2127
+ pdb_for_layer = ref_pdb_for_topology
2128
+ out_layered = layered_dir / f"{pdb_for_layer.stem}_layered.pdb"
2129
+ try:
2130
+ layer_info = _define_layers(
2131
+ input_pdb=pdb_for_layer,
2132
+ output_pdb=out_layered,
2133
+ model_pdb=ml_region_pdb,
2134
+ )
2135
+ _echo(f"[all] define-layer [{idx}]: {full_pdb.name} → {out_layered.name} "
2136
+ f"(ML={len(layer_info.get('ml_indices', []))}, "
2137
+ f"MovableMM={len(layer_info.get('movable_mm_indices', []))}, "
2138
+ f"FrozenMM={len(layer_info.get('frozen_indices', []))})")
2139
+ layered_inputs.append(out_layered)
2140
+ except Exception as e:
2141
+ _echo(f"[all] WARNING: define-layer failed for {full_pdb.name}: {e}", err=True)
2142
+ _echo(f"[all] Falling back to original PDB (no B-factor layers).", err=True)
2143
+ layered_inputs.append(full_pdb)
2144
+
2145
+ # --------------------------
2146
+ # Other path: single-structure + --tsopt True (and NO scan-lists) → TSOPT-only mode
2147
+ # --------------------------
2148
+ irc_trj_for_all: List[Tuple[Path, bool]] = []
2149
+
2150
+ if single_tsopt_mode:
2151
+ _echo_section("=== [all] TSOPT-only single-structure mode ===")
2152
+ tsroot = out_dir / "tsopt_single"
2153
+ ensure_dir(tsroot)
2154
+
2155
+ # Use the layered full-system PDB as TS initial guess
2156
+ layered_pdb = layered_inputs[0]
2157
+ # When --ref-pdb is given and input is XYZ, copy the XYZ next to the layered PDB
2158
+ # so that _run_tsopt_on_hei can use XYZ (full precision) + layered PDB (topology)
2159
+ if ref_pdb_for_topology is not None and extract_inputs[0].suffix.lower() != ".pdb":
2160
+ xyz_companion = layered_pdb.with_suffix(".xyz")
2161
+ if not xyz_companion.exists():
2162
+ shutil.copy2(extract_inputs[0], xyz_companion)
2163
+ _echo(f"[all] Copied XYZ input → {xyz_companion} (full precision for tsopt)")
2164
+ # TS optimization
2165
+ ts_pdb, g_ts = _run_tsopt_on_hei(
2166
+ layered_pdb,
2167
+ q_int,
2168
+ spin,
2169
+ real_parm7_path,
2170
+ ml_region_pdb,
2171
+ detect_layer,
2172
+ args_yaml,
2173
+ tsroot,
2174
+ tsopt_opt_mode_default,
2175
+ overrides=tsopt_overrides,
2176
+ backend=backend,
2177
+ embedcharge=embedcharge,
2178
+ embedcharge_cutoff=embedcharge_cutoff,
2179
+ ref_pdb=layered_pdb,
2180
+ )
2181
+
2182
+ # EulerPC IRC & map endpoints (no segment endpoints exist → fallback mapping)
2183
+ irc_pocket_ref = ref_pdb_for_topology if ref_pdb_for_topology is not None else first_pocket
2184
+ irc_res = _irc_and_match(seg_idx=1,
2185
+ seg_dir=tsroot,
2186
+ ref_pdb_for_seg=ts_pdb,
2187
+ seg_pocket_pdb=irc_pocket_ref,
2188
+ g_ts=g_ts,
2189
+ q_int=q_int,
2190
+ spin=spin,
2191
+ real_parm7=real_parm7_path,
2192
+ model_pdb=ml_region_pdb,
2193
+ detect_layer=detect_layer,
2194
+ backend=backend,
2195
+ embedcharge=embedcharge,
2196
+ embedcharge_cutoff=embedcharge_cutoff,
2197
+ args_yaml=args_yaml)
2198
+ gL = irc_res["left_min_geom"]
2199
+ gR = irc_res["right_min_geom"]
2200
+ gT = irc_res["ts_geom"]
2201
+ irc_plot_path = irc_res.get("irc_plot")
2202
+ irc_trj_path = irc_res.get("irc_trj")
2203
+ if irc_trj_path:
2204
+ try:
2205
+ irc_trj_for_all.append((Path(irc_trj_path), bool(irc_res.get("reverse_irc", False))))
2206
+ except Exception:
2207
+ logger.debug("Failed to append IRC trajectory path", exc_info=True)
2208
+
2209
+ # Ensure UMA energies
2210
+ eL = float(gL.energy)
2211
+ eT = float(gT.energy)
2212
+ eR = float(gR.energy)
2213
+
2214
+ # In this mode ONLY: assign Reactant/Product so that higher-energy end is the Reactant
2215
+ if eL >= eR:
2216
+ g_react, e_react = gL, eL
2217
+ g_prod, e_prod = gR, eR
2218
+ else:
2219
+ g_react, e_react = gR, eR
2220
+ g_prod, e_prod = gL, eL
2221
+
2222
+ # Save XYZ (full precision) + PDB (companion) and run endpoint-opt
2223
+ struct_dir = tsroot / "structures"
2224
+ ensure_dir(struct_dir)
2225
+ pocket_ref = ref_pdb_for_topology if ref_pdb_for_topology is not None else first_pocket
2226
+ xR_irc, pR_irc = _save_single_geom_for_tools(g_react, pocket_ref, struct_dir, "reactant_irc")
2227
+ xT, pT = _save_single_geom_for_tools(gT, pocket_ref, struct_dir, "ts")
2228
+ xP_irc, pP_irc = _save_single_geom_for_tools(g_prod, pocket_ref, struct_dir, "product_irc")
2229
+
2230
+ endpoint_opt_dir = tsroot / "endpoint_opt"
2231
+ ensure_dir(endpoint_opt_dir)
2232
+
2233
+ # Map IRC left/right Hessians → R/P endpoint (left=forward, right=backward)
2234
+ from .hessian_cache import load as _hess_load, store as _hess_store, clear as _clear_hess_cache
2235
+ _react_hk = "irc_left" if eL >= eR else "irc_right"
2236
+ _prod_hk = "irc_right" if eL >= eR else "irc_left"
2237
+
2238
+ _c = _hess_load(_react_hk)
2239
+ if _c:
2240
+ _hess_store("irc_endpoint", _c["hessian"], active_dofs=_c.get("active_dofs"), meta=_c.get("meta"))
2241
+ try:
2242
+ g_react, _ = _run_opt_for_state(
2243
+ pR_irc, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
2244
+ endpoint_opt_dir / "R", args_yaml, endpoint_opt_mode_default,
2245
+ convert_files=convert_files,
2246
+ backend=backend,
2247
+ embedcharge=embedcharge,
2248
+ embedcharge_cutoff=embedcharge_cutoff,
2249
+ thresh=thresh_post,
2250
+ xyz_path=xR_irc,
2251
+ )
2252
+ except Exception as e:
2253
+ _echo(
2254
+ f"[post] WARNING: Reactant endpoint optimization failed in TSOPT-only mode: {e}",
2255
+ err=True,
2256
+ )
2257
+
2258
+ _c = _hess_load(_prod_hk)
2259
+ if _c:
2260
+ _hess_store("irc_endpoint", _c["hessian"], active_dofs=_c.get("active_dofs"), meta=_c.get("meta"))
2261
+ try:
2262
+ g_prod, _ = _run_opt_for_state(
2263
+ pP_irc, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
2264
+ endpoint_opt_dir / "P", args_yaml, endpoint_opt_mode_default,
2265
+ convert_files=convert_files,
2266
+ backend=backend,
2267
+ embedcharge=embedcharge,
2268
+ embedcharge_cutoff=embedcharge_cutoff,
2269
+ thresh=thresh_post,
2270
+ xyz_path=xP_irc,
2271
+ )
2272
+ except Exception as e:
2273
+ _echo(
2274
+ f"[post] WARNING: Product endpoint optimization failed in TSOPT-only mode: {e}",
2275
+ err=True,
2276
+ )
2277
+ shutil.rmtree(endpoint_opt_dir, ignore_errors=True)
2278
+ _echo("[endpoint-opt] Clean endpoint-opt working dir.")
2279
+
2280
+ xR, pR = _save_single_geom_for_tools(g_react, pocket_ref, struct_dir, "reactant")
2281
+ xP, pP = _save_single_geom_for_tools(g_prod, pocket_ref, struct_dir, "product")
2282
+ e_react = float(g_react.energy)
2283
+ e_prod = float(g_prod.energy)
2284
+
2285
+ # UMA energy diagram (R, TS, P)
2286
+ uma_prefix = tsroot / "energy_diagram_UMA"
2287
+ uma_diag = _write_segment_energy_diagram(
2288
+ uma_prefix,
2289
+ labels=["R", "TS", "P"],
2290
+ energies_eh=[e_react, eT, e_prod],
2291
+ title_note="(UMA, TSOPT/IRC)",
2292
+ )
2293
+ g_uma_diag = None
2294
+ dft_diag = None
2295
+ g_dft_diag = None
2296
+
2297
+ # ── Release GPU memory before freq/thermo/DFT ──
2298
+ for _g in (gL, gR, gT, g_react, g_prod):
2299
+ if _g is not None and hasattr(_g, "calculator"):
2300
+ _g.calculator = None
2301
+ gc.collect()
2302
+ if torch.cuda.is_available():
2303
+ torch.cuda.empty_cache()
2304
+
2305
+ # Thermochemistry (UMA) Gibbs
2306
+ thermo_payloads: Dict[str, Dict[str, Any]] = {}
2307
+ GR = GT = GP = None
2308
+ eR_dft = eT_dft = eP_dft = None
2309
+ GR_dftUMA = GT_dftUMA = GP_dftUMA = None
2310
+ freq_root = _resolve_override_dir(tsroot / "freq", freq_out_dir)
2311
+ dft_root = _resolve_override_dir(tsroot / "dft", dft_out_dir)
2312
+
2313
+ if do_thermo:
2314
+ _echo(f"[thermo] Single TSOPT: freq on TS/R/P")
2315
+ tT = _run_freq_for_state(pT, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
2316
+ freq_root / "TS", args_yaml, overrides=freq_overrides,
2317
+ backend=backend, embedcharge=embedcharge, embedcharge_cutoff=embedcharge_cutoff, xyz_path=xT)
2318
+ _clear_hess_cache() # TS Hessian consumed; R/P need exact computation
2319
+ tR = _run_freq_for_state(pR, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
2320
+ freq_root / "R", args_yaml, overrides=freq_overrides,
2321
+ backend=backend, embedcharge=embedcharge, embedcharge_cutoff=embedcharge_cutoff, xyz_path=xR)
2322
+ tP = _run_freq_for_state(pP, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
2323
+ freq_root / "P", args_yaml, overrides=freq_overrides,
2324
+ backend=backend, embedcharge=embedcharge, embedcharge_cutoff=embedcharge_cutoff, xyz_path=xP)
2325
+ thermo_payloads = {"R": tR, "TS": tT, "P": tP}
2326
+ try:
2327
+ GR = float(tR.get("sum_EE_and_thermal_free_energy_ha", e_react))
2328
+ GT = float(tT.get("sum_EE_and_thermal_free_energy_ha", eT))
2329
+ GP = float(tP.get("sum_EE_and_thermal_free_energy_ha", e_prod))
2330
+ g_uma_diag = _write_segment_energy_diagram(
2331
+ tsroot / "energy_diagram_G_UMA",
2332
+ labels=["R", "TS", "P"],
2333
+ energies_eh=[GR, GT, GP],
2334
+ title_note="(Gibbs, UMA)",
2335
+ ylabel="ΔG (kcal/mol)",
2336
+ )
2337
+ except Exception as e:
2338
+ _echo(f"[thermo] WARNING: failed to build Gibbs diagram: {e}", err=True)
2339
+
2340
+ # DFT & DFT//UMA
2341
+ if do_dft:
2342
+ _echo(f"[dft] Single TSOPT: DFT on R/TS/P")
2343
+ dR = _run_dft_for_state(pR, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
2344
+ dft_root / "R", args_yaml, func_basis=dft_func_basis_use, overrides=dft_overrides,
2345
+ backend=backend, embedcharge=embedcharge, embedcharge_cutoff=embedcharge_cutoff, xyz_path=xR)
2346
+ dT = _run_dft_for_state(pT, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
2347
+ dft_root / "TS", args_yaml, func_basis=dft_func_basis_use, overrides=dft_overrides,
2348
+ backend=backend, embedcharge=embedcharge, embedcharge_cutoff=embedcharge_cutoff, xyz_path=xT)
2349
+ dP = _run_dft_for_state(pP, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
2350
+ dft_root / "P", args_yaml, func_basis=dft_func_basis_use, overrides=dft_overrides,
2351
+ backend=backend, embedcharge=embedcharge, embedcharge_cutoff=embedcharge_cutoff, xyz_path=xP)
2352
+ try:
2353
+ eR_dft = float(dR.get("energy", {}).get("hartree", e_react) if dR else e_react)
2354
+ eT_dft = float(dT.get("energy", {}).get("hartree", eT) if dT else eT)
2355
+ eP_dft = float(dP.get("energy", {}).get("hartree", e_prod) if dP else e_prod)
2356
+ dft_diag = _write_segment_energy_diagram(
2357
+ tsroot / "energy_diagram_DFT",
2358
+ labels=["R", "TS", "P"],
2359
+ energies_eh=[eR_dft, eT_dft, eP_dft],
2360
+ title_note=f"({dft_method_fallback})",
2361
+ )
2362
+ except Exception as e:
2363
+ _echo(f"[dft] WARNING: failed to build DFT diagram: {e}", err=True)
2364
+
2365
+ if do_thermo:
2366
+ try:
2367
+ dG_R = float(thermo_payloads.get("R", {}).get("thermal_correction_free_energy_ha", 0.0))
2368
+ dG_T = float(thermo_payloads.get("TS", {}).get("thermal_correction_free_energy_ha", 0.0))
2369
+ dG_P = float(thermo_payloads.get("P", {}).get("thermal_correction_free_energy_ha", 0.0))
2370
+ GR_dftUMA = eR_dft + dG_R
2371
+ GT_dftUMA = eT_dft + dG_T
2372
+ GP_dftUMA = eP_dft + dG_P
2373
+ g_dft_diag = _write_segment_energy_diagram(
2374
+ tsroot / "energy_diagram_G_DFT_plus_UMA",
2375
+ labels=["R", "TS", "P"],
2376
+ energies_eh=[GR_dftUMA, GT_dftUMA, GP_dftUMA],
2377
+ title_note="(Gibbs, DFT//UMA)",
2378
+ ylabel="ΔG (kcal/mol)",
2379
+ )
2380
+ except Exception as e:
2381
+ _echo(f"[dft//uma] WARNING: failed to build DFT//UMA Gibbs diagram: {e}", err=True)
2382
+
2383
+ # Summary.yaml / summary.log for TSOPT-only mode
2384
+ bond_cfg = dict(_path_search.BOND_KW)
2385
+ bond_summary = ""
2386
+ try:
2387
+ changed, bond_summary = _path_search._has_bond_change(g_react, g_prod, bond_cfg)
2388
+ if not changed:
2389
+ bond_summary = "(no covalent changes detected)"
2390
+ except Exception:
2391
+ bond_summary = "(no covalent changes detected)"
2392
+
2393
+ barrier = (eT - e_react) * AU2KCALPERMOL
2394
+ delta = (e_prod - e_react) * AU2KCALPERMOL
2395
+
2396
+ energy_diagrams: List[Dict[str, Any]] = []
2397
def _promote_all(diag: Optional[Dict[str, Any]], stem: str) -> None:
    """Record a run-level ("*_all") copy of a diagram payload.

    A shallow ``dict`` copy of *diag* is made with ``name`` set to
    ``"<stem>_all"`` and ``image`` set to ``str(out_dir / "<stem>_all.png")``.
    The copy is appended to the enclosing ``energy_diagrams`` list, and the
    input mapping itself is left untouched. A falsy *diag* (``None``, as used
    for diagrams that were never built, or an empty dict) is ignored.

    Args:
        diag: Diagram payload produced by the energy-diagram writer, or
            ``None`` when that diagram was not produced.
        stem: Base file stem of the diagram, e.g. ``"energy_diagram_UMA"``.
    """
    if diag:
        promoted_stem = f"{stem}_all"
        promoted = dict(diag)
        # Keys already present keep their position; missing ones are appended
        # in the order name, image.
        promoted.update(
            name=promoted_stem,
            image=str(out_dir / f"{promoted_stem}.png"),
        )
        energy_diagrams.append(promoted)
2404
+
2405
+ _promote_all(uma_diag, "energy_diagram_UMA")
2406
+ _promote_all(g_uma_diag, "energy_diagram_G_UMA")
2407
+ _promote_all(dft_diag, "energy_diagram_DFT")
2408
+ _promote_all(g_dft_diag, "energy_diagram_G_DFT_plus_UMA")
2409
+
2410
+ summary = {
2411
+ "out_dir": str(tsroot),
2412
+ "n_images": 0,
2413
+ "n_segments": 1,
2414
+ "segments": [
2415
+ {
2416
+ "index": 1,
2417
+ "tag": "seg_01",
2418
+ "kind": "tsopt",
2419
+ "barrier_kcal": float(barrier),
2420
+ "delta_kcal": float(delta),
2421
+ "bond_changes": bond_summary,
2422
+ }
2423
+ ],
2424
+ "energy_diagrams": list(energy_diagrams),
2425
+ }
2426
+ try:
2427
+ with open(tsroot / "summary.yaml", "w") as f:
2428
+ yaml.safe_dump(summary, f, sort_keys=False, allow_unicode=True)
2429
+ shutil.copy2(tsroot / "summary.yaml", out_dir / "summary.yaml")
2430
+ except Exception as e:
2431
+ _echo(f"[write] WARNING: failed to write summary.yaml: {e}", err=True)
2432
+
2433
+ segment_log: Dict[str, Any] = {
2434
+ "index": 1,
2435
+ "tag": "seg_01",
2436
+ "kind": "tsopt",
2437
+ "bond_changes": bond_summary,
2438
+ "post_dir": str(tsroot),
2439
+ "mep_barrier_kcal": barrier,
2440
+ "mep_delta_kcal": delta,
2441
+ }
2442
+ if irc_plot_path:
2443
+ segment_log["irc_plot"] = str(irc_plot_path)
2444
+ if irc_trj_path:
2445
+ segment_log["irc_traj"] = str(irc_trj_path)
2446
+ if do_thermo:
2447
+ n_imag = None
2448
+ try:
2449
+ n_imag = int(thermo_payloads.get("TS", {}).get("num_imag_freq"))
2450
+ except Exception:
2451
+ n_imag = None
2452
+ if n_imag is not None:
2453
+ segment_log["ts_imag"] = {"n_imag": n_imag}
2454
+
2455
+ segment_log["uma"] = {
2456
+ "labels": ["R", "TS", "P"],
2457
+ "energies_au": [e_react, eT, e_prod],
2458
+ "energies_kcal": [0.0, barrier, delta],
2459
+ "diagram": str((tsroot / "energy_diagram_UMA").with_suffix(".png")),
2460
+ "structures": {"R": pR, "TS": pT, "P": pP},
2461
+ "barrier_kcal": barrier,
2462
+ "delta_kcal": delta,
2463
+ }
2464
+ if GR is not None and GT is not None and GP is not None:
2465
+ segment_log["gibbs_uma"] = {
2466
+ "labels": ["R", "TS", "P"],
2467
+ "energies_au": [GR, GT, GP],
2468
+ "energies_kcal": [
2469
+ 0.0,
2470
+ (GT - GR) * AU2KCALPERMOL,
2471
+ (GP - GR) * AU2KCALPERMOL,
2472
+ ],
2473
+ "diagram": str((tsroot / "energy_diagram_G_UMA").with_suffix(".png")),
2474
+ "structures": {"R": pR, "TS": pT, "P": pP},
2475
+ "barrier_kcal": (GT - GR) * AU2KCALPERMOL,
2476
+ "delta_kcal": (GP - GR) * AU2KCALPERMOL,
2477
+ }
2478
+ if eR_dft is not None and eT_dft is not None and eP_dft is not None:
2479
+ segment_log["dft"] = {
2480
+ "labels": ["R", "TS", "P"],
2481
+ "energies_au": [eR_dft, eT_dft, eP_dft],
2482
+ "energies_kcal": [
2483
+ 0.0,
2484
+ (eT_dft - eR_dft) * AU2KCALPERMOL,
2485
+ (eP_dft - eR_dft) * AU2KCALPERMOL,
2486
+ ],
2487
+ "diagram": str((tsroot / "energy_diagram_DFT").with_suffix(".png")),
2488
+ "structures": {"R": pR, "TS": pT, "P": pP},
2489
+ "barrier_kcal": (eT_dft - eR_dft) * AU2KCALPERMOL,
2490
+ "delta_kcal": (eP_dft - eR_dft) * AU2KCALPERMOL,
2491
+ }
2492
+ if GR_dftUMA is not None and GT_dftUMA is not None and GP_dftUMA is not None:
2493
+ segment_log["gibbs_dft_uma"] = {
2494
+ "labels": ["R", "TS", "P"],
2495
+ "energies_au": [GR_dftUMA, GT_dftUMA, GP_dftUMA],
2496
+ "energies_kcal": [
2497
+ 0.0,
2498
+ (GT_dftUMA - GR_dftUMA) * AU2KCALPERMOL,
2499
+ (GP_dftUMA - GR_dftUMA) * AU2KCALPERMOL,
2500
+ ],
2501
+ "diagram": str((tsroot / "energy_diagram_G_DFT_plus_UMA").with_suffix(".png")),
2502
+ "structures": {"R": pR, "TS": pT, "P": pP},
2503
+ "barrier_kcal": (GT_dftUMA - GR_dftUMA) * AU2KCALPERMOL,
2504
+ "delta_kcal": (GP_dftUMA - GR_dftUMA) * AU2KCALPERMOL,
2505
+ }
2506
+
2507
+ summary_payload = {
2508
+ "root_out_dir": str(out_dir),
2509
+ "path_dir": str(tsroot),
2510
+ "path_module_dir": "tsopt_single",
2511
+ "pipeline_mode": "tsopt-only",
2512
+ "refine_path": bool(refine_path),
2513
+ "thresh": thresh,
2514
+ "thresh_post": thresh_post,
2515
+ "flatten": bool(flatten),
2516
+ "tsopt": do_tsopt,
2517
+ "thermo": do_thermo,
2518
+ "dft": do_dft,
2519
+ "opt_mode": tsopt_opt_mode_default,
2520
+ "mep_mode": "tsopt-only",
2521
+ "uma_model": None,
2522
+ "command": command_str,
2523
+ "charge": q_int,
2524
+ "spin": spin,
2525
+ "mep": {"n_images": 0, "n_segments": 1},
2526
+ "segments": summary.get("segments", []),
2527
+ "energy_diagrams": list(energy_diagrams),
2528
+ "post_segments": [segment_log],
2529
+ "key_files": {},
2530
+ }
2531
+ try:
2532
+ write_summary_log(tsroot / "summary.log", summary_payload)
2533
+ shutil.copy2(tsroot / "summary.log", out_dir / "summary.log")
2534
+ except Exception as e:
2535
+ _echo(f"[write] WARNING: failed to write summary.log: {e}", err=True)
2536
+
2537
+ try:
2538
+ for stem in (
2539
+ "energy_diagram_UMA",
2540
+ "energy_diagram_G_UMA",
2541
+ "energy_diagram_DFT",
2542
+ "energy_diagram_G_DFT_plus_UMA",
2543
+ ):
2544
+ src = tsroot / f"{stem}.png"
2545
+ if src.exists():
2546
+ shutil.copy2(src, out_dir / f"{stem}_all.png")
2547
+ except Exception as e:
2548
+ _echo(f"[all] WARNING: failed to copy *_all diagrams: {e}", err=True)
2549
+
2550
+ try:
2551
+ if irc_plot_path:
2552
+ irc_plot_src = Path(irc_plot_path)
2553
+ if irc_plot_src.exists():
2554
+ shutil.copy2(irc_plot_src, out_dir / "irc_plot_all.png")
2555
+ except Exception as e:
2556
+ _echo(f"[all] WARNING: failed to copy irc_plot_all.png: {e}", err=True)
2557
+
2558
+ # summary.md and key_* outputs are disabled.
2559
+ _echo_section("=== [all] TSOPT-only pipeline finished successfully ===")
2560
+ _echo(format_elapsed("[all] Elapsed for Whole Pipeline", time_start))
2561
+ return
2562
+
2563
+ # --------------------------
2564
+ # Stage 1c: Optional scan (single-structure only) to build ordered pocket inputs
2565
+ # --------------------------
2566
+ pockets_for_path: List[Path]
2567
+ if is_single and has_scan:
2568
+ _echo_section("=== [all] Stage 1c — Staged scan on layered full-system PDB (single-structure mode) ===")
2569
+ ensure_dir(scan_dir)
2570
+ layered_pdb = Path(layered_inputs[0]).resolve()
2571
+ full_input_pdb = Path(input_paths[0]).resolve()
2572
+ # Use the layered full-system PDB for scan (no pocket index remapping needed)
2573
+ full_atom_meta = load_pdb_atom_metadata(full_input_pdb)
2574
+ converted_scan_stages = _parse_scan_lists_literals(scan_lists_raw, atom_meta=full_atom_meta)
2575
+ scan_one_based_effective = True if scan_one_based is None else bool(scan_one_based)
2576
+ scan_stage_literals: List[str] = []
2577
+ for stage in converted_scan_stages:
2578
+ if scan_one_based_effective:
2579
+ stage_use = stage
2580
+ else:
2581
+ stage_use = [(i - 1, j - 1, target) for (i, j, target) in stage]
2582
+ scan_stage_literals.append(_format_scan_stage(stage_use))
2583
+ _echo("[all] Remapped --scan-lists indices from the full PDB to the pocket ordering.")
2584
+ scan_preopt_use = pre_opt if scan_preopt_override is None else bool(scan_preopt_override)
2585
+ scan_endopt_use = False if scan_endopt_override is None else bool(scan_endopt_override)
2586
+ scan_opt_mode_use = path_search_opt_mode
2587
+
2588
+ scan_args: List[str] = [
2589
+ "-i", str(layered_pdb),
2590
+ "--parm", str(real_parm7_path),
2591
+ "-q", str(int(q_int)),
2592
+ "-m", str(int(spin)),
2593
+ "--out-dir", str(scan_dir),
2594
+ "--preopt" if scan_preopt_use else "--no-preopt",
2595
+ "--endopt" if scan_endopt_use else "--no-endopt",
2596
+ "--opt-mode", str(scan_opt_mode_use),
2597
+ ]
2598
+ scan_args.append("--detect-layer" if detect_layer else "--no-detect-layer")
2599
+
2600
+ if dump_override_requested:
2601
+ scan_args.append("--dump" if dump else "--no-dump")
2602
+
2603
+ if scan_one_based is not None:
2604
+ scan_args.append("--one-based" if scan_one_based else "--zero-based")
2605
+
2606
+ _append_cli_arg(scan_args, "--max-step-size", scan_max_step_size)
2607
+ _append_cli_arg(scan_args, "--bias-k", scan_bias_k)
2608
+ _append_cli_arg(scan_args, "--relax-max-cycles", scan_relax_max_cycles)
2609
+ scan_args.append("--convert-files" if convert_files else "--no-convert-files")
2610
+ if thresh is not None:
2611
+ scan_args.extend(["--thresh", str(thresh)])
2612
+ if args_yaml is not None:
2613
+ scan_args.extend(["--config", str(args_yaml)])
2614
+ # Forward all converted --scan-lists (aligned to the pocket atom order)
2615
+ if scan_stage_literals:
2616
+ scan_args.append("--scan-lists")
2617
+ scan_args.extend(scan_stage_literals)
2618
+
2619
+ if backend is not None:
2620
+ scan_args.extend(["--backend", str(backend)])
2621
+ if embedcharge:
2622
+ scan_args.append("--embedcharge")
2623
+ if embedcharge_cutoff is not None:
2624
+ scan_args.extend(["--embedcharge-cutoff", str(embedcharge_cutoff)])
2625
+ else:
2626
+ scan_args.append("--no-embedcharge")
2627
+
2628
+ _echo("[all] Invoking scan with arguments:")
2629
+ _echo(" " + " ".join(scan_args))
2630
+
2631
+ _run_cli_main("scan", _scan_cli.cli, scan_args, on_nonzero="raise", on_exception="raise", prefix="all")
2632
+
2633
+ # Collect stage results — prefer XYZ (full precision), keep PDB as ref for topology
2634
+ stage_results: List[Path] = []
2635
+ stage_refs: List[Path] = []
2636
+ for st in sorted(scan_dir.glob("stage_*")):
2637
+ if not st.is_dir():
2638
+ continue
2639
+ xyz = st / "result.xyz"
2640
+ pdb = st / "result.pdb"
2641
+ if xyz.exists():
2642
+ stage_results.append(xyz.resolve())
2643
+ stage_refs.append(pdb.resolve() if pdb.exists() else layered_pdb)
2644
+ elif pdb.exists():
2645
+ stage_results.append(pdb.resolve())
2646
+ stage_refs.append(pdb.resolve())
2647
+ if not stage_results:
2648
+ raise click.ClickException("[all] No stage result files found under scan/.")
2649
+ _echo("[all] Collected scan stage files:")
2650
+ for p in stage_results:
2651
+ _echo(f" - {p}")
2652
+
2653
+ # Input series to path_search: [preopt result (if available), scan stage results ...]
2654
+ # When scan ran with --preopt, its optimised reactant geometry lives in
2655
+ # scan/preopt/result.xyz (full precision) or result.pdb. Using this
2656
+ # avoids a redundant ~2000-cycle re-optimisation inside path_search.
2657
+ preopt_xyz = scan_dir / "preopt" / "result.xyz"
2658
+ preopt_pdb = scan_dir / "preopt" / "result.pdb"
2659
+ if preopt_xyz.exists():
2660
+ init0_geom = preopt_xyz.resolve()
2661
+ init0_ref = layered_pdb # layered PDB has authoritative B-factor layers
2662
+ elif preopt_pdb.exists():
2663
+ init0_geom = preopt_pdb.resolve()
2664
+ init0_ref = layered_pdb
2665
+ else:
2666
+ # No preopt output — fall back to original layered PDB
2667
+ init0_geom = layered_pdb
2668
+ init0_ref = layered_pdb
2669
+ pockets_for_path = [init0_geom] + stage_results
2670
+ refs_for_path = [init0_ref] + stage_refs
2671
+ else:
2672
+ # Multi-structure standard route: use layered full-system PDBs
2673
+ pockets_for_path = list(layered_inputs)
2674
+
2675
+ # --------------------------
2676
+ # Stage 2: Path search on full-system layered PDBs
2677
+ # --------------------------
2678
+ if refine_path:
2679
+ _echo_section("=== [all] Stage 2/3 — MEP search on full-system layered PDBs (recursive GSM) ===")
2680
+
2681
+ # Build path_search CLI args using *repeated* options (robust for Click)
2682
+ ps_args: List[str] = []
2683
+
2684
+ # Inputs: single -i followed by all layered full-system PDBs
2685
+ ps_args.append("-i")
2686
+ for p in pockets_for_path:
2687
+ ps_args.append(str(p))
2688
+
2689
+ # Charge & spin
2690
+ ps_args.extend(["-q", str(q_int)])
2691
+ ps_args.extend(["-m", str(int(spin))])
2692
+ ps_args.extend(["--parm", str(real_parm7_path)])
2693
+ # Layered PDBs have B-factors → detect-layer will auto-identify layers
2694
+ ps_args.append("--detect-layer")
2695
+
2696
+ # Nodes, cycles, climb, optimizer, dump, out-dir, preopt, args-yaml
2697
+ ps_args.extend(["--max-nodes", str(int(max_nodes))])
2698
+ ps_args.extend(["--max-cycles", str(int(max_cycles))])
2699
+ ps_args.append("--climb" if climb else "--no-climb")
2700
+ ps_args.extend(["--opt-mode", str(path_search_opt_mode)])
2701
+ ps_args.append("--dump" if dump else "--no-dump")
2702
+ ps_args.extend(["--out-dir", str(path_dir)])
2703
+ ps_args.append("--preopt" if pre_opt else "--no-preopt")
2704
+ ps_args.append("--convert-files" if convert_files else "--no-convert-files")
2705
+ if thresh is not None:
2706
+ ps_args.extend(["--thresh", str(thresh)])
2707
+ if args_yaml is not None:
2708
+ ps_args.extend(["--config", str(args_yaml)])
2709
+
2710
+ # Provide --ref-pdb for topology/B-factor info (one per input)
2711
+ # MUST use layered PDBs (with B-factor layer info) so that downstream
2712
+ # PDB conversion preserves ML/MovableMM/FrozenMM layer encoding.
2713
+ ps_args.append("--ref-pdb")
2714
+ if is_single and has_scan:
2715
+ # single+scan: use refs_for_path which maps to each pocket (XYZ→PDB ref)
2716
+ for ref in refs_for_path:
2717
+ ps_args.append(str(ref))
2718
+ else:
2719
+ for lp in layered_inputs:
2720
+ ps_args.append(str(lp))
2721
+
2722
+ if backend is not None:
2723
+ ps_args.extend(["--backend", str(backend)])
2724
+ if embedcharge:
2725
+ ps_args.append("--embedcharge")
2726
+ if embedcharge_cutoff is not None:
2727
+ ps_args.extend(["--embedcharge-cutoff", str(embedcharge_cutoff)])
2728
+ else:
2729
+ ps_args.append("--no-embedcharge")
2730
+
2731
+ _echo("[all] Invoking path_search with arguments:")
2732
+ _echo(" " + " ".join(ps_args))
2733
+
2734
+ _run_cli_main("path_search", _path_search.cli, ps_args, on_nonzero="raise", on_exception="raise", prefix="all")
2735
+ else:
2736
+ # --no-refine-path: run path-opt GSM between each adjacent pair and concatenate
2737
+ _echo_section("=== [all] Stage 2/3 — MEP path-opt on full-system layered PDBs (single-pass GSM per pair) ===")
2738
+
2739
+ if len(pockets_for_path) < 2:
2740
+ raise click.ClickException("[all] Need at least two structures for path-opt MEP concatenation.")
2741
+
2742
+ ensure_dir(path_dir)
2743
+ combined_blocks: List[str] = []
2744
+ path_opt_segments: List[Dict[str, Any]] = []
2745
+
2746
+ for pair_idx in range(len(pockets_for_path) - 1):
2747
+ p_left = pockets_for_path[pair_idx]
2748
+ p_right = pockets_for_path[pair_idx + 1]
2749
+ seg_tag = f"seg_{pair_idx:02d}"
2750
+ seg_out = path_dir / f"{seg_tag}_mep"
2751
+ ensure_dir(seg_out)
2752
+
2753
+ po_args: List[str] = [
2754
+ "-i", str(p_left), str(p_right),
2755
+ "-q", str(q_int),
2756
+ "-m", str(int(spin)),
2757
+ "--parm", str(real_parm7_path),
2758
+ "--detect-layer",
2759
+ "--max-nodes", str(int(max_nodes)),
2760
+ "--max-cycles", str(int(max_cycles)),
2761
+ ]
2762
+ po_args.append("--climb" if climb else "--no-climb")
2763
+ po_args.append("--dump" if dump else "--no-dump")
2764
+ po_args.extend(["--out-dir", str(seg_out)])
2765
+ po_args.append("--preopt" if pre_opt else "--no-preopt")
2766
+ po_args.append("--convert-files" if convert_files else "--no-convert-files")
2767
+ if thresh is not None:
2768
+ po_args.extend(["--thresh", str(thresh)])
2769
+ if args_yaml is not None:
2770
+ po_args.extend(["--config", str(args_yaml)])
2771
+ if backend is not None:
2772
+ po_args.extend(["--backend", str(backend)])
2773
+ if embedcharge:
2774
+ po_args.append("--embedcharge")
2775
+ if embedcharge_cutoff is not None:
2776
+ po_args.extend(["--embedcharge-cutoff", str(embedcharge_cutoff)])
2777
+ else:
2778
+ po_args.append("--no-embedcharge")
2779
+
2780
+ _echo(f"[all] Invoking path_opt for pair {pair_idx} with arguments:")
2781
+ _echo(" " + " ".join(po_args))
2782
+ _run_cli_main("path_opt", _path_opt.cli, po_args, on_nonzero="raise", on_exception="raise", prefix="all")
2783
+
2784
+ # --- Post-processing per segment ---
2785
+ seg_trj = seg_out / "final_geometries_trj.xyz"
2786
+ if not seg_trj.exists():
2787
+ raise click.ClickException(
2788
+ f"[all] path-opt segment {pair_idx} did not produce final_geometries_trj.xyz"
2789
+ )
2790
+
2791
+ # Copy per-segment trajectory to path_dir
2792
+ try:
2793
+ seg_mep_trj = path_dir / f"mep_seg_{pair_idx:02d}_trj.xyz"
2794
+ shutil.copy2(seg_trj, seg_mep_trj)
2795
+ if pockets_for_path[0].suffix.lower() == ".pdb":
2796
+ _path_search._maybe_convert_to_pdb(
2797
+ seg_mep_trj,
2798
+ ref_pdb_path=pockets_for_path[0],
2799
+ out_path=path_dir / f"mep_seg_{pair_idx:02d}.pdb",
2800
+ )
2801
+ except Exception as e:
2802
+ _echo(
2803
+ f"[all] WARNING: failed to emit per-segment trajectory copies for segment {pair_idx:02d}: {e}",
2804
+ err=True,
2805
+ )
2806
+
2807
+ # Mirror HEI artifacts
2808
+ hei_src = seg_out / "hei.xyz"
2809
+ if hei_src.exists():
2810
+ try:
2811
+ shutil.copy2(hei_src, path_dir / f"hei_seg_{pair_idx:02d}.xyz")
2812
+ hei_pdb_src = seg_out / "hei.pdb"
2813
+ if hei_pdb_src.exists():
2814
+ shutil.copy2(hei_pdb_src, path_dir / f"hei_seg_{pair_idx:02d}.pdb")
2815
+ except Exception as e:
2816
+ _echo(
2817
+ f"[all] WARNING: failed to prepare HEI artifacts for segment {pair_idx:02d}: {e}",
2818
+ err=True,
2819
+ )
2820
+
2821
+ # Parse trajectory blocks for concatenation and energy extraction
2822
+ raw_blocks = read_xyz_as_blocks(seg_trj, strict=True)
2823
+ blocks = ["\n".join(b) + "\n" for b in raw_blocks]
2824
+ if not blocks:
2825
+ raise click.ClickException(
2826
+ f"[all] No frames read from path-opt segment {pair_idx} trajectory: {seg_trj}"
2827
+ )
2828
+ # Skip duplicate first frame for subsequent segments
2829
+ if pair_idx > 0:
2830
+ blocks = blocks[1:]
2831
+ combined_blocks.extend(blocks)
2832
+
2833
+ # Extract energies from trajectory comment lines
2834
+ energies_seg: List[float] = []
2835
+ for blk in raw_blocks:
2836
+ E = np.nan
2837
+ if len(blk) >= 2:
2838
+ try:
2839
+ E = float(blk[1].split()[0])
2840
+ except Exception:
2841
+ E = np.nan
2842
+ energies_seg.append(E)
2843
+
2844
+ # Parse first/last frame coordinates for bond-change detection
2845
+ first_last = None
2846
+ try:
2847
+ first_last = xyz_blocks_first_last(raw_blocks, path=seg_trj)
2848
+ except Exception as e:
2849
+ _echo(
2850
+ f"[all] WARNING: failed to parse first/last frames for segment {pair_idx:02d}: {e}",
2851
+ err=True,
2852
+ )
2853
+
2854
+ path_opt_segments.append(
2855
+ {
2856
+ "tag": seg_tag,
2857
+ "energies": energies_seg,
2858
+ "traj": seg_trj,
2859
+ "inputs": (p_left, p_right),
2860
+ "first_last": first_last,
2861
+ }
2862
+ )
2863
+
2864
+ # --- Concatenated MEP trajectory ---
2865
+ final_trj = path_dir / "mep_trj.xyz"
2866
+ try:
2867
+ final_trj.write_text("".join(combined_blocks), encoding="utf-8")
2868
+ _echo(f"[all] Wrote concatenated MEP trajectory: {final_trj}")
2869
+ except Exception as e:
2870
+ raise click.ClickException(f"[all] Failed to write concatenated MEP: {e}")
2871
+
2872
+ # Energy plot for concatenated trajectory
2873
+ try:
2874
+ run_trj2fig(final_trj, [path_dir / "mep_plot.png"], unit="kcal", reference="init", reverse_x=False)
2875
+ close_matplotlib_figures()
2876
+ _echo(f"[plot] Saved energy plot → '{path_dir / 'mep_plot.png'}'")
2877
+ except Exception as e:
2878
+ _echo(f"[plot] WARNING: Failed to plot concatenated MEP: {e}", err=True)
2879
+
2880
+ # PDB conversion of concatenated trajectory
2881
+ try:
2882
+ if pockets_for_path[0].suffix.lower() == ".pdb":
2883
+ mep_pdb_path = path_dir / "mep.pdb"
2884
+ _path_search._maybe_convert_to_pdb(
2885
+ final_trj, ref_pdb_path=pockets_for_path[0], out_path=mep_pdb_path
2886
+ )
2887
+ if mep_pdb_path.exists():
2888
+ shutil.copy2(mep_pdb_path, out_dir / mep_pdb_path.name)
2889
+ _echo(f"[all] Copied concatenated MEP PDB → {out_dir / mep_pdb_path.name}")
2890
+ except Exception as e:
2891
+ _echo(
2892
+ f"[all] WARNING: Failed to convert/copy concatenated MEP to PDB: {e}",
2893
+ err=True,
2894
+ )
2895
+
2896
+ # --- Energy diagram ---
2897
+ energy_diagrams_po: List[Dict[str, Any]] = []
2898
+ try:
2899
+ labels = _build_global_segment_labels(len(path_opt_segments))
2900
+ energies_chain: List[float] = []
2901
+ for si, seg_info in enumerate(path_opt_segments):
2902
+ Es = [float(x) for x in seg_info.get("energies", [])]
2903
+ if not Es:
2904
+ continue
2905
+ if si == 0:
2906
+ energies_chain.append(Es[0])
2907
+ energies_chain.append(float(np.nanmax(Es)))
2908
+ energies_chain.append(Es[-1])
2909
+ if labels and energies_chain and len(labels) == len(energies_chain):
2910
+ title_note = "(GSM; all segments)" if len(path_opt_segments) > 1 else "(GSM)"
2911
+ diag_payload = _write_segment_energy_diagram(
2912
+ path_dir / "energy_diagram_mep",
2913
+ labels=labels,
2914
+ energies_eh=energies_chain,
2915
+ title_note=title_note,
2916
+ )
2917
+ if diag_payload:
2918
+ energy_diagrams_po.append(diag_payload)
2919
+ except Exception as e:
2920
+ _echo(f"[diagram] WARNING: Failed to build GSM diagram for path-opt branch: {e}", err=True)
2921
+
2922
+ # --- Bond change detection and summary.yaml ---
2923
+ segments_summary: List[Dict[str, Any]] = []
2924
+ bond_cfg = dict(_path_search.BOND_KW)
2925
+ for seg_idx, info in enumerate(path_opt_segments):
2926
+ Es = [float(x) for x in info.get("energies", []) if np.isfinite(x)]
2927
+ if not Es:
2928
+ continue
2929
+ barrier = (max(Es) - Es[0]) * AU2KCALPERMOL
2930
+ delta = (Es[-1] - Es[0]) * AU2KCALPERMOL
2931
+ bond_summary = ""
2932
+ try:
2933
+ first_last = info.get("first_last")
2934
+ if first_last:
2935
+ elems, c_first, c_last = first_last
2936
+ else:
2937
+ elems, c_first, c_last = read_xyz_first_last(Path(info["traj"]))
2938
+ gL = _geom_from_angstrom(elems, c_first, [])
2939
+ gR = _geom_from_angstrom(elems, c_last, [])
2940
+ changed, bond_summary = _path_search._has_bond_change(gL, gR, bond_cfg)
2941
+ if not changed:
2942
+ bond_summary = "(no covalent changes detected)"
2943
+ except Exception as e:
2944
+ _echo(
2945
+ f"[all] WARNING: Failed to detect bond changes for segment {seg_idx:02d}: {e}",
2946
+ err=True,
2947
+ )
2948
+ bond_summary = "(no covalent changes detected)"
2949
+
2950
+ segments_summary.append(
2951
+ {
2952
+ "index": seg_idx,
2953
+ "tag": info.get("tag", f"seg_{seg_idx:02d}"),
2954
+ "kind": "seg",
2955
+ "barrier_kcal": float(barrier),
2956
+ "delta_kcal": float(delta),
2957
+ "bond_changes": bond_summary,
2958
+ }
2959
+ )
2960
+
2961
+ po_summary: Dict[str, Any] = {
2962
+ "out_dir": str(path_dir),
2963
+ "n_images": len(read_xyz_as_blocks(final_trj)),
2964
+ "n_segments": len(segments_summary),
2965
+ "segments": segments_summary,
2966
+ }
2967
+ if energy_diagrams_po:
2968
+ po_summary["energy_diagrams"] = list(energy_diagrams_po)
2969
+ try:
2970
+ with open(path_dir / "summary.yaml", "w") as f:
2971
+ yaml.safe_dump(po_summary, f, sort_keys=False, allow_unicode=True)
2972
+ _echo(f"[write] Wrote '{path_dir / 'summary.yaml'}'.")
2973
+ except Exception as e:
2974
+ _echo(f"[write] WARNING: Failed to write summary.yaml for path-opt branch: {e}", err=True)
2975
+
2976
+ # Copy key outputs to out_dir root
2977
+ try:
2978
+ for name in ("mep_plot.png", "energy_diagram_mep.png", "summary.yaml"):
2979
+ src = path_dir / name
2980
+ if src.exists():
2981
+ shutil.copy2(src, out_dir / name)
2982
+ for ext in ("_trj.xyz", ".xyz"):
2983
+ src = path_dir / f"mep{ext}"
2984
+ if src.exists():
2985
+ shutil.copy2(src, out_dir / src.name)
2986
+ except Exception as e:
2987
+ _echo(f"[all] WARNING: Failed to relocate path-opt summary files: {e}", err=True)
2988
+
2989
+ # --------------------------
2990
+ # Stage 3: Merge (performed by path_search when --ref-pdb was supplied)
2991
+ # --------------------------
2992
+ _echo_section("=== [all] Stage 3/3 — Final outputs ===")
2993
+ _echo(f"[all] Final products can be found under: {path_dir}")
2994
+ _echo(" - mep_trj.xyz (concatenated MEP trajectory)")
2995
+ _echo(" - mep.pdb (PDB conversion, if input was .pdb)")
2996
+ _echo(" - mep_seg_XX_trj.xyz (per-segment trajectories)")
2997
+ _echo(" - hei_seg_XX.xyz/.pdb (HEI per segment)")
2998
+ _echo(" - summary.yaml (segment barriers, ΔE, bond changes)")
2999
+ _echo(" - mep_plot.png / energy_diagram_MEP.png / summary.log")
3000
+ _echo_section("=== [all] Pipeline finished successfully (core path) ===")
3001
+
3002
+ summary_yaml = path_dir / "summary.yaml"
3003
+ summary_loaded = {}
3004
+ if summary_yaml.exists():
3005
+ try:
3006
+ summary_loaded = yaml.safe_load(summary_yaml.read_text(encoding="utf-8")) or {}
3007
+ except Exception:
3008
+ summary_loaded = {}
3009
+ summary: Dict[str, Any] = summary_loaded if isinstance(summary_loaded, dict) else {}
3010
+ segments = _read_summary(summary_yaml)
3011
+ energy_diagrams: List[Dict[str, Any]] = []
3012
+ existing_diagrams = summary.get("energy_diagrams", [])
3013
+ if isinstance(existing_diagrams, list):
3014
+ energy_diagrams.extend(existing_diagrams)
3015
+
3016
def _copy_path_outputs_to_root() -> None:
    """Mirror the key MEP artifacts from ``path_dir`` into the root ``out_dir``.

    Best-effort: each artifact is copied independently, so a failure on one
    file is reported as a warning and does not prevent the remaining files
    from being mirrored. Artifacts that do not exist are skipped silently.
    """
    artifact_names = (
        "mep_plot.png",
        "energy_diagram_MEP.png",
        # The path-opt branch writes its MEP diagram under a lowercase stem.
        "energy_diagram_mep.png",
        "mep.pdb",
        "summary.yaml",
        "summary.log",
        "mep_trj.xyz",
        "mep.xyz",
    )
    for name in artifact_names:
        src = path_dir / name
        try:
            if src.exists():
                shutil.copy2(src, out_dir / name)
        except Exception as e:
            # Keep going: one unreadable/unwritable file should not block the rest.
            _echo(f"[all] WARNING: Failed to copy path_search outputs ({name}): {e}", err=True)
3035
+
3036
def _write_pipeline_summary_log(post_segment_logs: Sequence[Dict[str, Any]]) -> None:
    """Write the pipeline-level ``summary.log`` and mirror outputs to the root dir.

    The payload combines three groups of information:

    - the run configuration: mode, thresholds, post-processing switches,
      charge/spin and the command string;
    - an MEP overview taken from ``summary``: image/segment counts, the first
      energy diagram whose ``name`` ends in "mep" (case-insensitive), and the
      MEP PDB / plot paths when those files exist under ``path_dir``;
    - the per-segment post-processing records.

    The payload is passed to ``write_summary_log``, after which
    ``_copy_path_outputs_to_root`` is called. Any exception is reported as a
    warning rather than propagated.

    Args:
        post_segment_logs: One record per post-processed reactive segment;
            stored as a list under the ``"post_segments"`` key.
    """
    try:
        def _is_mep_diagram(entry: Any) -> bool:
            # The MEP-level diagram is identified by its name suffix.
            return isinstance(entry, dict) and str(entry.get("name", "")).lower().endswith("mep")

        mep_diagram = next(
            (entry for entry in (summary.get("energy_diagrams", []) or []) if _is_mep_diagram(entry)),
            {},
        )

        mep_pdb = path_dir / "mep.pdb"
        mep_png = path_dir / "mep_plot.png"
        mep_block = {
            "n_images": summary.get("n_images"),
            "n_segments": summary.get("n_segments"),
            "traj_pdb": str(mep_pdb) if mep_pdb.exists() else None,
            "mep_plot": str(mep_png) if mep_png.exists() else None,
            "diagram": mep_diagram,
        }

        # Same label is reported both as the pipeline mode and the MEP mode.
        mode_label = "path-search" if refine_path else "path-opt"
        payload = {
            "root_out_dir": str(out_dir),
            "path_dir": str(path_dir),
            "path_module_dir": path_dir.name,
            "pipeline_mode": mode_label,
            "refine_path": bool(refine_path),
            "thresh": thresh,
            "thresh_post": thresh_post,
            "flatten": bool(flatten),
            "tsopt": do_tsopt,
            "thermo": do_thermo,
            "dft": do_dft,
            "opt_mode": opt_mode_norm,
            "opt_mode_post": opt_mode_post.lower() if opt_mode_post else None,
            "mep_mode": mode_label,
            "uma_model": None,
            "command": command_str,
            "charge": q_int,
            "spin": spin,
            "mep": mep_block,
            "segments": summary.get("segments", []),
            "energy_diagrams": summary.get("energy_diagrams", []),
            "post_segments": list(post_segment_logs),
            "key_files": {},
        }
        write_summary_log(path_dir / "summary.log", payload)
        _copy_path_outputs_to_root()
    except Exception as e:
        _echo(f"[write] WARNING: Failed to write summary.log: {e}", err=True)
3079
+
3080
+ # --------------------------
3081
+ # Optional Stage 4: TSOPT / THERMO / DFT (per reactive segment)
3082
+ # --------------------------
3083
+ if not (do_tsopt or do_thermo or do_dft):
3084
+ _write_pipeline_summary_log([])
3085
+ # summary.md and key_* outputs are disabled.
3086
+ # Elapsed time
3087
+ _echo(format_elapsed("[all] Elapsed for Whole Pipeline", time_start))
3088
+ return
3089
+
3090
+ _echo_section("=== [all] Stage 4 — Post-processing per reactive segment ===")
3091
+
3092
+ # Use segment summary from path_search / path-opt
3093
+ if not segments:
3094
+ _echo("[post] No segments found in summary; nothing to do.")
3095
+ _write_pipeline_summary_log([])
3096
+ # summary.md and key_* outputs are disabled.
3097
+ _echo(format_elapsed("[all] Elapsed for Whole Pipeline", time_start))
3098
+ return
3099
+
3100
+ # Iterate only bond-change segments (kind='seg' and bond_changes not empty and not '(no covalent...)')
3101
+ reactive = [s for s in segments if (s.get("kind", "seg") == "seg" and str(s.get("bond_changes", "")).strip() and str(s.get("bond_changes", "")).strip() != "(no covalent changes detected)")]
3102
+ if not reactive:
3103
+ _echo("[post] No bond-change segments. Skipping TS/thermo/DFT.")
3104
+ _write_pipeline_summary_log([])
3105
+ # summary.md and key_* outputs are disabled.
3106
+ _echo(format_elapsed("[all] Elapsed for Whole Pipeline", time_start))
3107
+ return
3108
+
3109
+ post_segment_logs: List[Dict[str, Any]] = []
3110
+ tsopt_seg_energies: List[Tuple[float, float, float]] = []
3111
+ g_uma_seg_energies: List[Tuple[float, float, float]] = []
3112
+ dft_seg_energies: List[Tuple[float, float, float]] = []
3113
+ g_dftuma_seg_energies: List[Tuple[float, float, float]] = []
3114
+ irc_trj_for_all: List[Tuple[Path, bool]] = []
3115
+
3116
+ # For each reactive segment
3117
+ for s in reactive:
3118
+ seg_idx = int(s.get("index", 0) or 0)
3119
+ seg_tag = s.get("tag", f"seg_{seg_idx:02d}")
3120
+ _echo_section(f"--- [post] Segment {seg_idx:02d} ({seg_tag}) ---")
3121
+
3122
+ seg_root = path_dir # base
3123
+ seg_dir = seg_root / f"post_seg_{seg_idx:02d}"
3124
+ ensure_dir(seg_dir)
3125
+
3126
+ # HEI pocket file prepared by path_search (only for bond-change segments)
3127
+ hei_pocket_pdb = seg_root / f"hei_seg_{seg_idx:02d}.pdb"
3128
+ if not hei_pocket_pdb.exists():
3129
+ _echo(f"[post] WARNING: HEI pocket PDB not found for segment {seg_idx:02d}; skipping TSOPT.", err=True)
3130
+ continue
3131
+
3132
+ # 4.1 TS optimization (optional; still needed to drive IRC & diagrams)
3133
+ if do_tsopt:
3134
+ ts_pdb, g_ts = _run_tsopt_on_hei(
3135
+ hei_pocket_pdb,
3136
+ q_int,
3137
+ spin,
3138
+ real_parm7_path,
3139
+ ml_region_pdb,
3140
+ detect_layer,
3141
+ args_yaml,
3142
+ seg_dir,
3143
+ tsopt_opt_mode_default,
3144
+ overrides=tsopt_overrides,
3145
+ backend=backend,
3146
+ embedcharge=embedcharge,
3147
+ embedcharge_cutoff=embedcharge_cutoff,
3148
+ ref_pdb=layered_inputs[0] if layered_inputs else None,
3149
+ )
3150
+ else:
3151
+ # If TSOPT off: use the GSM HEI (pocket) as TS geometry
3152
+ ts_pdb = hei_pocket_pdb
3153
+ g_ts = geom_loader(ts_pdb, coord_type="cart")
3154
+ calc = _mlmm_calc(
3155
+ model_charge=int(q_int),
3156
+ model_mult=int(spin),
3157
+ input_pdb=str(ts_pdb),
3158
+ real_parm7=str(real_parm7_path),
3159
+ model_pdb=str(ml_region_pdb),
3160
+ use_bfactor_layers=detect_layer,
3161
+ backend=backend,
3162
+ embedcharge=embedcharge,
3163
+ )
3164
+ g_ts.set_calculator(calc); _ = float(g_ts.energy)
3165
+
3166
+ # 4.2 EulerPC IRC & mapping to (left,right)
3167
+ irc_plot_path = None
3168
+ irc_trj_path = None
3169
+ irc_res = _irc_and_match(seg_idx=seg_idx,
3170
+ seg_dir=seg_dir,
3171
+ ref_pdb_for_seg=ts_pdb,
3172
+ seg_pocket_pdb=hei_pocket_pdb,
3173
+ g_ts=g_ts,
3174
+ q_int=q_int,
3175
+ spin=spin,
3176
+ real_parm7=real_parm7_path,
3177
+ model_pdb=ml_region_pdb,
3178
+ detect_layer=detect_layer,
3179
+ backend=backend,
3180
+ embedcharge=embedcharge,
3181
+ embedcharge_cutoff=embedcharge_cutoff,
3182
+ args_yaml=args_yaml)
3183
+ irc_plot_path = irc_res.get("irc_plot")
3184
+ irc_trj_path = irc_res.get("irc_trj")
3185
+ if irc_trj_path:
3186
+ try:
3187
+ irc_trj_for_all.append((Path(irc_trj_path), bool(irc_res.get("reverse_irc", False))))
3188
+ except Exception:
3189
+ logger.debug("Failed to append IRC trajectory path", exc_info=True)
3190
+
3191
+ gL = irc_res["left_min_geom"]
3192
+ gR = irc_res["right_min_geom"]
3193
+ gT = irc_res["ts_geom"]
3194
+ # Save IRC endpoints (XYZ primary), run endpoint-opt, then save optimized structures
3195
+ struct_dir = seg_dir / "structures"
3196
+ ensure_dir(struct_dir)
3197
+ xL_irc, pL_irc = _save_single_geom_for_tools(gL, hei_pocket_pdb, struct_dir, "reactant_irc")
3198
+ xT, pT = _save_single_geom_for_tools(gT, hei_pocket_pdb, struct_dir, "ts")
3199
+ xR_irc, pR_irc = _save_single_geom_for_tools(gR, hei_pocket_pdb, struct_dir, "product_irc")
3200
+
3201
+ endpoint_opt_dir = seg_dir / "endpoint_opt"
3202
+ ensure_dir(endpoint_opt_dir)
3203
+
3204
+ # Map IRC left/right Hessians → R/P endpoint
3205
+ # When reverse_irc is True, _irc_and_match swapped left/right to match GSM endpoints,
3206
+ # so "irc_left" (=forward) now corresponds to gR and "irc_right" (=backward) to gL.
3207
+ from .hessian_cache import load as _hess_load, store as _hess_store, clear as _clear_hess_cache
3208
+ _reversed = bool(irc_res.get("reverse_irc", False))
3209
+ _left_hk = "irc_right" if _reversed else "irc_left"
3210
+ _right_hk = "irc_left" if _reversed else "irc_right"
3211
+
3212
+ _c = _hess_load(_left_hk)
3213
+ if _c:
3214
+ _hess_store("irc_endpoint", _c["hessian"], active_dofs=_c.get("active_dofs"), meta=_c.get("meta"))
3215
+ try:
3216
+ gL, _ = _run_opt_for_state(
3217
+ pL_irc, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
3218
+ endpoint_opt_dir / "R", args_yaml, endpoint_opt_mode_default,
3219
+ convert_files=convert_files,
3220
+ backend=backend,
3221
+ embedcharge=embedcharge,
3222
+ embedcharge_cutoff=embedcharge_cutoff,
3223
+ thresh=thresh_post,
3224
+ xyz_path=xL_irc,
3225
+ )
3226
+ except Exception as e:
3227
+ _echo(
3228
+ f"[post] WARNING: Reactant endpoint optimization failed for segment {seg_idx:02d}: {e}",
3229
+ err=True,
3230
+ )
3231
+
3232
+ _c = _hess_load(_right_hk)
3233
+ if _c:
3234
+ _hess_store("irc_endpoint", _c["hessian"], active_dofs=_c.get("active_dofs"), meta=_c.get("meta"))
3235
+ try:
3236
+ gR, _ = _run_opt_for_state(
3237
+ pR_irc, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
3238
+ endpoint_opt_dir / "P", args_yaml, endpoint_opt_mode_default,
3239
+ convert_files=convert_files,
3240
+ backend=backend,
3241
+ embedcharge=embedcharge,
3242
+ embedcharge_cutoff=embedcharge_cutoff,
3243
+ thresh=thresh_post,
3244
+ xyz_path=xR_irc,
3245
+ )
3246
+ except Exception as e:
3247
+ _echo(
3248
+ f"[post] WARNING: Product endpoint optimization failed for segment {seg_idx:02d}: {e}",
3249
+ err=True,
3250
+ )
3251
+ shutil.rmtree(endpoint_opt_dir, ignore_errors=True)
3252
+ _echo("[endpoint-opt] Clean endpoint-opt working dir.")
3253
+
3254
+ xL, pL = _save_single_geom_for_tools(gL, hei_pocket_pdb, struct_dir, "reactant")
3255
+ xR, pR = _save_single_geom_for_tools(gR, hei_pocket_pdb, struct_dir, "product")
3256
+
3257
+ # 4.3 Segment-level energy diagram from UMA (R,TS,P)
3258
+ eR = float(gL.energy)
3259
+ eT = float(gT.energy)
3260
+ eP = float(gR.energy)
3261
+ tsopt_seg_energies.append((eR, eT, eP))
3262
+ uma_prefix = seg_dir / "energy_diagram_UMA"
3263
+ _write_segment_energy_diagram(
3264
+ uma_prefix,
3265
+ labels=["R", f"TS{seg_idx}", "P"],
3266
+ energies_eh=[eR, eT, eP],
3267
+ title_note="(UMA, TSOPT/IRC)",
3268
+ )
3269
+
3270
+ # ── Release GPU memory before freq/thermo/DFT ──
3271
+ for _g in (gL, gR, gT):
3272
+ if _g is not None and hasattr(_g, "calculator"):
3273
+ _g.calculator = None
3274
+ gc.collect()
3275
+ if torch.cuda.is_available():
3276
+ torch.cuda.empty_cache()
3277
+
3278
+ # 4.4 Thermochemistry (UMA freq) and Gibbs diagram
3279
+ thermo_payloads: Dict[str, Dict[str, Any]] = {}
3280
+ GR = GT = GP = None
3281
+ freq_seg_root = _resolve_override_dir(seg_dir / "freq", freq_out_dir)
3282
+ dft_seg_root = _resolve_override_dir(seg_dir / "dft", dft_out_dir)
3283
+
3284
+ if do_thermo:
3285
+ _echo(f"[thermo] Segment {seg_idx:02d}: freq on TS/R/P")
3286
+ tT = _run_freq_for_state(pT, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
3287
+ freq_seg_root / "TS", args_yaml, overrides=freq_overrides,
3288
+ backend=backend, embedcharge=embedcharge, embedcharge_cutoff=embedcharge_cutoff, xyz_path=xT)
3289
+ _clear_hess_cache() # TS Hessian consumed; R/P need exact computation
3290
+ tR = _run_freq_for_state(pL, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
3291
+ freq_seg_root / "R", args_yaml, overrides=freq_overrides,
3292
+ backend=backend, embedcharge=embedcharge, embedcharge_cutoff=embedcharge_cutoff, xyz_path=xL)
3293
+ tP = _run_freq_for_state(pR, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
3294
+ freq_seg_root / "P", args_yaml, overrides=freq_overrides,
3295
+ backend=backend, embedcharge=embedcharge, embedcharge_cutoff=embedcharge_cutoff, xyz_path=xR)
3296
+ thermo_payloads = {"R": tR, "TS": tT, "P": tP}
3297
+ try:
3298
+ GR = float(tR.get("sum_EE_and_thermal_free_energy_ha", eR))
3299
+ GT = float(tT.get("sum_EE_and_thermal_free_energy_ha", eT))
3300
+ GP = float(tP.get("sum_EE_and_thermal_free_energy_ha", eP))
3301
+ g_uma_seg_energies.append((GR, GT, GP))
3302
+ _write_segment_energy_diagram(
3303
+ seg_dir / "energy_diagram_G_UMA",
3304
+ labels=["R", f"TS{seg_idx}", "P"],
3305
+ energies_eh=[GR, GT, GP],
3306
+ title_note="(Gibbs, UMA)",
3307
+ ylabel="ΔG (kcal/mol)",
3308
+ )
3309
+ except Exception as e:
3310
+ _echo(f"[thermo] WARNING: failed to build Gibbs diagram: {e}", err=True)
3311
+
3312
+ # 4.5 DFT single-point and (optionally) DFT//UMA Gibbs
3313
+ eR_dft = eT_dft = eP_dft = None
3314
+ GR_dftUMA = GT_dftUMA = GP_dftUMA = None
3315
+ if do_dft:
3316
+ _echo(f"[dft] Segment {seg_idx:02d}: DFT on R/TS/P")
3317
+ dR = _run_dft_for_state(pL, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
3318
+ dft_seg_root / "R", args_yaml, func_basis=dft_func_basis_use, overrides=dft_overrides,
3319
+ backend=backend, embedcharge=embedcharge, embedcharge_cutoff=embedcharge_cutoff, xyz_path=xL)
3320
+ dT = _run_dft_for_state(pT, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
3321
+ dft_seg_root / "TS", args_yaml, func_basis=dft_func_basis_use, overrides=dft_overrides,
3322
+ backend=backend, embedcharge=embedcharge, embedcharge_cutoff=embedcharge_cutoff, xyz_path=xT)
3323
+ dP = _run_dft_for_state(pR, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
3324
+ dft_seg_root / "P", args_yaml, func_basis=dft_func_basis_use, overrides=dft_overrides,
3325
+ backend=backend, embedcharge=embedcharge, embedcharge_cutoff=embedcharge_cutoff, xyz_path=xR)
3326
+ try:
3327
+ eR_dft = float(dR.get("energy", {}).get("hartree", np.nan) if dR else np.nan)
3328
+ eT_dft = float(dT.get("energy", {}).get("hartree", np.nan) if dT else np.nan)
3329
+ eP_dft = float(dP.get("energy", {}).get("hartree", np.nan) if dP else np.nan)
3330
+ if all(map(np.isfinite, [eR_dft, eT_dft, eP_dft])):
3331
+ dft_seg_energies.append((eR_dft, eT_dft, eP_dft))
3332
+ _write_segment_energy_diagram(seg_dir / "energy_diagram_DFT",
3333
+ labels=["R", f"TS{seg_idx}", "P"],
3334
+ energies_eh=[eR_dft, eT_dft, eP_dft],
3335
+ title_note=f"({dft_method_fallback})")
3336
+ else:
3337
+ _echo("[dft] WARNING: some DFT energies missing; diagram skipped.", err=True)
3338
+ except Exception as e:
3339
+ _echo(f"[dft] WARNING: failed to build DFT diagram: {e}", err=True)
3340
+
3341
+ # DFT//UMA thermal Gibbs (E_DFT + ΔG_therm(UMA))
3342
+ if do_thermo:
3343
+ try:
3344
+ dG_R = float(thermo_payloads.get("R", {}).get("thermal_correction_free_energy_ha", 0.0))
3345
+ dG_T = float(thermo_payloads.get("TS", {}).get("thermal_correction_free_energy_ha", 0.0))
3346
+ dG_P = float(thermo_payloads.get("P", {}).get("thermal_correction_free_energy_ha", 0.0))
3347
+ eR_dft = float(dR.get("energy", {}).get("hartree", eR) if dR else eR)
3348
+ eT_dft = float(dT.get("energy", {}).get("hartree", eT) if dT else eT)
3349
+ eP_dft = float(dP.get("energy", {}).get("hartree", eP) if dP else eP)
3350
+ GR_dftUMA = eR_dft + dG_R
3351
+ GT_dftUMA = eT_dft + dG_T
3352
+ GP_dftUMA = eP_dft + dG_P
3353
+ g_dftuma_seg_energies.append((GR_dftUMA, GT_dftUMA, GP_dftUMA))
3354
+ _write_segment_energy_diagram(
3355
+ seg_dir / "energy_diagram_G_DFT_plus_UMA",
3356
+ labels=["R", f"TS{seg_idx}", "P"],
3357
+ energies_eh=[GR_dftUMA, GT_dftUMA, GP_dftUMA],
3358
+ title_note="(Gibbs, DFT//UMA)",
3359
+ ylabel="ΔG (kcal/mol)",
3360
+ )
3361
+ except Exception as e:
3362
+ _echo(f"[dft//uma] WARNING: failed to build DFT//UMA Gibbs diagram: {e}", err=True)
3363
+
3364
+ segment_log: Dict[str, Any] = {
3365
+ "index": seg_idx,
3366
+ "tag": seg_tag,
3367
+ "kind": s.get("kind", "seg"),
3368
+ "bond_changes": s.get("bond_changes", ""),
3369
+ "mep_barrier_kcal": s.get("barrier_kcal"),
3370
+ "mep_delta_kcal": s.get("delta_kcal"),
3371
+ "post_dir": str(seg_dir),
3372
+ }
3373
+ if irc_plot_path:
3374
+ segment_log["irc_plot"] = str(irc_plot_path)
3375
+ if irc_trj_path:
3376
+ segment_log["irc_traj"] = str(irc_trj_path)
3377
+ if do_thermo:
3378
+ n_imag = None
3379
+ try:
3380
+ n_imag = int(thermo_payloads.get("TS", {}).get("num_imag_freq"))
3381
+ except Exception:
3382
+ n_imag = None
3383
+ if n_imag is not None:
3384
+ segment_log["ts_imag"] = {"n_imag": n_imag}
3385
+ segment_log["uma"] = {
3386
+ "labels": ["R", "TS", "P"],
3387
+ "energies_au": [eR, eT, eP],
3388
+ "energies_kcal": [
3389
+ 0.0,
3390
+ (eT - eR) * AU2KCALPERMOL,
3391
+ (eP - eR) * AU2KCALPERMOL,
3392
+ ],
3393
+ "diagram": str((seg_dir / "energy_diagram_UMA").with_suffix(".png")),
3394
+ "structures": {"R": pL, "TS": pT, "P": pR},
3395
+ "barrier_kcal": (eT - eR) * AU2KCALPERMOL,
3396
+ "delta_kcal": (eP - eR) * AU2KCALPERMOL,
3397
+ }
3398
+ if GR is not None and GT is not None and GP is not None:
3399
+ segment_log["gibbs_uma"] = {
3400
+ "labels": ["R", "TS", "P"],
3401
+ "energies_au": [GR, GT, GP],
3402
+ "energies_kcal": [
3403
+ 0.0,
3404
+ (GT - GR) * AU2KCALPERMOL,
3405
+ (GP - GR) * AU2KCALPERMOL,
3406
+ ],
3407
+ "diagram": str((seg_dir / "energy_diagram_G_UMA").with_suffix(".png")),
3408
+ "structures": {"R": pL, "TS": pT, "P": pR},
3409
+ "barrier_kcal": (GT - GR) * AU2KCALPERMOL,
3410
+ "delta_kcal": (GP - GR) * AU2KCALPERMOL,
3411
+ }
3412
+ if eR_dft is not None and eT_dft is not None and eP_dft is not None and all(
3413
+ map(np.isfinite, [eR_dft, eT_dft, eP_dft])
3414
+ ):
3415
+ segment_log["dft"] = {
3416
+ "labels": ["R", "TS", "P"],
3417
+ "energies_au": [eR_dft, eT_dft, eP_dft],
3418
+ "energies_kcal": [
3419
+ 0.0,
3420
+ (eT_dft - eR_dft) * AU2KCALPERMOL,
3421
+ (eP_dft - eR_dft) * AU2KCALPERMOL,
3422
+ ],
3423
+ "diagram": str((seg_dir / "energy_diagram_DFT").with_suffix(".png")),
3424
+ "structures": {"R": pL, "TS": pT, "P": pR},
3425
+ "barrier_kcal": (eT_dft - eR_dft) * AU2KCALPERMOL,
3426
+ "delta_kcal": (eP_dft - eR_dft) * AU2KCALPERMOL,
3427
+ }
3428
+ if GR_dftUMA is not None and GT_dftUMA is not None and GP_dftUMA is not None:
3429
+ segment_log["gibbs_dft_uma"] = {
3430
+ "labels": ["R", "TS", "P"],
3431
+ "energies_au": [GR_dftUMA, GT_dftUMA, GP_dftUMA],
3432
+ "energies_kcal": [
3433
+ 0.0,
3434
+ (GT_dftUMA - GR_dftUMA) * AU2KCALPERMOL,
3435
+ (GP_dftUMA - GR_dftUMA) * AU2KCALPERMOL,
3436
+ ],
3437
+ "diagram": str((seg_dir / "energy_diagram_G_DFT_plus_UMA").with_suffix(".png")),
3438
+ "structures": {"R": pL, "TS": pT, "P": pR},
3439
+ "barrier_kcal": (GT_dftUMA - GR_dftUMA) * AU2KCALPERMOL,
3440
+ "delta_kcal": (GP_dftUMA - GR_dftUMA) * AU2KCALPERMOL,
3441
+ }
3442
+
3443
+ post_segment_logs.append(segment_log)
3444
+
3445
+ # -------------------------------------------------------------------------
3446
+ # Aggregate diagrams over all reactive segments
3447
+ # -------------------------------------------------------------------------
3448
+ if tsopt_seg_energies:
3449
+ tsopt_all_energies = [e for triple in tsopt_seg_energies for e in triple]
3450
+ tsopt_all_labels = _build_global_segment_labels(len(tsopt_seg_energies))
3451
+ if tsopt_all_labels and len(tsopt_all_labels) == len(tsopt_all_energies):
3452
+ diag_payload = _write_segment_energy_diagram(
3453
+ out_dir / "energy_diagram_UMA_all",
3454
+ labels=tsopt_all_labels,
3455
+ energies_eh=tsopt_all_energies,
3456
+ title_note="(UMA, TSOPT + IRC; all segments)",
3457
+ write_html=False,
3458
+ )
3459
+ if diag_payload:
3460
+ energy_diagrams.append(diag_payload)
3461
+
3462
+ if do_thermo and g_uma_seg_energies:
3463
+ g_uma_all_energies = [e for triple in g_uma_seg_energies for e in triple]
3464
+ g_uma_all_labels = _build_global_segment_labels(len(g_uma_seg_energies))
3465
+ if g_uma_all_labels and len(g_uma_all_labels) == len(g_uma_all_energies):
3466
+ diag_payload = _write_segment_energy_diagram(
3467
+ out_dir / "energy_diagram_G_UMA_all",
3468
+ labels=g_uma_all_labels,
3469
+ energies_eh=g_uma_all_energies,
3470
+ title_note="(UMA + Thermal Correction; all segments)",
3471
+ ylabel="ΔG (kcal/mol)",
3472
+ write_html=False,
3473
+ )
3474
+ if diag_payload:
3475
+ energy_diagrams.append(diag_payload)
3476
+
3477
+ if do_dft and dft_seg_energies:
3478
+ dft_all_energies = [e for triple in dft_seg_energies for e in triple]
3479
+ dft_all_labels = _build_global_segment_labels(len(dft_seg_energies))
3480
+ if dft_all_labels and len(dft_all_labels) == len(dft_all_energies):
3481
+ diag_payload = _write_segment_energy_diagram(
3482
+ out_dir / "energy_diagram_DFT_all",
3483
+ labels=dft_all_labels,
3484
+ energies_eh=dft_all_energies,
3485
+ title_note=f"({dft_method_fallback}; all segments)",
3486
+ write_html=False,
3487
+ )
3488
+ if diag_payload:
3489
+ energy_diagrams.append(diag_payload)
3490
+
3491
+ if do_dft and do_thermo and g_dftuma_seg_energies:
3492
+ g_dftuma_all_energies = [e for triple in g_dftuma_seg_energies for e in triple]
3493
+ g_dftuma_all_labels = _build_global_segment_labels(len(g_dftuma_seg_energies))
3494
+ if g_dftuma_all_labels and len(g_dftuma_all_labels) == len(g_dftuma_all_energies):
3495
+ diag_payload = _write_segment_energy_diagram(
3496
+ out_dir / "energy_diagram_G_DFT_plus_UMA_all",
3497
+ labels=g_dftuma_all_labels,
3498
+ energies_eh=g_dftuma_all_energies,
3499
+ title_note=f"({dft_method_fallback} // UMA + Thermal Correction; all segments)",
3500
+ ylabel="ΔG (kcal/mol)",
3501
+ write_html=False,
3502
+ )
3503
+ if diag_payload:
3504
+ energy_diagrams.append(diag_payload)
3505
+
3506
+ # -------------------------------------------------------------------------
3507
+ # Aggregated IRC plot over all reactive segments
3508
+ # -------------------------------------------------------------------------
3509
+ if irc_trj_for_all:
3510
+ _merge_irc_trajectories_to_single_plot(
3511
+ irc_trj_for_all, out_dir / "irc_plot_all.png"
3512
+ )
3513
+
3514
+ # Refresh summary.yaml with final energy diagram metadata
3515
+ try:
3516
+ summary["energy_diagrams"] = list(energy_diagrams)
3517
+ with open(path_dir / "summary.yaml", "w") as f:
3518
+ yaml.safe_dump(summary, f, sort_keys=False, allow_unicode=True)
3519
+ try:
3520
+ shutil.copy2(path_dir / "summary.yaml", out_dir / "summary.yaml")
3521
+ except Exception as e:
3522
+ _echo(f"[all] WARNING: Failed to mirror summary.yaml to {out_dir}: {e}", err=True)
3523
+ except Exception as e:
3524
+ _echo(f"[write] WARNING: Failed to refresh summary.yaml with energy diagram metadata: {e}", err=True)
3525
+
3526
+ _write_pipeline_summary_log(post_segment_logs)
3527
+ # summary.md and key_* outputs are disabled.
3528
+ _echo(format_elapsed("[all] Elapsed for Whole Pipeline", time_start))
3529
+
3530
+
3531
# Apply the help-visibility helper defined elsewhere in this module to the
# `cli` command (presumably controls which options appear in --help; see its
# definition).
+ _configure_all_help_visibility(cli)
3532
+
3533
+
3534
# Allow running this module directly as a script.
+ if __name__ == "__main__":
3535
+ cli()