mlmm-toolkit 0.2.2.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hessian_ff/__init__.py +50 -0
- hessian_ff/analytical_hessian.py +609 -0
- hessian_ff/constants.py +46 -0
- hessian_ff/forcefield.py +339 -0
- hessian_ff/loaders.py +608 -0
- hessian_ff/native/Makefile +8 -0
- hessian_ff/native/__init__.py +28 -0
- hessian_ff/native/analytical_hessian.py +88 -0
- hessian_ff/native/analytical_hessian_ext.cpp +258 -0
- hessian_ff/native/bonded.py +82 -0
- hessian_ff/native/bonded_ext.cpp +640 -0
- hessian_ff/native/loader.py +349 -0
- hessian_ff/native/nonbonded.py +118 -0
- hessian_ff/native/nonbonded_ext.cpp +1150 -0
- hessian_ff/prmtop_parmed.py +23 -0
- hessian_ff/system.py +107 -0
- hessian_ff/terms/__init__.py +14 -0
- hessian_ff/terms/angle.py +73 -0
- hessian_ff/terms/bond.py +44 -0
- hessian_ff/terms/cmap.py +406 -0
- hessian_ff/terms/dihedral.py +141 -0
- hessian_ff/terms/nonbonded.py +209 -0
- hessian_ff/tests/__init__.py +0 -0
- hessian_ff/tests/conftest.py +75 -0
- hessian_ff/tests/data/small/complex.parm7 +1346 -0
- hessian_ff/tests/data/small/complex.pdb +125 -0
- hessian_ff/tests/data/small/complex.rst7 +63 -0
- hessian_ff/tests/test_coords_input.py +44 -0
- hessian_ff/tests/test_energy_force.py +49 -0
- hessian_ff/tests/test_hessian.py +137 -0
- hessian_ff/tests/test_smoke.py +18 -0
- hessian_ff/tests/test_validation.py +40 -0
- hessian_ff/workflows.py +889 -0
- mlmm/__init__.py +36 -0
- mlmm/__main__.py +7 -0
- mlmm/_version.py +34 -0
- mlmm/add_elem_info.py +374 -0
- mlmm/advanced_help.py +91 -0
- mlmm/align_freeze_atoms.py +601 -0
- mlmm/all.py +3535 -0
- mlmm/bond_changes.py +231 -0
- mlmm/bool_compat.py +223 -0
- mlmm/cli.py +574 -0
- mlmm/cli_utils.py +166 -0
- mlmm/default_group.py +337 -0
- mlmm/defaults.py +467 -0
- mlmm/define_layer.py +526 -0
- mlmm/dft.py +1041 -0
- mlmm/energy_diagram.py +253 -0
- mlmm/extract.py +2213 -0
- mlmm/fix_altloc.py +464 -0
- mlmm/freq.py +1406 -0
- mlmm/harmonic_constraints.py +140 -0
- mlmm/hessian_cache.py +44 -0
- mlmm/hessian_calc.py +174 -0
- mlmm/irc.py +638 -0
- mlmm/mlmm_calc.py +2262 -0
- mlmm/mm_parm.py +945 -0
- mlmm/oniom_export.py +1983 -0
- mlmm/oniom_import.py +457 -0
- mlmm/opt.py +1742 -0
- mlmm/path_opt.py +1353 -0
- mlmm/path_search.py +2299 -0
- mlmm/preflight.py +88 -0
- mlmm/py.typed +1 -0
- mlmm/pysis_runner.py +45 -0
- mlmm/scan.py +1047 -0
- mlmm/scan2d.py +1226 -0
- mlmm/scan3d.py +1265 -0
- mlmm/scan_common.py +184 -0
- mlmm/summary_log.py +736 -0
- mlmm/trj2fig.py +448 -0
- mlmm/tsopt.py +2871 -0
- mlmm/utils.py +2309 -0
- mlmm/xtb_embedcharge_correction.py +475 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/METADATA +1159 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/RECORD +372 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/WHEEL +5 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/entry_points.txt +2 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/licenses/LICENSE +674 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/top_level.txt +4 -0
- pysisyphus/Geometry.py +1667 -0
- pysisyphus/LICENSE +674 -0
- pysisyphus/TableFormatter.py +63 -0
- pysisyphus/TablePrinter.py +74 -0
- pysisyphus/__init__.py +12 -0
- pysisyphus/calculators/AFIR.py +452 -0
- pysisyphus/calculators/AnaPot.py +20 -0
- pysisyphus/calculators/AnaPot2.py +48 -0
- pysisyphus/calculators/AnaPot3.py +12 -0
- pysisyphus/calculators/AnaPot4.py +20 -0
- pysisyphus/calculators/AnaPotBase.py +337 -0
- pysisyphus/calculators/AnaPotCBM.py +25 -0
- pysisyphus/calculators/AtomAtomTransTorque.py +154 -0
- pysisyphus/calculators/CFOUR.py +250 -0
- pysisyphus/calculators/Calculator.py +844 -0
- pysisyphus/calculators/CerjanMiller.py +24 -0
- pysisyphus/calculators/Composite.py +123 -0
- pysisyphus/calculators/ConicalIntersection.py +171 -0
- pysisyphus/calculators/DFTBp.py +430 -0
- pysisyphus/calculators/DFTD3.py +66 -0
- pysisyphus/calculators/DFTD4.py +84 -0
- pysisyphus/calculators/Dalton.py +61 -0
- pysisyphus/calculators/Dimer.py +681 -0
- pysisyphus/calculators/Dummy.py +20 -0
- pysisyphus/calculators/EGO.py +76 -0
- pysisyphus/calculators/EnergyMin.py +224 -0
- pysisyphus/calculators/ExternalPotential.py +264 -0
- pysisyphus/calculators/FakeASE.py +35 -0
- pysisyphus/calculators/FourWellAnaPot.py +28 -0
- pysisyphus/calculators/FreeEndNEBPot.py +39 -0
- pysisyphus/calculators/Gaussian09.py +18 -0
- pysisyphus/calculators/Gaussian16.py +726 -0
- pysisyphus/calculators/HardSphere.py +159 -0
- pysisyphus/calculators/IDPPCalculator.py +49 -0
- pysisyphus/calculators/IPIClient.py +133 -0
- pysisyphus/calculators/IPIServer.py +234 -0
- pysisyphus/calculators/LEPSBase.py +24 -0
- pysisyphus/calculators/LEPSExpr.py +139 -0
- pysisyphus/calculators/LennardJones.py +80 -0
- pysisyphus/calculators/MOPAC.py +219 -0
- pysisyphus/calculators/MullerBrownSympyPot.py +51 -0
- pysisyphus/calculators/MultiCalc.py +85 -0
- pysisyphus/calculators/NFK.py +45 -0
- pysisyphus/calculators/OBabel.py +87 -0
- pysisyphus/calculators/ONIOMv2.py +1129 -0
- pysisyphus/calculators/ORCA.py +893 -0
- pysisyphus/calculators/ORCA5.py +6 -0
- pysisyphus/calculators/OpenMM.py +88 -0
- pysisyphus/calculators/OpenMolcas.py +281 -0
- pysisyphus/calculators/OverlapCalculator.py +908 -0
- pysisyphus/calculators/Psi4.py +218 -0
- pysisyphus/calculators/PyPsi4.py +37 -0
- pysisyphus/calculators/PySCF.py +341 -0
- pysisyphus/calculators/PyXTB.py +73 -0
- pysisyphus/calculators/QCEngine.py +106 -0
- pysisyphus/calculators/Rastrigin.py +22 -0
- pysisyphus/calculators/Remote.py +76 -0
- pysisyphus/calculators/Rosenbrock.py +15 -0
- pysisyphus/calculators/SocketCalc.py +97 -0
- pysisyphus/calculators/TIP3P.py +111 -0
- pysisyphus/calculators/TransTorque.py +161 -0
- pysisyphus/calculators/Turbomole.py +965 -0
- pysisyphus/calculators/VRIPot.py +37 -0
- pysisyphus/calculators/WFOWrapper.py +333 -0
- pysisyphus/calculators/WFOWrapper2.py +341 -0
- pysisyphus/calculators/XTB.py +418 -0
- pysisyphus/calculators/__init__.py +81 -0
- pysisyphus/calculators/cosmo_data.py +139 -0
- pysisyphus/calculators/parser.py +150 -0
- pysisyphus/color.py +19 -0
- pysisyphus/config.py +133 -0
- pysisyphus/constants.py +65 -0
- pysisyphus/cos/AdaptiveNEB.py +230 -0
- pysisyphus/cos/ChainOfStates.py +725 -0
- pysisyphus/cos/FreeEndNEB.py +25 -0
- pysisyphus/cos/FreezingString.py +103 -0
- pysisyphus/cos/GrowingChainOfStates.py +71 -0
- pysisyphus/cos/GrowingNT.py +309 -0
- pysisyphus/cos/GrowingString.py +508 -0
- pysisyphus/cos/NEB.py +189 -0
- pysisyphus/cos/SimpleZTS.py +64 -0
- pysisyphus/cos/__init__.py +22 -0
- pysisyphus/cos/stiffness.py +199 -0
- pysisyphus/drivers/__init__.py +17 -0
- pysisyphus/drivers/afir.py +855 -0
- pysisyphus/drivers/barriers.py +271 -0
- pysisyphus/drivers/birkholz.py +138 -0
- pysisyphus/drivers/cluster.py +318 -0
- pysisyphus/drivers/diabatization.py +133 -0
- pysisyphus/drivers/merge.py +368 -0
- pysisyphus/drivers/merge_mol2.py +322 -0
- pysisyphus/drivers/opt.py +375 -0
- pysisyphus/drivers/perf.py +91 -0
- pysisyphus/drivers/pka.py +52 -0
- pysisyphus/drivers/precon_pos_rot.py +669 -0
- pysisyphus/drivers/rates.py +480 -0
- pysisyphus/drivers/replace.py +219 -0
- pysisyphus/drivers/scan.py +212 -0
- pysisyphus/drivers/spectrum.py +166 -0
- pysisyphus/drivers/thermo.py +31 -0
- pysisyphus/dynamics/Gaussian.py +103 -0
- pysisyphus/dynamics/__init__.py +20 -0
- pysisyphus/dynamics/colvars.py +136 -0
- pysisyphus/dynamics/driver.py +297 -0
- pysisyphus/dynamics/helpers.py +256 -0
- pysisyphus/dynamics/lincs.py +105 -0
- pysisyphus/dynamics/mdp.py +364 -0
- pysisyphus/dynamics/rattle.py +121 -0
- pysisyphus/dynamics/thermostats.py +128 -0
- pysisyphus/dynamics/wigner.py +266 -0
- pysisyphus/elem_data.py +3473 -0
- pysisyphus/exceptions.py +2 -0
- pysisyphus/filtertrj.py +69 -0
- pysisyphus/helpers.py +623 -0
- pysisyphus/helpers_pure.py +649 -0
- pysisyphus/init_logging.py +50 -0
- pysisyphus/intcoords/Bend.py +69 -0
- pysisyphus/intcoords/Bend2.py +25 -0
- pysisyphus/intcoords/BondedFragment.py +32 -0
- pysisyphus/intcoords/Cartesian.py +41 -0
- pysisyphus/intcoords/CartesianCoords.py +140 -0
- pysisyphus/intcoords/Coords.py +56 -0
- pysisyphus/intcoords/DLC.py +197 -0
- pysisyphus/intcoords/DistanceFunction.py +34 -0
- pysisyphus/intcoords/DummyImproper.py +70 -0
- pysisyphus/intcoords/DummyTorsion.py +72 -0
- pysisyphus/intcoords/LinearBend.py +105 -0
- pysisyphus/intcoords/LinearDisplacement.py +80 -0
- pysisyphus/intcoords/OutOfPlane.py +59 -0
- pysisyphus/intcoords/PrimTypes.py +286 -0
- pysisyphus/intcoords/Primitive.py +137 -0
- pysisyphus/intcoords/RedundantCoords.py +659 -0
- pysisyphus/intcoords/RobustTorsion.py +59 -0
- pysisyphus/intcoords/Rotation.py +147 -0
- pysisyphus/intcoords/Stretch.py +31 -0
- pysisyphus/intcoords/Torsion.py +101 -0
- pysisyphus/intcoords/Torsion2.py +25 -0
- pysisyphus/intcoords/Translation.py +45 -0
- pysisyphus/intcoords/__init__.py +61 -0
- pysisyphus/intcoords/augment_bonds.py +126 -0
- pysisyphus/intcoords/derivatives.py +10512 -0
- pysisyphus/intcoords/eval.py +80 -0
- pysisyphus/intcoords/exceptions.py +37 -0
- pysisyphus/intcoords/findiffs.py +48 -0
- pysisyphus/intcoords/generate_derivatives.py +414 -0
- pysisyphus/intcoords/helpers.py +235 -0
- pysisyphus/intcoords/logging_conf.py +10 -0
- pysisyphus/intcoords/mp_derivatives.py +10836 -0
- pysisyphus/intcoords/setup.py +962 -0
- pysisyphus/intcoords/setup_fast.py +176 -0
- pysisyphus/intcoords/update.py +272 -0
- pysisyphus/intcoords/valid.py +89 -0
- pysisyphus/interpolate/Geodesic.py +93 -0
- pysisyphus/interpolate/IDPP.py +55 -0
- pysisyphus/interpolate/Interpolator.py +116 -0
- pysisyphus/interpolate/LST.py +70 -0
- pysisyphus/interpolate/Redund.py +152 -0
- pysisyphus/interpolate/__init__.py +9 -0
- pysisyphus/interpolate/helpers.py +34 -0
- pysisyphus/io/__init__.py +22 -0
- pysisyphus/io/aomix.py +178 -0
- pysisyphus/io/cjson.py +24 -0
- pysisyphus/io/crd.py +101 -0
- pysisyphus/io/cube.py +220 -0
- pysisyphus/io/fchk.py +184 -0
- pysisyphus/io/hdf5.py +49 -0
- pysisyphus/io/hessian.py +72 -0
- pysisyphus/io/mol2.py +146 -0
- pysisyphus/io/molden.py +293 -0
- pysisyphus/io/orca.py +189 -0
- pysisyphus/io/pdb.py +269 -0
- pysisyphus/io/psf.py +79 -0
- pysisyphus/io/pubchem.py +31 -0
- pysisyphus/io/qcschema.py +34 -0
- pysisyphus/io/sdf.py +29 -0
- pysisyphus/io/xyz.py +61 -0
- pysisyphus/io/zmat.py +175 -0
- pysisyphus/irc/DWI.py +108 -0
- pysisyphus/irc/DampedVelocityVerlet.py +134 -0
- pysisyphus/irc/Euler.py +22 -0
- pysisyphus/irc/EulerPC.py +345 -0
- pysisyphus/irc/GonzalezSchlegel.py +187 -0
- pysisyphus/irc/IMKMod.py +164 -0
- pysisyphus/irc/IRC.py +878 -0
- pysisyphus/irc/IRCDummy.py +10 -0
- pysisyphus/irc/Instanton.py +307 -0
- pysisyphus/irc/LQA.py +53 -0
- pysisyphus/irc/ModeKill.py +136 -0
- pysisyphus/irc/ParamPlot.py +53 -0
- pysisyphus/irc/RK4.py +36 -0
- pysisyphus/irc/__init__.py +31 -0
- pysisyphus/irc/initial_displ.py +219 -0
- pysisyphus/linalg.py +411 -0
- pysisyphus/line_searches/Backtracking.py +88 -0
- pysisyphus/line_searches/HagerZhang.py +184 -0
- pysisyphus/line_searches/LineSearch.py +232 -0
- pysisyphus/line_searches/StrongWolfe.py +108 -0
- pysisyphus/line_searches/__init__.py +9 -0
- pysisyphus/line_searches/interpol.py +15 -0
- pysisyphus/modefollow/NormalMode.py +40 -0
- pysisyphus/modefollow/__init__.py +10 -0
- pysisyphus/modefollow/davidson.py +199 -0
- pysisyphus/modefollow/lanczos.py +95 -0
- pysisyphus/optimizers/BFGS.py +99 -0
- pysisyphus/optimizers/BacktrackingOptimizer.py +113 -0
- pysisyphus/optimizers/ConjugateGradient.py +98 -0
- pysisyphus/optimizers/CubicNewton.py +75 -0
- pysisyphus/optimizers/FIRE.py +113 -0
- pysisyphus/optimizers/HessianOptimizer.py +1176 -0
- pysisyphus/optimizers/LBFGS.py +228 -0
- pysisyphus/optimizers/LayerOpt.py +411 -0
- pysisyphus/optimizers/MicroOptimizer.py +169 -0
- pysisyphus/optimizers/NCOptimizer.py +90 -0
- pysisyphus/optimizers/Optimizer.py +1084 -0
- pysisyphus/optimizers/PreconLBFGS.py +260 -0
- pysisyphus/optimizers/PreconSteepestDescent.py +7 -0
- pysisyphus/optimizers/QuickMin.py +74 -0
- pysisyphus/optimizers/RFOptimizer.py +181 -0
- pysisyphus/optimizers/RSA.py +99 -0
- pysisyphus/optimizers/StabilizedQNMethod.py +248 -0
- pysisyphus/optimizers/SteepestDescent.py +23 -0
- pysisyphus/optimizers/StringOptimizer.py +173 -0
- pysisyphus/optimizers/__init__.py +41 -0
- pysisyphus/optimizers/closures.py +301 -0
- pysisyphus/optimizers/cls_map.py +58 -0
- pysisyphus/optimizers/exceptions.py +6 -0
- pysisyphus/optimizers/gdiis.py +280 -0
- pysisyphus/optimizers/guess_hessians.py +311 -0
- pysisyphus/optimizers/hessian_updates.py +355 -0
- pysisyphus/optimizers/poly_fit.py +285 -0
- pysisyphus/optimizers/precon.py +153 -0
- pysisyphus/optimizers/restrict_step.py +24 -0
- pysisyphus/pack.py +172 -0
- pysisyphus/peakdetect.py +948 -0
- pysisyphus/plot.py +1031 -0
- pysisyphus/run.py +2106 -0
- pysisyphus/socket_helper.py +74 -0
- pysisyphus/stocastic/FragmentKick.py +132 -0
- pysisyphus/stocastic/Kick.py +81 -0
- pysisyphus/stocastic/Pipeline.py +303 -0
- pysisyphus/stocastic/__init__.py +21 -0
- pysisyphus/stocastic/align.py +127 -0
- pysisyphus/testing.py +96 -0
- pysisyphus/thermo.py +156 -0
- pysisyphus/trj.py +824 -0
- pysisyphus/tsoptimizers/RSIRFOptimizer.py +56 -0
- pysisyphus/tsoptimizers/RSPRFOptimizer.py +182 -0
- pysisyphus/tsoptimizers/TRIM.py +59 -0
- pysisyphus/tsoptimizers/TSHessianOptimizer.py +463 -0
- pysisyphus/tsoptimizers/__init__.py +23 -0
- pysisyphus/wavefunction/Basis.py +239 -0
- pysisyphus/wavefunction/DIIS.py +76 -0
- pysisyphus/wavefunction/__init__.py +25 -0
- pysisyphus/wavefunction/build_ext.py +42 -0
- pysisyphus/wavefunction/cart2sph.py +190 -0
- pysisyphus/wavefunction/diabatization.py +304 -0
- pysisyphus/wavefunction/excited_states.py +435 -0
- pysisyphus/wavefunction/gen_ints.py +1811 -0
- pysisyphus/wavefunction/helpers.py +104 -0
- pysisyphus/wavefunction/ints/__init__.py +0 -0
- pysisyphus/wavefunction/ints/boys.py +193 -0
- pysisyphus/wavefunction/ints/boys_table_N_64_xasym_27.1_step_0.01.npy +0 -0
- pysisyphus/wavefunction/ints/cart_gto3d.py +176 -0
- pysisyphus/wavefunction/ints/coulomb3d.py +25928 -0
- pysisyphus/wavefunction/ints/diag_quadrupole3d.py +10036 -0
- pysisyphus/wavefunction/ints/dipole3d.py +8762 -0
- pysisyphus/wavefunction/ints/int2c2e3d.py +7198 -0
- pysisyphus/wavefunction/ints/int3c2e3d_sph.py +65040 -0
- pysisyphus/wavefunction/ints/kinetic3d.py +8240 -0
- pysisyphus/wavefunction/ints/ovlp3d.py +3777 -0
- pysisyphus/wavefunction/ints/quadrupole3d.py +15054 -0
- pysisyphus/wavefunction/ints/self_ovlp3d.py +198 -0
- pysisyphus/wavefunction/localization.py +458 -0
- pysisyphus/wavefunction/multipole.py +159 -0
- pysisyphus/wavefunction/normalization.py +36 -0
- pysisyphus/wavefunction/pop_analysis.py +134 -0
- pysisyphus/wavefunction/shells.py +1171 -0
- pysisyphus/wavefunction/wavefunction.py +504 -0
- pysisyphus/wrapper/__init__.py +11 -0
- pysisyphus/wrapper/exceptions.py +2 -0
- pysisyphus/wrapper/jmol.py +120 -0
- pysisyphus/wrapper/mwfn.py +169 -0
- pysisyphus/wrapper/packmol.py +71 -0
- pysisyphus/xyzloader.py +168 -0
- pysisyphus/yaml_mods.py +45 -0
- thermoanalysis/LICENSE +674 -0
- thermoanalysis/QCData.py +244 -0
- thermoanalysis/__init__.py +0 -0
- thermoanalysis/config.py +3 -0
- thermoanalysis/constants.py +20 -0
- thermoanalysis/thermo.py +1011 -0
mlmm/path_search.py
ADDED
|
@@ -0,0 +1,2299 @@
|
|
|
1
|
+
# mlmm/path_search.py
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
ML/MM recursive GSM segmentation for multistep minimum-energy paths.
|
|
5
|
+
|
|
6
|
+
Example:
|
|
7
|
+
mlmm path-search -i R.pdb P.pdb --parm real.parm7 --model-pdb ml_region.pdb -q 0
|
|
8
|
+
|
|
9
|
+
For detailed documentation, see: docs/path_search.md
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
from copy import deepcopy
|
|
15
|
+
from dataclasses import dataclass
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
|
|
18
|
+
|
|
19
|
+
import gc
|
|
20
|
+
import logging
|
|
21
|
+
import sys
|
|
22
|
+
import traceback
|
|
23
|
+
import textwrap
|
|
24
|
+
import tempfile
|
|
25
|
+
import os
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
import time # timing
|
|
29
|
+
import re # used in _segment_base_id
|
|
30
|
+
|
|
31
|
+
import click
|
|
32
|
+
import numpy as np
|
|
33
|
+
import torch
|
|
34
|
+
import yaml
|
|
35
|
+
|
|
36
|
+
from pysisyphus.helpers import geom_loader
|
|
37
|
+
from pysisyphus.cos.GrowingString import GrowingString
|
|
38
|
+
from pysisyphus.optimizers.StringOptimizer import StringOptimizer
|
|
39
|
+
from pysisyphus.optimizers.LBFGS import LBFGS
|
|
40
|
+
from pysisyphus.optimizers.exceptions import OptimizationError, ZeroStepLength
|
|
41
|
+
from pysisyphus.constants import AU2KCALPERMOL, BOHR2ANG
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
from .mlmm_calc import mlmm, MLMMASECalculator
|
|
45
|
+
from .defaults import (
|
|
46
|
+
BOND_KW as _BOND_KW_DEFAULT,
|
|
47
|
+
SEARCH_KW as _SEARCH_KW_DEFAULT,
|
|
48
|
+
)
|
|
49
|
+
from .path_opt import GS_KW as _PATH_GS_KW, STOPT_KW as _PATH_STOPT_KW, DMF_KW as _PATH_DMF_KW, _select_hei_index
|
|
50
|
+
from .opt import (
|
|
51
|
+
GEOM_KW as _OPT_GEOM_KW,
|
|
52
|
+
CALC_KW as _OPT_CALC_KW,
|
|
53
|
+
LBFGS_KW as _OPT_LBFGS_KW,
|
|
54
|
+
_parse_freeze_atoms as _parse_freeze_atoms_opt,
|
|
55
|
+
_normalize_geom_freeze as _normalize_geom_freeze_opt,
|
|
56
|
+
)
|
|
57
|
+
from .utils import (
|
|
58
|
+
apply_layer_freeze_constraints,
|
|
59
|
+
apply_ref_pdb_override,
|
|
60
|
+
convert_xyz_to_pdb,
|
|
61
|
+
set_convert_file_enabled,
|
|
62
|
+
load_yaml_dict,
|
|
63
|
+
deep_update,
|
|
64
|
+
apply_yaml_overrides,
|
|
65
|
+
pretty_block,
|
|
66
|
+
strip_inherited_keys,
|
|
67
|
+
filter_calc_for_echo,
|
|
68
|
+
format_freeze_atoms_for_echo,
|
|
69
|
+
format_elapsed,
|
|
70
|
+
merge_freeze_atom_indices,
|
|
71
|
+
build_energy_diagram,
|
|
72
|
+
prepare_input_structure,
|
|
73
|
+
resolve_charge_spin_or_raise,
|
|
74
|
+
PreparedInputStructure,
|
|
75
|
+
parse_indices_string,
|
|
76
|
+
build_model_pdb_from_bfactors,
|
|
77
|
+
build_model_pdb_from_indices,
|
|
78
|
+
read_bfactors_from_pdb,
|
|
79
|
+
has_valid_layer_bfactors,
|
|
80
|
+
parse_layer_indices_from_bfactors,
|
|
81
|
+
collect_single_option_values,
|
|
82
|
+
)
|
|
83
|
+
from .cli_utils import resolve_yaml_sources, load_merged_yaml_cfg, make_is_param_explicit
|
|
84
|
+
from .preflight import validate_existing_files
|
|
85
|
+
from .trj2fig import run_trj2fig # auto-generate an energy plot when a _trj.xyz is produced
|
|
86
|
+
from .summary_log import write_summary_log
|
|
87
|
+
from .bond_changes import compare_structures, summarize_changes
|
|
88
|
+
from .align_freeze_atoms import align_and_refine_sequence_inplace
|
|
89
|
+
|
|
90
|
+
# -----------------------------------------------
# Configuration defaults
# -----------------------------------------------
# Every *_KW dict below is a deepcopy of a default exported by opt.py,
# path_opt.py or defaults.py. The local overrides made here (e.g. "out_dir")
# therefore do not leak back into those modules' dicts.

# Geometry (input handling) — reuse opt.py defaults
GEOM_KW: Dict[str, Any] = deepcopy(_OPT_GEOM_KW)

# ML/MM calculator settings — reuse opt.py defaults
CALC_KW: Dict[str, Any] = deepcopy(_OPT_CALC_KW)

# GrowingString (path representation) — reuse path_opt.py defaults
GS_KW: Dict[str, Any] = deepcopy(_PATH_GS_KW)

# StringOptimizer (GSM optimization control) — path_opt.py defaults with the
# output directory redirected to ./result_path_search/
STOPT_KW: Dict[str, Any] = deepcopy(_PATH_STOPT_KW)
STOPT_KW.update({
    "out_dir": "./result_path_search/",
})

# LBFGS settings — opt.py defaults with the same output-directory override
LBFGS_KW: Dict[str, Any] = deepcopy(_OPT_LBFGS_KW)
LBFGS_KW.update({
    "out_dir": "./result_path_search/",
})

# Covalent-bond change detection — copy of defaults.BOND_KW
BOND_KW: Dict[str, Any] = deepcopy(_BOND_KW_DEFAULT)

# DMF (Direct Max Flux) defaults — reuse path_opt.py defaults
DMF_KW: Dict[str, Any] = deepcopy(_PATH_DMF_KW)

# Global search control — copy of defaults.SEARCH_KW
SEARCH_KW: Dict[str, Any] = deepcopy(_SEARCH_KW_DEFAULT)
|
|
123
|
+
|
|
124
|
+
# Multi-structure loader
|
|
125
|
+
def _load_structures(
    inputs: Sequence[PreparedInputStructure],
    coord_type: str,
    base_freeze: Sequence[int],
) -> List[Any]:
    """Load every prepared input as a Geometry and attach the shared freeze list.

    Args:
        inputs: Prepared input structures; each ``geom_path`` is passed to
            ``geom_loader``.
        coord_type: Coordinate system forwarded to ``geom_loader``.
        base_freeze: Atom indices to freeze in every geometry. They are
            normalized through ``merge_freeze_atom_indices`` before assignment.

    Returns:
        The loaded geometries, in the same order as ``inputs``.
    """

    def _load_one(prepared: PreparedInputStructure) -> Any:
        geom = geom_loader(prepared.geom_path, coord_type=coord_type)
        merged = merge_freeze_atom_indices({"freeze_atoms": list(base_freeze)})
        geom.freeze_atoms = np.array(merged, dtype=int)
        return geom

    return [_load_one(prepared) for prepared in inputs]
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
# Helpers shared with opt.py for freeze parsing/normalization.
# Local private aliases for the functions imported from opt.py under the
# "_opt"-suffixed names.
_parse_freeze_atoms = _parse_freeze_atoms_opt
_normalize_geom_freeze = _normalize_geom_freeze_opt
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def _write_xyz_trj_with_energy(images: Sequence, energies: Sequence[float], path: Path) -> None:
|
|
150
|
+
"""
|
|
151
|
+
Write an XYZ `_trj.xyz` with the energy on line 2 of each block.
|
|
152
|
+
"""
|
|
153
|
+
blocks: List[str] = []
|
|
154
|
+
E = np.array(energies, dtype=float)
|
|
155
|
+
for geom, e in zip(images, E):
|
|
156
|
+
s = geom.as_xyz()
|
|
157
|
+
lines = s.splitlines()
|
|
158
|
+
if len(lines) >= 2 and lines[0].strip().isdigit():
|
|
159
|
+
lines[1] = f"{e:.12f}"
|
|
160
|
+
s_mod = "\n".join(lines)
|
|
161
|
+
if not s_mod.endswith("\n"):
|
|
162
|
+
s_mod += "\n"
|
|
163
|
+
blocks.append(s_mod)
|
|
164
|
+
with open(path, "w") as f:
|
|
165
|
+
f.write("".join(blocks))
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _maybe_convert_to_pdb(in_path: Path, ref_pdb_path: Optional[Path], out_path: Optional[Path] = None) -> Optional[Path]:
|
|
169
|
+
"""
|
|
170
|
+
If any input is PDB, convert the given `.xyz/_trj.xyz` to PDB using `ref_pdb_path`.
|
|
171
|
+
Return the output path on success, else None.
|
|
172
|
+
"""
|
|
173
|
+
try:
|
|
174
|
+
if ref_pdb_path is None or (not in_path.exists()) or in_path.suffix.lower() not in (".xyz", "_trj.xyz"):
|
|
175
|
+
return None
|
|
176
|
+
out_pdb = out_path if out_path is not None else in_path.with_suffix(".pdb")
|
|
177
|
+
convert_xyz_to_pdb(in_path, ref_pdb_path, out_pdb)
|
|
178
|
+
click.echo(f"[convert] Wrote '{out_pdb}'.")
|
|
179
|
+
return out_pdb
|
|
180
|
+
except Exception as e:
|
|
181
|
+
click.echo(f"[convert] WARNING: Failed to convert '{in_path.name}' to PDB: {e}", err=True)
|
|
182
|
+
return None
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def _kabsch_rmsd(A: np.ndarray, B: np.ndarray, align: bool = True, indices: Optional[Sequence[int]] = None) -> float:
|
|
186
|
+
"""
|
|
187
|
+
RMSD between A and B (no rigid alignment; `align` is ignored). Optional subset selection via `indices`.
|
|
188
|
+
"""
|
|
189
|
+
assert A.shape == B.shape and A.shape[1] == 3
|
|
190
|
+
if indices is not None and len(indices) > 0:
|
|
191
|
+
idx = np.array(sorted({int(i) for i in indices if 0 <= int(i) < A.shape[0]}), dtype=int)
|
|
192
|
+
if idx.size == 0:
|
|
193
|
+
idx = np.arange(A.shape[0], dtype=int)
|
|
194
|
+
A = A[idx]
|
|
195
|
+
B = B[idx]
|
|
196
|
+
diff = A - B
|
|
197
|
+
return float(np.sqrt((diff * diff).sum() / A.shape[0]))
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def _has_bond_change(x, y, bond_cfg: Dict[str, Any]) -> Tuple[bool, str]:
    """Report whether any covalent bond forms or breaks between ``x`` and ``y``.

    Missing ``bond_cfg`` keys fall back to: device "cuda", bond_factor 1.20,
    margin_fraction 0.05, delta_fraction 0.05.

    Returns:
        ``(changed, summary)``. ``changed`` is True if ``compare_structures``
        reports at least one formed or broken covalent bond. ``summary`` is the
        text from ``summarize_changes`` with 1-based atom numbering.
    """
    get = bond_cfg.get
    result = compare_structures(
        x,
        y,
        device=get("device", "cuda"),
        bond_factor=float(get("bond_factor", 1.20)),
        margin_fraction=float(get("margin_fraction", 0.05)),
        delta_fraction=float(get("delta_fraction", 0.05)),
    )
    n_formed = len(result.formed_covalent)
    n_broken = len(result.broken_covalent)
    text = summarize_changes(x, result, one_based=True)
    return (n_formed > 0 or n_broken > 0), text
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
# ---------- Minimal GS configuration helper ----------
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
# -----------------------------------------------
|
|
223
|
+
# Kink detection & interpolation helpers
|
|
224
|
+
# -----------------------------------------------
|
|
225
|
+
|
|
226
|
+
def _new_geom_from_coords(atoms: Sequence[str], coords: np.ndarray, coord_type: str, freeze_atoms: Sequence[int]) -> Any:
    """Build a pysisyphus Geometry from Bohr coordinates via a temporary XYZ file.

    Args:
        atoms: Element symbols, one per atom.
        coords: Cartesian coordinates in Bohr; each row is unpacked as (x, y, z).
        coord_type: Coordinate system forwarded to ``geom_loader``.
        freeze_atoms: Atom indices to freeze. They are de-duplicated and sorted
            before assignment.

    Returns:
        The loaded Geometry with ``freeze_atoms`` set.
    """
    # Coordinates are written to the XYZ text in Angstrom.
    coords_ang = np.asarray(coords, dtype=float) * BOHR2ANG
    lines = [str(len(atoms)), ""]
    lines.extend(f"{sym} {x:.15f} {y:.15f} {z:.15f}" for sym, (x, y, z) in zip(atoms, coords_ang))
    payload = "\n".join(lines) + "\n"

    # delete=False: geom_loader reopens the file by name after it is closed, and
    # the file is removed explicitly in `finally`.
    tmp = tempfile.NamedTemporaryFile("w+", suffix=".xyz", delete=False)
    try:
        # Close the handle even if the write fails, so the unlink below can succeed.
        with tmp:
            tmp.write(payload)
        g = geom_loader(Path(tmp.name), coord_type=coord_type)
        g.freeze_atoms = np.array(sorted(set(map(int, freeze_atoms))), dtype=int)
        return g
    finally:
        try:
            os.unlink(tmp.name)
        except OSError:
            logger.debug("Failed to unlink temp file %s", tmp.name, exc_info=True)
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def _make_linear_interpolations(gL, gR, n_internal: int) -> List[Any]:
|
|
251
|
+
"""
|
|
252
|
+
Return `n_internal` linearly interpolated structures between gL → gR (excluding endpoints).
|
|
253
|
+
Atom order follows `gL`.
|
|
254
|
+
"""
|
|
255
|
+
A = np.asarray(gL.coords3d, dtype=float)
|
|
256
|
+
B = np.asarray(gR.coords3d, dtype=float)
|
|
257
|
+
assert A.shape == B.shape and A.shape[1] == 3, "Atom counts must match for interpolation."
|
|
258
|
+
atoms = [a for a in gL.atoms]
|
|
259
|
+
coord_type = gL.coord_type
|
|
260
|
+
faL = getattr(gL, "freeze_atoms", np.array([], dtype=int))
|
|
261
|
+
faR = getattr(gR, "freeze_atoms", np.array([], dtype=int))
|
|
262
|
+
freeze_union = sorted(set(map(int, faL)) | set(map(int, faR)))
|
|
263
|
+
interps: List[Any] = []
|
|
264
|
+
for k in range(1, n_internal + 1):
|
|
265
|
+
t = k / (n_internal + 1.0)
|
|
266
|
+
C = (1.0 - t) * A + t * B
|
|
267
|
+
interps.append(_new_geom_from_coords(atoms, C, coord_type, freeze_union))
|
|
268
|
+
return interps
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
# ---- Segment/bridge tagging helpers ----
|
|
272
|
+
|
|
273
|
+
def _tag_images(images: Sequence[Any], **attrs: Any) -> None:
|
|
274
|
+
"""
|
|
275
|
+
Attach arbitrary attributes to Geometry images.
|
|
276
|
+
"""
|
|
277
|
+
for im in images:
|
|
278
|
+
for k, v in attrs.items():
|
|
279
|
+
try:
|
|
280
|
+
setattr(im, k, v)
|
|
281
|
+
except Exception:
|
|
282
|
+
logger.debug("Failed to set attribute %s on image", k, exc_info=True)
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def _segment_base_id(tag: str) -> str:
|
|
286
|
+
"""
|
|
287
|
+
Extract base id 'seg_XXX' from a tag like 'seg_000_refine'; fallback to `tag` or 'seg'.
|
|
288
|
+
"""
|
|
289
|
+
m = re.search(r"(seg_\d{3})", tag or "")
|
|
290
|
+
return m.group(1) if m else (tag or "seg")
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
def _is_local_minimum(idx: int, energies: Sequence[float]) -> bool:
|
|
294
|
+
if idx < 0 or idx >= len(energies):
|
|
295
|
+
return False
|
|
296
|
+
if idx == 0:
|
|
297
|
+
return len(energies) > 1 and energies[1] > energies[0]
|
|
298
|
+
if idx == len(energies) - 1:
|
|
299
|
+
return energies[-2] > energies[-1]
|
|
300
|
+
return energies[idx - 1] > energies[idx] and energies[idx + 1] > energies[idx]
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def _find_nearest_local_minimum(
|
|
304
|
+
hei_idx: int,
|
|
305
|
+
direction: int,
|
|
306
|
+
energies: Sequence[float],
|
|
307
|
+
) -> Optional[int]:
|
|
308
|
+
i = hei_idx + direction
|
|
309
|
+
while 0 <= i < len(energies):
|
|
310
|
+
if _is_local_minimum(i, energies):
|
|
311
|
+
return i
|
|
312
|
+
i += direction
|
|
313
|
+
return None
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
@dataclass
class GSMResult:
    """Outcome of one path optimization between two endpoints.

    ``_run_gsm_between`` fills this from the optimized GrowingString.
    """

    images: List[Any]  # path images in order, endpoints included
    energies: List[float]  # one energy per image, aligned with ``images``
    hei_idx: int  # index of the highest-energy image (HEI) in ``images`` / ``energies``
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
# ---- Per-segment summary for the console report ----
@dataclass
class SegmentReport:
    """One row of the console report, describing either a segment or a bridge."""

    tag: str  # segment label
    barrier_kcal: float  # barrier height; the "_kcal" suffix indicates kcal/mol
    delta_kcal: float  # segment energy change (kcal/mol per the name); presumably end minus start — confirm at construction site
    summary: str  # summarize_changes string (empty for bridges)
    kind: str = "seg"  # "seg" or "bridge"
    seg_index: int = 0  # 1-based index along final MEP (assigned later)
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
def _pick_gsm_hei_index(E: np.ndarray) -> int:
    """Select the highest-energy image (HEI) index from a path energy profile.

    Preference order:
      1. the highest interior strict local maximum (``E[i-1] < E[i] > E[i+1]``);
      2. otherwise the highest interior node (endpoints excluded);
      3. for paths with fewer than three nodes, the global maximum.

    Raises:
        ValueError: If ``E`` is empty.
    """
    n = len(E)
    if n == 0:
        raise ValueError("Cannot select an HEI from an empty energy profile.")
    local_max = [i for i in range(1, n - 1) if E[i] > E[i - 1] and E[i] > E[i + 1]]
    if local_max:
        return int(max(local_max, key=lambda i: E[i]))
    if n >= 3:
        return int(np.argmax(E[1:-1])) + 1
    return int(np.argmax(E))


def _run_gsm_between(
    gA,
    gB,
    shared_calc,
    gs_cfg: Dict[str, Any],
    stopt_cfg: Dict[str, Any],
    out_dir: Path,
    tag: str,
    ref_pdb_path: Optional[Path],  # reference PDB for conversion
) -> GSMResult:
    """Run GSM between ``gA`` and ``gB``, write segment outputs, and return the result.

    Files written to ``out_dir / f"{tag}_mep"``:
      - ``final_geometries_trj.xyz``, with energies on each comment line when
        possible and the plain GrowingString XYZ otherwise;
      - ``mep_plot.png``, only when energies were written;
      - ``hei.xyz``, the selected highest-energy image;
      - ``final_geometries.pdb`` / ``hei.pdb``, when ``ref_pdb_path`` is given.

    Failures while writing outputs, plotting or converting are reported as
    warnings and do not abort the run. Errors raised by the optimizer itself
    propagate.

    Args:
        gA, gB: Endpoint geometries; ``shared_calc`` is attached to both.
        shared_calc: Calculator attached to the endpoints and returned by the
            GrowingString ``calc_getter``.
        gs_cfg: Keyword arguments for ``GrowingString``.
        stopt_cfg: Keyword arguments for ``StringOptimizer``. ``"type"`` is
            dropped and ``"out_dir"`` is replaced by the segment directory.
        out_dir: Parent output directory.
        tag: Segment label used in the directory name and console messages.
        ref_pdb_path: Reference PDB for XYZ→PDB conversion, or ``None`` to skip.

    Returns:
        GSMResult with the optimized images, their energies, and the HEI index.
    """
    # Attach calculator to endpoints
    for g in (gA, gB):
        g.set_calculator(shared_calc)

    gs = GrowingString(
        images=[gA, gB],
        calc_getter=(lambda: shared_calc),
        **gs_cfg,
    )

    seg_dir = out_dir / f"{tag}_mep"
    seg_dir.mkdir(parents=True, exist_ok=True)

    # Drop "type" and force optimizer output into the segment directory.
    opt_args = {k: v for k, v in stopt_cfg.items() if k != "type"}
    opt_args["out_dir"] = str(seg_dir)
    optimizer = StringOptimizer(geometry=gs, **opt_args)

    click.echo(f"\n=== [{tag}] GSM started ===\n")
    optimizer.run()
    click.echo(f"\n=== [{tag}] GSM finished ===\n")

    energies = list(map(float, np.array(gs.energy, dtype=float)))
    images = list(gs.images)
    E = np.array(energies, dtype=float)
    hei_idx = _pick_gsm_hei_index(E)

    # Write trajectory; fall back to the plain GrowingString XYZ (no energies) on failure.
    final_trj = seg_dir / "final_geometries_trj.xyz"
    wrote_with_energy = True
    try:
        _write_xyz_trj_with_energy(images, energies, final_trj)
        click.echo(f"[{tag}] Wrote '{final_trj}'.")
    except Exception:
        wrote_with_energy = False
        with open(final_trj, "w") as f:
            f.write(gs.as_xyz())
        click.echo(f"[{tag}] Wrote '{final_trj}'.")

    # Energy plot for the segment; only attempted when energies were written.
    try:
        if wrote_with_energy:
            run_trj2fig(final_trj, [seg_dir / "mep_plot.png"], unit="kcal", reference="init", reverse_x=False)
            click.echo(f"[{tag}] Saved energy plot → '{seg_dir / 'mep_plot.png'}'")
        else:
            click.echo(f"[{tag}] WARNING: Energies missing; skipping plot.", err=True)
    except Exception as e:
        click.echo(f"[{tag}] WARNING: Failed to plot energy: {e}", err=True)

    # If PDB input exists, convert intermediate _trj.xyz to PDB
    _maybe_convert_to_pdb(final_trj, ref_pdb_path, seg_dir / "final_geometries.pdb")

    # HEI structure: reuse the trajectory writer so the energy-line format is identical.
    try:
        hei_xyz = seg_dir / "hei.xyz"
        _write_xyz_trj_with_energy([images[hei_idx]], [float(E[hei_idx])], hei_xyz)
        click.echo(f"[{tag}] Wrote '{hei_xyz}'.")
        _maybe_convert_to_pdb(hei_xyz, ref_pdb_path, seg_dir / "hei.pdb")
    except Exception as e:
        click.echo(f"[{tag}] WARNING: Failed to write HEI structure: {e}", err=True)

    return GSMResult(images=images, energies=energies, hei_idx=hei_idx)
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
def _run_dmf_between(
    gA,
    gB,
    shared_calc,
    calc_cfg: Dict[str, Any],
    out_dir: Path,
    tag: str,
    ref_pdb_path: Optional[Path],
    max_nodes: int,
    dmf_cfg: Dict[str, Any],
) -> GSMResult:
    """Run DMF for a segment and convert outputs to pysisyphus Geometries.

    The two endpoints are converted to ASE ``Atoms``. An initial path is built
    with ``interpolate_fbenm`` and its coefficients seed a ``DirectMaxFlux``
    optimization. The optimized images are re-evaluated with ``shared_calc``,
    written to ``<out_dir>/<tag>_mep`` (XYZ trajectory, optional PDB, energy
    plot), and converted back into pysisyphus Geometries.

    Args:
        gA, gB: Endpoint geometries.
        shared_calc: pysisyphus calculator. Its ``core`` is wrapped in an ASE
            calculator for DMF, and it is attached to the returned images.
        calc_cfg: Calculator config. ``model_charge`` / ``model_mult`` are
            copied into each ASE image's ``info`` as ``charge`` / ``spin``.
        out_dir: Parent output directory.
        tag: Label for the segment directory and log messages.
        ref_pdb_path: Optional reference PDB used to convert XYZ output to PDB.
        max_nodes: Number of movable DMF nodes (``nmove``), clamped to >= 1.
        dmf_cfg: DMF settings merged on top of ``DMF_KW`` via ``deep_update``.

    Returns:
        GSMResult holding the converted images, one energy per image, and the
        index of the highest-energy image.

    Raises:
        RuntimeError: If the DMF dependencies cannot be imported.
    """
    import copy
    from io import StringIO

    from ase.io import read as ase_read, write as ase_write
    from pysisyphus.constants import ANG2BOHR
    from pysisyphus.helpers import geom_loader as gl

    seg_dir = out_dir / f"{tag}_mep"
    seg_dir.mkdir(parents=True, exist_ok=True)

    # Atoms frozen in either endpoint are restrained harmonically during DMF.
    fix_atoms: List[int] = []
    try:
        fix_atoms = sorted(
            {int(i) for g in (gA, gB) for i in getattr(g, "freeze_atoms", [])}
        )
    except Exception:
        logger.debug("Failed to extract freeze_atoms from endpoints", exc_info=True)

    try:
        from ase.calculators.mixing import SumCalculator
        from dmf import DirectMaxFlux, interpolate_fbenm
    except Exception as e:
        raise RuntimeError(f"DMF mode requires pydmf and cyipopt: {e}") from e

    from .harmonic_constraints import HarmonicFixAtoms
    from .utils import deep_update

    def _geom_to_ase(g):
        """Convert a pysisyphus Geometry to ASE Atoms through an in-memory XYZ string."""
        return ase_read(StringIO(g.as_xyz()), format="xyz")

    ref_images = [_geom_to_ase(g) for g in (gA, gB)]
    charge = int(calc_cfg.get("model_charge", 0))
    spin = int(calc_cfg.get("model_mult", 1))
    for img in ref_images:
        img.info["charge"] = charge
        img.info["spin"] = spin

    # Build ASE calculator from the shared PySisyphus calculator
    ase_calc = MLMMASECalculator(core=shared_calc.core)

    # Deep copy, so that a deep_update which merges nested dicts in place
    # cannot modify the module-level DMF_KW defaults for later segments.
    dmf_cfg_local = deep_update(copy.deepcopy(DMF_KW), dmf_cfg)
    fbenm_opts = dict(dmf_cfg_local.get("fbenm_options", {}))
    cfbenm_opts = dict(dmf_cfg_local.get("cfbenm_options", {}))
    dmf_opts = dict(dmf_cfg_local.get("dmf_options", {}))
    # Popped so it is passed only to DirectMaxFlux below, not via dmf_options.
    update_teval = bool(dmf_opts.pop("update_teval", False))
    k_fix = float(dmf_cfg_local.get("k_fix", 300.0))
    nmove = max(1, int(max_nodes))

    # Initial path from FB-ENM interpolation; its coefficients seed DMF.
    mxflx_fbenm = interpolate_fbenm(
        ref_images,
        nmove=nmove,
        fbenm_only_endpoints=bool(dmf_cfg_local.get("fbenm_only_endpoints", False)),
        correlated=bool(dmf_cfg_local.get("correlated", False)),
        sequential=bool(dmf_cfg_local.get("sequential", False)),
        output_file=str(seg_dir / "dmf_fbenm_ipopt.out"),
        fbenm_options=fbenm_opts,
        cfbenm_options=cfbenm_opts,
        dmf_options=dmf_opts,
    )
    coefs = mxflx_fbenm.coefs.copy()

    mxflx = DirectMaxFlux(
        ref_images,
        coefs=coefs,
        nmove=nmove,
        update_teval=update_teval,
        remove_rotation_and_translation=bool(dmf_opts.get("remove_rotation_and_translation", False)),
        mass_weighted=bool(dmf_opts.get("mass_weighted", False)),
        parallel=bool(dmf_opts.get("parallel", False)),
        eps_vel=float(dmf_opts.get("eps_vel", 0.01)),
        eps_rot=float(dmf_opts.get("eps_rot", 0.01)),
        beta=float(dmf_opts.get("beta", 10.0)),
    )

    for image in mxflx.images:
        image.info.setdefault("charge", charge)
        image.info.setdefault("spin", spin)
        if fix_atoms:
            # Restrain frozen atoms to their positions in this image.
            ref_positions = image.get_positions()[fix_atoms]
            harmonic_calc = HarmonicFixAtoms(indices=fix_atoms, ref_positions=ref_positions, k_fix=k_fix)
            image.calc = SumCalculator([ase_calc, harmonic_calc])
        else:
            image.calc = ase_calc

    mxflx.add_ipopt_options({"output_file": str(seg_dir / "dmf_ipopt.out")})
    max_cycles_dmf = dmf_cfg_local.get("max_cycles")
    if max_cycles_dmf is not None:
        try:
            max_iter = int(max_cycles_dmf)
            if max_iter > 0:
                mxflx.add_ipopt_options({"max_iter": max_iter})
        except Exception:
            logger.debug("Failed to set ipopt max_iter option", exc_info=True)

    click.echo(f"\n=== [{tag}] DMF started ===\n")
    mxflx.solve(tol="tight")
    click.echo(f"\n=== [{tag}] DMF finished ===\n")

    # Evaluate energies using PySisyphus calculator (ASE positions are in Å).
    energies = []
    for image in mxflx.images:
        elems = image.get_chemical_symbols()
        coords_bohr = np.asarray(image.get_positions(), dtype=float).reshape(-1, 3) * ANG2BOHR
        energies.append(float(shared_calc.get_energy(elems, coords_bohr)["energy"]))
    hei_idx = _select_hei_index(energies)

    # Write trajectory
    final_trj = seg_dir / "final_geometries_trj.xyz"
    _write_xyz_trj_with_energy_from_ase(mxflx.images, energies, final_trj)
    click.echo(f"[{tag}] Wrote '{final_trj}'.")
    _maybe_convert_to_pdb(final_trj, ref_pdb_path, seg_dir / "final_geometries.pdb")

    try:
        run_trj2fig(final_trj, [seg_dir / "mep_plot.png"], unit="kcal", reference="init", reverse_x=False)
        click.echo(f"[{tag}] Saved energy plot → '{seg_dir / 'mep_plot.png'}'")
    except Exception as e:
        click.echo(f"[{tag}] WARNING: Failed to plot energy: {e}", err=True)

    # Convert ASE images back to pysisyphus Geometries via temporary XYZ files.
    imgs = []
    for idx, atoms in enumerate(mxflx.images):
        buf = StringIO()
        ase_write(buf, atoms, format="xyz")
        tmp_xyz = seg_dir / f"_tmp_dmf_{idx}.xyz"
        try:
            tmp_xyz.write_text(buf.getvalue())
            g = gl(tmp_xyz, coord_type=gA.coord_type)
        finally:
            # Remove the temporary file even if loading it fails.
            tmp_xyz.unlink(missing_ok=True)
        try:
            g.freeze_atoms = np.array(getattr(gA, "freeze_atoms", []), dtype=int)
        except Exception:
            logger.debug("Failed to set freeze_atoms on interpolated image", exc_info=True)
        g.set_calculator(shared_calc)
        imgs.append(g)

    return GSMResult(images=imgs, energies=energies, hei_idx=hei_idx)
|
|
584
|
+
|
|
585
|
+
|
|
586
|
+
def _write_xyz_trj_with_energy_from_ase(images, energies, path: Path) -> None:
    """Write an ASE Atoms list with energies as an XYZ trajectory.

    Each image is serialized with ``ase.io.write(format="xyz")``. When the
    serialized frame starts with an integer atom-count line, its second
    (comment) line is replaced by the matching energy formatted with 12
    decimals. Energies are written verbatim, in whatever units the caller
    supplies.

    Args:
        images: Iterable of ASE ``Atoms`` objects.
        energies: Iterable with one energy per image, in the same order.
        path: Destination file; overwritten if it exists.

    Raises:
        ValueError: If the number of images and energies differ. ``zip``
            would otherwise silently drop the unmatched trailing frames.
    """
    images = list(images)
    energies = list(energies)
    if len(images) != len(energies):
        raise ValueError(
            f"Got {len(images)} images but {len(energies)} energies; "
            "cannot write an energy-annotated trajectory."
        )

    from io import StringIO

    from ase.io import write as ase_write

    blocks = []
    for atoms, E in zip(images, energies):
        buf = StringIO()
        ase_write(buf, atoms, format="xyz")
        lines = buf.getvalue().splitlines()
        # XYZ layout: line 1 = atom count, line 2 = free-form comment.
        if len(lines) >= 2 and lines[0].strip().isdigit():
            lines[1] = f"{E:.12f}"
        blocks.append("\n".join(lines) + "\n")
    with open(path, "w") as f:
        f.write("".join(blocks))
|
|
601
|
+
|
|
602
|
+
|
|
603
|
+
def _run_mep_between(
    gA,
    gB,
    shared_calc,
    gs_cfg: Dict[str, Any],
    stopt_cfg: Dict[str, Any],
    out_dir: Path,
    tag: str,
    ref_pdb_path: Optional[Path],
    mep_mode_kind: str = "gsm",
    calc_cfg: Optional[Dict[str, Any]] = None,
    max_nodes: int = 10,
    dmf_cfg: Optional[Dict[str, Any]] = None,
) -> GSMResult:
    """Compute an MEP segment between ``gA`` and ``gB`` with the chosen method.

    When ``mep_mode_kind`` is exactly ``"dmf"``, the work is delegated to
    ``_run_dmf_between``. It receives ``calc_cfg`` (``{}`` if falsy),
    ``max_nodes``, and ``dmf_cfg`` (a copy of ``DMF_KW`` if falsy).

    Any other value runs GSM through ``_run_gsm_between`` with ``gs_cfg`` and
    ``stopt_cfg``. In that case ``calc_cfg``, ``max_nodes`` and ``dmf_cfg``
    are not forwarded.
    """
    if mep_mode_kind != "dmf":
        return _run_gsm_between(
            gA, gB, shared_calc, gs_cfg, stopt_cfg, out_dir,
            tag=tag, ref_pdb_path=ref_pdb_path,
        )

    effective_calc_cfg = calc_cfg if calc_cfg else {}
    effective_dmf_cfg = dmf_cfg if dmf_cfg else dict(DMF_KW)
    return _run_dmf_between(
        gA,
        gB,
        shared_calc,
        calc_cfg=effective_calc_cfg,
        out_dir=out_dir,
        tag=tag,
        ref_pdb_path=ref_pdb_path,
        max_nodes=max_nodes,
        dmf_cfg=effective_dmf_cfg,
    )
|
|
628
|
+
|
|
629
|
+
|
|
630
|
+
def _optimize_single(
    g,
    shared_calc,
    lbfgs_cfg: Dict[str, Any],
    out_dir: Path,
    tag: str,
    ref_pdb_path: Optional[Path],  # for PDB conversion
):
    """
    Run single-structure optimization (LBFGS) and return the final Geometry.

    Output goes to ``<out_dir>/<tag>_lbfgs_opt``. The ``out_dir`` entry of a
    copy of ``lbfgs_cfg`` is overridden; the caller's dict is not modified.

    Returns:
        A Geometry re-loaded from the optimizer's ``final_fn``, carrying the
        input's ``freeze_atoms`` and ``shared_calc``. If converting or
        re-loading that file fails, the Geometry ``g`` the optimizer was run
        on is returned instead, and the failure is logged at debug level.
    """
    g.set_calculator(shared_calc)

    seg_dir = out_dir / f"{tag}_lbfgs_opt"
    seg_dir.mkdir(parents=True, exist_ok=True)
    # Copy so the caller's config is not mutated by the out_dir override.
    args = dict(lbfgs_cfg)
    args["out_dir"] = str(seg_dir)

    opt = LBFGS(g, **args)

    click.echo(f"\n=== [{tag}] Single-structure LBFGS started ===\n")
    opt.run()
    click.echo(f"\n=== [{tag}] Single-structure LBFGS finished ===\n")

    try:
        final_xyz = Path(opt.final_fn)
        _maybe_convert_to_pdb(final_xyz, ref_pdb_path)
        g_final = geom_loader(final_xyz, coord_type=g.coord_type)
        try:
            g_final.freeze_atoms = np.array(getattr(g, "freeze_atoms", []), dtype=int)
        except Exception:
            logger.debug("Failed to set freeze_atoms on final geometry", exc_info=True)
        g_final.set_calculator(shared_calc)
        return g_final
    except Exception:
        # Best-effort reload: fall back to the in-memory geometry, but leave a trace.
        logger.debug(
            "[%s] Failed to reload optimized geometry; returning the input Geometry",
            tag,
            exc_info=True,
        )
        return g
|
|
666
|
+
|
|
667
|
+
|
|
668
|
+
def _refine_between(
    gL,
    gR,
    shared_calc,
    gs_cfg: Dict[str, Any],
    stopt_cfg: Dict[str, Any],
    out_dir: Path,
    tag: str,
    ref_pdb_path: Optional[Path],  # for PDB conversion
    mep_mode_kind: str = "gsm",
    calc_cfg: Optional[Dict[str, Any]] = None,
    max_nodes: int = 10,
    dmf_cfg: Optional[Dict[str, Any]] = None,
) -> GSMResult:
    """
    Refine the path between two optimized endpoints via GSM or DMF.

    A copy of ``gs_cfg`` is used with ``climb`` and ``climb_lanczos`` forced
    to ``True``; the caller's dict is left unchanged. Everything else is
    forwarded to ``_run_mep_between`` under the tag ``<tag>_refine``.
    """
    refine_tag = f"{tag}_refine"
    gs_refine_cfg = dict(gs_cfg)
    gs_refine_cfg.update(climb=True, climb_lanczos=True)
    return _run_mep_between(
        gL,
        gR,
        shared_calc,
        gs_refine_cfg,
        stopt_cfg,
        out_dir,
        tag=refine_tag,
        ref_pdb_path=ref_pdb_path,
        mep_mode_kind=mep_mode_kind,
        calc_cfg=calc_cfg,
        max_nodes=max_nodes,
        dmf_cfg=dmf_cfg,
    )
|
|
691
|
+
|
|
692
|
+
|
|
693
|
+
def _maybe_bridge_segments(
    tail_g,
    head_g,
    shared_calc,
    gs_cfg: Dict[str, Any],  # bridge-specific GS config
    stopt_cfg: Dict[str, Any],
    out_dir: Path,
    tag: str,
    rmsd_thresh: float,
    ref_pdb_path: Optional[Path],  # for PDB conversion
    mep_mode_kind: str = "gsm",
    calc_cfg: Optional[Dict[str, Any]] = None,
    max_nodes: int = 5,
    dmf_cfg: Optional[Dict[str, Any]] = None,
) -> Optional[GSMResult]:
    """
    Bridge two consecutive segment endpoints when they do not coincide.

    The gap is the RMSD between the Cartesian coordinates of ``tail_g`` and
    ``head_g``, computed without alignment (``align=False``). If the gap is
    at most ``rmsd_thresh``, nothing is run and ``None`` is returned.
    Otherwise an MEP from ``tail_g`` to ``head_g`` is computed under the tag
    ``<tag>_bridge`` and returned.
    """
    tail_xyz = np.array(tail_g.coords3d)
    head_xyz = np.array(head_g.coords3d)
    gap = _kabsch_rmsd(tail_xyz, head_xyz, align=False)
    if gap <= rmsd_thresh:
        return None

    method_label = mep_mode_kind.upper()
    click.echo(f"[{tag}] Gap detected between segments (RMSD={gap:.4e} Å) — bridging via {method_label}.")
    return _run_mep_between(
        tail_g,
        head_g,
        shared_calc,
        gs_cfg,
        stopt_cfg,
        out_dir,
        tag=f"{tag}_bridge",
        ref_pdb_path=ref_pdb_path,
        mep_mode_kind=mep_mode_kind,
        calc_cfg=calc_cfg,
        max_nodes=max_nodes,
        dmf_cfg=dmf_cfg,
    )
|
|
720
|
+
|
|
721
|
+
|
|
722
|
+
def _stitch_paths(
    parts: List[Tuple[List[Any], List[float]]],
    stitch_rmsd_thresh: float,
    bridge_rmsd_thresh: float,
    shared_calc,
    gs_cfg,  # GS config for bridges (climb=False, max_nodes=search.max_nodes_bridge)
    stopt_cfg,
    out_dir: Path,
    tag: str,
    ref_pdb_path: Optional[Path],  # for PDB conversion
    bond_cfg: Optional[Dict[str, Any]] = None,  # detect bond changes between adjacent parts
    segment_builder: Optional[Callable[[Any, Any, str], "CombinedPath"]] = None,  # builds a recursive segment
    segments_out: Optional[List["SegmentReport"]] = None,  # append inserted segment summaries in order
    bridge_pair_index: Optional[int] = None,  # pair index to tag bridge frames across pairs
    mep_mode_kind: str = "gsm",
    calc_cfg: Optional[Dict[str, Any]] = None,
    dmf_cfg: Optional[Dict[str, Any]] = None,
) -> Tuple[List[Any], List[float]]:
    """
    Concatenate path parts (images, energies). Insert bridge GSMs when needed.

    At each interface between the accumulated path's last image (tail) and
    the next part's first image (head):

    * If ``segment_builder`` and ``bond_cfg`` are given and covalent changes
      are detected between tail and head, a *new* recursive segment is built
      with ``segment_builder(tail, head, f"{tag}_mid")`` and inserted. Its
      segment reports are appended to ``segments_out``.
    * Otherwise the unaligned RMSD between tail and head decides:
      - ``<= stitch_rmsd_thresh``: the head is a duplicate and is dropped.
      - ``> bridge_rmsd_thresh``: a bridge MEP is run via
        ``_maybe_bridge_segments`` and inserted. A ``"bridge"``
        ``SegmentReport`` is placed in ``segments_out`` just before the report
        tagged like the upcoming part, or appended if none matches.
      - otherwise: the part is appended as-is.

    Whenever inserted frames start on (within ``stitch_rmsd_thresh`` of) the
    current tail, that duplicate first frame is skipped.

    For DMF bridges, the number of movable nodes is taken from
    ``gs_cfg["max_nodes"]`` (default 5), the same node count GSM bridges use.

    Returns:
        ``(images, energies)`` of the stitched path.
    """
    all_imgs: List[Any] = []
    all_E: List[float] = []

    def _gap(a, b) -> float:
        """Unaligned RMSD between the Cartesian coordinates of two images."""
        return _kabsch_rmsd(np.array(a.coords3d), np.array(b.coords3d), align=False)

    def _last_known_seg_tag_from_images(imgs: List[Any]) -> Optional[str]:
        for im in reversed(imgs):
            t = getattr(im, "mep_seg_tag", None)
            if t:
                return t
        return None

    def _first_known_seg_tag_from_images(imgs: List[Any]) -> Optional[str]:
        for im in imgs:
            t = getattr(im, "mep_seg_tag", None)
            if t:
                return t
        return None

    def _extend_dedup(new_imgs: List[Any], new_E: List[float]) -> None:
        """Extend the stitched path, skipping new_imgs[0] if it duplicates the tail."""
        if new_imgs and all_imgs and _gap(all_imgs[-1], new_imgs[0]) <= stitch_rmsd_thresh:
            new_imgs = new_imgs[1:]
            new_E = new_E[1:]
        all_imgs.extend(new_imgs)
        all_E.extend(new_E)

    def _insert_bridge_report(br, bridge_tag: str, right_tag_upcoming: str) -> None:
        """Record a bridge SegmentReport just before the upcoming segment's report."""
        try:
            barrier_kcal = (max(br.energies) - br.energies[0]) * AU2KCALPERMOL
            delta_kcal = (br.energies[-1] - br.energies[0]) * AU2KCALPERMOL
        except Exception:
            barrier_kcal = float("nan")
            delta_kcal = float("nan")
        bridge_report = SegmentReport(
            tag=bridge_tag,
            barrier_kcal=float(barrier_kcal),
            delta_kcal=float(delta_kcal),
            summary="",
            kind="bridge"
        )
        insert_pos = next(
            (j for j, sr in enumerate(segments_out) if getattr(sr, "tag", None) == right_tag_upcoming),
            None,
        )
        if insert_pos is None:
            segments_out.append(bridge_report)
        else:
            segments_out.insert(insert_pos, bridge_report)

    def append_part(imgs: List[Any], Es: List[float]) -> None:
        if not imgs:
            return
        if not all_imgs:
            all_imgs.extend(imgs)
            all_E.extend(Es)
            return
        tail = all_imgs[-1]
        head = imgs[0]

        adj_changed, adj_summary = False, ""
        if segment_builder is not None and bond_cfg is not None:
            try:
                adj_changed, adj_summary = _has_bond_change(tail, head, bond_cfg)
            except Exception:
                logger.debug(
                    "[%s] Bond-change check at interface failed; treating it as unchanged",
                    tag,
                    exc_info=True,
                )
                adj_changed, adj_summary = False, ""

        if adj_changed and segment_builder is not None:
            click.echo(f"[{tag}] Covalent changes detected at interface — inserting a new recursive segment.")
            if adj_summary:
                click.echo(textwrap.indent(adj_summary, prefix=" "))
            sub = segment_builder(tail, head, f"{tag}_mid")
            if segments_out is not None and getattr(sub, "segments", None):
                segments_out.extend(sub.segments)
            _extend_dedup(sub.images, sub.energies)
            _extend_dedup(imgs, Es)
            return

        rmsd = _gap(tail, head)
        if rmsd <= stitch_rmsd_thresh:
            all_imgs.extend(imgs[1:])
            all_E.extend(Es[1:])
        elif rmsd > bridge_rmsd_thresh:
            left_tag_recent = _last_known_seg_tag_from_images(all_imgs) or "segL"
            right_tag_upcoming = _first_known_seg_tag_from_images(imgs) or "segR"
            left_base = _segment_base_id(left_tag_recent)
            right_base = _segment_base_id(right_tag_upcoming)
            bridge_name_base = f"{left_base}_{right_base}"
            bridge_tag = f"{bridge_name_base}_bridge"

            br = _maybe_bridge_segments(
                tail, head, shared_calc, gs_cfg, stopt_cfg, out_dir, tag=bridge_name_base,
                rmsd_thresh=bridge_rmsd_thresh, ref_pdb_path=ref_pdb_path,
                mep_mode_kind=mep_mode_kind, calc_cfg=calc_cfg,
                # Keep DMF bridges in sync with the GSM bridge node count.
                max_nodes=int(gs_cfg.get("max_nodes", 5)),
                dmf_cfg=dmf_cfg,
            )
            if br is not None:
                _tag_images(br.images, mep_seg_tag=bridge_tag, mep_seg_kind="bridge",
                            mep_has_bond_changes=False, pair_index=bridge_pair_index)
                # Guarded against an empty bridge (previously an IndexError).
                _extend_dedup(br.images, br.energies)
                if segments_out is not None:
                    _insert_bridge_report(br, bridge_tag, right_tag_upcoming)

            _extend_dedup(imgs, Es)
        else:
            all_imgs.extend(imgs)
            all_E.extend(Es)

    for (imgs, Es) in parts:
        append_part(imgs, Es)

    return all_imgs, all_E
|
|
868
|
+
|
|
869
|
+
|
|
870
|
+
# -----------------------------------------------
|
|
871
|
+
# Recursive search (core)
|
|
872
|
+
# -----------------------------------------------
|
|
873
|
+
|
|
874
|
+
@dataclass
class CombinedPath:
    """Result of a (possibly recursive) multistep MEP construction.

    Holds the stitched images in A→B order, the energy list built alongside
    them, and the per-segment summaries in the order the segments appear
    along the path.
    """

    # Geometries along the stitched path, in A→B order.
    images: List[Any]
    # Energies aligned index-by-index with ``images``.
    energies: List[float]
    # Segment summaries in final output order.
    segments: List[SegmentReport]
|
|
879
|
+
|
|
880
|
+
|
|
881
|
+
def _trailing_kink_count(segments: Sequence[SegmentReport]) -> int:
    """Count the kink segments forming an uninterrupted run at the end of *segments*.

    A segment counts as a kink when its ``tag`` is non-empty and contains
    ``"kink"``. Scanning goes from the last segment backwards and stops at
    the first segment that is not a kink.
    """
    run_length = 0
    position = len(segments) - 1
    while position >= 0:
        report = segments[position]
        if not (report.tag and "kink" in report.tag):
            break
        run_length += 1
        position -= 1
    return run_length
|
|
890
|
+
|
|
891
|
+
|
|
892
|
+
def _build_multistep_path(
    gA,
    gB,
    shared_calc,
    geom_cfg: Dict[str, Any],
    gs_cfg: Dict[str, Any],
    stopt_cfg: Dict[str, Any],
    single_opt_cfg: Dict[str, Any],
    bond_cfg: Dict[str, Any],
    search_cfg: Dict[str, Any],
    refine_mode_kind: str,
    out_dir: Path,
    ref_pdb_path: Optional[Path],
    depth: int,
    seg_counter: List[int],
    branch_tag: str,
    pair_index: Optional[int] = None,
    mep_mode_kind: str = "gsm",
    calc_cfg: Optional[Dict[str, Any]] = None,
    dmf_cfg: Optional[Dict[str, Any]] = None,
    kink_seq_count: int = 0,
) -> CombinedPath:
    """
    Recursively construct a multistep MEP from A–B and return it (A→B order).

    One recursion level does the following:

    1. Run a climbing MEP (GSM or DMF, selected by ``mep_mode_kind``) between
       ``gA`` and ``gB`` and take its highest-energy image (HEI).
    2. Choose the images flanking the HEI and optimize each with LBFGS, giving
       ``left_end`` and ``right_end``. With ``refine_mode_kind == "minima"``
       these are the nearest local minima on each side (falling back to
       HEI±1); otherwise they are HEI±1.
    3. If no covalent change is detected between ``left_end`` and
       ``right_end``, the step is treated as a "kink": linear interpolations
       between them are optimized one by one. Otherwise the step is refined
       with a climbing MEP.
    4. Recurse on ``gA``→``left_end`` and/or ``right_end``→``gB`` when those
       intervals still show covalent changes.
    5. Stitch left sub-path + this step + right sub-path with
       ``_stitch_paths``, which may insert bridges or new recursive segments.

    Early exits return a single path with ``segments=[]``:

    * ``depth`` exceeds ``search_cfg["max_depth"]`` (default 10): a fresh MEP
      between ``gA`` and ``gB``.
    * The HEI of the first MEP is at an endpoint: that first MEP, unchanged.
    * The number of consecutive kink steps reaches
      ``search_cfg["max_seq_kink"]`` (default 2): a fresh MEP between ``gA``
      and ``gB``.

    Args:
        gA, gB: Endpoint geometries of the interval.
        shared_calc: Calculator shared by all MEP runs and optimizations.
        geom_cfg: Only forwarded to recursive calls in this function.
        gs_cfg: Base MEP config. ``max_nodes`` is overridden per use (segments:
            ``search_cfg["max_nodes_segment"]``; bridges:
            ``search_cfg["max_nodes_bridge"]``).
        stopt_cfg: Forwarded to the MEP runs.
        single_opt_cfg: LBFGS config for ``_optimize_single``.
        bond_cfg: Config for ``_has_bond_change``.
        search_cfg: Recursion/search settings read below (``max_nodes_segment``,
            ``max_seq_kink``, ``max_depth``, ``kink_max_nodes``,
            ``max_nodes_bridge``, ``stitch_rmsd_thresh``, ``bridge_rmsd_thresh``).
        refine_mode_kind: ``"minima"``, or anything else for HEI±1 "peak" mode.
        out_dir: Output directory for all sub-runs.
        ref_pdb_path: Optional reference PDB for XYZ→PDB conversion.
        depth: Current recursion depth.
        seg_counter: One-element list used as a counter shared across all
            recursion levels; it provides the ``seg_NNN`` numbering.
        branch_tag: Label for log messages, extended with L/R/B per branch.
        pair_index: Forwarded to ``_tag_images`` for every produced image.
        mep_mode_kind, calc_cfg, dmf_cfg: Forwarded to the MEP dispatcher.
        kink_seq_count: Number of consecutive kink segments immediately
            preceding ``gA`` on the overall path.

    Returns:
        CombinedPath with the stitched images/energies and the segment
        reports in path order.
    """
    # Node count for segment MEPs: search.max_nodes_segment, else gs_cfg's max_nodes.
    seg_max_nodes = int(search_cfg.get("max_nodes_segment", gs_cfg.get("max_nodes", 10)))
    gs_seg_cfg = {**gs_cfg, "max_nodes": seg_max_nodes}
    max_seq_kink = int(search_cfg.get("max_seq_kink", 2))

    # Recursion guard: fall back to a single MEP between the current endpoints.
    if depth > int(search_cfg.get("max_depth", 10)):
        click.echo(f"[{branch_tag}] Reached maximum recursion depth. Returning current endpoints only.")
        gsm = _run_mep_between(
            gA, gB, shared_calc, gs_seg_cfg, stopt_cfg, out_dir, tag=f"seg_{seg_counter[0]:03d}_maxdepth",
            ref_pdb_path=ref_pdb_path, mep_mode_kind=mep_mode_kind,
            calc_cfg=calc_cfg, max_nodes=seg_max_nodes, dmf_cfg=dmf_cfg,
        )
        seg_counter[0] += 1
        _tag_images(gsm.images, pair_index=pair_index)
        return CombinedPath(images=gsm.images, energies=gsm.energies, segments=[])

    # Reserve this level's segment number before recursing (shared counter).
    seg_id = seg_counter[0]
    seg_counter[0] += 1
    tag0 = f"seg_{seg_id:03d}"

    # Initial A→B MEP with climbing enabled, used to locate the HEI.
    gs_seg_cfg_first = {**gs_seg_cfg, "climb": True, "climb_lanczos": True}
    gsm0 = _run_mep_between(
        gA, gB, shared_calc, gs_seg_cfg_first, stopt_cfg, out_dir, tag=tag0,
        ref_pdb_path=ref_pdb_path, mep_mode_kind=mep_mode_kind,
        calc_cfg=calc_cfg, max_nodes=seg_max_nodes, dmf_cfg=dmf_cfg,
    )

    # The HEI needs a neighbour on each side; otherwise return the path as-is.
    hei = int(gsm0.hei_idx)
    if not (1 <= hei <= len(gsm0.images) - 2):
        click.echo(f"[{tag0}] WARNING: HEI is at an endpoint (idx={hei}). Returning the raw GSM path.")
        _tag_images(gsm0.images, pair_index=pair_index)
        return CombinedPath(images=gsm0.images, energies=gsm0.energies, segments=[])

    # Pick the images flanking the HEI that will be optimized into End1/End2.
    if refine_mode_kind == "minima":
        left_idx = _find_nearest_local_minimum(hei_idx=hei, direction=-1, energies=gsm0.energies)
        right_idx = _find_nearest_local_minimum(hei_idx=hei, direction=1, energies=gsm0.energies)
        if left_idx is None:
            left_idx = hei - 1
        if right_idx is None:
            right_idx = hei + 1
        click.echo(f"[{tag0}] Using nearest local minima around HEI (left idx={left_idx}, right idx={right_idx}).")
        left_img = gsm0.images[left_idx]
        right_img = gsm0.images[right_idx]
    else:
        left_img = gsm0.images[hei - 1]
        right_img = gsm0.images[hei + 1]
        click.echo(f"[{tag0}] Refining HEI±1 (peak mode).")

    left_end = _optimize_single(left_img, shared_calc, single_opt_cfg, out_dir, tag=f"{tag0}_left", ref_pdb_path=ref_pdb_path)
    right_end = _optimize_single(right_img, shared_calc, single_opt_cfg, out_dir, tag=f"{tag0}_right", ref_pdb_path=ref_pdb_path)

    # If the bond check fails, assume changes are present (no kink shortcut).
    try:
        lr_changed, lr_summary = _has_bond_change(left_end, right_end, bond_cfg)
    except Exception as e:
        click.echo(f"[{tag0}] WARNING: Failed to evaluate bond changes for kink detection: {e}", err=True)
        lr_changed, lr_summary = True, ""
    use_kink = (not lr_changed)

    if use_kink:
        # Kink: replace the MEP by individually optimized linear interpolations.
        n_inter = int(search_cfg.get("kink_max_nodes", 3))
        click.echo(f"[{tag0}] Kink detected (no covalent changes between End1 and End2). "
                   f"Using {n_inter} linear interpolation nodes + single-structure optimizations instead of GSM.")
        inter_geoms = _make_linear_interpolations(left_end, right_end, n_inter)
        opt_inters: List[Any] = []
        for i, g_int in enumerate(inter_geoms, 1):
            g_int.set_calculator(shared_calc)
            g_opt = _optimize_single(g_int, shared_calc, single_opt_cfg, out_dir, tag=f"{tag0}_kink_int{i}", ref_pdb_path=ref_pdb_path)
            opt_inters.append(g_opt)
        step_imgs = [left_end] + opt_inters + [right_end]
        step_E = [float(img.energy) for img in step_imgs]
        ref1 = GSMResult(images=step_imgs, energies=step_E, hei_idx=int(np.argmax(step_E)))
        step_tag_for_report = f"{tag0}_kink"
    else:
        click.echo(f"[{tag0}] Kink not detected (covalent changes present between End1 and End2).")
        if lr_summary:
            click.echo(textwrap.indent(lr_summary, prefix=" "))
        ref1 = _refine_between(left_end, right_end, shared_calc, gs_seg_cfg, stopt_cfg, out_dir, tag=tag0,
                               ref_pdb_path=ref_pdb_path, mep_mode_kind=mep_mode_kind,
                               calc_cfg=calc_cfg, max_nodes=seg_max_nodes, dmf_cfg=dmf_cfg)
        step_tag_for_report = f"{tag0}_refine"

    step_imgs, step_E = ref1.images, ref1.energies

    _changed, step_summary = _has_bond_change(step_imgs[0], step_imgs[-1], bond_cfg)
    _tag_images(step_imgs, mep_seg_tag=step_tag_for_report, mep_seg_kind="seg",
                mep_has_bond_changes=bool(_changed), pair_index=pair_index)

    # Decide whether the outer intervals still need their own recursive search.
    left_changed, left_summary = _has_bond_change(gA, left_end, bond_cfg)
    right_changed, right_summary = _has_bond_change(right_end, gB, bond_cfg)

    click.echo(f"[{tag0}] Covalent changes (A vs left_end): {'Yes' if left_changed else 'No'}")
    if left_changed:
        click.echo(textwrap.indent(left_summary, prefix=" "))
    click.echo(f"[{tag0}] Covalent changes (right_end vs B): {'Yes' if right_changed else 'No'}")
    if right_changed:
        click.echo(textwrap.indent(right_summary, prefix=" "))

    # Barrier/reaction energy relative to the step's first image, in kcal/mol.
    try:
        barrier_kcal = (max(step_E) - step_E[0]) * AU2KCALPERMOL
        delta_kcal = (step_E[-1] - step_E[0]) * AU2KCALPERMOL
    except Exception:
        barrier_kcal = float("nan")
        delta_kcal = float("nan")

    seg_report = SegmentReport(
        tag=step_tag_for_report,
        barrier_kcal=float(barrier_kcal),
        delta_kcal=float(delta_kcal),
        summary=step_summary if _changed else "(no covalent changes detected)",
        kind="seg"
    )

    parts: List[Tuple[List[Any], List[float]]] = []
    seg_reports: List[SegmentReport] = []

    # The left sub-path starts where the caller's context ends, so it inherits
    # the incoming kink run; afterwards, recount from the reports it produced.
    trailing_kink_run = kink_seq_count
    if left_changed:
        subL = _build_multistep_path(
            gA, left_end, shared_calc, geom_cfg, gs_cfg, stopt_cfg,
            single_opt_cfg, bond_cfg, search_cfg, refine_mode_kind,
            out_dir, ref_pdb_path, depth + 1, seg_counter, branch_tag=f"{branch_tag}L",
            pair_index=pair_index,
            mep_mode_kind=mep_mode_kind, calc_cfg=calc_cfg, dmf_cfg=dmf_cfg,
            kink_seq_count=kink_seq_count,
        )
        _tag_images(subL.images, pair_index=pair_index)
        parts.append((subL.images, subL.energies))
        seg_reports.extend(subL.segments)
        trailing_kink_run = _trailing_kink_count(seg_reports)

    # Too many kinks in a row: abandon this level and return one plain MEP A→B
    # (any left sub-path computed above is discarded).
    current_kink_run = trailing_kink_run + 1 if use_kink else 0
    if use_kink and current_kink_run >= max_seq_kink:
        warning_msg = (
            f"[{tag0}] Consecutive kink segments were detected. Something seems wrong. "
            "Please check the initial structure and the generated intermediate structures. "
            "Alternatively, try switching the mep-mode. If that still fails, try including intermediate structures in the inputs."
        )
        click.echo(warning_msg)
        # NOTE(review): this fallback reuses the "_maxdepth" tag suffix although it
        # is triggered by the kink limit, not the depth limit — confirm intended.
        gsm = _run_mep_between(
            gA, gB, shared_calc, gs_seg_cfg, stopt_cfg, out_dir, tag=f"seg_{seg_counter[0]:03d}_maxdepth",
            ref_pdb_path=ref_pdb_path, mep_mode_kind=mep_mode_kind,
            calc_cfg=calc_cfg, max_nodes=seg_max_nodes, dmf_cfg=dmf_cfg,
        )
        seg_counter[0] += 1
        _tag_images(gsm.images, pair_index=pair_index)
        return CombinedPath(images=gsm.images, energies=gsm.energies, segments=[])

    parts.append((step_imgs, step_E))
    seg_reports.append(seg_report)

    # The right sub-path directly follows this step, so it inherits its kink run.
    if right_changed:
        subR = _build_multistep_path(
            right_end, gB, shared_calc, geom_cfg, gs_cfg, stopt_cfg,
            single_opt_cfg, bond_cfg, search_cfg, refine_mode_kind,
            out_dir, ref_pdb_path, depth + 1, seg_counter, branch_tag=f"{branch_tag}R",
            pair_index=pair_index,
            mep_mode_kind=mep_mode_kind, calc_cfg=calc_cfg, dmf_cfg=dmf_cfg,
            kink_seq_count=current_kink_run,
        )
        _tag_images(subR.images, pair_index=pair_index)
        parts.append((subR.images, subR.energies))
        seg_reports.extend(subR.segments)

    # Bridges are non-climbing MEPs with their own node budget.
    bridge_max_nodes = int(search_cfg.get("max_nodes_bridge", 5))
    gs_bridge_cfg = {**gs_cfg, "max_nodes": bridge_max_nodes, "climb": False, "climb_lanczos": False}

    # Called by _stitch_paths when an interface shows covalent changes. The kink
    # run is read from seg_reports at call time (it may have grown during stitching).
    def _segment_builder(tail_g, head_g, _tag: str) -> CombinedPath:
        sub = _build_multistep_path(
            tail_g, head_g,
            shared_calc,
            geom_cfg, gs_cfg, stopt_cfg,
            single_opt_cfg,
            bond_cfg, search_cfg, refine_mode_kind,
            out_dir=out_dir,
            ref_pdb_path=ref_pdb_path,
            depth=depth + 1,
            seg_counter=seg_counter,
            branch_tag=f"{branch_tag}B",
            pair_index=pair_index,
            mep_mode_kind=mep_mode_kind, calc_cfg=calc_cfg, dmf_cfg=dmf_cfg,
            kink_seq_count=_trailing_kink_count(seg_reports),
        )
        _tag_images(sub.images, pair_index=pair_index)
        return sub

    # seg_reports is passed as segments_out, so bridge/inserted reports land in it.
    stitched_imgs, stitched_E = _stitch_paths(
        parts,
        stitch_rmsd_thresh=float(search_cfg["stitch_rmsd_thresh"]),
        bridge_rmsd_thresh=float(search_cfg["bridge_rmsd_thresh"]),
        shared_calc=shared_calc,
        gs_cfg=gs_bridge_cfg,
        stopt_cfg=stopt_cfg,
        out_dir=out_dir,
        tag=tag0,
        ref_pdb_path=ref_pdb_path,
        bond_cfg=bond_cfg,
        segment_builder=_segment_builder,
        segments_out=seg_reports,
        bridge_pair_index=pair_index,
        mep_mode_kind=mep_mode_kind, calc_cfg=calc_cfg, dmf_cfg=dmf_cfg,
    )

    _tag_images(stitched_imgs, pair_index=pair_index)

    return CombinedPath(images=stitched_imgs, energies=stitched_E, segments=seg_reports)
|
|
1121
|
+
|
|
1122
|
+
|
|
1123
|
+
# -----------------------------------------------
|
|
1124
|
+
# CLI
|
|
1125
|
+
# -----------------------------------------------
|
|
1126
|
+
|
|
1127
|
+
@click.command(
|
|
1128
|
+
help="Multistep MEP search via recursive GSM segmentation.",
|
|
1129
|
+
context_settings={
|
|
1130
|
+
"help_option_names": ["-h", "--help"],
|
|
1131
|
+
"ignore_unknown_options": True,
|
|
1132
|
+
"allow_extra_args": True,
|
|
1133
|
+
},
|
|
1134
|
+
)
|
|
1135
|
+
@click.option(
|
|
1136
|
+
"-i", "--input",
|
|
1137
|
+
"input_paths",
|
|
1138
|
+
type=click.Path(path_type=Path, exists=True, dir_okay=False),
|
|
1139
|
+
multiple=True, # allow: -i A -i B -i C or -i A B C
|
|
1140
|
+
required=True,
|
|
1141
|
+
help=("Two or more structures in reaction order. "
|
|
1142
|
+
"Either repeat '-i' (e.g., '-i A -i B -i C') or use a single '-i' "
|
|
1143
|
+
"followed by multiple space-separated paths (e.g., '-i A B C').")
|
|
1144
|
+
)
|
|
1145
|
+
@click.option(
|
|
1146
|
+
"--parm",
|
|
1147
|
+
"real_parm7",
|
|
1148
|
+
type=click.Path(path_type=Path, exists=True, dir_okay=False),
|
|
1149
|
+
required=True,
|
|
1150
|
+
help="Amber parm7 topology covering the full enzyme complex.",
|
|
1151
|
+
)
|
|
1152
|
+
@click.option(
|
|
1153
|
+
"--model-pdb",
|
|
1154
|
+
"model_pdb",
|
|
1155
|
+
type=click.Path(path_type=Path, exists=True, dir_okay=False),
|
|
1156
|
+
required=False,
|
|
1157
|
+
help="PDB describing atoms that belong to the ML (high-level) region. "
|
|
1158
|
+
"Optional when --detect-layer is enabled.",
|
|
1159
|
+
)
|
|
1160
|
+
@click.option(
|
|
1161
|
+
"--model-indices",
|
|
1162
|
+
"model_indices_str",
|
|
1163
|
+
type=str,
|
|
1164
|
+
default=None,
|
|
1165
|
+
show_default=False,
|
|
1166
|
+
help="Comma-separated atom indices for the ML region (ranges allowed like 1-5). "
|
|
1167
|
+
"Used when --model-pdb is omitted.",
|
|
1168
|
+
)
|
|
1169
|
+
@click.option(
|
|
1170
|
+
"--model-indices-one-based/--model-indices-zero-based",
|
|
1171
|
+
"model_indices_one_based",
|
|
1172
|
+
default=True,
|
|
1173
|
+
show_default=True,
|
|
1174
|
+
help="Interpret --model-indices as 1-based (default) or 0-based.",
|
|
1175
|
+
)
|
|
1176
|
+
@click.option(
|
|
1177
|
+
"--detect-layer/--no-detect-layer",
|
|
1178
|
+
"detect_layer",
|
|
1179
|
+
default=True,
|
|
1180
|
+
show_default=True,
|
|
1181
|
+
help="Detect ML/MM layers from input PDB B-factors (B=0/10/20). "
|
|
1182
|
+
"If disabled, you must provide --model-pdb or --model-indices.",
|
|
1183
|
+
)
|
|
1184
|
+
@click.option(
|
|
1185
|
+
"-q",
|
|
1186
|
+
"--charge",
|
|
1187
|
+
type=int,
|
|
1188
|
+
required=False,
|
|
1189
|
+
help="Total system charge. Required unless --ligand-charge is provided.",
|
|
1190
|
+
)
|
|
1191
|
+
@click.option("-l", "--ligand-charge", type=str, default=None, show_default=False,
|
|
1192
|
+
help="Total charge or per-resname mapping (e.g., GPP:-3,SAM:1) used to derive "
|
|
1193
|
+
"charge when -q is omitted (requires PDB input or --ref-pdb).")
|
|
1194
|
+
@click.option(
|
|
1195
|
+
"-m",
|
|
1196
|
+
"--multiplicity",
|
|
1197
|
+
"spin",
|
|
1198
|
+
type=int,
|
|
1199
|
+
default=None,
|
|
1200
|
+
show_default=False,
|
|
1201
|
+
help="Spin multiplicity (2S+1). Defaults to 1 when omitted.",
|
|
1202
|
+
)
|
|
1203
|
+
@click.option(
|
|
1204
|
+
"--mep-mode",
|
|
1205
|
+
"mep_mode",
|
|
1206
|
+
type=click.Choice(["gsm", "dmf"], case_sensitive=False),
|
|
1207
|
+
default="gsm",
|
|
1208
|
+
show_default=True,
|
|
1209
|
+
help="MEP method: gsm (GrowingString) or dmf (Direct Max Flux).",
|
|
1210
|
+
)
|
|
1211
|
+
@click.option(
|
|
1212
|
+
"--refine-mode",
|
|
1213
|
+
type=click.Choice(["peak", "minima"], case_sensitive=False),
|
|
1214
|
+
default=None,
|
|
1215
|
+
show_default=True,
|
|
1216
|
+
help=(
|
|
1217
|
+
"Refinement seed around the highest-energy image: "
|
|
1218
|
+
"'peak' uses HEI±1, 'minima' uses nearest local minima. "
|
|
1219
|
+
"Defaults to peak for gsm and minima for dmf."
|
|
1220
|
+
),
|
|
1221
|
+
)
|
|
1222
|
+
@click.option(
|
|
1223
|
+
"--freeze-atoms",
|
|
1224
|
+
"freeze_atoms_text",
|
|
1225
|
+
type=str,
|
|
1226
|
+
default=None,
|
|
1227
|
+
show_default=False,
|
|
1228
|
+
help="Comma-separated 1-based atom indices to freeze (e.g., '1,3,5').",
|
|
1229
|
+
)
|
|
1230
|
+
@click.option(
|
|
1231
|
+
"--hess-cutoff",
|
|
1232
|
+
"hess_cutoff",
|
|
1233
|
+
type=float,
|
|
1234
|
+
default=None,
|
|
1235
|
+
show_default=False,
|
|
1236
|
+
help="Distance cutoff (Å) from ML region for MM atoms to include in Hessian calculation. "
|
|
1237
|
+
"Applied to movable MM atoms and can be combined with --detect-layer.",
|
|
1238
|
+
)
|
|
1239
|
+
@click.option(
|
|
1240
|
+
"--movable-cutoff",
|
|
1241
|
+
"movable_cutoff",
|
|
1242
|
+
type=float,
|
|
1243
|
+
default=None,
|
|
1244
|
+
show_default=False,
|
|
1245
|
+
help="Distance cutoff (Å) from ML region for movable MM atoms. MM atoms beyond this are frozen. "
|
|
1246
|
+
"Providing --movable-cutoff disables --detect-layer.",
|
|
1247
|
+
)
|
|
1248
|
+
@click.option("--max-nodes", type=int, default=10, show_default=True,
|
|
1249
|
+
help=("Number of internal nodes (string has max_nodes+2 images including endpoints). "
|
|
1250
|
+
"Used for *segment* GSM unless overridden by YAML search.max_nodes_segment."))
|
|
1251
|
+
@click.option("--max-cycles", type=int, default=300, show_default=True, help="Maximum GSM optimization cycles.")
|
|
1252
|
+
@click.option(
|
|
1253
|
+
"--climb/--no-climb",
|
|
1254
|
+
default=True,
|
|
1255
|
+
show_default=True,
|
|
1256
|
+
help="Enable transition-state search after path growth.",
|
|
1257
|
+
)
|
|
1258
|
+
@click.option(
|
|
1259
|
+
"--dump/--no-dump",
|
|
1260
|
+
default=False,
|
|
1261
|
+
show_default=True,
|
|
1262
|
+
help="Dump GSM/single-optimization trajectories during the run.",
|
|
1263
|
+
)
|
|
1264
|
+
@click.option(
|
|
1265
|
+
"--opt-mode",
|
|
1266
|
+
"opt_mode",
|
|
1267
|
+
type=click.Choice(["grad", "hess"], case_sensitive=False),
|
|
1268
|
+
default="grad",
|
|
1269
|
+
show_default=True,
|
|
1270
|
+
help="Single-structure optimizer: grad (=LBFGS) or hess (=RFO).",
|
|
1271
|
+
)
|
|
1272
|
+
@click.option("-o", "--out-dir", "out_dir", type=str, default="./result_path_search/", show_default=True, help="Output directory.")
|
|
1273
|
+
@click.option(
|
|
1274
|
+
"--thresh",
|
|
1275
|
+
type=click.Choice(["gau_loose", "gau", "gau_tight", "gau_vtight", "baker", "never"], case_sensitive=False),
|
|
1276
|
+
default=None,
|
|
1277
|
+
help="Convergence preset for GSM/StringOptimizer and single LBFGS runs.",
|
|
1278
|
+
)
|
|
1279
|
+
@click.option(
|
|
1280
|
+
"--config",
|
|
1281
|
+
"config_yaml",
|
|
1282
|
+
type=click.Path(path_type=Path, exists=True, dir_okay=False),
|
|
1283
|
+
default=None,
|
|
1284
|
+
help="Base YAML configuration file applied before explicit CLI options.",
|
|
1285
|
+
)
|
|
1286
|
+
@click.option(
|
|
1287
|
+
"--show-config/--no-show-config",
|
|
1288
|
+
"show_config",
|
|
1289
|
+
default=False,
|
|
1290
|
+
show_default=True,
|
|
1291
|
+
help="Print resolved configuration and continue execution.",
|
|
1292
|
+
)
|
|
1293
|
+
@click.option(
|
|
1294
|
+
"--dry-run/--no-dry-run",
|
|
1295
|
+
"dry_run",
|
|
1296
|
+
default=False,
|
|
1297
|
+
show_default=True,
|
|
1298
|
+
help="Validate options and print the execution plan without running path search.",
|
|
1299
|
+
)
|
|
1300
|
+
@click.option(
|
|
1301
|
+
"--preopt/--no-preopt",
|
|
1302
|
+
"pre_opt",
|
|
1303
|
+
default=True,
|
|
1304
|
+
show_default=True,
|
|
1305
|
+
help="If False, skip initial single-structure optimizations of inputs."
|
|
1306
|
+
)
|
|
1307
|
+
# Input alignment switch (default True)
|
|
1308
|
+
@click.option(
|
|
1309
|
+
"--align/--no-align",
|
|
1310
|
+
"align",
|
|
1311
|
+
default=True,
|
|
1312
|
+
show_default=True,
|
|
1313
|
+
help=("After pre-optimization, align all inputs to the *first* input and match freeze_atoms "
|
|
1314
|
+
"using the align_freeze_atoms API.")
|
|
1315
|
+
)
|
|
1316
|
+
# Full template PDBs for XYZ→PDB conversion and topology reference
|
|
1317
|
+
@click.option(
|
|
1318
|
+
"--ref-pdb",
|
|
1319
|
+
"ref_pdb_paths",
|
|
1320
|
+
type=click.Path(path_type=Path, exists=True, dir_okay=False),
|
|
1321
|
+
multiple=True,
|
|
1322
|
+
default=None,
|
|
1323
|
+
help=("Full-size template PDBs in the same reaction order as --input. "
|
|
1324
|
+
"Required when using XYZ inputs to provide topology and B-factor information.")
|
|
1325
|
+
)
|
|
1326
|
+
@click.option(
|
|
1327
|
+
"--convert-files/--no-convert-files",
|
|
1328
|
+
"convert_files",
|
|
1329
|
+
default=True,
|
|
1330
|
+
show_default=True,
|
|
1331
|
+
help="Convert XYZ/TRJ outputs into PDB companions based on the input format.",
|
|
1332
|
+
)
|
|
1333
|
+
@click.option(
|
|
1334
|
+
"-b", "--backend",
|
|
1335
|
+
type=click.Choice(["uma", "orb", "mace", "aimnet2"], case_sensitive=False),
|
|
1336
|
+
default=None,
|
|
1337
|
+
show_default=False,
|
|
1338
|
+
help="ML backend for the ONIOM high-level region (default: uma).",
|
|
1339
|
+
)
|
|
1340
|
+
@click.option(
|
|
1341
|
+
"--embedcharge/--no-embedcharge",
|
|
1342
|
+
"embedcharge",
|
|
1343
|
+
default=False,
|
|
1344
|
+
show_default=True,
|
|
1345
|
+
help="Enable xTB point-charge embedding correction for MM→ML environmental effects.",
|
|
1346
|
+
)
|
|
1347
|
+
@click.option(
|
|
1348
|
+
"--embedcharge-cutoff",
|
|
1349
|
+
"embedcharge_cutoff",
|
|
1350
|
+
type=float,
|
|
1351
|
+
default=None,
|
|
1352
|
+
show_default=False,
|
|
1353
|
+
help="Distance cutoff (Å) from ML region for MM point charges in xTB embedding. "
|
|
1354
|
+
"Default: 12.0 Å when --embedcharge is enabled.",
|
|
1355
|
+
)
|
|
1356
|
+
@click.pass_context
|
|
1357
|
+
def cli(
|
|
1358
|
+
ctx: click.Context,
|
|
1359
|
+
input_paths: Sequence[Path],
|
|
1360
|
+
real_parm7: Path,
|
|
1361
|
+
model_pdb: Optional[Path],
|
|
1362
|
+
model_indices_str: Optional[str],
|
|
1363
|
+
model_indices_one_based: bool,
|
|
1364
|
+
detect_layer: bool,
|
|
1365
|
+
charge: Optional[int],
|
|
1366
|
+
ligand_charge: Optional[str],
|
|
1367
|
+
spin: Optional[int],
|
|
1368
|
+
mep_mode: str,
|
|
1369
|
+
refine_mode: Optional[str],
|
|
1370
|
+
freeze_atoms_text: Optional[str],
|
|
1371
|
+
hess_cutoff: Optional[float],
|
|
1372
|
+
movable_cutoff: Optional[float],
|
|
1373
|
+
max_nodes: int,
|
|
1374
|
+
max_cycles: int,
|
|
1375
|
+
climb: bool,
|
|
1376
|
+
dump: bool,
|
|
1377
|
+
opt_mode: str,
|
|
1378
|
+
out_dir: str,
|
|
1379
|
+
thresh: Optional[str],
|
|
1380
|
+
config_yaml: Optional[Path],
|
|
1381
|
+
show_config: bool,
|
|
1382
|
+
dry_run: bool,
|
|
1383
|
+
pre_opt: bool,
|
|
1384
|
+
align: bool,
|
|
1385
|
+
ref_pdb_paths: Optional[Sequence[Path]],
|
|
1386
|
+
convert_files: bool,
|
|
1387
|
+
backend: Optional[str],
|
|
1388
|
+
embedcharge: bool,
|
|
1389
|
+
embedcharge_cutoff: Optional[float],
|
|
1390
|
+
) -> None:
|
|
1391
|
+
set_convert_file_enabled(convert_files)
|
|
1392
|
+
prepared_inputs: List[PreparedInputStructure] = []
|
|
1393
|
+
# --- Robustly accept both styles for -i/--input and --ref-pdb ---
|
|
1394
|
+
argv_all = sys.argv[1:] # drop program name
|
|
1395
|
+
i_vals = collect_single_option_values(argv_all, ("-i", "--input"), label="-i/--input")
|
|
1396
|
+
if i_vals:
|
|
1397
|
+
i_parsed = validate_existing_files(
|
|
1398
|
+
i_vals,
|
|
1399
|
+
option_name="-i/--input",
|
|
1400
|
+
hint="When using '-i', list only existing file paths (multiple paths may follow a single '-i').",
|
|
1401
|
+
)
|
|
1402
|
+
input_paths = tuple(i_parsed)
|
|
1403
|
+
|
|
1404
|
+
ref_vals = collect_single_option_values(argv_all, ("--ref-pdb",), label="--ref-pdb")
|
|
1405
|
+
if ref_vals:
|
|
1406
|
+
ref_parsed = validate_existing_files(
|
|
1407
|
+
ref_vals,
|
|
1408
|
+
option_name="--ref-pdb",
|
|
1409
|
+
hint="When using '--ref-pdb', multiple files may follow a single option.",
|
|
1410
|
+
)
|
|
1411
|
+
ref_pdb_paths = tuple(ref_parsed)
|
|
1412
|
+
# --- end of robust parsing fix ---
|
|
1413
|
+
|
|
1414
|
+
_is_param_explicit = make_is_param_explicit(ctx)
|
|
1415
|
+
|
|
1416
|
+
config_yaml, override_yaml, used_legacy_yaml = resolve_yaml_sources(
|
|
1417
|
+
config_yaml=config_yaml,
|
|
1418
|
+
override_yaml=None,
|
|
1419
|
+
args_yaml_legacy=None,
|
|
1420
|
+
)
|
|
1421
|
+
merged_yaml_cfg, _, _ = load_merged_yaml_cfg(
|
|
1422
|
+
config_yaml=config_yaml,
|
|
1423
|
+
override_yaml=None,
|
|
1424
|
+
)
|
|
1425
|
+
|
|
1426
|
+
time_start = time.perf_counter() # start timing
|
|
1427
|
+
command_str = "mlmm path-search " + " ".join(sys.argv[1:])
|
|
1428
|
+
try:
|
|
1429
|
+
# --------------------------
|
|
1430
|
+
# 0) Input validation (multi-structure)
|
|
1431
|
+
# --------------------------
|
|
1432
|
+
if len(input_paths) < 2:
|
|
1433
|
+
raise click.BadParameter("Provide at least two structures for --input in reaction order (reactant [intermediates ...] product).")
|
|
1434
|
+
|
|
1435
|
+
p_list = [Path(p) for p in input_paths]
|
|
1436
|
+
ref_list = list(ref_pdb_paths) if ref_pdb_paths else []
|
|
1437
|
+
prepared_inputs = []
|
|
1438
|
+
for i, p in enumerate(p_list):
|
|
1439
|
+
pi = prepare_input_structure(p)
|
|
1440
|
+
if p.suffix.lower() == ".xyz":
|
|
1441
|
+
if i < len(ref_list):
|
|
1442
|
+
apply_ref_pdb_override(pi, ref_list[i])
|
|
1443
|
+
else:
|
|
1444
|
+
raise click.BadParameter(
|
|
1445
|
+
f"XYZ input '{p}' requires a corresponding --ref-pdb for topology/B-factor info."
|
|
1446
|
+
)
|
|
1447
|
+
elif p.suffix.lower() != ".pdb":
|
|
1448
|
+
raise click.BadParameter(
|
|
1449
|
+
f"'{p}': unsupported format. Use .pdb or .xyz (with --ref-pdb)."
|
|
1450
|
+
)
|
|
1451
|
+
prepared_inputs.append(pi)
|
|
1452
|
+
# --------------------------
|
|
1453
|
+
# 1) Resolve settings (defaults < config < CLI(explicit) < override)
|
|
1454
|
+
# --------------------------
|
|
1455
|
+
config_layer_cfg = load_yaml_dict(config_yaml)
|
|
1456
|
+
override_layer_cfg = load_yaml_dict(override_yaml)
|
|
1457
|
+
|
|
1458
|
+
mep_mode_kind = mep_mode.lower().strip()
|
|
1459
|
+
refine_mode_kind = refine_mode.strip().lower() if refine_mode else None
|
|
1460
|
+
|
|
1461
|
+
geom_cfg = dict(GEOM_KW)
|
|
1462
|
+
calc_cfg = dict(CALC_KW)
|
|
1463
|
+
gs_cfg = dict(GS_KW)
|
|
1464
|
+
stopt_cfg = dict(STOPT_KW)
|
|
1465
|
+
lbfgs_cfg = dict(LBFGS_KW)
|
|
1466
|
+
bond_cfg = dict(BOND_KW)
|
|
1467
|
+
search_cfg = dict(SEARCH_KW)
|
|
1468
|
+
dmf_cfg = dict(DMF_KW)
|
|
1469
|
+
|
|
1470
|
+
apply_yaml_overrides(
|
|
1471
|
+
config_layer_cfg,
|
|
1472
|
+
[
|
|
1473
|
+
(geom_cfg, (("geom",),)),
|
|
1474
|
+
(calc_cfg, (("calc",), ("mlmm",))),
|
|
1475
|
+
(gs_cfg, (("gs",),)),
|
|
1476
|
+
(stopt_cfg, (("stopt",), ("opt",))),
|
|
1477
|
+
(lbfgs_cfg, (("stopt", "lbfgs"), ("lbfgs",))),
|
|
1478
|
+
(bond_cfg, (("bond",),)),
|
|
1479
|
+
(search_cfg, (("search",),)),
|
|
1480
|
+
(dmf_cfg, (("dmf",),)),
|
|
1481
|
+
],
|
|
1482
|
+
)
|
|
1483
|
+
|
|
1484
|
+
# CLI explicit overrides (after config YAML, before override YAML)
|
|
1485
|
+
if backend is not None:
|
|
1486
|
+
calc_cfg["backend"] = str(backend).lower()
|
|
1487
|
+
if _is_param_explicit("embedcharge"):
|
|
1488
|
+
calc_cfg["embedcharge"] = bool(embedcharge)
|
|
1489
|
+
if _is_param_explicit("embedcharge_cutoff"):
|
|
1490
|
+
calc_cfg["embedcharge_cutoff"] = embedcharge_cutoff
|
|
1491
|
+
|
|
1492
|
+
try:
|
|
1493
|
+
geom_freeze = _normalize_geom_freeze(geom_cfg.get("freeze_atoms"))
|
|
1494
|
+
except click.BadParameter as e:
|
|
1495
|
+
click.echo(f"ERROR: {e}", err=True)
|
|
1496
|
+
sys.exit(1)
|
|
1497
|
+
geom_cfg["freeze_atoms"] = geom_freeze
|
|
1498
|
+
|
|
1499
|
+
try:
|
|
1500
|
+
cli_freeze = _parse_freeze_atoms(freeze_atoms_text)
|
|
1501
|
+
except click.BadParameter as e:
|
|
1502
|
+
click.echo(f"ERROR: {e}", err=True)
|
|
1503
|
+
sys.exit(1)
|
|
1504
|
+
|
|
1505
|
+
model_indices: Optional[List[int]] = None
|
|
1506
|
+
if model_indices_str:
|
|
1507
|
+
try:
|
|
1508
|
+
model_indices = parse_indices_string(model_indices_str, one_based=model_indices_one_based)
|
|
1509
|
+
except click.BadParameter as e:
|
|
1510
|
+
click.echo(f"ERROR: {e}", err=True)
|
|
1511
|
+
sys.exit(1)
|
|
1512
|
+
if cli_freeze:
|
|
1513
|
+
merge_freeze_atom_indices(geom_cfg, cli_freeze)
|
|
1514
|
+
|
|
1515
|
+
freeze_atoms_final = list(geom_cfg.get("freeze_atoms") or [])
|
|
1516
|
+
calc_cfg["freeze_atoms"] = freeze_atoms_final
|
|
1517
|
+
|
|
1518
|
+
resolved_charge = charge
|
|
1519
|
+
resolved_spin = spin
|
|
1520
|
+
for prepared in prepared_inputs:
|
|
1521
|
+
resolved_charge, resolved_spin = resolve_charge_spin_or_raise(
|
|
1522
|
+
prepared,
|
|
1523
|
+
resolved_charge,
|
|
1524
|
+
resolved_spin,
|
|
1525
|
+
ligand_charge=ligand_charge,
|
|
1526
|
+
prefix="[path-search]",
|
|
1527
|
+
)
|
|
1528
|
+
charge_value = calc_cfg.get("model_charge", resolved_charge)
|
|
1529
|
+
if charge_value is None:
|
|
1530
|
+
charge_value = resolved_charge
|
|
1531
|
+
calc_cfg["model_charge"] = int(charge_value)
|
|
1532
|
+
if _is_param_explicit("charge"):
|
|
1533
|
+
calc_cfg["model_charge"] = int(resolved_charge)
|
|
1534
|
+
|
|
1535
|
+
spin_value = calc_cfg.get("model_mult", resolved_spin)
|
|
1536
|
+
if spin_value is None:
|
|
1537
|
+
spin_value = resolved_spin
|
|
1538
|
+
calc_cfg["model_mult"] = int(spin_value)
|
|
1539
|
+
if _is_param_explicit("spin"):
|
|
1540
|
+
calc_cfg["model_mult"] = int(resolved_spin)
|
|
1541
|
+
|
|
1542
|
+
first_input = p_list[0]
|
|
1543
|
+
# input_pdb must be a PDB (parmed requirement); use --ref-pdb when input is XYZ
|
|
1544
|
+
if first_input.suffix.lower() != ".pdb" and ref_list:
|
|
1545
|
+
calc_cfg["input_pdb"] = str(Path(ref_list[0]).resolve())
|
|
1546
|
+
else:
|
|
1547
|
+
calc_cfg["input_pdb"] = str(first_input)
|
|
1548
|
+
calc_cfg["real_parm7"] = str(real_parm7)
|
|
1549
|
+
|
|
1550
|
+
detect_layer_effective = bool(calc_cfg.get("use_bfactor_layers", detect_layer))
|
|
1551
|
+
if _is_param_explicit("detect_layer"):
|
|
1552
|
+
detect_layer_effective = bool(detect_layer)
|
|
1553
|
+
|
|
1554
|
+
if _is_param_explicit("max_nodes"):
|
|
1555
|
+
gs_cfg["max_nodes"] = int(max_nodes)
|
|
1556
|
+
search_cfg["max_nodes_segment"] = int(max_nodes)
|
|
1557
|
+
if _is_param_explicit("max_cycles"):
|
|
1558
|
+
stopt_cfg["max_cycles"] = int(max_cycles)
|
|
1559
|
+
stopt_cfg["stop_in_when_full"] = int(max_cycles)
|
|
1560
|
+
dmf_cfg["max_cycles"] = int(max_cycles)
|
|
1561
|
+
if _is_param_explicit("climb"):
|
|
1562
|
+
gs_cfg["climb"] = bool(climb)
|
|
1563
|
+
gs_cfg["climb_lanczos"] = bool(climb)
|
|
1564
|
+
if _is_param_explicit("dump"):
|
|
1565
|
+
stopt_cfg["dump"] = bool(dump)
|
|
1566
|
+
lbfgs_cfg["dump"] = bool(dump)
|
|
1567
|
+
if _is_param_explicit("out_dir"):
|
|
1568
|
+
stopt_cfg["out_dir"] = out_dir
|
|
1569
|
+
lbfgs_cfg["out_dir"] = out_dir
|
|
1570
|
+
if _is_param_explicit("thresh") and thresh is not None:
|
|
1571
|
+
stopt_cfg["thresh"] = str(thresh)
|
|
1572
|
+
lbfgs_cfg["thresh"] = str(thresh)
|
|
1573
|
+
if _is_param_explicit("hess_cutoff") and hess_cutoff is not None:
|
|
1574
|
+
calc_cfg["hess_cutoff"] = float(hess_cutoff)
|
|
1575
|
+
if _is_param_explicit("movable_cutoff") and movable_cutoff is not None:
|
|
1576
|
+
calc_cfg["movable_cutoff"] = float(movable_cutoff)
|
|
1577
|
+
detect_layer_effective = False
|
|
1578
|
+
if _is_param_explicit("refine_mode"):
|
|
1579
|
+
search_cfg["refine_mode"] = refine_mode_kind
|
|
1580
|
+
|
|
1581
|
+
apply_yaml_overrides(
|
|
1582
|
+
override_layer_cfg,
|
|
1583
|
+
[
|
|
1584
|
+
(geom_cfg, (("geom",),)),
|
|
1585
|
+
(calc_cfg, (("calc",), ("mlmm",))),
|
|
1586
|
+
(gs_cfg, (("gs",),)),
|
|
1587
|
+
(stopt_cfg, (("stopt",), ("opt",))),
|
|
1588
|
+
(lbfgs_cfg, (("stopt", "lbfgs"), ("lbfgs",))),
|
|
1589
|
+
(bond_cfg, (("bond",),)),
|
|
1590
|
+
(search_cfg, (("search",),)),
|
|
1591
|
+
(dmf_cfg, (("dmf",),)),
|
|
1592
|
+
],
|
|
1593
|
+
)
|
|
1594
|
+
|
|
1595
|
+
refine_mode_kind = search_cfg.get("refine_mode")
|
|
1596
|
+
if refine_mode_kind is None:
|
|
1597
|
+
refine_mode_kind = "peak" if mep_mode_kind == "gsm" else "minima"
|
|
1598
|
+
else:
|
|
1599
|
+
refine_mode_kind = str(refine_mode_kind).strip().lower()
|
|
1600
|
+
if refine_mode_kind not in {"peak", "minima"}:
|
|
1601
|
+
raise click.BadParameter(f"Unknown --refine-mode '{refine_mode_kind}'.")
|
|
1602
|
+
search_cfg["refine_mode"] = refine_mode_kind
|
|
1603
|
+
|
|
1604
|
+
out_dir_path = Path(stopt_cfg.get("out_dir", out_dir)).resolve()
|
|
1605
|
+
detect_layer_effective = bool(calc_cfg.get("use_bfactor_layers", detect_layer_effective))
|
|
1606
|
+
|
|
1607
|
+
model_pdb_effective: Optional[Path] = None
|
|
1608
|
+
if _is_param_explicit("model_pdb") and model_pdb is not None:
|
|
1609
|
+
model_pdb_effective = Path(model_pdb)
|
|
1610
|
+
else:
|
|
1611
|
+
model_pdb_cfg = calc_cfg.get("model_pdb")
|
|
1612
|
+
if isinstance(model_pdb_cfg, (str, Path)) and str(model_pdb_cfg).strip():
|
|
1613
|
+
model_pdb_effective = Path(model_pdb_cfg)
|
|
1614
|
+
|
|
1615
|
+
hess_cutoff_effective = calc_cfg.get("hess_cutoff")
|
|
1616
|
+
movable_cutoff_effective = calc_cfg.get("movable_cutoff")
|
|
1617
|
+
if movable_cutoff_effective is not None:
|
|
1618
|
+
if detect_layer_effective:
|
|
1619
|
+
click.echo("[layer] movable_cutoff is set; disabling detect-layer.", err=True)
|
|
1620
|
+
detect_layer_effective = False
|
|
1621
|
+
|
|
1622
|
+
# For layer detection, prefer --ref-pdb (which carries B-factor layers)
|
|
1623
|
+
# over the first input (which may be XYZ).
|
|
1624
|
+
if ref_list and ref_list[0]:
|
|
1625
|
+
layer_source_pdb = Path(ref_list[0]).resolve()
|
|
1626
|
+
else:
|
|
1627
|
+
layer_source_pdb = first_input
|
|
1628
|
+
if detect_layer_effective and layer_source_pdb.suffix.lower() != ".pdb":
|
|
1629
|
+
click.echo("ERROR: --detect-layer requires a PDB input (or --ref-pdb).", err=True)
|
|
1630
|
+
sys.exit(1)
|
|
1631
|
+
|
|
1632
|
+
if dry_run:
|
|
1633
|
+
layer_info_preview: Optional[Dict[str, List[int]]] = None
|
|
1634
|
+
model_region_source = "bfactor"
|
|
1635
|
+
|
|
1636
|
+
if detect_layer_effective:
|
|
1637
|
+
try:
|
|
1638
|
+
bfactors = read_bfactors_from_pdb(layer_source_pdb)
|
|
1639
|
+
if not bfactors:
|
|
1640
|
+
raise ValueError(f"No ATOM/HETATM records found in {layer_source_pdb}.")
|
|
1641
|
+
if not has_valid_layer_bfactors(bfactors):
|
|
1642
|
+
raise ValueError(
|
|
1643
|
+
"Invalid or missing layer B-factors (expected ~0/10/20). "
|
|
1644
|
+
"Provide --no-detect-layer with --model-pdb/--model-indices."
|
|
1645
|
+
)
|
|
1646
|
+
layer_info_preview = parse_layer_indices_from_bfactors(bfactors)
|
|
1647
|
+
if not layer_info_preview.get("ml_indices"):
|
|
1648
|
+
raise ValueError("No ML atoms detected from B-factors (value ~0).")
|
|
1649
|
+
except Exception as e:
|
|
1650
|
+
if model_pdb_effective is None and not model_indices:
|
|
1651
|
+
click.echo(f"ERROR: {e}", err=True)
|
|
1652
|
+
sys.exit(1)
|
|
1653
|
+
click.echo(f"[layer] WARNING: {e} Falling back to explicit ML region.", err=True)
|
|
1654
|
+
detect_layer_effective = False
|
|
1655
|
+
|
|
1656
|
+
if not detect_layer_effective:
|
|
1657
|
+
if model_pdb_effective is not None:
|
|
1658
|
+
model_region_source = "model_pdb"
|
|
1659
|
+
elif model_indices:
|
|
1660
|
+
model_region_source = "model_indices"
|
|
1661
|
+
if layer_source_pdb.suffix.lower() != ".pdb":
|
|
1662
|
+
click.echo("ERROR: --model-indices requires a PDB input.", err=True)
|
|
1663
|
+
sys.exit(1)
|
|
1664
|
+
n_atoms = 0
|
|
1665
|
+
with layer_source_pdb.open("r", encoding="utf-8", errors="ignore") as fh:
|
|
1666
|
+
for line in fh:
|
|
1667
|
+
if line.startswith(("ATOM ", "HETATM")):
|
|
1668
|
+
n_atoms += 1
|
|
1669
|
+
bad_idx = [i for i in model_indices if i < 0 or i >= n_atoms]
|
|
1670
|
+
if bad_idx:
|
|
1671
|
+
click.echo(
|
|
1672
|
+
f"ERROR: model index out of range: {bad_idx[0]} (valid: 0 <= idx < {n_atoms})",
|
|
1673
|
+
err=True,
|
|
1674
|
+
)
|
|
1675
|
+
sys.exit(1)
|
|
1676
|
+
else:
|
|
1677
|
+
click.echo("ERROR: Provide --model-pdb or --model-indices when --no-detect-layer.", err=True)
|
|
1678
|
+
sys.exit(1)
|
|
1679
|
+
|
|
1680
|
+
if show_config:
|
|
1681
|
+
click.echo(
|
|
1682
|
+
pretty_block(
|
|
1683
|
+
"yaml_layers",
|
|
1684
|
+
{
|
|
1685
|
+
"config": None if config_yaml is None else str(config_yaml),
|
|
1686
|
+
"override": None if override_yaml is None else str(override_yaml),
|
|
1687
|
+
"merged_keys": sorted(merged_yaml_cfg.keys()),
|
|
1688
|
+
},
|
|
1689
|
+
)
|
|
1690
|
+
)
|
|
1691
|
+
|
|
1692
|
+
dry_payload: Dict[str, Any] = {
|
|
1693
|
+
"input_count": len(p_list),
|
|
1694
|
+
"input_first": str(p_list[0]) if p_list else None,
|
|
1695
|
+
"input_last": str(p_list[-1]) if p_list else None,
|
|
1696
|
+
"output_dir": str(out_dir_path),
|
|
1697
|
+
"mep_mode": mep_mode_kind,
|
|
1698
|
+
"refine_mode": refine_mode_kind,
|
|
1699
|
+
"opt_mode": str(opt_mode),
|
|
1700
|
+
"detect_layer": bool(detect_layer_effective),
|
|
1701
|
+
"model_region_source": model_region_source,
|
|
1702
|
+
"model_indices_count": 0 if not model_indices else len(model_indices),
|
|
1703
|
+
"pre_opt": bool(pre_opt),
|
|
1704
|
+
"align": bool(align),
|
|
1705
|
+
"max_depth": int(search_cfg.get("max_depth", SEARCH_KW["max_depth"])),
|
|
1706
|
+
"max_nodes_segment": int(search_cfg.get("max_nodes_segment", gs_cfg.get("max_nodes", 0))),
|
|
1707
|
+
"will_run_path_search": True,
|
|
1708
|
+
"will_write_summary": True,
|
|
1709
|
+
"backend": calc_cfg.get("backend", "uma"),
|
|
1710
|
+
"embedcharge": bool(calc_cfg.get("embedcharge", False)),
|
|
1711
|
+
}
|
|
1712
|
+
if layer_info_preview is not None:
|
|
1713
|
+
dry_payload["layer_counts"] = {
|
|
1714
|
+
"ml": len(layer_info_preview.get("ml_indices", [])),
|
|
1715
|
+
"movable_mm": len(layer_info_preview.get("movable_mm_indices", [])),
|
|
1716
|
+
"frozen": len(layer_info_preview.get("frozen_indices", [])),
|
|
1717
|
+
"unassigned": len(layer_info_preview.get("unassigned_indices", [])),
|
|
1718
|
+
}
|
|
1719
|
+
|
|
1720
|
+
click.echo(pretty_block("dry_run_plan", dry_payload))
|
|
1721
|
+
click.echo("[dry-run] Validation complete. Path search execution was skipped.")
|
|
1722
|
+
return
|
|
1723
|
+
|
|
1724
|
+
model_pdb_path: Optional[Path] = None
|
|
1725
|
+
layer_info: Optional[Dict[str, List[int]]] = None
|
|
1726
|
+
|
|
1727
|
+
if detect_layer_effective:
|
|
1728
|
+
try:
|
|
1729
|
+
model_pdb_path, layer_info = build_model_pdb_from_bfactors(layer_source_pdb, out_dir_path)
|
|
1730
|
+
calc_cfg["use_bfactor_layers"] = True
|
|
1731
|
+
click.echo(
|
|
1732
|
+
f"[layer] Detected B-factor layers: ML={len(layer_info.get('ml_indices', []))}, "
|
|
1733
|
+
f"MovableMM={len(layer_info.get('movable_mm_indices', []))}, "
|
|
1734
|
+
f"FrozenMM={len(layer_info.get('frozen_indices', []))}"
|
|
1735
|
+
)
|
|
1736
|
+
except Exception as e:
|
|
1737
|
+
if model_pdb_effective is None and not model_indices:
|
|
1738
|
+
click.echo(f"ERROR: {e}", err=True)
|
|
1739
|
+
sys.exit(1)
|
|
1740
|
+
click.echo(f"[layer] WARNING: {e} Falling back to explicit ML region.", err=True)
|
|
1741
|
+
detect_layer_effective = False
|
|
1742
|
+
|
|
1743
|
+
if not detect_layer_effective:
|
|
1744
|
+
if model_pdb_effective is None and not model_indices:
|
|
1745
|
+
click.echo("ERROR: Provide --model-pdb or --model-indices when --no-detect-layer.", err=True)
|
|
1746
|
+
sys.exit(1)
|
|
1747
|
+
if model_pdb_effective is not None:
|
|
1748
|
+
model_pdb_path = Path(model_pdb_effective)
|
|
1749
|
+
else:
|
|
1750
|
+
if layer_source_pdb.suffix.lower() != ".pdb":
|
|
1751
|
+
click.echo("ERROR: --model-indices requires a PDB input.", err=True)
|
|
1752
|
+
sys.exit(1)
|
|
1753
|
+
try:
|
|
1754
|
+
model_pdb_path = build_model_pdb_from_indices(layer_source_pdb, out_dir_path, model_indices or [])
|
|
1755
|
+
except Exception as e:
|
|
1756
|
+
click.echo(f"ERROR: {e}", err=True)
|
|
1757
|
+
sys.exit(1)
|
|
1758
|
+
calc_cfg["use_bfactor_layers"] = False
|
|
1759
|
+
|
|
1760
|
+
if model_pdb_path is None:
|
|
1761
|
+
click.echo("ERROR: Failed to resolve model PDB for the ML region.", err=True)
|
|
1762
|
+
sys.exit(1)
|
|
1763
|
+
|
|
1764
|
+
calc_cfg["model_pdb"] = str(model_pdb_path)
|
|
1765
|
+
freeze_atoms_final = apply_layer_freeze_constraints(
|
|
1766
|
+
geom_cfg,
|
|
1767
|
+
calc_cfg,
|
|
1768
|
+
layer_info,
|
|
1769
|
+
echo_fn=click.echo,
|
|
1770
|
+
)
|
|
1771
|
+
|
|
1772
|
+
# Distance-based overrides for Hessian-target and movable MM selection.
|
|
1773
|
+
if hess_cutoff_effective is not None:
|
|
1774
|
+
calc_cfg["hess_cutoff"] = float(hess_cutoff_effective)
|
|
1775
|
+
if movable_cutoff_effective is not None:
|
|
1776
|
+
calc_cfg["movable_cutoff"] = float(movable_cutoff_effective)
|
|
1777
|
+
calc_cfg["use_bfactor_layers"] = False
|
|
1778
|
+
|
|
1779
|
+
for key in ("input_pdb", "real_parm7", "model_pdb", "mm_fd_dir"):
|
|
1780
|
+
val = calc_cfg.get(key)
|
|
1781
|
+
if isinstance(val, (str, Path)):
|
|
1782
|
+
calc_cfg[key] = str(Path(val).expanduser().resolve())
|
|
1783
|
+
|
|
1784
|
+
stopt_cfg["stop_in_when_full"] = int(stopt_cfg.get("max_cycles", STOPT_KW["max_cycles"]))
|
|
1785
|
+
out_dir_path = Path(stopt_cfg.get("out_dir", out_dir)).resolve()
|
|
1786
|
+
echo_geom = format_freeze_atoms_for_echo(geom_cfg, key="freeze_atoms")
|
|
1787
|
+
echo_calc = format_freeze_atoms_for_echo(filter_calc_for_echo(calc_cfg), key="freeze_atoms")
|
|
1788
|
+
echo_gs = strip_inherited_keys(gs_cfg, GS_KW, mode="same")
|
|
1789
|
+
echo_stopt = strip_inherited_keys({**stopt_cfg, "out_dir": str(out_dir_path)}, STOPT_KW, mode="same")
|
|
1790
|
+
echo_lbfgs = strip_inherited_keys(lbfgs_cfg, LBFGS_KW, mode="same")
|
|
1791
|
+
echo_bond = strip_inherited_keys(bond_cfg, BOND_KW, mode="same")
|
|
1792
|
+
echo_search = strip_inherited_keys(search_cfg, SEARCH_KW, mode="same")
|
|
1793
|
+
|
|
1794
|
+
click.echo(pretty_block("geom", echo_geom))
|
|
1795
|
+
click.echo(pretty_block("calc", echo_calc))
|
|
1796
|
+
click.echo(pretty_block("gs", echo_gs))
|
|
1797
|
+
click.echo(pretty_block("stopt", echo_stopt))
|
|
1798
|
+
click.echo(pretty_block("lbfgs", echo_lbfgs))
|
|
1799
|
+
click.echo(pretty_block("bond", echo_bond))
|
|
1800
|
+
click.echo(pretty_block("search", echo_search))
|
|
1801
|
+
# Echo pre-optimization and alignment flags
|
|
1802
|
+
click.echo(
|
|
1803
|
+
pretty_block(
|
|
1804
|
+
"run_flags",
|
|
1805
|
+
{
|
|
1806
|
+
"pre_opt": bool(pre_opt),
|
|
1807
|
+
"align": bool(align),
|
|
1808
|
+
"mep_mode": mep_mode_kind,
|
|
1809
|
+
"refine_mode": refine_mode_kind,
|
|
1810
|
+
"opt_mode": str(opt_mode),
|
|
1811
|
+
},
|
|
1812
|
+
)
|
|
1813
|
+
)
|
|
1814
|
+
|
|
1815
|
+
if show_config:
|
|
1816
|
+
click.echo(
|
|
1817
|
+
pretty_block(
|
|
1818
|
+
"yaml_layers",
|
|
1819
|
+
{
|
|
1820
|
+
"config": None if config_yaml is None else str(config_yaml),
|
|
1821
|
+
"override": None if override_yaml is None else str(override_yaml),
|
|
1822
|
+
"merged_keys": sorted(merged_yaml_cfg.keys()),
|
|
1823
|
+
},
|
|
1824
|
+
)
|
|
1825
|
+
)
|
|
1826
|
+
|
|
1827
|
+
if int(stopt_cfg.get("max_cycles", 0)) <= 0:
|
|
1828
|
+
click.echo("[INFO] max_cycles <= 0: skipping path search.")
|
|
1829
|
+
return
|
|
1830
|
+
|
|
1831
|
+
# --------------------------
|
|
1832
|
+
# 2) Prepare inputs
|
|
1833
|
+
# --------------------------
|
|
1834
|
+
out_dir_path.mkdir(parents=True, exist_ok=True)
|
|
1835
|
+
|
|
1836
|
+
geoms = _load_structures(
|
|
1837
|
+
inputs=prepared_inputs,
|
|
1838
|
+
coord_type=geom_cfg.get("coord_type", "cart"),
|
|
1839
|
+
base_freeze=geom_cfg.get("freeze_atoms", []),
|
|
1840
|
+
)
|
|
1841
|
+
|
|
1842
|
+
shared_calc = mlmm(**calc_cfg)
|
|
1843
|
+
for g in geoms:
|
|
1844
|
+
g.set_calculator(shared_calc)
|
|
1845
|
+
|
|
1846
|
+
# Reference PDB for output conversion: prefer --ref-pdb, fall back to input PDBs
|
|
1847
|
+
ref_pdb_for_segments: Optional[Path] = None
|
|
1848
|
+
if ref_list:
|
|
1849
|
+
ref_pdb_for_segments = Path(ref_list[0]).resolve()
|
|
1850
|
+
else:
|
|
1851
|
+
for p in p_list:
|
|
1852
|
+
if p.suffix.lower() == ".pdb":
|
|
1853
|
+
ref_pdb_for_segments = p.resolve()
|
|
1854
|
+
break
|
|
1855
|
+
|
|
1856
|
+
if pre_opt:
|
|
1857
|
+
new_geoms: List[Any] = []
|
|
1858
|
+
for i, g in enumerate(geoms):
|
|
1859
|
+
tag = f"init{i:02d}"
|
|
1860
|
+
g_opt = _optimize_single(g, shared_calc, lbfgs_cfg, out_dir_path, tag=tag, ref_pdb_path=ref_pdb_for_segments)
|
|
1861
|
+
new_geoms.append(g_opt)
|
|
1862
|
+
geoms = new_geoms
|
|
1863
|
+
else:
|
|
1864
|
+
click.echo("[init] Skipping endpoint pre-optimization as requested by --no-preopt.")
|
|
1865
|
+
|
|
1866
|
+
# Align all inputs to the first structure, guided by freeze constraints, when requested
|
|
1867
|
+
align_thresh = str(stopt_cfg.get("thresh", "gau"))
|
|
1868
|
+
if align:
|
|
1869
|
+
try:
|
|
1870
|
+
click.echo("\n=== Aligning all inputs to the first structure (freeze-guided scan + relaxation) ===\n")
|
|
1871
|
+
_ = align_and_refine_sequence_inplace(
|
|
1872
|
+
geoms,
|
|
1873
|
+
thresh=align_thresh,
|
|
1874
|
+
shared_calc=shared_calc,
|
|
1875
|
+
out_dir=out_dir_path / "align_refine",
|
|
1876
|
+
verbose=True,
|
|
1877
|
+
)
|
|
1878
|
+
click.echo("[align] Completed input alignment.")
|
|
1879
|
+
except Exception as e:
|
|
1880
|
+
click.echo(f"[align] WARNING: Alignment failed; continuing without alignment: {e}", err=True)
|
|
1881
|
+
else:
|
|
1882
|
+
click.echo("[align] Skipping input alignment as requested by --no-align.")
|
|
1883
|
+
|
|
1884
|
+
# --------------------------
|
|
1885
|
+
# 3) Run recursive search for each adjacent pair and stitch
|
|
1886
|
+
# --------------------------
|
|
1887
|
+
click.echo("\n=== Multistep MEP search (multi-structure) started ===\n")
|
|
1888
|
+
seg_counter = [0]
|
|
1889
|
+
|
|
1890
|
+
bridge_max_nodes = int(search_cfg.get("max_nodes_bridge", 5))
|
|
1891
|
+
gs_bridge_cfg = {**gs_cfg, "max_nodes": bridge_max_nodes, "climb": False, "climb_lanczos": False}
|
|
1892
|
+
|
|
1893
|
+
combined_imgs: List[Any] = []
|
|
1894
|
+
combined_Es: List[float] = []
|
|
1895
|
+
seg_reports_all: List[SegmentReport] = []
|
|
1896
|
+
|
|
1897
|
+
def _segment_builder_for_pairs(tail_g, head_g, _tag: str) -> CombinedPath:
|
|
1898
|
+
sub = _build_multistep_path(
|
|
1899
|
+
tail_g, head_g,
|
|
1900
|
+
shared_calc,
|
|
1901
|
+
geom_cfg, gs_cfg, stopt_cfg,
|
|
1902
|
+
lbfgs_cfg,
|
|
1903
|
+
bond_cfg, search_cfg, refine_mode_kind,
|
|
1904
|
+
out_dir=out_dir_path,
|
|
1905
|
+
ref_pdb_path=ref_pdb_for_segments,
|
|
1906
|
+
depth=0,
|
|
1907
|
+
seg_counter=seg_counter,
|
|
1908
|
+
branch_tag="B",
|
|
1909
|
+
pair_index=None,
|
|
1910
|
+
mep_mode_kind=mep_mode_kind, calc_cfg=calc_cfg, dmf_cfg=dmf_cfg,
|
|
1911
|
+
kink_seq_count=_trailing_kink_count(seg_reports_all),
|
|
1912
|
+
)
|
|
1913
|
+
return sub
|
|
1914
|
+
|
|
1915
|
+
for i in range(len(geoms) - 1):
|
|
1916
|
+
gA, gB = geoms[i], geoms[i + 1]
|
|
1917
|
+
pair_tag = f"pair_{i:02d}"
|
|
1918
|
+
click.echo(f"\n--- Processing pair {i:02d}: image {i} → {i+1} ---")
|
|
1919
|
+
pair_path = _build_multistep_path(
|
|
1920
|
+
gA, gB,
|
|
1921
|
+
shared_calc,
|
|
1922
|
+
geom_cfg, gs_cfg, stopt_cfg,
|
|
1923
|
+
lbfgs_cfg,
|
|
1924
|
+
bond_cfg, search_cfg, refine_mode_kind,
|
|
1925
|
+
out_dir=out_dir_path,
|
|
1926
|
+
ref_pdb_path=ref_pdb_for_segments,
|
|
1927
|
+
depth=0,
|
|
1928
|
+
seg_counter=seg_counter,
|
|
1929
|
+
branch_tag=pair_tag,
|
|
1930
|
+
pair_index=i,
|
|
1931
|
+
mep_mode_kind=mep_mode_kind, calc_cfg=calc_cfg, dmf_cfg=dmf_cfg,
|
|
1932
|
+
)
|
|
1933
|
+
|
|
1934
|
+
if i == 0:
|
|
1935
|
+
combined_imgs = list(pair_path.images)
|
|
1936
|
+
combined_Es = list(pair_path.energies)
|
|
1937
|
+
seg_reports_all.extend(pair_path.segments)
|
|
1938
|
+
else:
|
|
1939
|
+
parts = [(combined_imgs, combined_Es), (pair_path.images, pair_path.energies)]
|
|
1940
|
+
combined_imgs, combined_Es = _stitch_paths(
|
|
1941
|
+
parts=parts,
|
|
1942
|
+
stitch_rmsd_thresh=float(search_cfg["stitch_rmsd_thresh"]),
|
|
1943
|
+
bridge_rmsd_thresh=float(search_cfg["bridge_rmsd_thresh"]),
|
|
1944
|
+
shared_calc=shared_calc,
|
|
1945
|
+
gs_cfg=gs_bridge_cfg,
|
|
1946
|
+
stopt_cfg=stopt_cfg,
|
|
1947
|
+
out_dir=out_dir_path,
|
|
1948
|
+
tag=pair_tag,
|
|
1949
|
+
ref_pdb_path=ref_pdb_for_segments,
|
|
1950
|
+
bond_cfg=bond_cfg,
|
|
1951
|
+
segment_builder=_segment_builder_for_pairs,
|
|
1952
|
+
segments_out=seg_reports_all,
|
|
1953
|
+
bridge_pair_index=i,
|
|
1954
|
+
mep_mode_kind=mep_mode_kind, calc_cfg=calc_cfg, dmf_cfg=dmf_cfg,
|
|
1955
|
+
)
|
|
1956
|
+
seg_reports_all.extend(pair_path.segments)
|
|
1957
|
+
|
|
1958
|
+
click.echo("\n=== Multistep MEP search (multi-structure) finished ===\n")
|
|
1959
|
+
|
|
1960
|
+
combined_all = CombinedPath(images=combined_imgs, energies=combined_Es, segments=seg_reports_all)
|
|
1961
|
+
|
|
1962
|
+
# --------------------------
|
|
1963
|
+
# 4) Outputs
|
|
1964
|
+
# --------------------------
|
|
1965
|
+
for idx, srep in enumerate(combined_all.segments, 1):
|
|
1966
|
+
srep.seg_index = idx
|
|
1967
|
+
tag_to_index = {s.tag: int(s.seg_index) for s in combined_all.segments}
|
|
1968
|
+
for im in combined_all.images:
|
|
1969
|
+
tag = getattr(im, "mep_seg_tag", None)
|
|
1970
|
+
if tag and tag in tag_to_index:
|
|
1971
|
+
try:
|
|
1972
|
+
setattr(im, "mep_seg_index", int(tag_to_index[tag]))
|
|
1973
|
+
except Exception:
|
|
1974
|
+
logger.debug("Failed to set mep_seg_index on image", exc_info=True)
|
|
1975
|
+
|
|
1976
|
+
# Always write mep_trj.xyz for downstream compatibility; convert to PDB when possible.
|
|
1977
|
+
pdb_input = ref_pdb_for_segments is not None
|
|
1978
|
+
final_trj = out_dir_path / "mep_trj.xyz"
|
|
1979
|
+
_write_xyz_trj_with_energy(combined_all.images, combined_all.energies, final_trj)
|
|
1980
|
+
click.echo(f"[write] Wrote '{final_trj}'.")
|
|
1981
|
+
try:
|
|
1982
|
+
run_trj2fig(final_trj, [out_dir_path / "mep_plot.png"], unit="kcal", reference="init", reverse_x=False)
|
|
1983
|
+
click.echo(f"[plot] Saved energy plot → '{out_dir_path / 'mep_plot.png'}'")
|
|
1984
|
+
except Exception as e:
|
|
1985
|
+
click.echo(f"[plot] WARNING: Failed to plot final energy: {e}", err=True)
|
|
1986
|
+
|
|
1987
|
+
if pdb_input:
|
|
1988
|
+
try:
|
|
1989
|
+
final_pdb = out_dir_path / "mep.pdb"
|
|
1990
|
+
convert_xyz_to_pdb(final_trj, ref_pdb_for_segments, final_pdb)
|
|
1991
|
+
click.echo(f"[convert] Wrote '{final_pdb}'.")
|
|
1992
|
+
except Exception as e:
|
|
1993
|
+
click.echo(f"[convert] WARNING: Failed to convert final MEP to PDB: {e}", err=True)
|
|
1994
|
+
|
|
1995
|
+
# ---- Pocket-only per-segment trajectories & HEIs ----
|
|
1996
|
+
try:
|
|
1997
|
+
# Map frames → segment indices
|
|
1998
|
+
frame_seg_indices: List[int] = [int(getattr(im, "mep_seg_index", 0) or 0) for im in combined_all.images]
|
|
1999
|
+
seg_to_frames: Dict[int, List[int]] = {}
|
|
2000
|
+
for ii, sidx in enumerate(frame_seg_indices):
|
|
2001
|
+
if sidx <= 0:
|
|
2002
|
+
continue
|
|
2003
|
+
seg_to_frames.setdefault(int(sidx), []).append(ii)
|
|
2004
|
+
|
|
2005
|
+
for s in combined_all.segments:
|
|
2006
|
+
seg_idx = int(s.seg_index)
|
|
2007
|
+
idxs = seg_to_frames.get(seg_idx, [])
|
|
2008
|
+
if not idxs:
|
|
2009
|
+
continue
|
|
2010
|
+
|
|
2011
|
+
# (A) Only for bond-change segments: pocket-only per-segment path
|
|
2012
|
+
if s.kind != "bridge" and s.summary and s.summary.strip() != "(no covalent changes detected)":
|
|
2013
|
+
seg_imgs = [combined_all.images[j] for j in idxs]
|
|
2014
|
+
seg_Es = [combined_all.energies[j] for j in idxs]
|
|
2015
|
+
seg_trj = out_dir_path / f"mep_seg_{seg_idx:02d}_trj.xyz"
|
|
2016
|
+
_write_xyz_trj_with_energy(seg_imgs, seg_Es, seg_trj)
|
|
2017
|
+
click.echo(f"[write] Wrote per-segment pocket trajectory → '{seg_trj}'")
|
|
2018
|
+
if ref_pdb_for_segments is not None:
|
|
2019
|
+
_maybe_convert_to_pdb(seg_trj, ref_pdb_for_segments, out_path=out_dir_path / f"mep_seg_{seg_idx:02d}.pdb")
|
|
2020
|
+
|
|
2021
|
+
# (B) HEI pocket files only for bond-change segments
|
|
2022
|
+
if s.kind != "bridge" and s.summary and s.summary.strip() != "(no covalent changes detected)":
|
|
2023
|
+
energies_seg = [combined_all.energies[j] for j in idxs]
|
|
2024
|
+
imax_rel = int(np.argmax(np.array(energies_seg, dtype=float)))
|
|
2025
|
+
imax_abs = idxs[imax_rel]
|
|
2026
|
+
hei_img = combined_all.images[imax_abs]
|
|
2027
|
+
hei_E = [combined_all.energies[imax_abs]]
|
|
2028
|
+
hei_trj = out_dir_path / f"hei_seg_{seg_idx:02d}.xyz"
|
|
2029
|
+
_write_xyz_trj_with_energy([hei_img], hei_E, hei_trj)
|
|
2030
|
+
click.echo(f"[write] Wrote segment HEI (pocket) → '{hei_trj}'")
|
|
2031
|
+
if ref_pdb_for_segments is not None:
|
|
2032
|
+
_maybe_convert_to_pdb(hei_trj, ref_pdb_for_segments, out_path=out_dir_path / f"hei_seg_{seg_idx:02d}.pdb")
|
|
2033
|
+
except Exception as e:
|
|
2034
|
+
click.echo(f"[write] WARNING: Failed to emit per-segment pocket outputs: {e}", err=True)
|
|
2035
|
+
# ---- END ----
|
|
2036
|
+
|
|
2037
|
+
summary = {
|
|
2038
|
+
"out_dir": str(out_dir_path),
|
|
2039
|
+
"n_images": len(combined_all.images),
|
|
2040
|
+
"n_segments": len(combined_all.segments),
|
|
2041
|
+
"segments": [
|
|
2042
|
+
{
|
|
2043
|
+
"index": int(s.seg_index),
|
|
2044
|
+
"tag": s.tag,
|
|
2045
|
+
"kind": s.kind,
|
|
2046
|
+
"barrier_kcal": float(s.barrier_kcal),
|
|
2047
|
+
"delta_kcal": float(s.delta_kcal),
|
|
2048
|
+
"bond_changes": (s.summary if (s.kind != "bridge") else "")
|
|
2049
|
+
} for s in combined_all.segments
|
|
2050
|
+
],
|
|
2051
|
+
}
|
|
2052
|
+
|
|
2053
|
+
# --------------------------
|
|
2054
|
+
# 5) Console summary
|
|
2055
|
+
# --------------------------
|
|
2056
|
+
try:
|
|
2057
|
+
overall_changed, overall_summary = _has_bond_change(combined_all.images[0], combined_all.images[-1], bond_cfg)
|
|
2058
|
+
except Exception:
|
|
2059
|
+
overall_changed, overall_summary = False, ""
|
|
2060
|
+
|
|
2061
|
+
click.echo("\n=== MEP Summary ===\n")
|
|
2062
|
+
|
|
2063
|
+
click.echo("\n[overall] Covalent-bond changes between first and last image:")
|
|
2064
|
+
if overall_changed and overall_summary.strip():
|
|
2065
|
+
click.echo(textwrap.indent(overall_summary.strip(), prefix=" "))
|
|
2066
|
+
else:
|
|
2067
|
+
click.echo(" (no covalent changes detected)")
|
|
2068
|
+
|
|
2069
|
+
if combined_all.segments:
|
|
2070
|
+
click.echo("\n[segments] Along the final MEP order (ΔE‡, ΔE). Bridges are shown between connected segments:")
|
|
2071
|
+
for i, seg in enumerate(combined_all.segments, 1):
|
|
2072
|
+
kind_label = "BRIDGE" if seg.kind == "bridge" else "SEG"
|
|
2073
|
+
click.echo(f" [{i:02d}] ({kind_label}) {seg.tag} | ΔE‡ = {seg.barrier_kcal:.2f} kcal/mol, ΔE = {seg.delta_kcal:.2f} kcal/mol")
|
|
2074
|
+
if seg.kind != "bridge" and seg.summary.strip():
|
|
2075
|
+
click.echo(textwrap.indent(seg.summary.strip(), prefix=" "))
|
|
2076
|
+
else:
|
|
2077
|
+
click.echo("\n[segments] (no segment reports)")
|
|
2078
|
+
|
|
2079
|
+
# --------------------------
|
|
2080
|
+
# 6) Energy diagram from bond-change segments (state labeling; compressed)
|
|
2081
|
+
# --------------------------
|
|
2082
|
+
diagram_payload: Optional[Dict[str, Any]] = None
|
|
2083
|
+
try:
|
|
2084
|
+
# Map each segment index → list of frame indices
|
|
2085
|
+
frame_seg_indices: List[int] = [int(getattr(im, "mep_seg_index", 0) or 0) for im in combined_all.images]
|
|
2086
|
+
seg_to_frames: Dict[int, List[int]] = {}
|
|
2087
|
+
for ii, sidx in enumerate(frame_seg_indices):
|
|
2088
|
+
if sidx <= 0:
|
|
2089
|
+
continue
|
|
2090
|
+
seg_to_frames.setdefault(int(sidx), []).append(ii)
|
|
2091
|
+
|
|
2092
|
+
# Build TS groups (each bond-change segment starts a group)
|
|
2093
|
+
ts_groups: List[Dict[str, Any]] = []
|
|
2094
|
+
ts_count = 0
|
|
2095
|
+
current: Optional[Dict[str, Any]] = None
|
|
2096
|
+
|
|
2097
|
+
for s in combined_all.segments:
|
|
2098
|
+
idxs = seg_to_frames.get(int(s.seg_index), [])
|
|
2099
|
+
if not idxs:
|
|
2100
|
+
continue
|
|
2101
|
+
|
|
2102
|
+
if s.kind == "seg" and s.summary and s.summary.strip() != "(no covalent changes detected)":
|
|
2103
|
+
# New TS group
|
|
2104
|
+
ts_count += 1
|
|
2105
|
+
imax = max(idxs, key=lambda j: combined_all.energies[j])
|
|
2106
|
+
ts_e = float(combined_all.energies[imax])
|
|
2107
|
+
first_im_e = float(combined_all.energies[idxs[-1]])
|
|
2108
|
+
current = {
|
|
2109
|
+
"ts_label": f"TS{ts_count}",
|
|
2110
|
+
"ts_energy": ts_e,
|
|
2111
|
+
"first_im_energy": first_im_e,
|
|
2112
|
+
"tail_im_energy": first_im_e,
|
|
2113
|
+
"has_extra": False,
|
|
2114
|
+
"index": ts_count,
|
|
2115
|
+
}
|
|
2116
|
+
ts_groups.append(current)
|
|
2117
|
+
else:
|
|
2118
|
+
# Kink/bridge: fold into current group as "extra" and update tail energy
|
|
2119
|
+
if current is not None:
|
|
2120
|
+
current["tail_im_energy"] = float(combined_all.energies[idxs[-1]])
|
|
2121
|
+
current["has_extra"] = True
|
|
2122
|
+
else:
|
|
2123
|
+
# pre-TS region without bond change → ignore
|
|
2124
|
+
pass
|
|
2125
|
+
|
|
2126
|
+
# Clip endpoints to first/last bond-change segment edges
|
|
2127
|
+
start_idx_for_diag = 0
|
|
2128
|
+
end_idx_for_diag = len(combined_all.energies) - 1
|
|
2129
|
+
bc_segments_in_order: List[SegmentReport] = [
|
|
2130
|
+
s for s in combined_all.segments
|
|
2131
|
+
if (s.kind == "seg" and s.summary and s.summary.strip() != "(no covalent changes detected)")
|
|
2132
|
+
]
|
|
2133
|
+
if bc_segments_in_order:
|
|
2134
|
+
first_bc = bc_segments_in_order[0]
|
|
2135
|
+
last_bc = bc_segments_in_order[-1]
|
|
2136
|
+
idxs_first_bc = seg_to_frames.get(int(first_bc.seg_index), [])
|
|
2137
|
+
idxs_last_bc = seg_to_frames.get(int(last_bc.seg_index), [])
|
|
2138
|
+
if idxs_first_bc:
|
|
2139
|
+
start_idx_for_diag = int(idxs_first_bc[0])
|
|
2140
|
+
if idxs_last_bc:
|
|
2141
|
+
end_idx_for_diag = int(idxs_last_bc[-1])
|
|
2142
|
+
|
|
2143
|
+
# Compose compressed labels/energies & human-readable chain
|
|
2144
|
+
labels: List[str] = ["R"]
|
|
2145
|
+
energies_eh: List[float] = [float(combined_all.energies[start_idx_for_diag])]
|
|
2146
|
+
chain_tokens: List[str] = ["R"]
|
|
2147
|
+
|
|
2148
|
+
for i, g in enumerate(ts_groups, start=1):
|
|
2149
|
+
last_group = (i == len(ts_groups))
|
|
2150
|
+
|
|
2151
|
+
# TS
|
|
2152
|
+
labels.append(g["ts_label"])
|
|
2153
|
+
energies_eh.append(g["ts_energy"])
|
|
2154
|
+
chain_tokens.extend(["-->", g["ts_label"]])
|
|
2155
|
+
|
|
2156
|
+
# For the last TS group: compress directly to P (no IMs)
|
|
2157
|
+
if last_group:
|
|
2158
|
+
continue
|
|
2159
|
+
|
|
2160
|
+
# IM1 (always keep)
|
|
2161
|
+
labels.append(f"IM{i}_1")
|
|
2162
|
+
energies_eh.append(g["first_im_energy"])
|
|
2163
|
+
chain_tokens.extend(["-->", f"IM{i}_1"])
|
|
2164
|
+
|
|
2165
|
+
# IM2 (represent all extra kink/bridge before next TS)
|
|
2166
|
+
if g["has_extra"]:
|
|
2167
|
+
labels.append(f"IM{i}_2")
|
|
2168
|
+
energies_eh.append(g["tail_im_energy"])
|
|
2169
|
+
chain_tokens.extend(["-|-->", f"IM{i}_2"])
|
|
2170
|
+
|
|
2171
|
+
# Product
|
|
2172
|
+
labels.append("P")
|
|
2173
|
+
energies_eh.append(float(combined_all.energies[end_idx_for_diag]))
|
|
2174
|
+
chain_tokens.extend(["-->", "P"])
|
|
2175
|
+
|
|
2176
|
+
# Convert to kcal/mol relative to R
|
|
2177
|
+
e0 = energies_eh[0]
|
|
2178
|
+
energies_kcal = [(e - e0) * AU2KCALPERMOL for e in energies_eh]
|
|
2179
|
+
energies_au = list(energies_eh)
|
|
2180
|
+
diagram_payload = {
|
|
2181
|
+
"name": "energy_diagram_MEP",
|
|
2182
|
+
"labels": list(labels),
|
|
2183
|
+
"energies_kcal": energies_kcal,
|
|
2184
|
+
"ylabel": "ΔE (kcal/mol)",
|
|
2185
|
+
"energies_au": energies_au,
|
|
2186
|
+
"image": str(out_dir_path / "energy_diagram_MEP.png"),
|
|
2187
|
+
}
|
|
2188
|
+
|
|
2189
|
+
# Log exact inputs to build_energy_diagram, and the human-readable chain
|
|
2190
|
+
labels_repr = "[" + ", ".join(f'"{lab}"' for lab in labels) + "]"
|
|
2191
|
+
energies_repr = "[" + ", ".join(f"{val:.6f}" for val in energies_kcal) + "]"
|
|
2192
|
+
click.echo(f"[diagram] build_energy_diagram.labels = {labels_repr}")
|
|
2193
|
+
click.echo(f"[diagram] build_energy_diagram.energies_kcal = {energies_repr}")
|
|
2194
|
+
|
|
2195
|
+
fig = build_energy_diagram(
|
|
2196
|
+
energies=energies_kcal,
|
|
2197
|
+
labels=labels,
|
|
2198
|
+
ylabel="ΔE (kcal/mol)",
|
|
2199
|
+
baseline=True,
|
|
2200
|
+
showgrid=False,
|
|
2201
|
+
)
|
|
2202
|
+
|
|
2203
|
+
try:
|
|
2204
|
+
png_path = out_dir_path / "energy_diagram_MEP.png"
|
|
2205
|
+
fig.write_image(str(png_path), scale=2)
|
|
2206
|
+
click.echo(f"[diagram] Wrote energy diagram (PNG) → '{png_path}'")
|
|
2207
|
+
except Exception as e:
|
|
2208
|
+
click.echo(f"[diagram] NOTE: PNG export skipped (install 'kaleido' to enable): {e}", err=True)
|
|
2209
|
+
|
|
2210
|
+
chain_text = " ".join(chain_tokens)
|
|
2211
|
+
click.echo(f"[diagram] State label sequence: {chain_text}")
|
|
2212
|
+
|
|
2213
|
+
except Exception as e:
|
|
2214
|
+
click.echo(f"[diagram] WARNING: Failed to build energy diagram: {e}", err=True)
|
|
2215
|
+
|
|
2216
|
+
# --------------------------
|
|
2217
|
+
# 7) Summary (YAML + log)
|
|
2218
|
+
# --------------------------
|
|
2219
|
+
if diagram_payload is not None:
|
|
2220
|
+
summary["energy_diagrams"] = [diagram_payload]
|
|
2221
|
+
|
|
2222
|
+
with open(out_dir_path / "summary.yaml", "w") as f:
|
|
2223
|
+
yaml.safe_dump(summary, f, sort_keys=False, allow_unicode=True)
|
|
2224
|
+
click.echo(f"[write] Wrote '{out_dir_path / 'summary.yaml'}'.")
|
|
2225
|
+
|
|
2226
|
+
try:
|
|
2227
|
+
freeze_atoms_for_log: List[int] = []
|
|
2228
|
+
try:
|
|
2229
|
+
freeze_atoms_for_log = sorted(
|
|
2230
|
+
{
|
|
2231
|
+
int(i)
|
|
2232
|
+
for g in getattr(combined_all, "images", [])
|
|
2233
|
+
for i in getattr(g, "freeze_atoms", [])
|
|
2234
|
+
}
|
|
2235
|
+
)
|
|
2236
|
+
except Exception:
|
|
2237
|
+
freeze_atoms_for_log = []
|
|
2238
|
+
|
|
2239
|
+
diag_for_log: Dict[str, Any] = diagram_payload or {}
|
|
2240
|
+
mep_info = {
|
|
2241
|
+
"n_images": len(combined_all.images),
|
|
2242
|
+
"n_segments": len(combined_all.segments),
|
|
2243
|
+
"traj_pdb": str(out_dir_path / "mep.pdb") if (out_dir_path / "mep.pdb").exists() else None,
|
|
2244
|
+
"mep_plot": str(out_dir_path / "mep_plot.png") if (out_dir_path / "mep_plot.png").exists() else None,
|
|
2245
|
+
"diagram": diag_for_log,
|
|
2246
|
+
}
|
|
2247
|
+
summary_payload = {
|
|
2248
|
+
"root_out_dir": str(out_dir_path),
|
|
2249
|
+
"path_dir": str(out_dir_path),
|
|
2250
|
+
"path_module_dir": "path_search",
|
|
2251
|
+
"pipeline_mode": "path-search",
|
|
2252
|
+
"refine_path": True,
|
|
2253
|
+
"tsopt": False,
|
|
2254
|
+
"thermo": False,
|
|
2255
|
+
"dft": False,
|
|
2256
|
+
"opt_mode": opt_mode,
|
|
2257
|
+
"mep_mode": "path-search",
|
|
2258
|
+
"uma_model": calc_cfg.get("uma_model"),
|
|
2259
|
+
"command": command_str,
|
|
2260
|
+
"charge": calc_cfg.get("model_charge"),
|
|
2261
|
+
"spin": calc_cfg.get("model_mult"),
|
|
2262
|
+
"freeze_atoms": freeze_atoms_for_log,
|
|
2263
|
+
"mep": mep_info,
|
|
2264
|
+
"segments": summary.get("segments", []),
|
|
2265
|
+
"energy_diagrams": summary.get("energy_diagrams", []),
|
|
2266
|
+
"key_files": {},
|
|
2267
|
+
}
|
|
2268
|
+
write_summary_log(out_dir_path / "summary.log", summary_payload)
|
|
2269
|
+
click.echo(f"[write] Wrote '{out_dir_path / 'summary.log'}'.")
|
|
2270
|
+
except Exception as e:
|
|
2271
|
+
click.echo(f"[write] WARNING: Failed to write summary.log: {e}", err=True)
|
|
2272
|
+
|
|
2273
|
+
# summary.md and key_* outputs are disabled.
|
|
2274
|
+
# --------------------------
|
|
2275
|
+
# 8) Elapsed time
|
|
2276
|
+
# --------------------------
|
|
2277
|
+
click.echo(format_elapsed("[time] Elapsed for Path Search", time_start))
|
|
2278
|
+
|
|
2279
|
+
except ZeroStepLength:
|
|
2280
|
+
click.echo("ERROR: Proposed step length dropped below the minimum allowed (ZeroStepLength).", err=True)
|
|
2281
|
+
sys.exit(2)
|
|
2282
|
+
except OptimizationError as e:
|
|
2283
|
+
click.echo(f"ERROR: Path search failed — {e}", err=True)
|
|
2284
|
+
sys.exit(3)
|
|
2285
|
+
except KeyboardInterrupt:
|
|
2286
|
+
click.echo("\nInterrupted by user.", err=True)
|
|
2287
|
+
sys.exit(130)
|
|
2288
|
+
except Exception as e:
|
|
2289
|
+
tb = "".join(traceback.format_exception(type(e), e, e.__traceback__))
|
|
2290
|
+
click.echo("Unhandled error during path search:\n" + textwrap.indent(tb, " "), err=True)
|
|
2291
|
+
sys.exit(1)
|
|
2292
|
+
finally:
|
|
2293
|
+
for prepared in prepared_inputs:
|
|
2294
|
+
prepared.cleanup()
|
|
2295
|
+
# Release GPU memory so subsequent pipeline stages don't OOM
|
|
2296
|
+
shared_calc = geoms = None
|
|
2297
|
+
gc.collect() # break cyclic refs inside torch.nn.Module
|
|
2298
|
+
if torch.cuda.is_available():
|
|
2299
|
+
torch.cuda.empty_cache()
|