mlmm-toolkit 0.2.2.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (372) hide show
  1. hessian_ff/__init__.py +50 -0
  2. hessian_ff/analytical_hessian.py +609 -0
  3. hessian_ff/constants.py +46 -0
  4. hessian_ff/forcefield.py +339 -0
  5. hessian_ff/loaders.py +608 -0
  6. hessian_ff/native/Makefile +8 -0
  7. hessian_ff/native/__init__.py +28 -0
  8. hessian_ff/native/analytical_hessian.py +88 -0
  9. hessian_ff/native/analytical_hessian_ext.cpp +258 -0
  10. hessian_ff/native/bonded.py +82 -0
  11. hessian_ff/native/bonded_ext.cpp +640 -0
  12. hessian_ff/native/loader.py +349 -0
  13. hessian_ff/native/nonbonded.py +118 -0
  14. hessian_ff/native/nonbonded_ext.cpp +1150 -0
  15. hessian_ff/prmtop_parmed.py +23 -0
  16. hessian_ff/system.py +107 -0
  17. hessian_ff/terms/__init__.py +14 -0
  18. hessian_ff/terms/angle.py +73 -0
  19. hessian_ff/terms/bond.py +44 -0
  20. hessian_ff/terms/cmap.py +406 -0
  21. hessian_ff/terms/dihedral.py +141 -0
  22. hessian_ff/terms/nonbonded.py +209 -0
  23. hessian_ff/tests/__init__.py +0 -0
  24. hessian_ff/tests/conftest.py +75 -0
  25. hessian_ff/tests/data/small/complex.parm7 +1346 -0
  26. hessian_ff/tests/data/small/complex.pdb +125 -0
  27. hessian_ff/tests/data/small/complex.rst7 +63 -0
  28. hessian_ff/tests/test_coords_input.py +44 -0
  29. hessian_ff/tests/test_energy_force.py +49 -0
  30. hessian_ff/tests/test_hessian.py +137 -0
  31. hessian_ff/tests/test_smoke.py +18 -0
  32. hessian_ff/tests/test_validation.py +40 -0
  33. hessian_ff/workflows.py +889 -0
  34. mlmm/__init__.py +36 -0
  35. mlmm/__main__.py +7 -0
  36. mlmm/_version.py +34 -0
  37. mlmm/add_elem_info.py +374 -0
  38. mlmm/advanced_help.py +91 -0
  39. mlmm/align_freeze_atoms.py +601 -0
  40. mlmm/all.py +3535 -0
  41. mlmm/bond_changes.py +231 -0
  42. mlmm/bool_compat.py +223 -0
  43. mlmm/cli.py +574 -0
  44. mlmm/cli_utils.py +166 -0
  45. mlmm/default_group.py +337 -0
  46. mlmm/defaults.py +467 -0
  47. mlmm/define_layer.py +526 -0
  48. mlmm/dft.py +1041 -0
  49. mlmm/energy_diagram.py +253 -0
  50. mlmm/extract.py +2213 -0
  51. mlmm/fix_altloc.py +464 -0
  52. mlmm/freq.py +1406 -0
  53. mlmm/harmonic_constraints.py +140 -0
  54. mlmm/hessian_cache.py +44 -0
  55. mlmm/hessian_calc.py +174 -0
  56. mlmm/irc.py +638 -0
  57. mlmm/mlmm_calc.py +2262 -0
  58. mlmm/mm_parm.py +945 -0
  59. mlmm/oniom_export.py +1983 -0
  60. mlmm/oniom_import.py +457 -0
  61. mlmm/opt.py +1742 -0
  62. mlmm/path_opt.py +1353 -0
  63. mlmm/path_search.py +2299 -0
  64. mlmm/preflight.py +88 -0
  65. mlmm/py.typed +1 -0
  66. mlmm/pysis_runner.py +45 -0
  67. mlmm/scan.py +1047 -0
  68. mlmm/scan2d.py +1226 -0
  69. mlmm/scan3d.py +1265 -0
  70. mlmm/scan_common.py +184 -0
  71. mlmm/summary_log.py +736 -0
  72. mlmm/trj2fig.py +448 -0
  73. mlmm/tsopt.py +2871 -0
  74. mlmm/utils.py +2309 -0
  75. mlmm/xtb_embedcharge_correction.py +475 -0
  76. mlmm_toolkit-0.2.2.dev0.dist-info/METADATA +1159 -0
  77. mlmm_toolkit-0.2.2.dev0.dist-info/RECORD +372 -0
  78. mlmm_toolkit-0.2.2.dev0.dist-info/WHEEL +5 -0
  79. mlmm_toolkit-0.2.2.dev0.dist-info/entry_points.txt +2 -0
  80. mlmm_toolkit-0.2.2.dev0.dist-info/licenses/LICENSE +674 -0
  81. mlmm_toolkit-0.2.2.dev0.dist-info/top_level.txt +4 -0
  82. pysisyphus/Geometry.py +1667 -0
  83. pysisyphus/LICENSE +674 -0
  84. pysisyphus/TableFormatter.py +63 -0
  85. pysisyphus/TablePrinter.py +74 -0
  86. pysisyphus/__init__.py +12 -0
  87. pysisyphus/calculators/AFIR.py +452 -0
  88. pysisyphus/calculators/AnaPot.py +20 -0
  89. pysisyphus/calculators/AnaPot2.py +48 -0
  90. pysisyphus/calculators/AnaPot3.py +12 -0
  91. pysisyphus/calculators/AnaPot4.py +20 -0
  92. pysisyphus/calculators/AnaPotBase.py +337 -0
  93. pysisyphus/calculators/AnaPotCBM.py +25 -0
  94. pysisyphus/calculators/AtomAtomTransTorque.py +154 -0
  95. pysisyphus/calculators/CFOUR.py +250 -0
  96. pysisyphus/calculators/Calculator.py +844 -0
  97. pysisyphus/calculators/CerjanMiller.py +24 -0
  98. pysisyphus/calculators/Composite.py +123 -0
  99. pysisyphus/calculators/ConicalIntersection.py +171 -0
  100. pysisyphus/calculators/DFTBp.py +430 -0
  101. pysisyphus/calculators/DFTD3.py +66 -0
  102. pysisyphus/calculators/DFTD4.py +84 -0
  103. pysisyphus/calculators/Dalton.py +61 -0
  104. pysisyphus/calculators/Dimer.py +681 -0
  105. pysisyphus/calculators/Dummy.py +20 -0
  106. pysisyphus/calculators/EGO.py +76 -0
  107. pysisyphus/calculators/EnergyMin.py +224 -0
  108. pysisyphus/calculators/ExternalPotential.py +264 -0
  109. pysisyphus/calculators/FakeASE.py +35 -0
  110. pysisyphus/calculators/FourWellAnaPot.py +28 -0
  111. pysisyphus/calculators/FreeEndNEBPot.py +39 -0
  112. pysisyphus/calculators/Gaussian09.py +18 -0
  113. pysisyphus/calculators/Gaussian16.py +726 -0
  114. pysisyphus/calculators/HardSphere.py +159 -0
  115. pysisyphus/calculators/IDPPCalculator.py +49 -0
  116. pysisyphus/calculators/IPIClient.py +133 -0
  117. pysisyphus/calculators/IPIServer.py +234 -0
  118. pysisyphus/calculators/LEPSBase.py +24 -0
  119. pysisyphus/calculators/LEPSExpr.py +139 -0
  120. pysisyphus/calculators/LennardJones.py +80 -0
  121. pysisyphus/calculators/MOPAC.py +219 -0
  122. pysisyphus/calculators/MullerBrownSympyPot.py +51 -0
  123. pysisyphus/calculators/MultiCalc.py +85 -0
  124. pysisyphus/calculators/NFK.py +45 -0
  125. pysisyphus/calculators/OBabel.py +87 -0
  126. pysisyphus/calculators/ONIOMv2.py +1129 -0
  127. pysisyphus/calculators/ORCA.py +893 -0
  128. pysisyphus/calculators/ORCA5.py +6 -0
  129. pysisyphus/calculators/OpenMM.py +88 -0
  130. pysisyphus/calculators/OpenMolcas.py +281 -0
  131. pysisyphus/calculators/OverlapCalculator.py +908 -0
  132. pysisyphus/calculators/Psi4.py +218 -0
  133. pysisyphus/calculators/PyPsi4.py +37 -0
  134. pysisyphus/calculators/PySCF.py +341 -0
  135. pysisyphus/calculators/PyXTB.py +73 -0
  136. pysisyphus/calculators/QCEngine.py +106 -0
  137. pysisyphus/calculators/Rastrigin.py +22 -0
  138. pysisyphus/calculators/Remote.py +76 -0
  139. pysisyphus/calculators/Rosenbrock.py +15 -0
  140. pysisyphus/calculators/SocketCalc.py +97 -0
  141. pysisyphus/calculators/TIP3P.py +111 -0
  142. pysisyphus/calculators/TransTorque.py +161 -0
  143. pysisyphus/calculators/Turbomole.py +965 -0
  144. pysisyphus/calculators/VRIPot.py +37 -0
  145. pysisyphus/calculators/WFOWrapper.py +333 -0
  146. pysisyphus/calculators/WFOWrapper2.py +341 -0
  147. pysisyphus/calculators/XTB.py +418 -0
  148. pysisyphus/calculators/__init__.py +81 -0
  149. pysisyphus/calculators/cosmo_data.py +139 -0
  150. pysisyphus/calculators/parser.py +150 -0
  151. pysisyphus/color.py +19 -0
  152. pysisyphus/config.py +133 -0
  153. pysisyphus/constants.py +65 -0
  154. pysisyphus/cos/AdaptiveNEB.py +230 -0
  155. pysisyphus/cos/ChainOfStates.py +725 -0
  156. pysisyphus/cos/FreeEndNEB.py +25 -0
  157. pysisyphus/cos/FreezingString.py +103 -0
  158. pysisyphus/cos/GrowingChainOfStates.py +71 -0
  159. pysisyphus/cos/GrowingNT.py +309 -0
  160. pysisyphus/cos/GrowingString.py +508 -0
  161. pysisyphus/cos/NEB.py +189 -0
  162. pysisyphus/cos/SimpleZTS.py +64 -0
  163. pysisyphus/cos/__init__.py +22 -0
  164. pysisyphus/cos/stiffness.py +199 -0
  165. pysisyphus/drivers/__init__.py +17 -0
  166. pysisyphus/drivers/afir.py +855 -0
  167. pysisyphus/drivers/barriers.py +271 -0
  168. pysisyphus/drivers/birkholz.py +138 -0
  169. pysisyphus/drivers/cluster.py +318 -0
  170. pysisyphus/drivers/diabatization.py +133 -0
  171. pysisyphus/drivers/merge.py +368 -0
  172. pysisyphus/drivers/merge_mol2.py +322 -0
  173. pysisyphus/drivers/opt.py +375 -0
  174. pysisyphus/drivers/perf.py +91 -0
  175. pysisyphus/drivers/pka.py +52 -0
  176. pysisyphus/drivers/precon_pos_rot.py +669 -0
  177. pysisyphus/drivers/rates.py +480 -0
  178. pysisyphus/drivers/replace.py +219 -0
  179. pysisyphus/drivers/scan.py +212 -0
  180. pysisyphus/drivers/spectrum.py +166 -0
  181. pysisyphus/drivers/thermo.py +31 -0
  182. pysisyphus/dynamics/Gaussian.py +103 -0
  183. pysisyphus/dynamics/__init__.py +20 -0
  184. pysisyphus/dynamics/colvars.py +136 -0
  185. pysisyphus/dynamics/driver.py +297 -0
  186. pysisyphus/dynamics/helpers.py +256 -0
  187. pysisyphus/dynamics/lincs.py +105 -0
  188. pysisyphus/dynamics/mdp.py +364 -0
  189. pysisyphus/dynamics/rattle.py +121 -0
  190. pysisyphus/dynamics/thermostats.py +128 -0
  191. pysisyphus/dynamics/wigner.py +266 -0
  192. pysisyphus/elem_data.py +3473 -0
  193. pysisyphus/exceptions.py +2 -0
  194. pysisyphus/filtertrj.py +69 -0
  195. pysisyphus/helpers.py +623 -0
  196. pysisyphus/helpers_pure.py +649 -0
  197. pysisyphus/init_logging.py +50 -0
  198. pysisyphus/intcoords/Bend.py +69 -0
  199. pysisyphus/intcoords/Bend2.py +25 -0
  200. pysisyphus/intcoords/BondedFragment.py +32 -0
  201. pysisyphus/intcoords/Cartesian.py +41 -0
  202. pysisyphus/intcoords/CartesianCoords.py +140 -0
  203. pysisyphus/intcoords/Coords.py +56 -0
  204. pysisyphus/intcoords/DLC.py +197 -0
  205. pysisyphus/intcoords/DistanceFunction.py +34 -0
  206. pysisyphus/intcoords/DummyImproper.py +70 -0
  207. pysisyphus/intcoords/DummyTorsion.py +72 -0
  208. pysisyphus/intcoords/LinearBend.py +105 -0
  209. pysisyphus/intcoords/LinearDisplacement.py +80 -0
  210. pysisyphus/intcoords/OutOfPlane.py +59 -0
  211. pysisyphus/intcoords/PrimTypes.py +286 -0
  212. pysisyphus/intcoords/Primitive.py +137 -0
  213. pysisyphus/intcoords/RedundantCoords.py +659 -0
  214. pysisyphus/intcoords/RobustTorsion.py +59 -0
  215. pysisyphus/intcoords/Rotation.py +147 -0
  216. pysisyphus/intcoords/Stretch.py +31 -0
  217. pysisyphus/intcoords/Torsion.py +101 -0
  218. pysisyphus/intcoords/Torsion2.py +25 -0
  219. pysisyphus/intcoords/Translation.py +45 -0
  220. pysisyphus/intcoords/__init__.py +61 -0
  221. pysisyphus/intcoords/augment_bonds.py +126 -0
  222. pysisyphus/intcoords/derivatives.py +10512 -0
  223. pysisyphus/intcoords/eval.py +80 -0
  224. pysisyphus/intcoords/exceptions.py +37 -0
  225. pysisyphus/intcoords/findiffs.py +48 -0
  226. pysisyphus/intcoords/generate_derivatives.py +414 -0
  227. pysisyphus/intcoords/helpers.py +235 -0
  228. pysisyphus/intcoords/logging_conf.py +10 -0
  229. pysisyphus/intcoords/mp_derivatives.py +10836 -0
  230. pysisyphus/intcoords/setup.py +962 -0
  231. pysisyphus/intcoords/setup_fast.py +176 -0
  232. pysisyphus/intcoords/update.py +272 -0
  233. pysisyphus/intcoords/valid.py +89 -0
  234. pysisyphus/interpolate/Geodesic.py +93 -0
  235. pysisyphus/interpolate/IDPP.py +55 -0
  236. pysisyphus/interpolate/Interpolator.py +116 -0
  237. pysisyphus/interpolate/LST.py +70 -0
  238. pysisyphus/interpolate/Redund.py +152 -0
  239. pysisyphus/interpolate/__init__.py +9 -0
  240. pysisyphus/interpolate/helpers.py +34 -0
  241. pysisyphus/io/__init__.py +22 -0
  242. pysisyphus/io/aomix.py +178 -0
  243. pysisyphus/io/cjson.py +24 -0
  244. pysisyphus/io/crd.py +101 -0
  245. pysisyphus/io/cube.py +220 -0
  246. pysisyphus/io/fchk.py +184 -0
  247. pysisyphus/io/hdf5.py +49 -0
  248. pysisyphus/io/hessian.py +72 -0
  249. pysisyphus/io/mol2.py +146 -0
  250. pysisyphus/io/molden.py +293 -0
  251. pysisyphus/io/orca.py +189 -0
  252. pysisyphus/io/pdb.py +269 -0
  253. pysisyphus/io/psf.py +79 -0
  254. pysisyphus/io/pubchem.py +31 -0
  255. pysisyphus/io/qcschema.py +34 -0
  256. pysisyphus/io/sdf.py +29 -0
  257. pysisyphus/io/xyz.py +61 -0
  258. pysisyphus/io/zmat.py +175 -0
  259. pysisyphus/irc/DWI.py +108 -0
  260. pysisyphus/irc/DampedVelocityVerlet.py +134 -0
  261. pysisyphus/irc/Euler.py +22 -0
  262. pysisyphus/irc/EulerPC.py +345 -0
  263. pysisyphus/irc/GonzalezSchlegel.py +187 -0
  264. pysisyphus/irc/IMKMod.py +164 -0
  265. pysisyphus/irc/IRC.py +878 -0
  266. pysisyphus/irc/IRCDummy.py +10 -0
  267. pysisyphus/irc/Instanton.py +307 -0
  268. pysisyphus/irc/LQA.py +53 -0
  269. pysisyphus/irc/ModeKill.py +136 -0
  270. pysisyphus/irc/ParamPlot.py +53 -0
  271. pysisyphus/irc/RK4.py +36 -0
  272. pysisyphus/irc/__init__.py +31 -0
  273. pysisyphus/irc/initial_displ.py +219 -0
  274. pysisyphus/linalg.py +411 -0
  275. pysisyphus/line_searches/Backtracking.py +88 -0
  276. pysisyphus/line_searches/HagerZhang.py +184 -0
  277. pysisyphus/line_searches/LineSearch.py +232 -0
  278. pysisyphus/line_searches/StrongWolfe.py +108 -0
  279. pysisyphus/line_searches/__init__.py +9 -0
  280. pysisyphus/line_searches/interpol.py +15 -0
  281. pysisyphus/modefollow/NormalMode.py +40 -0
  282. pysisyphus/modefollow/__init__.py +10 -0
  283. pysisyphus/modefollow/davidson.py +199 -0
  284. pysisyphus/modefollow/lanczos.py +95 -0
  285. pysisyphus/optimizers/BFGS.py +99 -0
  286. pysisyphus/optimizers/BacktrackingOptimizer.py +113 -0
  287. pysisyphus/optimizers/ConjugateGradient.py +98 -0
  288. pysisyphus/optimizers/CubicNewton.py +75 -0
  289. pysisyphus/optimizers/FIRE.py +113 -0
  290. pysisyphus/optimizers/HessianOptimizer.py +1176 -0
  291. pysisyphus/optimizers/LBFGS.py +228 -0
  292. pysisyphus/optimizers/LayerOpt.py +411 -0
  293. pysisyphus/optimizers/MicroOptimizer.py +169 -0
  294. pysisyphus/optimizers/NCOptimizer.py +90 -0
  295. pysisyphus/optimizers/Optimizer.py +1084 -0
  296. pysisyphus/optimizers/PreconLBFGS.py +260 -0
  297. pysisyphus/optimizers/PreconSteepestDescent.py +7 -0
  298. pysisyphus/optimizers/QuickMin.py +74 -0
  299. pysisyphus/optimizers/RFOptimizer.py +181 -0
  300. pysisyphus/optimizers/RSA.py +99 -0
  301. pysisyphus/optimizers/StabilizedQNMethod.py +248 -0
  302. pysisyphus/optimizers/SteepestDescent.py +23 -0
  303. pysisyphus/optimizers/StringOptimizer.py +173 -0
  304. pysisyphus/optimizers/__init__.py +41 -0
  305. pysisyphus/optimizers/closures.py +301 -0
  306. pysisyphus/optimizers/cls_map.py +58 -0
  307. pysisyphus/optimizers/exceptions.py +6 -0
  308. pysisyphus/optimizers/gdiis.py +280 -0
  309. pysisyphus/optimizers/guess_hessians.py +311 -0
  310. pysisyphus/optimizers/hessian_updates.py +355 -0
  311. pysisyphus/optimizers/poly_fit.py +285 -0
  312. pysisyphus/optimizers/precon.py +153 -0
  313. pysisyphus/optimizers/restrict_step.py +24 -0
  314. pysisyphus/pack.py +172 -0
  315. pysisyphus/peakdetect.py +948 -0
  316. pysisyphus/plot.py +1031 -0
  317. pysisyphus/run.py +2106 -0
  318. pysisyphus/socket_helper.py +74 -0
  319. pysisyphus/stocastic/FragmentKick.py +132 -0
  320. pysisyphus/stocastic/Kick.py +81 -0
  321. pysisyphus/stocastic/Pipeline.py +303 -0
  322. pysisyphus/stocastic/__init__.py +21 -0
  323. pysisyphus/stocastic/align.py +127 -0
  324. pysisyphus/testing.py +96 -0
  325. pysisyphus/thermo.py +156 -0
  326. pysisyphus/trj.py +824 -0
  327. pysisyphus/tsoptimizers/RSIRFOptimizer.py +56 -0
  328. pysisyphus/tsoptimizers/RSPRFOptimizer.py +182 -0
  329. pysisyphus/tsoptimizers/TRIM.py +59 -0
  330. pysisyphus/tsoptimizers/TSHessianOptimizer.py +463 -0
  331. pysisyphus/tsoptimizers/__init__.py +23 -0
  332. pysisyphus/wavefunction/Basis.py +239 -0
  333. pysisyphus/wavefunction/DIIS.py +76 -0
  334. pysisyphus/wavefunction/__init__.py +25 -0
  335. pysisyphus/wavefunction/build_ext.py +42 -0
  336. pysisyphus/wavefunction/cart2sph.py +190 -0
  337. pysisyphus/wavefunction/diabatization.py +304 -0
  338. pysisyphus/wavefunction/excited_states.py +435 -0
  339. pysisyphus/wavefunction/gen_ints.py +1811 -0
  340. pysisyphus/wavefunction/helpers.py +104 -0
  341. pysisyphus/wavefunction/ints/__init__.py +0 -0
  342. pysisyphus/wavefunction/ints/boys.py +193 -0
  343. pysisyphus/wavefunction/ints/boys_table_N_64_xasym_27.1_step_0.01.npy +0 -0
  344. pysisyphus/wavefunction/ints/cart_gto3d.py +176 -0
  345. pysisyphus/wavefunction/ints/coulomb3d.py +25928 -0
  346. pysisyphus/wavefunction/ints/diag_quadrupole3d.py +10036 -0
  347. pysisyphus/wavefunction/ints/dipole3d.py +8762 -0
  348. pysisyphus/wavefunction/ints/int2c2e3d.py +7198 -0
  349. pysisyphus/wavefunction/ints/int3c2e3d_sph.py +65040 -0
  350. pysisyphus/wavefunction/ints/kinetic3d.py +8240 -0
  351. pysisyphus/wavefunction/ints/ovlp3d.py +3777 -0
  352. pysisyphus/wavefunction/ints/quadrupole3d.py +15054 -0
  353. pysisyphus/wavefunction/ints/self_ovlp3d.py +198 -0
  354. pysisyphus/wavefunction/localization.py +458 -0
  355. pysisyphus/wavefunction/multipole.py +159 -0
  356. pysisyphus/wavefunction/normalization.py +36 -0
  357. pysisyphus/wavefunction/pop_analysis.py +134 -0
  358. pysisyphus/wavefunction/shells.py +1171 -0
  359. pysisyphus/wavefunction/wavefunction.py +504 -0
  360. pysisyphus/wrapper/__init__.py +11 -0
  361. pysisyphus/wrapper/exceptions.py +2 -0
  362. pysisyphus/wrapper/jmol.py +120 -0
  363. pysisyphus/wrapper/mwfn.py +169 -0
  364. pysisyphus/wrapper/packmol.py +71 -0
  365. pysisyphus/xyzloader.py +168 -0
  366. pysisyphus/yaml_mods.py +45 -0
  367. thermoanalysis/LICENSE +674 -0
  368. thermoanalysis/QCData.py +244 -0
  369. thermoanalysis/__init__.py +0 -0
  370. thermoanalysis/config.py +3 -0
  371. thermoanalysis/constants.py +20 -0
  372. thermoanalysis/thermo.py +1011 -0
mlmm/path_search.py ADDED
@@ -0,0 +1,2299 @@
1
+ # mlmm/path_search.py
2
+
3
+ """
4
+ ML/MM recursive GSM segmentation for multistep minimum-energy paths.
5
+
6
+ Example:
7
+ mlmm path-search -i R.pdb P.pdb --parm real.parm7 --model-pdb ml_region.pdb -q 0
8
+
9
+ For detailed documentation, see: docs/path_search.md
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from copy import deepcopy
15
+ from dataclasses import dataclass
16
+ from pathlib import Path
17
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
18
+
19
+ import gc
20
+ import logging
21
+ import sys
22
+ import traceback
23
+ import textwrap
24
+ import tempfile
25
+ import os
26
+
27
+ logger = logging.getLogger(__name__)
28
+ import time # timing
29
+ import re # used in _segment_base_id
30
+
31
+ import click
32
+ import numpy as np
33
+ import torch
34
+ import yaml
35
+
36
+ from pysisyphus.helpers import geom_loader
37
+ from pysisyphus.cos.GrowingString import GrowingString
38
+ from pysisyphus.optimizers.StringOptimizer import StringOptimizer
39
+ from pysisyphus.optimizers.LBFGS import LBFGS
40
+ from pysisyphus.optimizers.exceptions import OptimizationError, ZeroStepLength
41
+ from pysisyphus.constants import AU2KCALPERMOL, BOHR2ANG
42
+
43
+
44
+ from .mlmm_calc import mlmm, MLMMASECalculator
45
+ from .defaults import (
46
+ BOND_KW as _BOND_KW_DEFAULT,
47
+ SEARCH_KW as _SEARCH_KW_DEFAULT,
48
+ )
49
+ from .path_opt import GS_KW as _PATH_GS_KW, STOPT_KW as _PATH_STOPT_KW, DMF_KW as _PATH_DMF_KW, _select_hei_index
50
+ from .opt import (
51
+ GEOM_KW as _OPT_GEOM_KW,
52
+ CALC_KW as _OPT_CALC_KW,
53
+ LBFGS_KW as _OPT_LBFGS_KW,
54
+ _parse_freeze_atoms as _parse_freeze_atoms_opt,
55
+ _normalize_geom_freeze as _normalize_geom_freeze_opt,
56
+ )
57
+ from .utils import (
58
+ apply_layer_freeze_constraints,
59
+ apply_ref_pdb_override,
60
+ convert_xyz_to_pdb,
61
+ set_convert_file_enabled,
62
+ load_yaml_dict,
63
+ deep_update,
64
+ apply_yaml_overrides,
65
+ pretty_block,
66
+ strip_inherited_keys,
67
+ filter_calc_for_echo,
68
+ format_freeze_atoms_for_echo,
69
+ format_elapsed,
70
+ merge_freeze_atom_indices,
71
+ build_energy_diagram,
72
+ prepare_input_structure,
73
+ resolve_charge_spin_or_raise,
74
+ PreparedInputStructure,
75
+ parse_indices_string,
76
+ build_model_pdb_from_bfactors,
77
+ build_model_pdb_from_indices,
78
+ read_bfactors_from_pdb,
79
+ has_valid_layer_bfactors,
80
+ parse_layer_indices_from_bfactors,
81
+ collect_single_option_values,
82
+ )
83
+ from .cli_utils import resolve_yaml_sources, load_merged_yaml_cfg, make_is_param_explicit
84
+ from .preflight import validate_existing_files
85
+ from .trj2fig import run_trj2fig # auto-generate an energy plot when a _trj.xyz is produced
86
+ from .summary_log import write_summary_log
87
+ from .bond_changes import compare_structures, summarize_changes
88
+ from .align_freeze_atoms import align_and_refine_sequence_inplace
89
+
90
# -----------------------------------------------
# Configuration defaults
# -----------------------------------------------
# Every table below is an independent deep copy of the upstream defaults, so
# modifying these dicts does not modify the dicts they were copied from.

# Geometry (input handling) — reuse opt.py defaults
GEOM_KW: Dict[str, Any] = deepcopy(_OPT_GEOM_KW)

# ML/MM calculator settings — reuse opt.py defaults
CALC_KW: Dict[str, Any] = deepcopy(_OPT_CALC_KW)

# GrowingString (path representation) — reuse path_opt.py defaults
GS_KW: Dict[str, Any] = deepcopy(_PATH_GS_KW)

# StringOptimizer (GSM optimization control); outputs go to this command's directory
STOPT_KW: Dict[str, Any] = deepcopy(_PATH_STOPT_KW)
STOPT_KW["out_dir"] = "./result_path_search/"

# LBFGS settings; outputs go to this command's directory
LBFGS_KW: Dict[str, Any] = deepcopy(_OPT_LBFGS_KW)
LBFGS_KW["out_dir"] = "./result_path_search/"

# Covalent-bond change detection
BOND_KW: Dict[str, Any] = deepcopy(_BOND_KW_DEFAULT)

# DMF (Direct Max Flux) defaults
DMF_KW: Dict[str, Any] = deepcopy(_PATH_DMF_KW)

# Global search control
SEARCH_KW: Dict[str, Any] = deepcopy(_SEARCH_KW_DEFAULT)
123
+
124
+ # Multi-structure loader
125
def _load_structures(
    inputs: Sequence[PreparedInputStructure],
    coord_type: str,
    base_freeze: Sequence[int],
) -> List[Any]:
    """
    Build one pysisyphus Geometry per prepared input and attach frozen atoms.

    Each input's ``geom_path`` is loaded with ``geom_loader`` using ``coord_type``.
    The indices obtained from ``merge_freeze_atom_indices`` for ``base_freeze``
    are stored on the geometry as an integer NumPy array (a fresh array per
    geometry). Geometries are returned in input order.
    """
    loaded: List[Any] = []
    for item in inputs:
        geometry = geom_loader(item.geom_path, coord_type=coord_type)
        merged = merge_freeze_atom_indices({"freeze_atoms": list(base_freeze)})
        geometry.freeze_atoms = np.array(merged, dtype=int)
        loaded.append(geometry)
    return loaded
142
+
143
+
144
# Freeze-atom parsing/normalization is shared with opt.py; expose the same
# callables here under local private names.
_parse_freeze_atoms, _normalize_geom_freeze = (
    _parse_freeze_atoms_opt,
    _normalize_geom_freeze_opt,
)
147
+
148
+
149
+ def _write_xyz_trj_with_energy(images: Sequence, energies: Sequence[float], path: Path) -> None:
150
+ """
151
+ Write an XYZ `_trj.xyz` with the energy on line 2 of each block.
152
+ """
153
+ blocks: List[str] = []
154
+ E = np.array(energies, dtype=float)
155
+ for geom, e in zip(images, E):
156
+ s = geom.as_xyz()
157
+ lines = s.splitlines()
158
+ if len(lines) >= 2 and lines[0].strip().isdigit():
159
+ lines[1] = f"{e:.12f}"
160
+ s_mod = "\n".join(lines)
161
+ if not s_mod.endswith("\n"):
162
+ s_mod += "\n"
163
+ blocks.append(s_mod)
164
+ with open(path, "w") as f:
165
+ f.write("".join(blocks))
166
+
167
+
168
+ def _maybe_convert_to_pdb(in_path: Path, ref_pdb_path: Optional[Path], out_path: Optional[Path] = None) -> Optional[Path]:
169
+ """
170
+ If any input is PDB, convert the given `.xyz/_trj.xyz` to PDB using `ref_pdb_path`.
171
+ Return the output path on success, else None.
172
+ """
173
+ try:
174
+ if ref_pdb_path is None or (not in_path.exists()) or in_path.suffix.lower() not in (".xyz", "_trj.xyz"):
175
+ return None
176
+ out_pdb = out_path if out_path is not None else in_path.with_suffix(".pdb")
177
+ convert_xyz_to_pdb(in_path, ref_pdb_path, out_pdb)
178
+ click.echo(f"[convert] Wrote '{out_pdb}'.")
179
+ return out_pdb
180
+ except Exception as e:
181
+ click.echo(f"[convert] WARNING: Failed to convert '{in_path.name}' to PDB: {e}", err=True)
182
+ return None
183
+
184
+
185
+ def _kabsch_rmsd(A: np.ndarray, B: np.ndarray, align: bool = True, indices: Optional[Sequence[int]] = None) -> float:
186
+ """
187
+ RMSD between A and B (no rigid alignment; `align` is ignored). Optional subset selection via `indices`.
188
+ """
189
+ assert A.shape == B.shape and A.shape[1] == 3
190
+ if indices is not None and len(indices) > 0:
191
+ idx = np.array(sorted({int(i) for i in indices if 0 <= int(i) < A.shape[0]}), dtype=int)
192
+ if idx.size == 0:
193
+ idx = np.arange(A.shape[0], dtype=int)
194
+ A = A[idx]
195
+ B = B[idx]
196
+ diff = A - B
197
+ return float(np.sqrt((diff * diff).sum() / A.shape[0]))
198
+
199
+
200
+
201
+
202
def _has_bond_change(x, y, bond_cfg: Dict[str, Any]) -> Tuple[bool, str]:
    """
    Report whether any covalent bond is formed or broken going from ``x`` to ``y``.

    ``bond_cfg`` may override the comparison settings; missing keys fall back to
    device ``"cuda"``, bond_factor 1.20, margin_fraction 0.05 and delta_fraction 0.05.

    Returns:
        ``(changed, summary)``: ``changed`` is True when ``compare_structures``
        reports at least one formed or broken covalent bond, and ``summary`` is the
        ``summarize_changes`` text using 1-based atom numbering.
    """
    get = bond_cfg.get
    comparison = compare_structures(
        x,
        y,
        device=get("device", "cuda"),
        bond_factor=float(get("bond_factor", 1.20)),
        margin_fraction=float(get("margin_fraction", 0.05)),
        delta_fraction=float(get("delta_fraction", 0.05)),
    )
    any_formed = len(comparison.formed_covalent) > 0
    any_broken = len(comparison.broken_covalent) > 0
    description = summarize_changes(x, comparison, one_based=True)
    return any_formed or any_broken, description
217
+
218
+
219
+ # ---------- Minimal GS configuration helper ----------
220
+
221
+
222
+ # -----------------------------------------------
223
+ # Kink detection & interpolation helpers
224
+ # -----------------------------------------------
225
+
226
def _new_geom_from_coords(atoms: Sequence[str], coords: np.ndarray, coord_type: str, freeze_atoms: Sequence[int]) -> Any:
    """
    Create a pysisyphus Geometry from Bohr coordinates and attach ``freeze_atoms``.

    Coordinates are converted to Angstrom and written to a temporary XYZ file,
    which is loaded with ``geom_loader`` and then deleted. A failed delete is only
    logged at debug level. ``coords`` must iterate as (x, y, z) rows, one per atom.
    ``freeze_atoms`` is stored de-duplicated and sorted as an int array.
    """
    xyz_ang = np.asarray(coords, dtype=float) * BOHR2ANG
    atom_lines = [f"{sym} {x:.15f} {y:.15f} {z:.15f}" for sym, (x, y, z) in zip(atoms, xyz_ang)]
    text = "\n".join([str(len(atoms)), "", *atom_lines]) + "\n"
    handle = tempfile.NamedTemporaryFile("w+", suffix=".xyz", delete=False)
    tmp_name = handle.name
    try:
        handle.write(text)
        handle.flush()
        # Close before loading so the reader sees a complete file on every platform.
        handle.close()
        geom = geom_loader(Path(tmp_name), coord_type=coord_type)
        frozen = sorted({int(i) for i in freeze_atoms})
        geom.freeze_atoms = np.array(frozen, dtype=int)
        return geom
    finally:
        try:
            os.unlink(tmp_name)
        except Exception:
            logger.debug("Failed to unlink temp file %s", tmp_name, exc_info=True)
248
+
249
+
250
def _make_linear_interpolations(gL, gR, n_internal: int) -> List[Any]:
    """
    Return ``n_internal`` evenly spaced Cartesian interpolants strictly between gL and gR.

    Endpoints are excluded. The k-th image uses t = k / (n_internal + 1). Atom
    symbols and ``coord_type`` come from ``gL``. Each new image freezes the union
    of both endpoints' ``freeze_atoms``.
    """
    start = np.asarray(gL.coords3d, dtype=float)
    end = np.asarray(gR.coords3d, dtype=float)
    assert start.shape == end.shape and start.shape[1] == 3, "Atom counts must match for interpolation."
    symbols = list(gL.atoms)
    coord_type = gL.coord_type
    no_freeze = np.array([], dtype=int)
    frozen_left = {int(i) for i in getattr(gL, "freeze_atoms", no_freeze)}
    frozen_right = {int(i) for i in getattr(gR, "freeze_atoms", no_freeze)}
    freeze_union = sorted(frozen_left | frozen_right)
    n_segments = n_internal + 1.0
    images: List[Any] = []
    for k in range(1, n_internal + 1):
        t = k / n_segments
        blended = (1.0 - t) * start + t * end
        images.append(_new_geom_from_coords(symbols, blended, coord_type, freeze_union))
    return images
269
+
270
+
271
+ # ---- Segment/bridge tagging helpers ----
272
+
273
+ def _tag_images(images: Sequence[Any], **attrs: Any) -> None:
274
+ """
275
+ Attach arbitrary attributes to Geometry images.
276
+ """
277
+ for im in images:
278
+ for k, v in attrs.items():
279
+ try:
280
+ setattr(im, k, v)
281
+ except Exception:
282
+ logger.debug("Failed to set attribute %s on image", k, exc_info=True)
283
+
284
+
285
+ def _segment_base_id(tag: str) -> str:
286
+ """
287
+ Extract base id 'seg_XXX' from a tag like 'seg_000_refine'; fallback to `tag` or 'seg'.
288
+ """
289
+ m = re.search(r"(seg_\d{3})", tag or "")
290
+ return m.group(1) if m else (tag or "seg")
291
+
292
+
293
+ def _is_local_minimum(idx: int, energies: Sequence[float]) -> bool:
294
+ if idx < 0 or idx >= len(energies):
295
+ return False
296
+ if idx == 0:
297
+ return len(energies) > 1 and energies[1] > energies[0]
298
+ if idx == len(energies) - 1:
299
+ return energies[-2] > energies[-1]
300
+ return energies[idx - 1] > energies[idx] and energies[idx + 1] > energies[idx]
301
+
302
+
303
+ def _find_nearest_local_minimum(
304
+ hei_idx: int,
305
+ direction: int,
306
+ energies: Sequence[float],
307
+ ) -> Optional[int]:
308
+ i = hei_idx + direction
309
+ while 0 <= i < len(energies):
310
+ if _is_local_minimum(i, energies):
311
+ return i
312
+ i += direction
313
+ return None
314
+
315
+
316
@dataclass
class GSMResult:
    """Outcome of one path segment: its images, their energies and the chosen HEI."""

    # Path images in path order, as returned by the segment optimizer.
    images: List[Any]
    # Energy per image (built in the same order as `images` by _run_gsm_between).
    energies: List[float]
    # Index into `images` of the selected highest-energy image (HEI).
    hei_idx: int


# ---- Per-segment summary for the console report ----
@dataclass
class SegmentReport:
    """One row of the per-segment console report (a reaction segment or a bridge)."""

    tag: str
    barrier_kcal: float  # barrier, presumably kcal/mol (per the field name)
    delta_kcal: float  # energy change, presumably kcal/mol (per the field name)
    summary: str  # summarize_changes string (empty for bridges)
    kind: str = "seg"  # "seg" or "bridge"
    seg_index: int = 0  # 1-based index along final MEP (assigned later)
332
+
333
+
334
def _pick_segment_hei(E: np.ndarray) -> int:
    """
    Choose the HEI index of a segment's energy profile.

    Interior local maxima (strictly above both neighbours) are preferred, and the
    highest of them wins. Otherwise the highest interior node is used, or the
    global maximum when there are fewer than three nodes.

    Raises:
        ValueError: if ``E`` is empty.
    """
    n = len(E)
    if n == 0:
        raise ValueError("Cannot select an HEI from an empty energy profile.")
    local_max = [i for i in range(1, n - 1) if E[i] > E[i - 1] and E[i] > E[i + 1]]
    if local_max:
        return int(max(local_max, key=lambda i: E[i]))
    if n >= 3:
        return int(np.argmax(E[1:-1])) + 1
    return int(np.argmax(E))


def _run_gsm_between(
    gA,
    gB,
    shared_calc,
    gs_cfg: Dict[str, Any],
    stopt_cfg: Dict[str, Any],
    out_dir: Path,
    tag: str,
    ref_pdb_path: Optional[Path],  # reference PDB for conversion
) -> GSMResult:
    """
    Run GSM between ``gA`` and ``gB``, save segment outputs, and return the result.

    Files are written under ``out_dir / f"{tag}_mep"``:
      - ``final_geometries_trj.xyz``: energies on line 2 of each frame. If that
        write fails, falls back to ``gs.as_xyz()`` without energies.
      - ``mep_plot.png``: only when energies were written.
      - ``hei.xyz``: plus PDB conversions of the trajectory and HEI when
        ``ref_pdb_path`` is given.

    Args:
        gA, gB: endpoint geometries; ``shared_calc`` is attached to both.
        shared_calc: calculator reused for every string image.
        gs_cfg: keyword arguments for ``GrowingString``.
        stopt_cfg: keyword arguments for ``StringOptimizer``. The ``"type"`` key
            is not forwarded, and ``"out_dir"`` is replaced by the segment directory.
        out_dir: parent output directory.
        tag: segment label used in directory names and console messages.
        ref_pdb_path: reference PDB for XYZ→PDB conversion, or None.

    Returns:
        GSMResult with the string images, their energies and the HEI index.

    Raises:
        ValueError: if the optimized string yields no energies. Exceptions from
            ``optimizer.run()`` propagate unchanged.
    """
    # Attach calculator to endpoints
    for g in (gA, gB):
        g.set_calculator(shared_calc)

    gs = GrowingString(
        images=[gA, gB],
        calc_getter=(lambda: shared_calc),
        **gs_cfg,
    )

    seg_dir = out_dir / f"{tag}_mep"
    seg_dir.mkdir(parents=True, exist_ok=True)
    opt_kwargs = {k: v for k, v in stopt_cfg.items() if k != "type"}
    opt_kwargs["out_dir"] = str(seg_dir)

    optimizer = StringOptimizer(geometry=gs, **opt_kwargs)

    click.echo(f"\n=== [{tag}] GSM started ===\n")
    optimizer.run()
    click.echo(f"\n=== [{tag}] GSM finished ===\n")

    energies = list(map(float, np.array(gs.energy, dtype=float)))
    images = list(gs.images)
    E = np.array(energies, dtype=float)
    hei_idx = _pick_segment_hei(E)

    # Write trajectory (with energies when possible)
    final_trj = seg_dir / "final_geometries_trj.xyz"
    wrote_with_energy = True
    try:
        _write_xyz_trj_with_energy(images, energies, final_trj)
        click.echo(f"[{tag}] Wrote '{final_trj}'.")
    except Exception:
        wrote_with_energy = False
        with open(final_trj, "w") as f:
            f.write(gs.as_xyz())
        click.echo(f"[{tag}] Wrote '{final_trj}'.")

    # Energy plot for the segment (needs the energies on the comment lines)
    plot_path = seg_dir / "mep_plot.png"
    try:
        if wrote_with_energy:
            run_trj2fig(final_trj, [plot_path], unit="kcal", reference="init", reverse_x=False)
            click.echo(f"[{tag}] Saved energy plot → '{plot_path}'")
        else:
            click.echo(f"[{tag}] WARNING: Energies missing; skipping plot.", err=True)
    except Exception as e:
        click.echo(f"[{tag}] WARNING: Failed to plot energy: {e}", err=True)

    # If PDB input exists, convert intermediate _trj.xyz to PDB
    _maybe_convert_to_pdb(final_trj, ref_pdb_path, seg_dir / "final_geometries.pdb")

    # Write HEI structure as a single-frame XYZ with its energy on line 2,
    # reusing the trajectory writer instead of duplicating its formatting.
    try:
        hei_xyz = seg_dir / "hei.xyz"
        _write_xyz_trj_with_energy([images[hei_idx]], [float(E[hei_idx])], hei_xyz)
        click.echo(f"[{tag}] Wrote '{hei_xyz}'.")
        _maybe_convert_to_pdb(hei_xyz, ref_pdb_path, seg_dir / "hei.pdb")
    except Exception as e:
        click.echo(f"[{tag}] WARNING: Failed to write HEI structure: {e}", err=True)

    return GSMResult(images=images, energies=energies, hei_idx=hei_idx)
430
+
431
+
432
def _run_dmf_between(
    gA,
    gB,
    shared_calc,
    calc_cfg: Dict[str, Any],
    out_dir: Path,
    tag: str,
    ref_pdb_path: Optional[Path],
    max_nodes: int,
    dmf_cfg: Dict[str, Any],
) -> GSMResult:
    """Run DMF for a segment and convert outputs to pysisyphus Geometries.

    The two endpoints are converted to ASE ``Atoms``. An initial path is built with
    ``dmf.interpolate_fbenm`` and then optimized with ``dmf.DirectMaxFlux``. Atoms frozen
    in either endpoint are held by ``HarmonicFixAtoms`` restraints during the DMF run.
    The energies of the final images are re-evaluated with ``shared_calc``.

    Files written to ``out_dir / f"{tag}_mep"``:

    * ``dmf_fbenm_ipopt.out`` and ``dmf_ipopt.out``;
    * ``final_geometries_trj.xyz``, with the energy on each comment line;
    * optional PDB conversions;
    * ``mep_plot.png``, written best effort.

    Args:
        gA, gB: Endpoint geometries (must provide ``as_xyz()``). ``gA.coord_type`` and
            ``gA.freeze_atoms`` are reused for the returned images.
        shared_calc: Calculator used to evaluate the final image energies. Its ``core``
            backs the ASE calculator driven by DMF.
        calc_cfg: Calculator config. ``model_charge`` (default 0) and ``model_mult``
            (default 1) are stored in every image's ``info``.
        out_dir: Parent output directory.
        tag: Segment tag, used for the sub-directory name and log messages.
        ref_pdb_path: Optional reference PDB for XYZ→PDB conversion.
        max_nodes: Number of movable DMF nodes (clamped to at least 1).
        dmf_cfg: Overrides merged on top of ``DMF_KW``.

    Returns:
        GSMResult holding pysisyphus Geometries, their energies, and the HEI index.

    Raises:
        RuntimeError: If the DMF dependencies (pydmf / cyipopt) cannot be imported.
    """
    import copy
    from io import StringIO

    from ase.io import read as ase_read, write as ase_write
    from pysisyphus.constants import ANG2BOHR

    seg_dir = out_dir / f"{tag}_mep"
    seg_dir.mkdir(parents=True, exist_ok=True)

    # Union of the frozen atoms of both endpoints; these get harmonic restraints in DMF.
    fix_atoms: List[int] = []
    try:
        fix_atoms = sorted(
            {int(i) for g in (gA, gB) for i in getattr(g, "freeze_atoms", [])}
        )
    except Exception:
        logger.debug("Failed to extract freeze_atoms from endpoints", exc_info=True)

    try:
        from ase.calculators.mixing import SumCalculator
        from dmf import DirectMaxFlux, interpolate_fbenm
    except Exception as e:
        raise RuntimeError(f"DMF mode requires pydmf and cyipopt: {e}") from e

    from .harmonic_constraints import HarmonicFixAtoms
    from .utils import deep_update

    # Convert the pysisyphus endpoints to ASE Atoms for DMF.
    ref_images = [ase_read(StringIO(g.as_xyz()), format="xyz") for g in (gA, gB)]
    charge = int(calc_cfg.get("model_charge", 0))
    spin = int(calc_cfg.get("model_mult", 1))
    for img in ref_images:
        img.info["charge"] = charge
        img.info["spin"] = spin

    # Build the ASE calculator from the shared PySisyphus calculator.
    ase_calc = MLMMASECalculator(core=shared_calc.core)

    # Deep copy so that the module-level DMF_KW (which contains nested option dicts) cannot
    # be modified, in case deep_update merges nested dicts in place.
    dmf_cfg_local = deep_update(copy.deepcopy(DMF_KW), dmf_cfg)
    fbenm_opts = dict(dmf_cfg_local.get("fbenm_options", {}))
    cfbenm_opts = dict(dmf_cfg_local.get("cfbenm_options", {}))
    dmf_opts = dict(dmf_cfg_local.get("dmf_options", {}))
    # update_teval is a DirectMaxFlux argument, not an FB-ENM dmf_option; remove it here.
    update_teval = bool(dmf_opts.pop("update_teval", False))
    k_fix = float(dmf_cfg_local.get("k_fix", 300.0))
    n_move = max(1, int(max_nodes))

    mxflx_fbenm = interpolate_fbenm(
        ref_images,
        nmove=n_move,
        fbenm_only_endpoints=bool(dmf_cfg_local.get("fbenm_only_endpoints", False)),
        correlated=bool(dmf_cfg_local.get("correlated", False)),
        sequential=bool(dmf_cfg_local.get("sequential", False)),
        output_file=str(seg_dir / "dmf_fbenm_ipopt.out"),
        fbenm_options=fbenm_opts,
        cfbenm_options=cfbenm_opts,
        dmf_options=dmf_opts,
    )
    coefs = mxflx_fbenm.coefs.copy()

    mxflx = DirectMaxFlux(
        ref_images,
        coefs=coefs,
        nmove=n_move,
        update_teval=update_teval,
        remove_rotation_and_translation=bool(dmf_opts.get("remove_rotation_and_translation", False)),
        mass_weighted=bool(dmf_opts.get("mass_weighted", False)),
        parallel=bool(dmf_opts.get("parallel", False)),
        eps_vel=float(dmf_opts.get("eps_vel", 0.01)),
        eps_rot=float(dmf_opts.get("eps_rot", 0.01)),
        beta=float(dmf_opts.get("beta", 10.0)),
    )

    for image in mxflx.images:
        if "charge" not in image.info:
            image.info["charge"] = charge
        if "spin" not in image.info:
            image.info["spin"] = spin
        if fix_atoms:
            # Restrain the frozen atoms to the positions they have in this image.
            ref_positions = image.get_positions()[fix_atoms]
            harmonic_calc = HarmonicFixAtoms(indices=fix_atoms, ref_positions=ref_positions, k_fix=k_fix)
            image.calc = SumCalculator([ase_calc, harmonic_calc])
        else:
            image.calc = ase_calc

    mxflx.add_ipopt_options({"output_file": str(seg_dir / "dmf_ipopt.out")})
    max_cycles_dmf = dmf_cfg_local.get("max_cycles")
    if max_cycles_dmf is not None:
        try:
            max_iter = int(max_cycles_dmf)
            if max_iter > 0:
                mxflx.add_ipopt_options({"max_iter": max_iter})
        except Exception:
            logger.debug("Failed to set ipopt max_iter option", exc_info=True)

    click.echo(f"\n=== [{tag}] DMF started ===\n")
    mxflx.solve(tol="tight")
    click.echo(f"\n=== [{tag}] DMF finished ===\n")

    # Evaluate energies using the PySisyphus calculator (positions Å → Bohr).
    energies = []
    for image in mxflx.images:
        elems = image.get_chemical_symbols()
        coords_bohr = np.asarray(image.get_positions(), dtype=float).reshape(-1, 3) * ANG2BOHR
        energies.append(float(shared_calc.get_energy(elems, coords_bohr)["energy"]))
    hei_idx = _select_hei_index(energies)

    # Write the trajectory.
    final_trj = seg_dir / "final_geometries_trj.xyz"
    _write_xyz_trj_with_energy_from_ase(mxflx.images, energies, final_trj)
    click.echo(f"[{tag}] Wrote '{final_trj}'.")
    _maybe_convert_to_pdb(final_trj, ref_pdb_path, seg_dir / "final_geometries.pdb")

    try:
        run_trj2fig(final_trj, [seg_dir / "mep_plot.png"], unit="kcal", reference="init", reverse_x=False)
        click.echo(f"[{tag}] Saved energy plot → '{seg_dir / 'mep_plot.png'}'")
    except Exception as e:
        click.echo(f"[{tag}] WARNING: Failed to plot energy: {e}", err=True)

    # Convert the ASE images back to pysisyphus Geometries via temporary XYZ files.
    from pysisyphus.helpers import geom_loader as gl
    imgs = []
    for atoms in mxflx.images:
        buf = StringIO()
        ase_write(buf, atoms, format="xyz")
        tmp_xyz = seg_dir / f"_tmp_dmf_{len(imgs)}.xyz"
        try:
            tmp_xyz.write_text(buf.getvalue())
            g = gl(tmp_xyz, coord_type=gA.coord_type)
        finally:
            # Always remove the temp file, even if loading fails.
            tmp_xyz.unlink(missing_ok=True)
        try:
            g.freeze_atoms = np.array(getattr(gA, "freeze_atoms", []), dtype=int)
        except Exception:
            logger.debug("Failed to set freeze_atoms on interpolated image", exc_info=True)
        g.set_calculator(shared_calc)
        imgs.append(g)

    return GSMResult(images=imgs, energies=energies, hei_idx=hei_idx)
584
+
585
+
586
def _write_xyz_trj_with_energy_from_ase(images, energies, path: Path) -> None:
    """Write ASE ``Atoms`` frames to a multi-frame XYZ file, one energy per comment line.

    Each frame is serialized with ``ase.io.write(format="xyz")``. When the serialized
    text starts with an atom-count line followed by a comment line, the comment line is
    replaced by the frame's energy (``%.12f``). Other frames are written unchanged.
    Frames are paired with energies via ``zip``, so any surplus entries are ignored.
    The file is opened only after every frame has been serialized.
    """
    from io import StringIO
    from ase.io import write as ase_write

    def _frame_text(atoms, energy) -> str:
        sink = StringIO()
        ase_write(sink, atoms, format="xyz")
        rows = sink.getvalue().splitlines()
        has_header = len(rows) >= 2 and rows[0].strip().isdigit()
        if has_header:
            rows[1] = f"{energy:.12f}"
        return "\n".join(rows) + "\n"

    payload = "".join(_frame_text(atoms, energy) for atoms, energy in zip(images, energies))
    with open(path, "w") as fh:
        fh.write(payload)
601
+
602
+
603
def _run_mep_between(
    gA,
    gB,
    shared_calc,
    gs_cfg: Dict[str, Any],
    stopt_cfg: Dict[str, Any],
    out_dir: Path,
    tag: str,
    ref_pdb_path: Optional[Path],
    mep_mode_kind: str = "gsm",
    calc_cfg: Optional[Dict[str, Any]] = None,
    max_nodes: int = 10,
    dmf_cfg: Optional[Dict[str, Any]] = None,
) -> GSMResult:
    """Dispatcher: run GSM or DMF between two geometries.

    ``mep_mode_kind == "dmf"`` routes to ``_run_dmf_between``, which uses ``calc_cfg``,
    ``max_nodes`` and ``dmf_cfg``. Any other value routes to ``_run_gsm_between``, which
    uses ``gs_cfg`` and ``stopt_cfg``. Falsy ``calc_cfg``/``dmf_cfg`` values fall back
    to ``{}`` and a copy of ``DMF_KW``, respectively.
    """
    if mep_mode_kind != "dmf":
        return _run_gsm_between(gA, gB, shared_calc, gs_cfg, stopt_cfg, out_dir, tag=tag, ref_pdb_path=ref_pdb_path)

    effective_calc_cfg = calc_cfg if calc_cfg else {}
    effective_dmf_cfg = dmf_cfg if dmf_cfg else dict(DMF_KW)
    return _run_dmf_between(
        gA,
        gB,
        shared_calc,
        calc_cfg=effective_calc_cfg,
        out_dir=out_dir,
        tag=tag,
        ref_pdb_path=ref_pdb_path,
        max_nodes=max_nodes,
        dmf_cfg=effective_dmf_cfg,
    )
628
+
629
+
630
def _optimize_single(
    g,
    shared_calc,
    lbfgs_cfg: Dict[str, Any],
    out_dir: Path,
    tag: str,
    ref_pdb_path: Optional[Path],  # for PDB conversion
):
    """
    Run single-structure optimization (LBFGS) and return the final Geometry.

    The optimizer writes into ``out_dir / f"{tag}_lbfgs_opt"`` (overriding any
    ``out_dir`` entry in ``lbfgs_cfg``).

    After the run, the optimizer's ``final_fn`` XYZ is reloaded as a fresh geometry with
    ``g``'s coordinate type and frozen atoms, and ``shared_calc`` attached. If that reload
    (or the optional PDB conversion) fails, a warning is printed and ``g`` itself is
    returned; ``g`` is the object that was passed to the optimizer.
    """
    g.set_calculator(shared_calc)

    seg_dir = out_dir / f"{tag}_lbfgs_opt"
    seg_dir.mkdir(parents=True, exist_ok=True)
    args = {**lbfgs_cfg, "out_dir": str(seg_dir)}

    opt = LBFGS(g, **args)

    click.echo(f"\n=== [{tag}] Single-structure LBFGS started ===\n")
    opt.run()
    click.echo(f"\n=== [{tag}] Single-structure LBFGS finished ===\n")

    try:
        final_xyz = Path(opt.final_fn)
        _maybe_convert_to_pdb(final_xyz, ref_pdb_path)
        g_final = geom_loader(final_xyz, coord_type=g.coord_type)
        try:
            g_final.freeze_atoms = np.array(getattr(g, "freeze_atoms", []), dtype=int)
        except Exception:
            logger.debug("Failed to set freeze_atoms on final geometry", exc_info=True)
        g_final.set_calculator(shared_calc)
        return g_final
    except Exception as e:
        # Best effort: fall back to the geometry object the optimizer worked on,
        # but make the failure visible instead of silently swallowing it.
        click.echo(
            f"[{tag}] WARNING: Failed to reload optimized geometry ({e}); using in-memory geometry.",
            err=True,
        )
        logger.debug("Reloading final geometry failed", exc_info=True)
        return g
666
+
667
+
668
def _refine_between(
    gL,
    gR,
    shared_calc,
    gs_cfg: Dict[str, Any],
    stopt_cfg: Dict[str, Any],
    out_dir: Path,
    tag: str,
    ref_pdb_path: Optional[Path],  # for PDB conversion
    mep_mode_kind: str = "gsm",
    calc_cfg: Optional[Dict[str, Any]] = None,
    max_nodes: int = 10,
    dmf_cfg: Optional[Dict[str, Any]] = None,
) -> GSMResult:
    """
    Refine the End1→End2 interval with a climbing MEP.

    ``gs_cfg`` is copied with ``climb`` and ``climb_lanczos`` forced to True. The copy
    only matters for GSM, since DMF does not read ``gs_cfg``. The run is tagged
    ``f"{tag}_refine"``.
    """
    refine_cfg = dict(gs_cfg)
    refine_cfg.update(climb=True, climb_lanczos=True)
    refine_tag = f"{tag}_refine"
    return _run_mep_between(
        gL,
        gR,
        shared_calc,
        refine_cfg,
        stopt_cfg,
        out_dir,
        tag=refine_tag,
        ref_pdb_path=ref_pdb_path,
        mep_mode_kind=mep_mode_kind,
        calc_cfg=calc_cfg,
        max_nodes=max_nodes,
        dmf_cfg=dmf_cfg,
    )
691
+
692
+
693
def _maybe_bridge_segments(
    tail_g,
    head_g,
    shared_calc,
    gs_cfg: Dict[str, Any],  # bridge-specific GS config
    stopt_cfg: Dict[str, Any],
    out_dir: Path,
    tag: str,
    rmsd_thresh: float,
    ref_pdb_path: Optional[Path],  # for PDB conversion
    mep_mode_kind: str = "gsm",
    calc_cfg: Optional[Dict[str, Any]] = None,
    max_nodes: int = 5,
    dmf_cfg: Optional[Dict[str, Any]] = None,
) -> Optional[GSMResult]:
    """
    Bridge two segment endpoints with an extra GSM/DMF run when they are too far apart.

    The gap is measured as the RMSD between ``tail_g`` and ``head_g`` coordinates without
    alignment. If the gap is at most ``rmsd_thresh``, nothing is run and None is returned.
    Otherwise the result of ``_run_mep_between`` is returned, tagged ``f"{tag}_bridge"``.
    """
    tail_xyz = np.array(tail_g.coords3d)
    head_xyz = np.array(head_g.coords3d)
    gap = _kabsch_rmsd(tail_xyz, head_xyz, align=False)
    if gap <= rmsd_thresh:
        return None

    method_label = mep_mode_kind.upper()
    click.echo(f"[{tag}] Gap detected between segments (RMSD={gap:.4e} Å) — bridging via {method_label}.")
    bridge_tag = f"{tag}_bridge"
    return _run_mep_between(
        tail_g,
        head_g,
        shared_calc,
        gs_cfg,
        stopt_cfg,
        out_dir,
        tag=bridge_tag,
        ref_pdb_path=ref_pdb_path,
        mep_mode_kind=mep_mode_kind,
        calc_cfg=calc_cfg,
        max_nodes=max_nodes,
        dmf_cfg=dmf_cfg,
    )
720
+
721
+
722
def _stitch_paths(
    parts: List[Tuple[List[Any], List[float]]],
    stitch_rmsd_thresh: float,
    bridge_rmsd_thresh: float,
    shared_calc,
    gs_cfg,  # GS config for bridges (climb=False, max_nodes=search.max_nodes_bridge)
    stopt_cfg,
    out_dir: Path,
    tag: str,
    ref_pdb_path: Optional[Path],  # for PDB conversion
    bond_cfg: Optional[Dict[str, Any]] = None,  # detect bond changes between adjacent parts
    segment_builder: Optional[Callable[[Any, Any, str], "CombinedPath"]] = None,  # builds a recursive segment
    segments_out: Optional[List["SegmentReport"]] = None,  # inserted segment summaries, kept in path order
    bridge_pair_index: Optional[int] = None,  # pair index to tag bridge frames across pairs
    mep_mode_kind: str = "gsm",
    calc_cfg: Optional[Dict[str, Any]] = None,
    dmf_cfg: Optional[Dict[str, Any]] = None,
) -> Tuple[List[Any], List[float]]:
    """
    Concatenate path parts (images, energies), inserting bridges or new segments at gaps.

    For each interface between the accumulated tail and the next part's head:

    * Covalent changes are detected (requires ``segment_builder`` and ``bond_cfg``):
      a *new* recursive segment is built with ``segment_builder`` and inserted.
    * RMSD <= ``stitch_rmsd_thresh``: the head frame is treated as a duplicate and dropped.
    * RMSD > ``bridge_rmsd_thresh``: a bridge MEP (GSM or DMF) is run and inserted.
    * Otherwise the part is appended as-is.

    When inserted frames start with a duplicate of the current tail, that first frame is
    dropped. Reports for inserted segments and bridges go into ``segments_out`` just
    before the report of the segment that follows them (matched via the images'
    ``mep_seg_tag``). If no such report exists, they are appended instead.

    Returns:
        ``(images, energies)`` of the stitched path.
    """
    all_imgs: List[Any] = []
    all_E: List[float] = []

    # DMF reads its node count from ``max_nodes`` rather than from ``gs_cfg``. Forward
    # the bridge node count so that DMF bridges honour search.max_nodes_bridge like GSM does.
    try:
        bridge_max_nodes = int(gs_cfg.get("max_nodes", 5))
    except (AttributeError, TypeError, ValueError):
        bridge_max_nodes = 5

    def _last_known_seg_tag_from_images(imgs: List[Any]) -> Optional[str]:
        for im in reversed(imgs):
            t = getattr(im, "mep_seg_tag", None)
            if t:
                return t
        return None

    def _first_known_seg_tag_from_images(imgs: List[Any]) -> Optional[str]:
        for im in imgs:
            t = getattr(im, "mep_seg_tag", None)
            if t:
                return t
        return None

    def _rmsd_to_tail(g) -> float:
        """Non-aligned RMSD between the current stitched tail and ``g``."""
        return _kabsch_rmsd(np.array(all_imgs[-1].coords3d), np.array(g.coords3d), align=False)

    def _extend_dedup(imgs: List[Any], Es: List[float]) -> None:
        """Append ``imgs``/``Es``, dropping the first frame if it duplicates the tail."""
        if imgs and _rmsd_to_tail(imgs[0]) <= stitch_rmsd_thresh:
            imgs, Es = imgs[1:], Es[1:]
        all_imgs.extend(imgs)
        all_E.extend(Es)

    def _insert_reports_before(new_reports: List[Any], anchor_tag: Optional[str]) -> None:
        """Insert ``new_reports`` before the report tagged ``anchor_tag`` (else append)."""
        if segments_out is None or not new_reports:
            return
        insert_pos: Optional[int] = None
        if anchor_tag is not None:
            for j, sr in enumerate(segments_out):
                if getattr(sr, "tag", None) == anchor_tag:
                    insert_pos = j
                    break
        if insert_pos is None:
            segments_out.extend(new_reports)
        else:
            segments_out[insert_pos:insert_pos] = list(new_reports)

    def append_part(imgs: List[Any], Es: List[float]) -> None:
        if not imgs:
            return
        if not all_imgs:
            all_imgs.extend(imgs)
            all_E.extend(Es)
            return
        tail = all_imgs[-1]
        head = imgs[0]

        adj_changed, adj_summary = False, ""
        if segment_builder is not None and bond_cfg is not None:
            try:
                adj_changed, adj_summary = _has_bond_change(tail, head, bond_cfg)
            except Exception:
                logger.debug("Bond-change check at stitch interface failed; treating as unchanged", exc_info=True)
                adj_changed, adj_summary = False, ""

        if adj_changed and segment_builder is not None:
            click.echo(f"[{tag}] Covalent changes detected at interface — inserting a new recursive segment.")
            if adj_summary:
                click.echo(textwrap.indent(adj_summary, prefix=" "))
            sub = segment_builder(tail, head, f"{tag}_mid")
            # Keep report order consistent with the image order: the new segment sits
            # before the segment that ``imgs`` belongs to (same policy as bridges).
            _insert_reports_before(getattr(sub, "segments", None) or [], _first_known_seg_tag_from_images(imgs))
            if sub.images:
                _extend_dedup(sub.images, sub.energies)
            _extend_dedup(imgs, Es)
            return

        rmsd = _kabsch_rmsd(np.array(tail.coords3d), np.array(head.coords3d), align=False)
        if rmsd <= stitch_rmsd_thresh:
            all_imgs.extend(imgs[1:])
            all_E.extend(Es[1:])
        elif rmsd > bridge_rmsd_thresh:
            left_tag_recent = _last_known_seg_tag_from_images(all_imgs) or "segL"
            right_tag_upcoming = _first_known_seg_tag_from_images(imgs) or "segR"
            left_base = _segment_base_id(left_tag_recent)
            right_base = _segment_base_id(right_tag_upcoming)
            bridge_name_base = f"{left_base}_{right_base}"

            br = _maybe_bridge_segments(
                tail, head, shared_calc, gs_cfg, stopt_cfg, out_dir, tag=bridge_name_base,
                rmsd_thresh=bridge_rmsd_thresh, ref_pdb_path=ref_pdb_path,
                mep_mode_kind=mep_mode_kind, calc_cfg=calc_cfg,
                max_nodes=bridge_max_nodes, dmf_cfg=dmf_cfg,
            )
            if br is not None:
                _tag_images(br.images, mep_seg_tag=f"{bridge_name_base}_bridge", mep_seg_kind="bridge",
                            mep_has_bond_changes=False, pair_index=bridge_pair_index)
                _extend_dedup(br.images, br.energies)

                if segments_out is not None:
                    try:
                        barrier_kcal = (max(br.energies) - br.energies[0]) * AU2KCALPERMOL
                        delta_kcal = (br.energies[-1] - br.energies[0]) * AU2KCALPERMOL
                    except Exception:
                        barrier_kcal = float("nan")
                        delta_kcal = float("nan")
                    bridge_report = SegmentReport(
                        tag=f"{bridge_name_base}_bridge",
                        barrier_kcal=float(barrier_kcal),
                        delta_kcal=float(delta_kcal),
                        summary="",
                        kind="bridge"
                    )
                    _insert_reports_before([bridge_report], right_tag_upcoming)

            _extend_dedup(imgs, Es)
        else:
            all_imgs.extend(imgs)
            all_E.extend(Es)

    for (imgs, Es) in parts:
        append_part(imgs, Es)

    return all_imgs, all_E
868
+
869
+
870
+ # -----------------------------------------------
871
+ # Recursive search (core)
872
+ # -----------------------------------------------
873
+
874
@dataclass
class CombinedPath:
    """Result of the recursive multistep path construction.

    Returned by ``_build_multistep_path`` (images in A→B order). Early-exit paths
    in that function return the raw MEP with an empty ``segments`` list.
    """
    images: List[Any]  # path images in A→B order (pysisyphus-style geometries)
    energies: List[float]  # energy per image, paired with ``images``; presumably Hartree (reports scale by AU2KCALPERMOL) — confirm
    segments: List[SegmentReport]  # segment summaries in final output order
879
+
880
+
881
def _trailing_kink_count(segments: Sequence[SegmentReport]) -> int:
    """Count the consecutive kink segments at the tail of *segments*.

    A segment counts as a kink when its ``tag`` is non-empty and contains ``"kink"``.
    Walking backwards from the last segment, counting stops at the first non-kink.
    """
    newest_first = list(reversed(segments))
    for offset, report in enumerate(newest_first):
        label = report.tag
        if not label or "kink" not in label:
            return offset
    return len(newest_first)
890
+
891
+
892
def _build_multistep_path(
    gA,
    gB,
    shared_calc,
    geom_cfg: Dict[str, Any],
    gs_cfg: Dict[str, Any],
    stopt_cfg: Dict[str, Any],
    single_opt_cfg: Dict[str, Any],
    bond_cfg: Dict[str, Any],
    search_cfg: Dict[str, Any],
    refine_mode_kind: str,
    out_dir: Path,
    ref_pdb_path: Optional[Path],
    depth: int,
    seg_counter: List[int],
    branch_tag: str,
    pair_index: Optional[int] = None,
    mep_mode_kind: str = "gsm",
    calc_cfg: Optional[Dict[str, Any]] = None,
    dmf_cfg: Optional[Dict[str, Any]] = None,
    kink_seq_count: int = 0,
) -> CombinedPath:
    """
    Recursively construct a multistep MEP from A–B and return it (A→B order).

    Steps as implemented below:
      1. Run a climbing MEP (GSM or DMF) between ``gA`` and ``gB`` and take its HEI.
      2. Pick two seed images around the HEI and optimize each with ``_optimize_single``
         (End1 = ``left_end``, End2 = ``right_end``). "minima" mode uses the nearest
         local minima; any other mode uses HEI±1.
      3. If End1 and End2 show no covalent changes, the step is a "kink". It is filled
         with linearly interpolated, individually optimized nodes. Otherwise End1→End2
         is refined with another climbing MEP (``_refine_between``).
      4. Recurse on A→End1 and/or End2→B when those intervals show covalent changes.
      5. Stitch all parts with ``_stitch_paths``, which deduplicates frames, adds bridges,
         or inserts new segments.

    Args:
        gA, gB: Endpoint geometries.
        shared_calc: Calculator shared by all MEP runs and optimizations.
        geom_cfg: Only forwarded to the recursive calls in this function.
        gs_cfg: Base GrowingString config; ``max_nodes``/``climb`` are overridden per stage.
        stopt_cfg: String-optimizer config forwarded to the MEP runner.
        single_opt_cfg: Config for ``_optimize_single``.
        bond_cfg: Config for ``_has_bond_change``.
        search_cfg: Keys read here are ``max_nodes_segment``, ``max_seq_kink``,
            ``max_depth``, ``kink_max_nodes`` and ``max_nodes_bridge`` (all with defaults),
            plus the required ``stitch_rmsd_thresh`` and ``bridge_rmsd_thresh``.
        refine_mode_kind: "minima" selects the local-minima seeds; anything else uses HEI±1.
        out_dir: Root output directory for every sub-run.
        ref_pdb_path: Optional reference PDB for XYZ→PDB conversion.
        depth: Current recursion depth, compared with ``search_cfg["max_depth"]``.
        seg_counter: One-element list used as a shared, mutable counter. It is
            incremented in place so that segment tags stay unique across the recursion.
        branch_tag: Log label, extended with "L"/"R"/"B" in recursive calls.
        pair_index: Forwarded to ``_tag_images``.
        mep_mode_kind: "gsm" or "dmf".
        calc_cfg, dmf_cfg: Forwarded to the DMF runner.
        kink_seq_count: Number of consecutive kink segments directly to the left of ``gA``.

    Returns:
        CombinedPath with the stitched images/energies and the ordered segment reports.
        Early exits return the raw MEP with ``segments=[]``: maximum depth reached, HEI
        at an endpoint, or too many consecutive kinks.
    """
    seg_max_nodes = int(search_cfg.get("max_nodes_segment", gs_cfg.get("max_nodes", 10)))
    gs_seg_cfg = {**gs_cfg, "max_nodes": seg_max_nodes}
    max_seq_kink = int(search_cfg.get("max_seq_kink", 2))

    # Recursion guard: run a plain MEP between the endpoints and stop subdividing.
    if depth > int(search_cfg.get("max_depth", 10)):
        click.echo(f"[{branch_tag}] Reached maximum recursion depth. Returning current endpoints only.")
        gsm = _run_mep_between(
            gA, gB, shared_calc, gs_seg_cfg, stopt_cfg, out_dir, tag=f"seg_{seg_counter[0]:03d}_maxdepth",
            ref_pdb_path=ref_pdb_path, mep_mode_kind=mep_mode_kind,
            calc_cfg=calc_cfg, max_nodes=seg_max_nodes, dmf_cfg=dmf_cfg,
        )
        seg_counter[0] += 1
        _tag_images(gsm.images, pair_index=pair_index)
        return CombinedPath(images=gsm.images, energies=gsm.energies, segments=[])

    # Reserve a unique segment id before recursing (recursive calls share seg_counter).
    seg_id = seg_counter[0]
    seg_counter[0] += 1
    tag0 = f"seg_{seg_id:03d}"

    # The first (coarse) MEP over A–B always climbs, whatever gs_cfg says.
    gs_seg_cfg_first = {**gs_seg_cfg, "climb": True, "climb_lanczos": True}
    gsm0 = _run_mep_between(
        gA, gB, shared_calc, gs_seg_cfg_first, stopt_cfg, out_dir, tag=tag0,
        ref_pdb_path=ref_pdb_path, mep_mode_kind=mep_mode_kind,
        calc_cfg=calc_cfg, max_nodes=seg_max_nodes, dmf_cfg=dmf_cfg,
    )

    # HEI±1 must be valid indices; an endpoint HEI means there is no interior barrier to split on.
    hei = int(gsm0.hei_idx)
    if not (1 <= hei <= len(gsm0.images) - 2):
        click.echo(f"[{tag0}] WARNING: HEI is at an endpoint (idx={hei}). Returning the raw GSM path.")
        _tag_images(gsm0.images, pair_index=pair_index)
        return CombinedPath(images=gsm0.images, energies=gsm0.energies, segments=[])

    if refine_mode_kind == "minima":
        left_idx = _find_nearest_local_minimum(hei_idx=hei, direction=-1, energies=gsm0.energies)
        right_idx = _find_nearest_local_minimum(hei_idx=hei, direction=1, energies=gsm0.energies)
        # Fall back to HEI±1 when no local minimum is found on a side.
        if left_idx is None:
            left_idx = hei - 1
        if right_idx is None:
            right_idx = hei + 1
        click.echo(f"[{tag0}] Using nearest local minima around HEI (left idx={left_idx}, right idx={right_idx}).")
        left_img = gsm0.images[left_idx]
        right_img = gsm0.images[right_idx]
    else:
        left_img = gsm0.images[hei - 1]
        right_img = gsm0.images[hei + 1]
        click.echo(f"[{tag0}] Refining HEI±1 (peak mode).")

    # End1 / End2: relax the seeds on either side of the barrier.
    left_end = _optimize_single(left_img, shared_calc, single_opt_cfg, out_dir, tag=f"{tag0}_left", ref_pdb_path=ref_pdb_path)
    right_end = _optimize_single(right_img, shared_calc, single_opt_cfg, out_dir, tag=f"{tag0}_right", ref_pdb_path=ref_pdb_path)

    # If the bond check fails, assume bonds changed so that a full MEP refinement runs.
    try:
        lr_changed, lr_summary = _has_bond_change(left_end, right_end, bond_cfg)
    except Exception as e:
        click.echo(f"[{tag0}] WARNING: Failed to evaluate bond changes for kink detection: {e}", err=True)
        lr_changed, lr_summary = True, ""
    use_kink = (not lr_changed)

    if use_kink:
        # Kink: no covalent change between End1 and End2, so optimized linear
        # interpolations replace a full MEP refinement.
        n_inter = int(search_cfg.get("kink_max_nodes", 3))
        click.echo(f"[{tag0}] Kink detected (no covalent changes between End1 and End2). "
                   f"Using {n_inter} linear interpolation nodes + single-structure optimizations instead of GSM.")
        inter_geoms = _make_linear_interpolations(left_end, right_end, n_inter)
        opt_inters: List[Any] = []
        for i, g_int in enumerate(inter_geoms, 1):
            g_int.set_calculator(shared_calc)
            g_opt = _optimize_single(g_int, shared_calc, single_opt_cfg, out_dir, tag=f"{tag0}_kink_int{i}", ref_pdb_path=ref_pdb_path)
            opt_inters.append(g_opt)
        step_imgs = [left_end] + opt_inters + [right_end]
        step_E = [float(img.energy) for img in step_imgs]
        ref1 = GSMResult(images=step_imgs, energies=step_E, hei_idx=int(np.argmax(step_E)))
        step_tag_for_report = f"{tag0}_kink"
    else:
        click.echo(f"[{tag0}] Kink not detected (covalent changes present between End1 and End2).")
        if lr_summary:
            click.echo(textwrap.indent(lr_summary, prefix=" "))
        ref1 = _refine_between(left_end, right_end, shared_calc, gs_seg_cfg, stopt_cfg, out_dir, tag=tag0,
                               ref_pdb_path=ref_pdb_path, mep_mode_kind=mep_mode_kind,
                               calc_cfg=calc_cfg, max_nodes=seg_max_nodes, dmf_cfg=dmf_cfg)
        step_tag_for_report = f"{tag0}_refine"

    step_imgs, step_E = ref1.images, ref1.energies

    # Tag the central step's images; _stitch_paths later matches reports via mep_seg_tag.
    _changed, step_summary = _has_bond_change(step_imgs[0], step_imgs[-1], bond_cfg)
    _tag_images(step_imgs, mep_seg_tag=step_tag_for_report, mep_seg_kind="seg",
                mep_has_bond_changes=bool(_changed), pair_index=pair_index)

    # Decide whether the outer intervals A→End1 and End2→B need their own recursion.
    left_changed, left_summary = _has_bond_change(gA, left_end, bond_cfg)
    right_changed, right_summary = _has_bond_change(right_end, gB, bond_cfg)

    click.echo(f"[{tag0}] Covalent changes (A vs left_end): {'Yes' if left_changed else 'No'}")
    if left_changed:
        click.echo(textwrap.indent(left_summary, prefix=" "))
    click.echo(f"[{tag0}] Covalent changes (right_end vs B): {'Yes' if right_changed else 'No'}")
    if right_changed:
        click.echo(textwrap.indent(right_summary, prefix=" "))

    # Barrier and reaction energy of the central step, relative to its first image.
    try:
        barrier_kcal = (max(step_E) - step_E[0]) * AU2KCALPERMOL
        delta_kcal = (step_E[-1] - step_E[0]) * AU2KCALPERMOL
    except Exception:
        barrier_kcal = float("nan")
        delta_kcal = float("nan")

    seg_report = SegmentReport(
        tag=step_tag_for_report,
        barrier_kcal=float(barrier_kcal),
        delta_kcal=float(delta_kcal),
        summary=step_summary if _changed else "(no covalent changes detected)",
        kind="seg"
    )

    # Parts and reports are collected in A→B order: [left recursion] + step + [right recursion].
    parts: List[Tuple[List[Any], List[float]]] = []
    seg_reports: List[SegmentReport] = []

    trailing_kink_run = kink_seq_count
    if left_changed:
        subL = _build_multistep_path(
            gA, left_end, shared_calc, geom_cfg, gs_cfg, stopt_cfg,
            single_opt_cfg, bond_cfg, search_cfg, refine_mode_kind,
            out_dir, ref_pdb_path, depth + 1, seg_counter, branch_tag=f"{branch_tag}L",
            pair_index=pair_index,
            mep_mode_kind=mep_mode_kind, calc_cfg=calc_cfg, dmf_cfg=dmf_cfg,
            kink_seq_count=kink_seq_count,
        )
        _tag_images(subL.images, pair_index=pair_index)
        parts.append((subL.images, subL.energies))
        seg_reports.extend(subL.segments)
        # The kink run directly before this step is whatever trails the left sub-path.
        trailing_kink_run = _trailing_kink_count(seg_reports)

    # Too many consecutive kinks: abandon subdivision and return one plain MEP over A–B.
    current_kink_run = trailing_kink_run + 1 if use_kink else 0
    if use_kink and current_kink_run >= max_seq_kink:
        warning_msg = (
            f"[{tag0}] Consecutive kink segments were detected. Something seems wrong. "
            "Please check the initial structure and the generated intermediate structures. "
            "Alternatively, try switching the mep-mode. If that still fails, try including intermediate structures in the inputs."
        )
        click.echo(warning_msg)
        gsm = _run_mep_between(
            gA, gB, shared_calc, gs_seg_cfg, stopt_cfg, out_dir, tag=f"seg_{seg_counter[0]:03d}_maxdepth",
            ref_pdb_path=ref_pdb_path, mep_mode_kind=mep_mode_kind,
            calc_cfg=calc_cfg, max_nodes=seg_max_nodes, dmf_cfg=dmf_cfg,
        )
        seg_counter[0] += 1
        _tag_images(gsm.images, pair_index=pair_index)
        return CombinedPath(images=gsm.images, energies=gsm.energies, segments=[])

    parts.append((step_imgs, step_E))
    seg_reports.append(seg_report)

    if right_changed:
        subR = _build_multistep_path(
            right_end, gB, shared_calc, geom_cfg, gs_cfg, stopt_cfg,
            single_opt_cfg, bond_cfg, search_cfg, refine_mode_kind,
            out_dir, ref_pdb_path, depth + 1, seg_counter, branch_tag=f"{branch_tag}R",
            pair_index=pair_index,
            mep_mode_kind=mep_mode_kind, calc_cfg=calc_cfg, dmf_cfg=dmf_cfg,
            kink_seq_count=current_kink_run,
        )
        _tag_images(subR.images, pair_index=pair_index)
        parts.append((subR.images, subR.energies))
        seg_reports.extend(subR.segments)

    # Bridges between parts: non-climbing runs with their own node budget.
    bridge_max_nodes = int(search_cfg.get("max_nodes_bridge", 5))
    gs_bridge_cfg = {**gs_cfg, "max_nodes": bridge_max_nodes, "climb": False, "climb_lanczos": False}

    # Called by _stitch_paths when an interface shows covalent changes; recurses one level deeper.
    def _segment_builder(tail_g, head_g, _tag: str) -> CombinedPath:
        sub = _build_multistep_path(
            tail_g, head_g,
            shared_calc,
            geom_cfg, gs_cfg, stopt_cfg,
            single_opt_cfg,
            bond_cfg, search_cfg, refine_mode_kind,
            out_dir=out_dir,
            ref_pdb_path=ref_pdb_path,
            depth=depth + 1,
            seg_counter=seg_counter,
            branch_tag=f"{branch_tag}B",
            pair_index=pair_index,
            mep_mode_kind=mep_mode_kind, calc_cfg=calc_cfg, dmf_cfg=dmf_cfg,
            kink_seq_count=_trailing_kink_count(seg_reports),
        )
        _tag_images(sub.images, pair_index=pair_index)
        return sub

    stitched_imgs, stitched_E = _stitch_paths(
        parts,
        stitch_rmsd_thresh=float(search_cfg["stitch_rmsd_thresh"]),
        bridge_rmsd_thresh=float(search_cfg["bridge_rmsd_thresh"]),
        shared_calc=shared_calc,
        gs_cfg=gs_bridge_cfg,
        stopt_cfg=stopt_cfg,
        out_dir=out_dir,
        tag=tag0,
        ref_pdb_path=ref_pdb_path,
        bond_cfg=bond_cfg,
        segment_builder=_segment_builder,
        segments_out=seg_reports,
        bridge_pair_index=pair_index,
        mep_mode_kind=mep_mode_kind, calc_cfg=calc_cfg, dmf_cfg=dmf_cfg,
    )

    _tag_images(stitched_imgs, pair_index=pair_index)

    return CombinedPath(images=stitched_imgs, energies=stitched_E, segments=seg_reports)
1121
+
1122
+
1123
+ # -----------------------------------------------
1124
+ # CLI
1125
+ # -----------------------------------------------
1126
+
1127
+ @click.command(
1128
+ help="Multistep MEP search via recursive GSM segmentation.",
1129
+ context_settings={
1130
+ "help_option_names": ["-h", "--help"],
1131
+ "ignore_unknown_options": True,
1132
+ "allow_extra_args": True,
1133
+ },
1134
+ )
1135
+ @click.option(
1136
+ "-i", "--input",
1137
+ "input_paths",
1138
+ type=click.Path(path_type=Path, exists=True, dir_okay=False),
1139
+ multiple=True, # allow: -i A -i B -i C or -i A B C
1140
+ required=True,
1141
+ help=("Two or more structures in reaction order. "
1142
+ "Either repeat '-i' (e.g., '-i A -i B -i C') or use a single '-i' "
1143
+ "followed by multiple space-separated paths (e.g., '-i A B C').")
1144
+ )
1145
+ @click.option(
1146
+ "--parm",
1147
+ "real_parm7",
1148
+ type=click.Path(path_type=Path, exists=True, dir_okay=False),
1149
+ required=True,
1150
+ help="Amber parm7 topology covering the full enzyme complex.",
1151
+ )
1152
+ @click.option(
1153
+ "--model-pdb",
1154
+ "model_pdb",
1155
+ type=click.Path(path_type=Path, exists=True, dir_okay=False),
1156
+ required=False,
1157
+ help="PDB describing atoms that belong to the ML (high-level) region. "
1158
+ "Optional when --detect-layer is enabled.",
1159
+ )
1160
+ @click.option(
1161
+ "--model-indices",
1162
+ "model_indices_str",
1163
+ type=str,
1164
+ default=None,
1165
+ show_default=False,
1166
+ help="Comma-separated atom indices for the ML region (ranges allowed like 1-5). "
1167
+ "Used when --model-pdb is omitted.",
1168
+ )
1169
+ @click.option(
1170
+ "--model-indices-one-based/--model-indices-zero-based",
1171
+ "model_indices_one_based",
1172
+ default=True,
1173
+ show_default=True,
1174
+ help="Interpret --model-indices as 1-based (default) or 0-based.",
1175
+ )
1176
+ @click.option(
1177
+ "--detect-layer/--no-detect-layer",
1178
+ "detect_layer",
1179
+ default=True,
1180
+ show_default=True,
1181
+ help="Detect ML/MM layers from input PDB B-factors (B=0/10/20). "
1182
+ "If disabled, you must provide --model-pdb or --model-indices.",
1183
+ )
1184
+ @click.option(
1185
+ "-q",
1186
+ "--charge",
1187
+ type=int,
1188
+ required=False,
1189
+ help="Total system charge. Required unless --ligand-charge is provided.",
1190
+ )
1191
+ @click.option("-l", "--ligand-charge", type=str, default=None, show_default=False,
1192
+ help="Total charge or per-resname mapping (e.g., GPP:-3,SAM:1) used to derive "
1193
+ "charge when -q is omitted (requires PDB input or --ref-pdb).")
1194
+ @click.option(
1195
+ "-m",
1196
+ "--multiplicity",
1197
+ "spin",
1198
+ type=int,
1199
+ default=None,
1200
+ show_default=False,
1201
+ help="Spin multiplicity (2S+1). Defaults to 1 when omitted.",
1202
+ )
1203
+ @click.option(
1204
+ "--mep-mode",
1205
+ "mep_mode",
1206
+ type=click.Choice(["gsm", "dmf"], case_sensitive=False),
1207
+ default="gsm",
1208
+ show_default=True,
1209
+ help="MEP method: gsm (GrowingString) or dmf (Direct Max Flux).",
1210
+ )
1211
+ @click.option(
1212
+ "--refine-mode",
1213
+ type=click.Choice(["peak", "minima"], case_sensitive=False),
1214
+ default=None,
1215
+ show_default=True,
1216
+ help=(
1217
+ "Refinement seed around the highest-energy image: "
1218
+ "'peak' uses HEI±1, 'minima' uses nearest local minima. "
1219
+ "Defaults to peak for gsm and minima for dmf."
1220
+ ),
1221
+ )
1222
+ @click.option(
1223
+ "--freeze-atoms",
1224
+ "freeze_atoms_text",
1225
+ type=str,
1226
+ default=None,
1227
+ show_default=False,
1228
+ help="Comma-separated 1-based atom indices to freeze (e.g., '1,3,5').",
1229
+ )
1230
+ @click.option(
1231
+ "--hess-cutoff",
1232
+ "hess_cutoff",
1233
+ type=float,
1234
+ default=None,
1235
+ show_default=False,
1236
+ help="Distance cutoff (Å) from ML region for MM atoms to include in Hessian calculation. "
1237
+ "Applied to movable MM atoms and can be combined with --detect-layer.",
1238
+ )
1239
+ @click.option(
1240
+ "--movable-cutoff",
1241
+ "movable_cutoff",
1242
+ type=float,
1243
+ default=None,
1244
+ show_default=False,
1245
+ help="Distance cutoff (Å) from ML region for movable MM atoms. MM atoms beyond this are frozen. "
1246
+ "Providing --movable-cutoff disables --detect-layer.",
1247
+ )
1248
+ @click.option("--max-nodes", type=int, default=10, show_default=True,
1249
+ help=("Number of internal nodes (string has max_nodes+2 images including endpoints). "
1250
+ "Used for *segment* GSM unless overridden by YAML search.max_nodes_segment."))
1251
+ @click.option("--max-cycles", type=int, default=300, show_default=True, help="Maximum GSM optimization cycles.")
1252
+ @click.option(
1253
+ "--climb/--no-climb",
1254
+ default=True,
1255
+ show_default=True,
1256
+ help="Enable transition-state search after path growth.",
1257
+ )
1258
+ @click.option(
1259
+ "--dump/--no-dump",
1260
+ default=False,
1261
+ show_default=True,
1262
+ help="Dump GSM/single-optimization trajectories during the run.",
1263
+ )
1264
+ @click.option(
1265
+ "--opt-mode",
1266
+ "opt_mode",
1267
+ type=click.Choice(["grad", "hess"], case_sensitive=False),
1268
+ default="grad",
1269
+ show_default=True,
1270
+ help="Single-structure optimizer: grad (=LBFGS) or hess (=RFO).",
1271
+ )
1272
+ @click.option("-o", "--out-dir", "out_dir", type=str, default="./result_path_search/", show_default=True, help="Output directory.")
1273
+ @click.option(
1274
+ "--thresh",
1275
+ type=click.Choice(["gau_loose", "gau", "gau_tight", "gau_vtight", "baker", "never"], case_sensitive=False),
1276
+ default=None,
1277
+ help="Convergence preset for GSM/StringOptimizer and single LBFGS runs.",
1278
+ )
1279
+ @click.option(
1280
+ "--config",
1281
+ "config_yaml",
1282
+ type=click.Path(path_type=Path, exists=True, dir_okay=False),
1283
+ default=None,
1284
+ help="Base YAML configuration file applied before explicit CLI options.",
1285
+ )
1286
+ @click.option(
1287
+ "--show-config/--no-show-config",
1288
+ "show_config",
1289
+ default=False,
1290
+ show_default=True,
1291
+ help="Print resolved configuration and continue execution.",
1292
+ )
1293
+ @click.option(
1294
+ "--dry-run/--no-dry-run",
1295
+ "dry_run",
1296
+ default=False,
1297
+ show_default=True,
1298
+ help="Validate options and print the execution plan without running path search.",
1299
+ )
1300
+ @click.option(
1301
+ "--preopt/--no-preopt",
1302
+ "pre_opt",
1303
+ default=True,
1304
+ show_default=True,
1305
+ help="If False, skip initial single-structure optimizations of inputs."
1306
+ )
1307
+ # Input alignment switch (default True)
1308
+ @click.option(
1309
+ "--align/--no-align",
1310
+ "align",
1311
+ default=True,
1312
+ show_default=True,
1313
+ help=("After pre-optimization, align all inputs to the *first* input and match freeze_atoms "
1314
+ "using the align_freeze_atoms API.")
1315
+ )
1316
+ # Full template PDBs for XYZ→PDB conversion and topology reference
1317
+ @click.option(
1318
+ "--ref-pdb",
1319
+ "ref_pdb_paths",
1320
+ type=click.Path(path_type=Path, exists=True, dir_okay=False),
1321
+ multiple=True,
1322
+ default=None,
1323
+ help=("Full-size template PDBs in the same reaction order as --input. "
1324
+ "Required when using XYZ inputs to provide topology and B-factor information.")
1325
+ )
1326
+ @click.option(
1327
+ "--convert-files/--no-convert-files",
1328
+ "convert_files",
1329
+ default=True,
1330
+ show_default=True,
1331
+ help="Convert XYZ/TRJ outputs into PDB companions based on the input format.",
1332
+ )
1333
+ @click.option(
1334
+ "-b", "--backend",
1335
+ type=click.Choice(["uma", "orb", "mace", "aimnet2"], case_sensitive=False),
1336
+ default=None,
1337
+ show_default=False,
1338
+ help="ML backend for the ONIOM high-level region (default: uma).",
1339
+ )
1340
+ @click.option(
1341
+ "--embedcharge/--no-embedcharge",
1342
+ "embedcharge",
1343
+ default=False,
1344
+ show_default=True,
1345
+ help="Enable xTB point-charge embedding correction for MM→ML environmental effects.",
1346
+ )
1347
+ @click.option(
1348
+ "--embedcharge-cutoff",
1349
+ "embedcharge_cutoff",
1350
+ type=float,
1351
+ default=None,
1352
+ show_default=False,
1353
+ help="Distance cutoff (Å) from ML region for MM point charges in xTB embedding. "
1354
+ "Default: 12.0 Å when --embedcharge is enabled.",
1355
+ )
1356
+ @click.pass_context
1357
+ def cli(
1358
+ ctx: click.Context,
1359
+ input_paths: Sequence[Path],
1360
+ real_parm7: Path,
1361
+ model_pdb: Optional[Path],
1362
+ model_indices_str: Optional[str],
1363
+ model_indices_one_based: bool,
1364
+ detect_layer: bool,
1365
+ charge: Optional[int],
1366
+ ligand_charge: Optional[str],
1367
+ spin: Optional[int],
1368
+ mep_mode: str,
1369
+ refine_mode: Optional[str],
1370
+ freeze_atoms_text: Optional[str],
1371
+ hess_cutoff: Optional[float],
1372
+ movable_cutoff: Optional[float],
1373
+ max_nodes: int,
1374
+ max_cycles: int,
1375
+ climb: bool,
1376
+ dump: bool,
1377
+ opt_mode: str,
1378
+ out_dir: str,
1379
+ thresh: Optional[str],
1380
+ config_yaml: Optional[Path],
1381
+ show_config: bool,
1382
+ dry_run: bool,
1383
+ pre_opt: bool,
1384
+ align: bool,
1385
+ ref_pdb_paths: Optional[Sequence[Path]],
1386
+ convert_files: bool,
1387
+ backend: Optional[str],
1388
+ embedcharge: bool,
1389
+ embedcharge_cutoff: Optional[float],
1390
+ ) -> None:
1391
+ set_convert_file_enabled(convert_files)
1392
+ prepared_inputs: List[PreparedInputStructure] = []
1393
+ # --- Robustly accept both styles for -i/--input and --ref-pdb ---
1394
+ argv_all = sys.argv[1:] # drop program name
1395
+ i_vals = collect_single_option_values(argv_all, ("-i", "--input"), label="-i/--input")
1396
+ if i_vals:
1397
+ i_parsed = validate_existing_files(
1398
+ i_vals,
1399
+ option_name="-i/--input",
1400
+ hint="When using '-i', list only existing file paths (multiple paths may follow a single '-i').",
1401
+ )
1402
+ input_paths = tuple(i_parsed)
1403
+
1404
+ ref_vals = collect_single_option_values(argv_all, ("--ref-pdb",), label="--ref-pdb")
1405
+ if ref_vals:
1406
+ ref_parsed = validate_existing_files(
1407
+ ref_vals,
1408
+ option_name="--ref-pdb",
1409
+ hint="When using '--ref-pdb', multiple files may follow a single option.",
1410
+ )
1411
+ ref_pdb_paths = tuple(ref_parsed)
1412
+ # --- end of robust parsing fix ---
1413
+
1414
+ _is_param_explicit = make_is_param_explicit(ctx)
1415
+
1416
+ config_yaml, override_yaml, used_legacy_yaml = resolve_yaml_sources(
1417
+ config_yaml=config_yaml,
1418
+ override_yaml=None,
1419
+ args_yaml_legacy=None,
1420
+ )
1421
+ merged_yaml_cfg, _, _ = load_merged_yaml_cfg(
1422
+ config_yaml=config_yaml,
1423
+ override_yaml=None,
1424
+ )
1425
+
1426
+ time_start = time.perf_counter() # start timing
1427
+ command_str = "mlmm path-search " + " ".join(sys.argv[1:])
1428
+ try:
1429
+ # --------------------------
1430
+ # 0) Input validation (multi-structure)
1431
+ # --------------------------
1432
+ if len(input_paths) < 2:
1433
+ raise click.BadParameter("Provide at least two structures for --input in reaction order (reactant [intermediates ...] product).")
1434
+
1435
+ p_list = [Path(p) for p in input_paths]
1436
+ ref_list = list(ref_pdb_paths) if ref_pdb_paths else []
1437
+ prepared_inputs = []
1438
+ for i, p in enumerate(p_list):
1439
+ pi = prepare_input_structure(p)
1440
+ if p.suffix.lower() == ".xyz":
1441
+ if i < len(ref_list):
1442
+ apply_ref_pdb_override(pi, ref_list[i])
1443
+ else:
1444
+ raise click.BadParameter(
1445
+ f"XYZ input '{p}' requires a corresponding --ref-pdb for topology/B-factor info."
1446
+ )
1447
+ elif p.suffix.lower() != ".pdb":
1448
+ raise click.BadParameter(
1449
+ f"'{p}': unsupported format. Use .pdb or .xyz (with --ref-pdb)."
1450
+ )
1451
+ prepared_inputs.append(pi)
1452
+ # --------------------------
1453
+ # 1) Resolve settings (defaults < config < CLI(explicit) < override)
1454
+ # --------------------------
1455
+ config_layer_cfg = load_yaml_dict(config_yaml)
1456
+ override_layer_cfg = load_yaml_dict(override_yaml)
1457
+
1458
+ mep_mode_kind = mep_mode.lower().strip()
1459
+ refine_mode_kind = refine_mode.strip().lower() if refine_mode else None
1460
+
1461
+ geom_cfg = dict(GEOM_KW)
1462
+ calc_cfg = dict(CALC_KW)
1463
+ gs_cfg = dict(GS_KW)
1464
+ stopt_cfg = dict(STOPT_KW)
1465
+ lbfgs_cfg = dict(LBFGS_KW)
1466
+ bond_cfg = dict(BOND_KW)
1467
+ search_cfg = dict(SEARCH_KW)
1468
+ dmf_cfg = dict(DMF_KW)
1469
+
1470
+ apply_yaml_overrides(
1471
+ config_layer_cfg,
1472
+ [
1473
+ (geom_cfg, (("geom",),)),
1474
+ (calc_cfg, (("calc",), ("mlmm",))),
1475
+ (gs_cfg, (("gs",),)),
1476
+ (stopt_cfg, (("stopt",), ("opt",))),
1477
+ (lbfgs_cfg, (("stopt", "lbfgs"), ("lbfgs",))),
1478
+ (bond_cfg, (("bond",),)),
1479
+ (search_cfg, (("search",),)),
1480
+ (dmf_cfg, (("dmf",),)),
1481
+ ],
1482
+ )
1483
+
1484
+ # CLI explicit overrides (after config YAML, before override YAML)
1485
+ if backend is not None:
1486
+ calc_cfg["backend"] = str(backend).lower()
1487
+ if _is_param_explicit("embedcharge"):
1488
+ calc_cfg["embedcharge"] = bool(embedcharge)
1489
+ if _is_param_explicit("embedcharge_cutoff"):
1490
+ calc_cfg["embedcharge_cutoff"] = embedcharge_cutoff
1491
+
1492
+ try:
1493
+ geom_freeze = _normalize_geom_freeze(geom_cfg.get("freeze_atoms"))
1494
+ except click.BadParameter as e:
1495
+ click.echo(f"ERROR: {e}", err=True)
1496
+ sys.exit(1)
1497
+ geom_cfg["freeze_atoms"] = geom_freeze
1498
+
1499
+ try:
1500
+ cli_freeze = _parse_freeze_atoms(freeze_atoms_text)
1501
+ except click.BadParameter as e:
1502
+ click.echo(f"ERROR: {e}", err=True)
1503
+ sys.exit(1)
1504
+
1505
+ model_indices: Optional[List[int]] = None
1506
+ if model_indices_str:
1507
+ try:
1508
+ model_indices = parse_indices_string(model_indices_str, one_based=model_indices_one_based)
1509
+ except click.BadParameter as e:
1510
+ click.echo(f"ERROR: {e}", err=True)
1511
+ sys.exit(1)
1512
+ if cli_freeze:
1513
+ merge_freeze_atom_indices(geom_cfg, cli_freeze)
1514
+
1515
+ freeze_atoms_final = list(geom_cfg.get("freeze_atoms") or [])
1516
+ calc_cfg["freeze_atoms"] = freeze_atoms_final
1517
+
1518
+ resolved_charge = charge
1519
+ resolved_spin = spin
1520
+ for prepared in prepared_inputs:
1521
+ resolved_charge, resolved_spin = resolve_charge_spin_or_raise(
1522
+ prepared,
1523
+ resolved_charge,
1524
+ resolved_spin,
1525
+ ligand_charge=ligand_charge,
1526
+ prefix="[path-search]",
1527
+ )
1528
+ charge_value = calc_cfg.get("model_charge", resolved_charge)
1529
+ if charge_value is None:
1530
+ charge_value = resolved_charge
1531
+ calc_cfg["model_charge"] = int(charge_value)
1532
+ if _is_param_explicit("charge"):
1533
+ calc_cfg["model_charge"] = int(resolved_charge)
1534
+
1535
+ spin_value = calc_cfg.get("model_mult", resolved_spin)
1536
+ if spin_value is None:
1537
+ spin_value = resolved_spin
1538
+ calc_cfg["model_mult"] = int(spin_value)
1539
+ if _is_param_explicit("spin"):
1540
+ calc_cfg["model_mult"] = int(resolved_spin)
1541
+
1542
+ first_input = p_list[0]
1543
+ # input_pdb must be a PDB (parmed requirement); use --ref-pdb when input is XYZ
1544
+ if first_input.suffix.lower() != ".pdb" and ref_list:
1545
+ calc_cfg["input_pdb"] = str(Path(ref_list[0]).resolve())
1546
+ else:
1547
+ calc_cfg["input_pdb"] = str(first_input)
1548
+ calc_cfg["real_parm7"] = str(real_parm7)
1549
+
1550
+ detect_layer_effective = bool(calc_cfg.get("use_bfactor_layers", detect_layer))
1551
+ if _is_param_explicit("detect_layer"):
1552
+ detect_layer_effective = bool(detect_layer)
1553
+
1554
+ if _is_param_explicit("max_nodes"):
1555
+ gs_cfg["max_nodes"] = int(max_nodes)
1556
+ search_cfg["max_nodes_segment"] = int(max_nodes)
1557
+ if _is_param_explicit("max_cycles"):
1558
+ stopt_cfg["max_cycles"] = int(max_cycles)
1559
+ stopt_cfg["stop_in_when_full"] = int(max_cycles)
1560
+ dmf_cfg["max_cycles"] = int(max_cycles)
1561
+ if _is_param_explicit("climb"):
1562
+ gs_cfg["climb"] = bool(climb)
1563
+ gs_cfg["climb_lanczos"] = bool(climb)
1564
+ if _is_param_explicit("dump"):
1565
+ stopt_cfg["dump"] = bool(dump)
1566
+ lbfgs_cfg["dump"] = bool(dump)
1567
+ if _is_param_explicit("out_dir"):
1568
+ stopt_cfg["out_dir"] = out_dir
1569
+ lbfgs_cfg["out_dir"] = out_dir
1570
+ if _is_param_explicit("thresh") and thresh is not None:
1571
+ stopt_cfg["thresh"] = str(thresh)
1572
+ lbfgs_cfg["thresh"] = str(thresh)
1573
+ if _is_param_explicit("hess_cutoff") and hess_cutoff is not None:
1574
+ calc_cfg["hess_cutoff"] = float(hess_cutoff)
1575
+ if _is_param_explicit("movable_cutoff") and movable_cutoff is not None:
1576
+ calc_cfg["movable_cutoff"] = float(movable_cutoff)
1577
+ detect_layer_effective = False
1578
+ if _is_param_explicit("refine_mode"):
1579
+ search_cfg["refine_mode"] = refine_mode_kind
1580
+
1581
+ apply_yaml_overrides(
1582
+ override_layer_cfg,
1583
+ [
1584
+ (geom_cfg, (("geom",),)),
1585
+ (calc_cfg, (("calc",), ("mlmm",))),
1586
+ (gs_cfg, (("gs",),)),
1587
+ (stopt_cfg, (("stopt",), ("opt",))),
1588
+ (lbfgs_cfg, (("stopt", "lbfgs"), ("lbfgs",))),
1589
+ (bond_cfg, (("bond",),)),
1590
+ (search_cfg, (("search",),)),
1591
+ (dmf_cfg, (("dmf",),)),
1592
+ ],
1593
+ )
1594
+
1595
+ refine_mode_kind = search_cfg.get("refine_mode")
1596
+ if refine_mode_kind is None:
1597
+ refine_mode_kind = "peak" if mep_mode_kind == "gsm" else "minima"
1598
+ else:
1599
+ refine_mode_kind = str(refine_mode_kind).strip().lower()
1600
+ if refine_mode_kind not in {"peak", "minima"}:
1601
+ raise click.BadParameter(f"Unknown --refine-mode '{refine_mode_kind}'.")
1602
+ search_cfg["refine_mode"] = refine_mode_kind
1603
+
1604
+ out_dir_path = Path(stopt_cfg.get("out_dir", out_dir)).resolve()
1605
+ detect_layer_effective = bool(calc_cfg.get("use_bfactor_layers", detect_layer_effective))
1606
+
1607
+ model_pdb_effective: Optional[Path] = None
1608
+ if _is_param_explicit("model_pdb") and model_pdb is not None:
1609
+ model_pdb_effective = Path(model_pdb)
1610
+ else:
1611
+ model_pdb_cfg = calc_cfg.get("model_pdb")
1612
+ if isinstance(model_pdb_cfg, (str, Path)) and str(model_pdb_cfg).strip():
1613
+ model_pdb_effective = Path(model_pdb_cfg)
1614
+
1615
+ hess_cutoff_effective = calc_cfg.get("hess_cutoff")
1616
+ movable_cutoff_effective = calc_cfg.get("movable_cutoff")
1617
+ if movable_cutoff_effective is not None:
1618
+ if detect_layer_effective:
1619
+ click.echo("[layer] movable_cutoff is set; disabling detect-layer.", err=True)
1620
+ detect_layer_effective = False
1621
+
1622
+ # For layer detection, prefer --ref-pdb (which carries B-factor layers)
1623
+ # over the first input (which may be XYZ).
1624
+ if ref_list and ref_list[0]:
1625
+ layer_source_pdb = Path(ref_list[0]).resolve()
1626
+ else:
1627
+ layer_source_pdb = first_input
1628
+ if detect_layer_effective and layer_source_pdb.suffix.lower() != ".pdb":
1629
+ click.echo("ERROR: --detect-layer requires a PDB input (or --ref-pdb).", err=True)
1630
+ sys.exit(1)
1631
+
1632
+ if dry_run:
1633
+ layer_info_preview: Optional[Dict[str, List[int]]] = None
1634
+ model_region_source = "bfactor"
1635
+
1636
+ if detect_layer_effective:
1637
+ try:
1638
+ bfactors = read_bfactors_from_pdb(layer_source_pdb)
1639
+ if not bfactors:
1640
+ raise ValueError(f"No ATOM/HETATM records found in {layer_source_pdb}.")
1641
+ if not has_valid_layer_bfactors(bfactors):
1642
+ raise ValueError(
1643
+ "Invalid or missing layer B-factors (expected ~0/10/20). "
1644
+ "Provide --no-detect-layer with --model-pdb/--model-indices."
1645
+ )
1646
+ layer_info_preview = parse_layer_indices_from_bfactors(bfactors)
1647
+ if not layer_info_preview.get("ml_indices"):
1648
+ raise ValueError("No ML atoms detected from B-factors (value ~0).")
1649
+ except Exception as e:
1650
+ if model_pdb_effective is None and not model_indices:
1651
+ click.echo(f"ERROR: {e}", err=True)
1652
+ sys.exit(1)
1653
+ click.echo(f"[layer] WARNING: {e} Falling back to explicit ML region.", err=True)
1654
+ detect_layer_effective = False
1655
+
1656
+ if not detect_layer_effective:
1657
+ if model_pdb_effective is not None:
1658
+ model_region_source = "model_pdb"
1659
+ elif model_indices:
1660
+ model_region_source = "model_indices"
1661
+ if layer_source_pdb.suffix.lower() != ".pdb":
1662
+ click.echo("ERROR: --model-indices requires a PDB input.", err=True)
1663
+ sys.exit(1)
1664
+ n_atoms = 0
1665
+ with layer_source_pdb.open("r", encoding="utf-8", errors="ignore") as fh:
1666
+ for line in fh:
1667
+ if line.startswith(("ATOM ", "HETATM")):
1668
+ n_atoms += 1
1669
+ bad_idx = [i for i in model_indices if i < 0 or i >= n_atoms]
1670
+ if bad_idx:
1671
+ click.echo(
1672
+ f"ERROR: model index out of range: {bad_idx[0]} (valid: 0 <= idx < {n_atoms})",
1673
+ err=True,
1674
+ )
1675
+ sys.exit(1)
1676
+ else:
1677
+ click.echo("ERROR: Provide --model-pdb or --model-indices when --no-detect-layer.", err=True)
1678
+ sys.exit(1)
1679
+
1680
+ if show_config:
1681
+ click.echo(
1682
+ pretty_block(
1683
+ "yaml_layers",
1684
+ {
1685
+ "config": None if config_yaml is None else str(config_yaml),
1686
+ "override": None if override_yaml is None else str(override_yaml),
1687
+ "merged_keys": sorted(merged_yaml_cfg.keys()),
1688
+ },
1689
+ )
1690
+ )
1691
+
1692
+ dry_payload: Dict[str, Any] = {
1693
+ "input_count": len(p_list),
1694
+ "input_first": str(p_list[0]) if p_list else None,
1695
+ "input_last": str(p_list[-1]) if p_list else None,
1696
+ "output_dir": str(out_dir_path),
1697
+ "mep_mode": mep_mode_kind,
1698
+ "refine_mode": refine_mode_kind,
1699
+ "opt_mode": str(opt_mode),
1700
+ "detect_layer": bool(detect_layer_effective),
1701
+ "model_region_source": model_region_source,
1702
+ "model_indices_count": 0 if not model_indices else len(model_indices),
1703
+ "pre_opt": bool(pre_opt),
1704
+ "align": bool(align),
1705
+ "max_depth": int(search_cfg.get("max_depth", SEARCH_KW["max_depth"])),
1706
+ "max_nodes_segment": int(search_cfg.get("max_nodes_segment", gs_cfg.get("max_nodes", 0))),
1707
+ "will_run_path_search": True,
1708
+ "will_write_summary": True,
1709
+ "backend": calc_cfg.get("backend", "uma"),
1710
+ "embedcharge": bool(calc_cfg.get("embedcharge", False)),
1711
+ }
1712
+ if layer_info_preview is not None:
1713
+ dry_payload["layer_counts"] = {
1714
+ "ml": len(layer_info_preview.get("ml_indices", [])),
1715
+ "movable_mm": len(layer_info_preview.get("movable_mm_indices", [])),
1716
+ "frozen": len(layer_info_preview.get("frozen_indices", [])),
1717
+ "unassigned": len(layer_info_preview.get("unassigned_indices", [])),
1718
+ }
1719
+
1720
+ click.echo(pretty_block("dry_run_plan", dry_payload))
1721
+ click.echo("[dry-run] Validation complete. Path search execution was skipped.")
1722
+ return
1723
+
1724
+ model_pdb_path: Optional[Path] = None
1725
+ layer_info: Optional[Dict[str, List[int]]] = None
1726
+
1727
+ if detect_layer_effective:
1728
+ try:
1729
+ model_pdb_path, layer_info = build_model_pdb_from_bfactors(layer_source_pdb, out_dir_path)
1730
+ calc_cfg["use_bfactor_layers"] = True
1731
+ click.echo(
1732
+ f"[layer] Detected B-factor layers: ML={len(layer_info.get('ml_indices', []))}, "
1733
+ f"MovableMM={len(layer_info.get('movable_mm_indices', []))}, "
1734
+ f"FrozenMM={len(layer_info.get('frozen_indices', []))}"
1735
+ )
1736
+ except Exception as e:
1737
+ if model_pdb_effective is None and not model_indices:
1738
+ click.echo(f"ERROR: {e}", err=True)
1739
+ sys.exit(1)
1740
+ click.echo(f"[layer] WARNING: {e} Falling back to explicit ML region.", err=True)
1741
+ detect_layer_effective = False
1742
+
1743
+ if not detect_layer_effective:
1744
+ if model_pdb_effective is None and not model_indices:
1745
+ click.echo("ERROR: Provide --model-pdb or --model-indices when --no-detect-layer.", err=True)
1746
+ sys.exit(1)
1747
+ if model_pdb_effective is not None:
1748
+ model_pdb_path = Path(model_pdb_effective)
1749
+ else:
1750
+ if layer_source_pdb.suffix.lower() != ".pdb":
1751
+ click.echo("ERROR: --model-indices requires a PDB input.", err=True)
1752
+ sys.exit(1)
1753
+ try:
1754
+ model_pdb_path = build_model_pdb_from_indices(layer_source_pdb, out_dir_path, model_indices or [])
1755
+ except Exception as e:
1756
+ click.echo(f"ERROR: {e}", err=True)
1757
+ sys.exit(1)
1758
+ calc_cfg["use_bfactor_layers"] = False
1759
+
1760
+ if model_pdb_path is None:
1761
+ click.echo("ERROR: Failed to resolve model PDB for the ML region.", err=True)
1762
+ sys.exit(1)
1763
+
1764
+ calc_cfg["model_pdb"] = str(model_pdb_path)
1765
+ freeze_atoms_final = apply_layer_freeze_constraints(
1766
+ geom_cfg,
1767
+ calc_cfg,
1768
+ layer_info,
1769
+ echo_fn=click.echo,
1770
+ )
1771
+
1772
+ # Distance-based overrides for Hessian-target and movable MM selection.
1773
+ if hess_cutoff_effective is not None:
1774
+ calc_cfg["hess_cutoff"] = float(hess_cutoff_effective)
1775
+ if movable_cutoff_effective is not None:
1776
+ calc_cfg["movable_cutoff"] = float(movable_cutoff_effective)
1777
+ calc_cfg["use_bfactor_layers"] = False
1778
+
1779
+ for key in ("input_pdb", "real_parm7", "model_pdb", "mm_fd_dir"):
1780
+ val = calc_cfg.get(key)
1781
+ if isinstance(val, (str, Path)):
1782
+ calc_cfg[key] = str(Path(val).expanduser().resolve())
1783
+
1784
+ stopt_cfg["stop_in_when_full"] = int(stopt_cfg.get("max_cycles", STOPT_KW["max_cycles"]))
1785
+ out_dir_path = Path(stopt_cfg.get("out_dir", out_dir)).resolve()
1786
+ echo_geom = format_freeze_atoms_for_echo(geom_cfg, key="freeze_atoms")
1787
+ echo_calc = format_freeze_atoms_for_echo(filter_calc_for_echo(calc_cfg), key="freeze_atoms")
1788
+ echo_gs = strip_inherited_keys(gs_cfg, GS_KW, mode="same")
1789
+ echo_stopt = strip_inherited_keys({**stopt_cfg, "out_dir": str(out_dir_path)}, STOPT_KW, mode="same")
1790
+ echo_lbfgs = strip_inherited_keys(lbfgs_cfg, LBFGS_KW, mode="same")
1791
+ echo_bond = strip_inherited_keys(bond_cfg, BOND_KW, mode="same")
1792
+ echo_search = strip_inherited_keys(search_cfg, SEARCH_KW, mode="same")
1793
+
1794
+ click.echo(pretty_block("geom", echo_geom))
1795
+ click.echo(pretty_block("calc", echo_calc))
1796
+ click.echo(pretty_block("gs", echo_gs))
1797
+ click.echo(pretty_block("stopt", echo_stopt))
1798
+ click.echo(pretty_block("lbfgs", echo_lbfgs))
1799
+ click.echo(pretty_block("bond", echo_bond))
1800
+ click.echo(pretty_block("search", echo_search))
1801
+ # Echo pre-optimization and alignment flags
1802
+ click.echo(
1803
+ pretty_block(
1804
+ "run_flags",
1805
+ {
1806
+ "pre_opt": bool(pre_opt),
1807
+ "align": bool(align),
1808
+ "mep_mode": mep_mode_kind,
1809
+ "refine_mode": refine_mode_kind,
1810
+ "opt_mode": str(opt_mode),
1811
+ },
1812
+ )
1813
+ )
1814
+
1815
+ if show_config:
1816
+ click.echo(
1817
+ pretty_block(
1818
+ "yaml_layers",
1819
+ {
1820
+ "config": None if config_yaml is None else str(config_yaml),
1821
+ "override": None if override_yaml is None else str(override_yaml),
1822
+ "merged_keys": sorted(merged_yaml_cfg.keys()),
1823
+ },
1824
+ )
1825
+ )
1826
+
1827
+ if int(stopt_cfg.get("max_cycles", 0)) <= 0:
1828
+ click.echo("[INFO] max_cycles <= 0: skipping path search.")
1829
+ return
1830
+
1831
+ # --------------------------
1832
+ # 2) Prepare inputs
1833
+ # --------------------------
1834
+ out_dir_path.mkdir(parents=True, exist_ok=True)
1835
+
1836
+ geoms = _load_structures(
1837
+ inputs=prepared_inputs,
1838
+ coord_type=geom_cfg.get("coord_type", "cart"),
1839
+ base_freeze=geom_cfg.get("freeze_atoms", []),
1840
+ )
1841
+
1842
+ shared_calc = mlmm(**calc_cfg)
1843
+ for g in geoms:
1844
+ g.set_calculator(shared_calc)
1845
+
1846
+ # Reference PDB for output conversion: prefer --ref-pdb, fall back to input PDBs
1847
+ ref_pdb_for_segments: Optional[Path] = None
1848
+ if ref_list:
1849
+ ref_pdb_for_segments = Path(ref_list[0]).resolve()
1850
+ else:
1851
+ for p in p_list:
1852
+ if p.suffix.lower() == ".pdb":
1853
+ ref_pdb_for_segments = p.resolve()
1854
+ break
1855
+
1856
+ if pre_opt:
1857
+ new_geoms: List[Any] = []
1858
+ for i, g in enumerate(geoms):
1859
+ tag = f"init{i:02d}"
1860
+ g_opt = _optimize_single(g, shared_calc, lbfgs_cfg, out_dir_path, tag=tag, ref_pdb_path=ref_pdb_for_segments)
1861
+ new_geoms.append(g_opt)
1862
+ geoms = new_geoms
1863
+ else:
1864
+ click.echo("[init] Skipping endpoint pre-optimization as requested by --no-preopt.")
1865
+
1866
+ # Align all inputs to the first structure, guided by freeze constraints, when requested
1867
+ align_thresh = str(stopt_cfg.get("thresh", "gau"))
1868
+ if align:
1869
+ try:
1870
+ click.echo("\n=== Aligning all inputs to the first structure (freeze-guided scan + relaxation) ===\n")
1871
+ _ = align_and_refine_sequence_inplace(
1872
+ geoms,
1873
+ thresh=align_thresh,
1874
+ shared_calc=shared_calc,
1875
+ out_dir=out_dir_path / "align_refine",
1876
+ verbose=True,
1877
+ )
1878
+ click.echo("[align] Completed input alignment.")
1879
+ except Exception as e:
1880
+ click.echo(f"[align] WARNING: Alignment failed; continuing without alignment: {e}", err=True)
1881
+ else:
1882
+ click.echo("[align] Skipping input alignment as requested by --no-align.")
1883
+
1884
+ # --------------------------
1885
+ # 3) Run recursive search for each adjacent pair and stitch
1886
+ # --------------------------
1887
+ click.echo("\n=== Multistep MEP search (multi-structure) started ===\n")
1888
+ seg_counter = [0]
1889
+
1890
+ bridge_max_nodes = int(search_cfg.get("max_nodes_bridge", 5))
1891
+ gs_bridge_cfg = {**gs_cfg, "max_nodes": bridge_max_nodes, "climb": False, "climb_lanczos": False}
1892
+
1893
+ combined_imgs: List[Any] = []
1894
+ combined_Es: List[float] = []
1895
+ seg_reports_all: List[SegmentReport] = []
1896
+
1897
+ def _segment_builder_for_pairs(tail_g, head_g, _tag: str) -> CombinedPath:
1898
+ sub = _build_multistep_path(
1899
+ tail_g, head_g,
1900
+ shared_calc,
1901
+ geom_cfg, gs_cfg, stopt_cfg,
1902
+ lbfgs_cfg,
1903
+ bond_cfg, search_cfg, refine_mode_kind,
1904
+ out_dir=out_dir_path,
1905
+ ref_pdb_path=ref_pdb_for_segments,
1906
+ depth=0,
1907
+ seg_counter=seg_counter,
1908
+ branch_tag="B",
1909
+ pair_index=None,
1910
+ mep_mode_kind=mep_mode_kind, calc_cfg=calc_cfg, dmf_cfg=dmf_cfg,
1911
+ kink_seq_count=_trailing_kink_count(seg_reports_all),
1912
+ )
1913
+ return sub
1914
+
1915
+ for i in range(len(geoms) - 1):
1916
+ gA, gB = geoms[i], geoms[i + 1]
1917
+ pair_tag = f"pair_{i:02d}"
1918
+ click.echo(f"\n--- Processing pair {i:02d}: image {i} → {i+1} ---")
1919
+ pair_path = _build_multistep_path(
1920
+ gA, gB,
1921
+ shared_calc,
1922
+ geom_cfg, gs_cfg, stopt_cfg,
1923
+ lbfgs_cfg,
1924
+ bond_cfg, search_cfg, refine_mode_kind,
1925
+ out_dir=out_dir_path,
1926
+ ref_pdb_path=ref_pdb_for_segments,
1927
+ depth=0,
1928
+ seg_counter=seg_counter,
1929
+ branch_tag=pair_tag,
1930
+ pair_index=i,
1931
+ mep_mode_kind=mep_mode_kind, calc_cfg=calc_cfg, dmf_cfg=dmf_cfg,
1932
+ )
1933
+
1934
+ if i == 0:
1935
+ combined_imgs = list(pair_path.images)
1936
+ combined_Es = list(pair_path.energies)
1937
+ seg_reports_all.extend(pair_path.segments)
1938
+ else:
1939
+ parts = [(combined_imgs, combined_Es), (pair_path.images, pair_path.energies)]
1940
+ combined_imgs, combined_Es = _stitch_paths(
1941
+ parts=parts,
1942
+ stitch_rmsd_thresh=float(search_cfg["stitch_rmsd_thresh"]),
1943
+ bridge_rmsd_thresh=float(search_cfg["bridge_rmsd_thresh"]),
1944
+ shared_calc=shared_calc,
1945
+ gs_cfg=gs_bridge_cfg,
1946
+ stopt_cfg=stopt_cfg,
1947
+ out_dir=out_dir_path,
1948
+ tag=pair_tag,
1949
+ ref_pdb_path=ref_pdb_for_segments,
1950
+ bond_cfg=bond_cfg,
1951
+ segment_builder=_segment_builder_for_pairs,
1952
+ segments_out=seg_reports_all,
1953
+ bridge_pair_index=i,
1954
+ mep_mode_kind=mep_mode_kind, calc_cfg=calc_cfg, dmf_cfg=dmf_cfg,
1955
+ )
1956
+ seg_reports_all.extend(pair_path.segments)
1957
+
1958
+ click.echo("\n=== Multistep MEP search (multi-structure) finished ===\n")
1959
+
1960
+ combined_all = CombinedPath(images=combined_imgs, energies=combined_Es, segments=seg_reports_all)
1961
+
1962
+ # --------------------------
1963
+ # 4) Outputs
1964
+ # --------------------------
1965
+ for idx, srep in enumerate(combined_all.segments, 1):
1966
+ srep.seg_index = idx
1967
+ tag_to_index = {s.tag: int(s.seg_index) for s in combined_all.segments}
1968
+ for im in combined_all.images:
1969
+ tag = getattr(im, "mep_seg_tag", None)
1970
+ if tag and tag in tag_to_index:
1971
+ try:
1972
+ setattr(im, "mep_seg_index", int(tag_to_index[tag]))
1973
+ except Exception:
1974
+ logger.debug("Failed to set mep_seg_index on image", exc_info=True)
1975
+
1976
+ # Always write mep_trj.xyz for downstream compatibility; convert to PDB when possible.
1977
+ pdb_input = ref_pdb_for_segments is not None
1978
+ final_trj = out_dir_path / "mep_trj.xyz"
1979
+ _write_xyz_trj_with_energy(combined_all.images, combined_all.energies, final_trj)
1980
+ click.echo(f"[write] Wrote '{final_trj}'.")
1981
+ try:
1982
+ run_trj2fig(final_trj, [out_dir_path / "mep_plot.png"], unit="kcal", reference="init", reverse_x=False)
1983
+ click.echo(f"[plot] Saved energy plot → '{out_dir_path / 'mep_plot.png'}'")
1984
+ except Exception as e:
1985
+ click.echo(f"[plot] WARNING: Failed to plot final energy: {e}", err=True)
1986
+
1987
+ if pdb_input:
1988
+ try:
1989
+ final_pdb = out_dir_path / "mep.pdb"
1990
+ convert_xyz_to_pdb(final_trj, ref_pdb_for_segments, final_pdb)
1991
+ click.echo(f"[convert] Wrote '{final_pdb}'.")
1992
+ except Exception as e:
1993
+ click.echo(f"[convert] WARNING: Failed to convert final MEP to PDB: {e}", err=True)
1994
+
1995
+ # ---- Pocket-only per-segment trajectories & HEIs ----
1996
+ try:
1997
+ # Map frames → segment indices
1998
+ frame_seg_indices: List[int] = [int(getattr(im, "mep_seg_index", 0) or 0) for im in combined_all.images]
1999
+ seg_to_frames: Dict[int, List[int]] = {}
2000
+ for ii, sidx in enumerate(frame_seg_indices):
2001
+ if sidx <= 0:
2002
+ continue
2003
+ seg_to_frames.setdefault(int(sidx), []).append(ii)
2004
+
2005
+ for s in combined_all.segments:
2006
+ seg_idx = int(s.seg_index)
2007
+ idxs = seg_to_frames.get(seg_idx, [])
2008
+ if not idxs:
2009
+ continue
2010
+
2011
+ # (A) Only for bond-change segments: pocket-only per-segment path
2012
+ if s.kind != "bridge" and s.summary and s.summary.strip() != "(no covalent changes detected)":
2013
+ seg_imgs = [combined_all.images[j] for j in idxs]
2014
+ seg_Es = [combined_all.energies[j] for j in idxs]
2015
+ seg_trj = out_dir_path / f"mep_seg_{seg_idx:02d}_trj.xyz"
2016
+ _write_xyz_trj_with_energy(seg_imgs, seg_Es, seg_trj)
2017
+ click.echo(f"[write] Wrote per-segment pocket trajectory → '{seg_trj}'")
2018
+ if ref_pdb_for_segments is not None:
2019
+ _maybe_convert_to_pdb(seg_trj, ref_pdb_for_segments, out_path=out_dir_path / f"mep_seg_{seg_idx:02d}.pdb")
2020
+
2021
+ # (B) HEI pocket files only for bond-change segments
2022
+ if s.kind != "bridge" and s.summary and s.summary.strip() != "(no covalent changes detected)":
2023
+ energies_seg = [combined_all.energies[j] for j in idxs]
2024
+ imax_rel = int(np.argmax(np.array(energies_seg, dtype=float)))
2025
+ imax_abs = idxs[imax_rel]
2026
+ hei_img = combined_all.images[imax_abs]
2027
+ hei_E = [combined_all.energies[imax_abs]]
2028
+ hei_trj = out_dir_path / f"hei_seg_{seg_idx:02d}.xyz"
2029
+ _write_xyz_trj_with_energy([hei_img], hei_E, hei_trj)
2030
+ click.echo(f"[write] Wrote segment HEI (pocket) → '{hei_trj}'")
2031
+ if ref_pdb_for_segments is not None:
2032
+ _maybe_convert_to_pdb(hei_trj, ref_pdb_for_segments, out_path=out_dir_path / f"hei_seg_{seg_idx:02d}.pdb")
2033
+ except Exception as e:
2034
+ click.echo(f"[write] WARNING: Failed to emit per-segment pocket outputs: {e}", err=True)
2035
+ # ---- END ----
2036
+
2037
+ summary = {
2038
+ "out_dir": str(out_dir_path),
2039
+ "n_images": len(combined_all.images),
2040
+ "n_segments": len(combined_all.segments),
2041
+ "segments": [
2042
+ {
2043
+ "index": int(s.seg_index),
2044
+ "tag": s.tag,
2045
+ "kind": s.kind,
2046
+ "barrier_kcal": float(s.barrier_kcal),
2047
+ "delta_kcal": float(s.delta_kcal),
2048
+ "bond_changes": (s.summary if (s.kind != "bridge") else "")
2049
+ } for s in combined_all.segments
2050
+ ],
2051
+ }
2052
+
2053
+ # --------------------------
2054
+ # 5) Console summary
2055
+ # --------------------------
2056
+ try:
2057
+ overall_changed, overall_summary = _has_bond_change(combined_all.images[0], combined_all.images[-1], bond_cfg)
2058
+ except Exception:
2059
+ overall_changed, overall_summary = False, ""
2060
+
2061
+ click.echo("\n=== MEP Summary ===\n")
2062
+
2063
+ click.echo("\n[overall] Covalent-bond changes between first and last image:")
2064
+ if overall_changed and overall_summary.strip():
2065
+ click.echo(textwrap.indent(overall_summary.strip(), prefix=" "))
2066
+ else:
2067
+ click.echo(" (no covalent changes detected)")
2068
+
2069
+ if combined_all.segments:
2070
+ click.echo("\n[segments] Along the final MEP order (ΔE‡, ΔE). Bridges are shown between connected segments:")
2071
+ for i, seg in enumerate(combined_all.segments, 1):
2072
+ kind_label = "BRIDGE" if seg.kind == "bridge" else "SEG"
2073
+ click.echo(f" [{i:02d}] ({kind_label}) {seg.tag} | ΔE‡ = {seg.barrier_kcal:.2f} kcal/mol, ΔE = {seg.delta_kcal:.2f} kcal/mol")
2074
+ if seg.kind != "bridge" and seg.summary.strip():
2075
+ click.echo(textwrap.indent(seg.summary.strip(), prefix=" "))
2076
+ else:
2077
+ click.echo("\n[segments] (no segment reports)")
2078
+
2079
+ # --------------------------
2080
+ # 6) Energy diagram from bond-change segments (state labeling; compressed)
2081
+ # --------------------------
2082
+ diagram_payload: Optional[Dict[str, Any]] = None
2083
+ try:
2084
+ # Map each segment index → list of frame indices
2085
+ frame_seg_indices: List[int] = [int(getattr(im, "mep_seg_index", 0) or 0) for im in combined_all.images]
2086
+ seg_to_frames: Dict[int, List[int]] = {}
2087
+ for ii, sidx in enumerate(frame_seg_indices):
2088
+ if sidx <= 0:
2089
+ continue
2090
+ seg_to_frames.setdefault(int(sidx), []).append(ii)
2091
+
2092
+ # Build TS groups (each bond-change segment starts a group)
2093
+ ts_groups: List[Dict[str, Any]] = []
2094
+ ts_count = 0
2095
+ current: Optional[Dict[str, Any]] = None
2096
+
2097
+ for s in combined_all.segments:
2098
+ idxs = seg_to_frames.get(int(s.seg_index), [])
2099
+ if not idxs:
2100
+ continue
2101
+
2102
+ if s.kind == "seg" and s.summary and s.summary.strip() != "(no covalent changes detected)":
2103
+ # New TS group
2104
+ ts_count += 1
2105
+ imax = max(idxs, key=lambda j: combined_all.energies[j])
2106
+ ts_e = float(combined_all.energies[imax])
2107
+ first_im_e = float(combined_all.energies[idxs[-1]])
2108
+ current = {
2109
+ "ts_label": f"TS{ts_count}",
2110
+ "ts_energy": ts_e,
2111
+ "first_im_energy": first_im_e,
2112
+ "tail_im_energy": first_im_e,
2113
+ "has_extra": False,
2114
+ "index": ts_count,
2115
+ }
2116
+ ts_groups.append(current)
2117
+ else:
2118
+ # Kink/bridge: fold into current group as "extra" and update tail energy
2119
+ if current is not None:
2120
+ current["tail_im_energy"] = float(combined_all.energies[idxs[-1]])
2121
+ current["has_extra"] = True
2122
+ else:
2123
+ # pre-TS region without bond change → ignore
2124
+ pass
2125
+
2126
+ # Clip endpoints to first/last bond-change segment edges
2127
+ start_idx_for_diag = 0
2128
+ end_idx_for_diag = len(combined_all.energies) - 1
2129
+ bc_segments_in_order: List[SegmentReport] = [
2130
+ s for s in combined_all.segments
2131
+ if (s.kind == "seg" and s.summary and s.summary.strip() != "(no covalent changes detected)")
2132
+ ]
2133
+ if bc_segments_in_order:
2134
+ first_bc = bc_segments_in_order[0]
2135
+ last_bc = bc_segments_in_order[-1]
2136
+ idxs_first_bc = seg_to_frames.get(int(first_bc.seg_index), [])
2137
+ idxs_last_bc = seg_to_frames.get(int(last_bc.seg_index), [])
2138
+ if idxs_first_bc:
2139
+ start_idx_for_diag = int(idxs_first_bc[0])
2140
+ if idxs_last_bc:
2141
+ end_idx_for_diag = int(idxs_last_bc[-1])
2142
+
2143
+ # Compose compressed labels/energies & human-readable chain
2144
+ labels: List[str] = ["R"]
2145
+ energies_eh: List[float] = [float(combined_all.energies[start_idx_for_diag])]
2146
+ chain_tokens: List[str] = ["R"]
2147
+
2148
+ for i, g in enumerate(ts_groups, start=1):
2149
+ last_group = (i == len(ts_groups))
2150
+
2151
+ # TS
2152
+ labels.append(g["ts_label"])
2153
+ energies_eh.append(g["ts_energy"])
2154
+ chain_tokens.extend(["-->", g["ts_label"]])
2155
+
2156
+ # For the last TS group: compress directly to P (no IMs)
2157
+ if last_group:
2158
+ continue
2159
+
2160
+ # IM1 (always keep)
2161
+ labels.append(f"IM{i}_1")
2162
+ energies_eh.append(g["first_im_energy"])
2163
+ chain_tokens.extend(["-->", f"IM{i}_1"])
2164
+
2165
+ # IM2 (represent all extra kink/bridge before next TS)
2166
+ if g["has_extra"]:
2167
+ labels.append(f"IM{i}_2")
2168
+ energies_eh.append(g["tail_im_energy"])
2169
+ chain_tokens.extend(["-|-->", f"IM{i}_2"])
2170
+
2171
+ # Product
2172
+ labels.append("P")
2173
+ energies_eh.append(float(combined_all.energies[end_idx_for_diag]))
2174
+ chain_tokens.extend(["-->", "P"])
2175
+
2176
+ # Convert to kcal/mol relative to R
2177
+ e0 = energies_eh[0]
2178
+ energies_kcal = [(e - e0) * AU2KCALPERMOL for e in energies_eh]
2179
+ energies_au = list(energies_eh)
2180
+ diagram_payload = {
2181
+ "name": "energy_diagram_MEP",
2182
+ "labels": list(labels),
2183
+ "energies_kcal": energies_kcal,
2184
+ "ylabel": "ΔE (kcal/mol)",
2185
+ "energies_au": energies_au,
2186
+ "image": str(out_dir_path / "energy_diagram_MEP.png"),
2187
+ }
2188
+
2189
+ # Log exact inputs to build_energy_diagram, and the human-readable chain
2190
+ labels_repr = "[" + ", ".join(f'"{lab}"' for lab in labels) + "]"
2191
+ energies_repr = "[" + ", ".join(f"{val:.6f}" for val in energies_kcal) + "]"
2192
+ click.echo(f"[diagram] build_energy_diagram.labels = {labels_repr}")
2193
+ click.echo(f"[diagram] build_energy_diagram.energies_kcal = {energies_repr}")
2194
+
2195
+ fig = build_energy_diagram(
2196
+ energies=energies_kcal,
2197
+ labels=labels,
2198
+ ylabel="ΔE (kcal/mol)",
2199
+ baseline=True,
2200
+ showgrid=False,
2201
+ )
2202
+
2203
+ try:
2204
+ png_path = out_dir_path / "energy_diagram_MEP.png"
2205
+ fig.write_image(str(png_path), scale=2)
2206
+ click.echo(f"[diagram] Wrote energy diagram (PNG) → '{png_path}'")
2207
+ except Exception as e:
2208
+ click.echo(f"[diagram] NOTE: PNG export skipped (install 'kaleido' to enable): {e}", err=True)
2209
+
2210
+ chain_text = " ".join(chain_tokens)
2211
+ click.echo(f"[diagram] State label sequence: {chain_text}")
2212
+
2213
+ except Exception as e:
2214
+ click.echo(f"[diagram] WARNING: Failed to build energy diagram: {e}", err=True)
2215
+
2216
+ # --------------------------
2217
+ # 7) Summary (YAML + log)
2218
+ # --------------------------
2219
+ if diagram_payload is not None:
2220
+ summary["energy_diagrams"] = [diagram_payload]
2221
+
2222
+ with open(out_dir_path / "summary.yaml", "w") as f:
2223
+ yaml.safe_dump(summary, f, sort_keys=False, allow_unicode=True)
2224
+ click.echo(f"[write] Wrote '{out_dir_path / 'summary.yaml'}'.")
2225
+
2226
+ try:
2227
+ freeze_atoms_for_log: List[int] = []
2228
+ try:
2229
+ freeze_atoms_for_log = sorted(
2230
+ {
2231
+ int(i)
2232
+ for g in getattr(combined_all, "images", [])
2233
+ for i in getattr(g, "freeze_atoms", [])
2234
+ }
2235
+ )
2236
+ except Exception:
2237
+ freeze_atoms_for_log = []
2238
+
2239
+ diag_for_log: Dict[str, Any] = diagram_payload or {}
2240
+ mep_info = {
2241
+ "n_images": len(combined_all.images),
2242
+ "n_segments": len(combined_all.segments),
2243
+ "traj_pdb": str(out_dir_path / "mep.pdb") if (out_dir_path / "mep.pdb").exists() else None,
2244
+ "mep_plot": str(out_dir_path / "mep_plot.png") if (out_dir_path / "mep_plot.png").exists() else None,
2245
+ "diagram": diag_for_log,
2246
+ }
2247
+ summary_payload = {
2248
+ "root_out_dir": str(out_dir_path),
2249
+ "path_dir": str(out_dir_path),
2250
+ "path_module_dir": "path_search",
2251
+ "pipeline_mode": "path-search",
2252
+ "refine_path": True,
2253
+ "tsopt": False,
2254
+ "thermo": False,
2255
+ "dft": False,
2256
+ "opt_mode": opt_mode,
2257
+ "mep_mode": "path-search",
2258
+ "uma_model": calc_cfg.get("uma_model"),
2259
+ "command": command_str,
2260
+ "charge": calc_cfg.get("model_charge"),
2261
+ "spin": calc_cfg.get("model_mult"),
2262
+ "freeze_atoms": freeze_atoms_for_log,
2263
+ "mep": mep_info,
2264
+ "segments": summary.get("segments", []),
2265
+ "energy_diagrams": summary.get("energy_diagrams", []),
2266
+ "key_files": {},
2267
+ }
2268
+ write_summary_log(out_dir_path / "summary.log", summary_payload)
2269
+ click.echo(f"[write] Wrote '{out_dir_path / 'summary.log'}'.")
2270
+ except Exception as e:
2271
+ click.echo(f"[write] WARNING: Failed to write summary.log: {e}", err=True)
2272
+
2273
+ # summary.md and key_* outputs are disabled.
2274
+ # --------------------------
2275
+ # 8) Elapsed time
2276
+ # --------------------------
2277
+ click.echo(format_elapsed("[time] Elapsed for Path Search", time_start))
2278
+
2279
+ except ZeroStepLength:
2280
+ click.echo("ERROR: Proposed step length dropped below the minimum allowed (ZeroStepLength).", err=True)
2281
+ sys.exit(2)
2282
+ except OptimizationError as e:
2283
+ click.echo(f"ERROR: Path search failed — {e}", err=True)
2284
+ sys.exit(3)
2285
+ except KeyboardInterrupt:
2286
+ click.echo("\nInterrupted by user.", err=True)
2287
+ sys.exit(130)
2288
+ except Exception as e:
2289
+ tb = "".join(traceback.format_exception(type(e), e, e.__traceback__))
2290
+ click.echo("Unhandled error during path search:\n" + textwrap.indent(tb, " "), err=True)
2291
+ sys.exit(1)
2292
+ finally:
2293
+ for prepared in prepared_inputs:
2294
+ prepared.cleanup()
2295
+ # Release GPU memory so subsequent pipeline stages don't OOM
2296
+ shared_calc = geoms = None
2297
+ gc.collect() # break cyclic refs inside torch.nn.Module
2298
+ if torch.cuda.is_available():
2299
+ torch.cuda.empty_cache()