mlmm-toolkit 0.2.2.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hessian_ff/__init__.py +50 -0
- hessian_ff/analytical_hessian.py +609 -0
- hessian_ff/constants.py +46 -0
- hessian_ff/forcefield.py +339 -0
- hessian_ff/loaders.py +608 -0
- hessian_ff/native/Makefile +8 -0
- hessian_ff/native/__init__.py +28 -0
- hessian_ff/native/analytical_hessian.py +88 -0
- hessian_ff/native/analytical_hessian_ext.cpp +258 -0
- hessian_ff/native/bonded.py +82 -0
- hessian_ff/native/bonded_ext.cpp +640 -0
- hessian_ff/native/loader.py +349 -0
- hessian_ff/native/nonbonded.py +118 -0
- hessian_ff/native/nonbonded_ext.cpp +1150 -0
- hessian_ff/prmtop_parmed.py +23 -0
- hessian_ff/system.py +107 -0
- hessian_ff/terms/__init__.py +14 -0
- hessian_ff/terms/angle.py +73 -0
- hessian_ff/terms/bond.py +44 -0
- hessian_ff/terms/cmap.py +406 -0
- hessian_ff/terms/dihedral.py +141 -0
- hessian_ff/terms/nonbonded.py +209 -0
- hessian_ff/tests/__init__.py +0 -0
- hessian_ff/tests/conftest.py +75 -0
- hessian_ff/tests/data/small/complex.parm7 +1346 -0
- hessian_ff/tests/data/small/complex.pdb +125 -0
- hessian_ff/tests/data/small/complex.rst7 +63 -0
- hessian_ff/tests/test_coords_input.py +44 -0
- hessian_ff/tests/test_energy_force.py +49 -0
- hessian_ff/tests/test_hessian.py +137 -0
- hessian_ff/tests/test_smoke.py +18 -0
- hessian_ff/tests/test_validation.py +40 -0
- hessian_ff/workflows.py +889 -0
- mlmm/__init__.py +36 -0
- mlmm/__main__.py +7 -0
- mlmm/_version.py +34 -0
- mlmm/add_elem_info.py +374 -0
- mlmm/advanced_help.py +91 -0
- mlmm/align_freeze_atoms.py +601 -0
- mlmm/all.py +3535 -0
- mlmm/bond_changes.py +231 -0
- mlmm/bool_compat.py +223 -0
- mlmm/cli.py +574 -0
- mlmm/cli_utils.py +166 -0
- mlmm/default_group.py +337 -0
- mlmm/defaults.py +467 -0
- mlmm/define_layer.py +526 -0
- mlmm/dft.py +1041 -0
- mlmm/energy_diagram.py +253 -0
- mlmm/extract.py +2213 -0
- mlmm/fix_altloc.py +464 -0
- mlmm/freq.py +1406 -0
- mlmm/harmonic_constraints.py +140 -0
- mlmm/hessian_cache.py +44 -0
- mlmm/hessian_calc.py +174 -0
- mlmm/irc.py +638 -0
- mlmm/mlmm_calc.py +2262 -0
- mlmm/mm_parm.py +945 -0
- mlmm/oniom_export.py +1983 -0
- mlmm/oniom_import.py +457 -0
- mlmm/opt.py +1742 -0
- mlmm/path_opt.py +1353 -0
- mlmm/path_search.py +2299 -0
- mlmm/preflight.py +88 -0
- mlmm/py.typed +1 -0
- mlmm/pysis_runner.py +45 -0
- mlmm/scan.py +1047 -0
- mlmm/scan2d.py +1226 -0
- mlmm/scan3d.py +1265 -0
- mlmm/scan_common.py +184 -0
- mlmm/summary_log.py +736 -0
- mlmm/trj2fig.py +448 -0
- mlmm/tsopt.py +2871 -0
- mlmm/utils.py +2309 -0
- mlmm/xtb_embedcharge_correction.py +475 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/METADATA +1159 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/RECORD +372 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/WHEEL +5 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/entry_points.txt +2 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/licenses/LICENSE +674 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/top_level.txt +4 -0
- pysisyphus/Geometry.py +1667 -0
- pysisyphus/LICENSE +674 -0
- pysisyphus/TableFormatter.py +63 -0
- pysisyphus/TablePrinter.py +74 -0
- pysisyphus/__init__.py +12 -0
- pysisyphus/calculators/AFIR.py +452 -0
- pysisyphus/calculators/AnaPot.py +20 -0
- pysisyphus/calculators/AnaPot2.py +48 -0
- pysisyphus/calculators/AnaPot3.py +12 -0
- pysisyphus/calculators/AnaPot4.py +20 -0
- pysisyphus/calculators/AnaPotBase.py +337 -0
- pysisyphus/calculators/AnaPotCBM.py +25 -0
- pysisyphus/calculators/AtomAtomTransTorque.py +154 -0
- pysisyphus/calculators/CFOUR.py +250 -0
- pysisyphus/calculators/Calculator.py +844 -0
- pysisyphus/calculators/CerjanMiller.py +24 -0
- pysisyphus/calculators/Composite.py +123 -0
- pysisyphus/calculators/ConicalIntersection.py +171 -0
- pysisyphus/calculators/DFTBp.py +430 -0
- pysisyphus/calculators/DFTD3.py +66 -0
- pysisyphus/calculators/DFTD4.py +84 -0
- pysisyphus/calculators/Dalton.py +61 -0
- pysisyphus/calculators/Dimer.py +681 -0
- pysisyphus/calculators/Dummy.py +20 -0
- pysisyphus/calculators/EGO.py +76 -0
- pysisyphus/calculators/EnergyMin.py +224 -0
- pysisyphus/calculators/ExternalPotential.py +264 -0
- pysisyphus/calculators/FakeASE.py +35 -0
- pysisyphus/calculators/FourWellAnaPot.py +28 -0
- pysisyphus/calculators/FreeEndNEBPot.py +39 -0
- pysisyphus/calculators/Gaussian09.py +18 -0
- pysisyphus/calculators/Gaussian16.py +726 -0
- pysisyphus/calculators/HardSphere.py +159 -0
- pysisyphus/calculators/IDPPCalculator.py +49 -0
- pysisyphus/calculators/IPIClient.py +133 -0
- pysisyphus/calculators/IPIServer.py +234 -0
- pysisyphus/calculators/LEPSBase.py +24 -0
- pysisyphus/calculators/LEPSExpr.py +139 -0
- pysisyphus/calculators/LennardJones.py +80 -0
- pysisyphus/calculators/MOPAC.py +219 -0
- pysisyphus/calculators/MullerBrownSympyPot.py +51 -0
- pysisyphus/calculators/MultiCalc.py +85 -0
- pysisyphus/calculators/NFK.py +45 -0
- pysisyphus/calculators/OBabel.py +87 -0
- pysisyphus/calculators/ONIOMv2.py +1129 -0
- pysisyphus/calculators/ORCA.py +893 -0
- pysisyphus/calculators/ORCA5.py +6 -0
- pysisyphus/calculators/OpenMM.py +88 -0
- pysisyphus/calculators/OpenMolcas.py +281 -0
- pysisyphus/calculators/OverlapCalculator.py +908 -0
- pysisyphus/calculators/Psi4.py +218 -0
- pysisyphus/calculators/PyPsi4.py +37 -0
- pysisyphus/calculators/PySCF.py +341 -0
- pysisyphus/calculators/PyXTB.py +73 -0
- pysisyphus/calculators/QCEngine.py +106 -0
- pysisyphus/calculators/Rastrigin.py +22 -0
- pysisyphus/calculators/Remote.py +76 -0
- pysisyphus/calculators/Rosenbrock.py +15 -0
- pysisyphus/calculators/SocketCalc.py +97 -0
- pysisyphus/calculators/TIP3P.py +111 -0
- pysisyphus/calculators/TransTorque.py +161 -0
- pysisyphus/calculators/Turbomole.py +965 -0
- pysisyphus/calculators/VRIPot.py +37 -0
- pysisyphus/calculators/WFOWrapper.py +333 -0
- pysisyphus/calculators/WFOWrapper2.py +341 -0
- pysisyphus/calculators/XTB.py +418 -0
- pysisyphus/calculators/__init__.py +81 -0
- pysisyphus/calculators/cosmo_data.py +139 -0
- pysisyphus/calculators/parser.py +150 -0
- pysisyphus/color.py +19 -0
- pysisyphus/config.py +133 -0
- pysisyphus/constants.py +65 -0
- pysisyphus/cos/AdaptiveNEB.py +230 -0
- pysisyphus/cos/ChainOfStates.py +725 -0
- pysisyphus/cos/FreeEndNEB.py +25 -0
- pysisyphus/cos/FreezingString.py +103 -0
- pysisyphus/cos/GrowingChainOfStates.py +71 -0
- pysisyphus/cos/GrowingNT.py +309 -0
- pysisyphus/cos/GrowingString.py +508 -0
- pysisyphus/cos/NEB.py +189 -0
- pysisyphus/cos/SimpleZTS.py +64 -0
- pysisyphus/cos/__init__.py +22 -0
- pysisyphus/cos/stiffness.py +199 -0
- pysisyphus/drivers/__init__.py +17 -0
- pysisyphus/drivers/afir.py +855 -0
- pysisyphus/drivers/barriers.py +271 -0
- pysisyphus/drivers/birkholz.py +138 -0
- pysisyphus/drivers/cluster.py +318 -0
- pysisyphus/drivers/diabatization.py +133 -0
- pysisyphus/drivers/merge.py +368 -0
- pysisyphus/drivers/merge_mol2.py +322 -0
- pysisyphus/drivers/opt.py +375 -0
- pysisyphus/drivers/perf.py +91 -0
- pysisyphus/drivers/pka.py +52 -0
- pysisyphus/drivers/precon_pos_rot.py +669 -0
- pysisyphus/drivers/rates.py +480 -0
- pysisyphus/drivers/replace.py +219 -0
- pysisyphus/drivers/scan.py +212 -0
- pysisyphus/drivers/spectrum.py +166 -0
- pysisyphus/drivers/thermo.py +31 -0
- pysisyphus/dynamics/Gaussian.py +103 -0
- pysisyphus/dynamics/__init__.py +20 -0
- pysisyphus/dynamics/colvars.py +136 -0
- pysisyphus/dynamics/driver.py +297 -0
- pysisyphus/dynamics/helpers.py +256 -0
- pysisyphus/dynamics/lincs.py +105 -0
- pysisyphus/dynamics/mdp.py +364 -0
- pysisyphus/dynamics/rattle.py +121 -0
- pysisyphus/dynamics/thermostats.py +128 -0
- pysisyphus/dynamics/wigner.py +266 -0
- pysisyphus/elem_data.py +3473 -0
- pysisyphus/exceptions.py +2 -0
- pysisyphus/filtertrj.py +69 -0
- pysisyphus/helpers.py +623 -0
- pysisyphus/helpers_pure.py +649 -0
- pysisyphus/init_logging.py +50 -0
- pysisyphus/intcoords/Bend.py +69 -0
- pysisyphus/intcoords/Bend2.py +25 -0
- pysisyphus/intcoords/BondedFragment.py +32 -0
- pysisyphus/intcoords/Cartesian.py +41 -0
- pysisyphus/intcoords/CartesianCoords.py +140 -0
- pysisyphus/intcoords/Coords.py +56 -0
- pysisyphus/intcoords/DLC.py +197 -0
- pysisyphus/intcoords/DistanceFunction.py +34 -0
- pysisyphus/intcoords/DummyImproper.py +70 -0
- pysisyphus/intcoords/DummyTorsion.py +72 -0
- pysisyphus/intcoords/LinearBend.py +105 -0
- pysisyphus/intcoords/LinearDisplacement.py +80 -0
- pysisyphus/intcoords/OutOfPlane.py +59 -0
- pysisyphus/intcoords/PrimTypes.py +286 -0
- pysisyphus/intcoords/Primitive.py +137 -0
- pysisyphus/intcoords/RedundantCoords.py +659 -0
- pysisyphus/intcoords/RobustTorsion.py +59 -0
- pysisyphus/intcoords/Rotation.py +147 -0
- pysisyphus/intcoords/Stretch.py +31 -0
- pysisyphus/intcoords/Torsion.py +101 -0
- pysisyphus/intcoords/Torsion2.py +25 -0
- pysisyphus/intcoords/Translation.py +45 -0
- pysisyphus/intcoords/__init__.py +61 -0
- pysisyphus/intcoords/augment_bonds.py +126 -0
- pysisyphus/intcoords/derivatives.py +10512 -0
- pysisyphus/intcoords/eval.py +80 -0
- pysisyphus/intcoords/exceptions.py +37 -0
- pysisyphus/intcoords/findiffs.py +48 -0
- pysisyphus/intcoords/generate_derivatives.py +414 -0
- pysisyphus/intcoords/helpers.py +235 -0
- pysisyphus/intcoords/logging_conf.py +10 -0
- pysisyphus/intcoords/mp_derivatives.py +10836 -0
- pysisyphus/intcoords/setup.py +962 -0
- pysisyphus/intcoords/setup_fast.py +176 -0
- pysisyphus/intcoords/update.py +272 -0
- pysisyphus/intcoords/valid.py +89 -0
- pysisyphus/interpolate/Geodesic.py +93 -0
- pysisyphus/interpolate/IDPP.py +55 -0
- pysisyphus/interpolate/Interpolator.py +116 -0
- pysisyphus/interpolate/LST.py +70 -0
- pysisyphus/interpolate/Redund.py +152 -0
- pysisyphus/interpolate/__init__.py +9 -0
- pysisyphus/interpolate/helpers.py +34 -0
- pysisyphus/io/__init__.py +22 -0
- pysisyphus/io/aomix.py +178 -0
- pysisyphus/io/cjson.py +24 -0
- pysisyphus/io/crd.py +101 -0
- pysisyphus/io/cube.py +220 -0
- pysisyphus/io/fchk.py +184 -0
- pysisyphus/io/hdf5.py +49 -0
- pysisyphus/io/hessian.py +72 -0
- pysisyphus/io/mol2.py +146 -0
- pysisyphus/io/molden.py +293 -0
- pysisyphus/io/orca.py +189 -0
- pysisyphus/io/pdb.py +269 -0
- pysisyphus/io/psf.py +79 -0
- pysisyphus/io/pubchem.py +31 -0
- pysisyphus/io/qcschema.py +34 -0
- pysisyphus/io/sdf.py +29 -0
- pysisyphus/io/xyz.py +61 -0
- pysisyphus/io/zmat.py +175 -0
- pysisyphus/irc/DWI.py +108 -0
- pysisyphus/irc/DampedVelocityVerlet.py +134 -0
- pysisyphus/irc/Euler.py +22 -0
- pysisyphus/irc/EulerPC.py +345 -0
- pysisyphus/irc/GonzalezSchlegel.py +187 -0
- pysisyphus/irc/IMKMod.py +164 -0
- pysisyphus/irc/IRC.py +878 -0
- pysisyphus/irc/IRCDummy.py +10 -0
- pysisyphus/irc/Instanton.py +307 -0
- pysisyphus/irc/LQA.py +53 -0
- pysisyphus/irc/ModeKill.py +136 -0
- pysisyphus/irc/ParamPlot.py +53 -0
- pysisyphus/irc/RK4.py +36 -0
- pysisyphus/irc/__init__.py +31 -0
- pysisyphus/irc/initial_displ.py +219 -0
- pysisyphus/linalg.py +411 -0
- pysisyphus/line_searches/Backtracking.py +88 -0
- pysisyphus/line_searches/HagerZhang.py +184 -0
- pysisyphus/line_searches/LineSearch.py +232 -0
- pysisyphus/line_searches/StrongWolfe.py +108 -0
- pysisyphus/line_searches/__init__.py +9 -0
- pysisyphus/line_searches/interpol.py +15 -0
- pysisyphus/modefollow/NormalMode.py +40 -0
- pysisyphus/modefollow/__init__.py +10 -0
- pysisyphus/modefollow/davidson.py +199 -0
- pysisyphus/modefollow/lanczos.py +95 -0
- pysisyphus/optimizers/BFGS.py +99 -0
- pysisyphus/optimizers/BacktrackingOptimizer.py +113 -0
- pysisyphus/optimizers/ConjugateGradient.py +98 -0
- pysisyphus/optimizers/CubicNewton.py +75 -0
- pysisyphus/optimizers/FIRE.py +113 -0
- pysisyphus/optimizers/HessianOptimizer.py +1176 -0
- pysisyphus/optimizers/LBFGS.py +228 -0
- pysisyphus/optimizers/LayerOpt.py +411 -0
- pysisyphus/optimizers/MicroOptimizer.py +169 -0
- pysisyphus/optimizers/NCOptimizer.py +90 -0
- pysisyphus/optimizers/Optimizer.py +1084 -0
- pysisyphus/optimizers/PreconLBFGS.py +260 -0
- pysisyphus/optimizers/PreconSteepestDescent.py +7 -0
- pysisyphus/optimizers/QuickMin.py +74 -0
- pysisyphus/optimizers/RFOptimizer.py +181 -0
- pysisyphus/optimizers/RSA.py +99 -0
- pysisyphus/optimizers/StabilizedQNMethod.py +248 -0
- pysisyphus/optimizers/SteepestDescent.py +23 -0
- pysisyphus/optimizers/StringOptimizer.py +173 -0
- pysisyphus/optimizers/__init__.py +41 -0
- pysisyphus/optimizers/closures.py +301 -0
- pysisyphus/optimizers/cls_map.py +58 -0
- pysisyphus/optimizers/exceptions.py +6 -0
- pysisyphus/optimizers/gdiis.py +280 -0
- pysisyphus/optimizers/guess_hessians.py +311 -0
- pysisyphus/optimizers/hessian_updates.py +355 -0
- pysisyphus/optimizers/poly_fit.py +285 -0
- pysisyphus/optimizers/precon.py +153 -0
- pysisyphus/optimizers/restrict_step.py +24 -0
- pysisyphus/pack.py +172 -0
- pysisyphus/peakdetect.py +948 -0
- pysisyphus/plot.py +1031 -0
- pysisyphus/run.py +2106 -0
- pysisyphus/socket_helper.py +74 -0
- pysisyphus/stocastic/FragmentKick.py +132 -0
- pysisyphus/stocastic/Kick.py +81 -0
- pysisyphus/stocastic/Pipeline.py +303 -0
- pysisyphus/stocastic/__init__.py +21 -0
- pysisyphus/stocastic/align.py +127 -0
- pysisyphus/testing.py +96 -0
- pysisyphus/thermo.py +156 -0
- pysisyphus/trj.py +824 -0
- pysisyphus/tsoptimizers/RSIRFOptimizer.py +56 -0
- pysisyphus/tsoptimizers/RSPRFOptimizer.py +182 -0
- pysisyphus/tsoptimizers/TRIM.py +59 -0
- pysisyphus/tsoptimizers/TSHessianOptimizer.py +463 -0
- pysisyphus/tsoptimizers/__init__.py +23 -0
- pysisyphus/wavefunction/Basis.py +239 -0
- pysisyphus/wavefunction/DIIS.py +76 -0
- pysisyphus/wavefunction/__init__.py +25 -0
- pysisyphus/wavefunction/build_ext.py +42 -0
- pysisyphus/wavefunction/cart2sph.py +190 -0
- pysisyphus/wavefunction/diabatization.py +304 -0
- pysisyphus/wavefunction/excited_states.py +435 -0
- pysisyphus/wavefunction/gen_ints.py +1811 -0
- pysisyphus/wavefunction/helpers.py +104 -0
- pysisyphus/wavefunction/ints/__init__.py +0 -0
- pysisyphus/wavefunction/ints/boys.py +193 -0
- pysisyphus/wavefunction/ints/boys_table_N_64_xasym_27.1_step_0.01.npy +0 -0
- pysisyphus/wavefunction/ints/cart_gto3d.py +176 -0
- pysisyphus/wavefunction/ints/coulomb3d.py +25928 -0
- pysisyphus/wavefunction/ints/diag_quadrupole3d.py +10036 -0
- pysisyphus/wavefunction/ints/dipole3d.py +8762 -0
- pysisyphus/wavefunction/ints/int2c2e3d.py +7198 -0
- pysisyphus/wavefunction/ints/int3c2e3d_sph.py +65040 -0
- pysisyphus/wavefunction/ints/kinetic3d.py +8240 -0
- pysisyphus/wavefunction/ints/ovlp3d.py +3777 -0
- pysisyphus/wavefunction/ints/quadrupole3d.py +15054 -0
- pysisyphus/wavefunction/ints/self_ovlp3d.py +198 -0
- pysisyphus/wavefunction/localization.py +458 -0
- pysisyphus/wavefunction/multipole.py +159 -0
- pysisyphus/wavefunction/normalization.py +36 -0
- pysisyphus/wavefunction/pop_analysis.py +134 -0
- pysisyphus/wavefunction/shells.py +1171 -0
- pysisyphus/wavefunction/wavefunction.py +504 -0
- pysisyphus/wrapper/__init__.py +11 -0
- pysisyphus/wrapper/exceptions.py +2 -0
- pysisyphus/wrapper/jmol.py +120 -0
- pysisyphus/wrapper/mwfn.py +169 -0
- pysisyphus/wrapper/packmol.py +71 -0
- pysisyphus/xyzloader.py +168 -0
- pysisyphus/yaml_mods.py +45 -0
- thermoanalysis/LICENSE +674 -0
- thermoanalysis/QCData.py +244 -0
- thermoanalysis/__init__.py +0 -0
- thermoanalysis/config.py +3 -0
- thermoanalysis/constants.py +20 -0
- thermoanalysis/thermo.py +1011 -0
mlmm/opt.py
ADDED
|
@@ -0,0 +1,1742 @@
|
|
|
1
|
+
# mlmm/opt.py
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
ML/MM geometry optimization (LBFGS or RFO) with UMA + hessian_ff calculator.
|
|
5
|
+
|
|
6
|
+
Example:
|
|
7
|
+
mlmm opt -i pocket.pdb --parm real.parm7 --model-pdb ml_region.pdb -q 0
|
|
8
|
+
|
|
9
|
+
For detailed documentation, see: docs/opt.md
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Set
|
|
14
|
+
|
|
15
|
+
import ast
|
|
16
|
+
import contextlib
|
|
17
|
+
import gc
|
|
18
|
+
import io
|
|
19
|
+
import logging
|
|
20
|
+
|
|
21
|
+
import sys
|
|
22
|
+
import textwrap
|
|
23
|
+
import traceback
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
import click
|
|
28
|
+
import numpy as np
|
|
29
|
+
import torch
|
|
30
|
+
import time
|
|
31
|
+
|
|
32
|
+
from pysisyphus.helpers import geom_loader
|
|
33
|
+
from pysisyphus.optimizers.LBFGS import LBFGS
|
|
34
|
+
from pysisyphus.optimizers.RFOptimizer import RFOptimizer
|
|
35
|
+
from pysisyphus.optimizers.exceptions import OptimizationError, ZeroStepLength
|
|
36
|
+
from pysisyphus.constants import ANG2BOHR, BOHR2ANG, AU2EV
|
|
37
|
+
from pysisyphus.TablePrinter import TablePrinter
|
|
38
|
+
|
|
39
|
+
from .mlmm_calc import mlmm, mlmm_mm_only
|
|
40
|
+
from .defaults import (
|
|
41
|
+
GEOM_KW_DEFAULT,
|
|
42
|
+
HESSIAN_DIMER_KW,
|
|
43
|
+
MLMM_CALC_KW,
|
|
44
|
+
OPT_BASE_KW,
|
|
45
|
+
LBFGS_KW,
|
|
46
|
+
RFO_KW,
|
|
47
|
+
OPT_MODE_ALIASES,
|
|
48
|
+
MICROITER_KW,
|
|
49
|
+
BFACTOR_ML,
|
|
50
|
+
BFACTOR_MOVABLE_MM,
|
|
51
|
+
BFACTOR_FROZEN,
|
|
52
|
+
)
|
|
53
|
+
from .utils import (
|
|
54
|
+
append_xyz_trajectory as _append_xyz_trajectory,
|
|
55
|
+
convert_xyz_to_pdb,
|
|
56
|
+
set_convert_file_enabled,
|
|
57
|
+
is_convert_file_enabled,
|
|
58
|
+
convert_xyz_like_outputs,
|
|
59
|
+
deep_update,
|
|
60
|
+
load_yaml_dict,
|
|
61
|
+
apply_yaml_overrides,
|
|
62
|
+
pretty_block,
|
|
63
|
+
strip_inherited_keys,
|
|
64
|
+
filter_calc_for_echo,
|
|
65
|
+
format_freeze_atoms_for_echo,
|
|
66
|
+
format_elapsed,
|
|
67
|
+
merge_freeze_atom_indices,
|
|
68
|
+
prepare_input_structure,
|
|
69
|
+
apply_ref_pdb_override,
|
|
70
|
+
resolve_charge_spin_or_raise,
|
|
71
|
+
parse_indices_string,
|
|
72
|
+
build_model_pdb_from_bfactors,
|
|
73
|
+
build_model_pdb_from_indices,
|
|
74
|
+
update_pdb_bfactors_from_layers,
|
|
75
|
+
normalize_choice,
|
|
76
|
+
yaml_section_has_key,
|
|
77
|
+
is_scan_spec_file,
|
|
78
|
+
parse_dist_freeze_list,
|
|
79
|
+
parse_dist_freeze_spec,
|
|
80
|
+
load_pdb_atom_metadata,
|
|
81
|
+
)
|
|
82
|
+
from .cli_utils import resolve_yaml_sources, load_merged_yaml_cfg, make_is_param_explicit
|
|
83
|
+
|
|
84
|
+
EV2AU = 1.0 / AU2EV  # eV → Hartree
# Conversion factor for harmonic force constants given in eV/Å^2.
H_EVAA_2_AU = EV2AU / (ANG2BOHR * ANG2BOHR)  # (eV/Å^2) → (Hartree/Bohr^2)

# Flatten-loop constants (sourced from defaults.py)
# These are bound once, at import time, from HESSIAN_DIMER_KW.
OPT_FLATTEN_NEG_FREQ_THRESH_CM = HESSIAN_DIMER_KW["neg_freq_thresh_cm"]
OPT_FLATTEN_AMP_ANG = HESSIAN_DIMER_KW["flatten_amp_ang"]
OPT_FLATTEN_MAX_ITER = HESSIAN_DIMER_KW["flatten_max_iter"]


# -----------------------------------------------
# Default settings (imported from defaults.py, aliased for compatibility)
# -----------------------------------------------

# dict(...) makes shallow copies: adding/replacing top-level keys here does not
# touch the shared defaults, but nested containers are still shared.
GEOM_KW: Dict[str, Any] = dict(GEOM_KW_DEFAULT)
CALC_KW: Dict[str, Any] = dict(MLMM_CALC_KW)

# Note: OPT_BASE_KW, LBFGS_KW, RFO_KW are imported from defaults.py
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class HarmonicBiasCalculator:
    """Wrap a base calculator with harmonic distance restraints.

    For every restrained pair ``(i, j, r0)`` the bias energy
    ``0.5 * k * (r_ij - r0)**2`` is added to the base energy, and its negative
    gradient is added to the base forces. ``k`` is supplied in eV/Å^2 and
    ``r0`` in Å. Both are converted to atomic units (Hartree, Bohr) before
    being combined with the base results. Coordinates passed to the ``get_*``
    methods are treated as Bohr.

    Attributes not defined on this wrapper are looked up on the wrapped
    calculator.
    """

    def __init__(self, base_calc, k: float = 10.0, pairs: Optional[List[Tuple[int, int, float]]] = None):
        self.base = base_calc
        self.k_evAA = float(k)
        # Force constant converted from eV/Å^2 to Hartree/Bohr^2.
        self.k_au_bohr2 = self.k_evAA * H_EVAA_2_AU
        self._pairs: List[Tuple[int, int, float]] = []
        # Normalize exactly like later set_pairs() calls do.
        self.set_pairs([] if pairs is None else pairs)

    def set_pairs(self, pairs: List[Tuple[int, int, float]]) -> None:
        """Replace the restraint list with ``(i, j, target_Å)`` tuples (0-based indices)."""
        self._pairs = [(int(i), int(j), float(t)) for (i, j, t) in pairs]

    def _bias_energy_forces_bohr(self, coords_bohr: np.ndarray) -> Tuple[float, np.ndarray]:
        """Return the bias energy (Hartree) and flat bias forces (Hartree/Bohr).

        Pairs with an index outside the geometry, or whose two atoms coincide
        (direction undefined), contribute nothing.
        """
        coords = np.array(coords_bohr, dtype=float).reshape(-1, 3)
        n = coords.shape[0]
        energy = 0.0
        forces = np.zeros((n, 3), dtype=float)
        k = self.k_au_bohr2
        for (i, j, target_ang) in self._pairs:
            if not (0 <= i < n and 0 <= j < n):
                continue
            rij_vec = coords[i] - coords[j]
            rij = float(np.linalg.norm(rij_vec))
            if rij < 1e-14:
                continue
            diff_bohr = rij - float(target_ang) * ANG2BOHR
            energy += 0.5 * k * diff_bohr * diff_bohr
            # F_i = -dE/dx_i = -k (r - r0) (x_i - x_j)/r ; F_j = -F_i.
            # rij >= 1e-14 here, so the division is safe.
            fi = -k * diff_bohr * (rij_vec / rij)
            forces[i] += fi
            forces[j] -= fi
        return energy, forces.reshape(-1)

    def get_forces(self, elem, coords):
        """Return ``{"energy", "forces"}`` of the base calculator plus the bias."""
        coords_bohr = np.asarray(coords, dtype=float).reshape(-1, 3)
        base = self.base.get_forces(elem, coords_bohr)
        E0 = float(base["energy"])
        F0 = np.asarray(base["forces"], dtype=float).reshape(-1)
        Ebias, Fbias = self._bias_energy_forces_bohr(coords_bohr)
        return {"energy": E0 + Ebias, "forces": F0 + Fbias}

    def get_energy(self, elem, coords):
        """Return ``{"energy"}`` of the base calculator plus the bias energy."""
        coords_bohr = np.asarray(coords, dtype=float).reshape(-1, 3)
        E0 = float(self.base.get_energy(elem, coords_bohr)["energy"])
        Ebias, _ = self._bias_energy_forces_bohr(coords_bohr)
        return {"energy": E0 + Ebias}

    def get_energy_and_forces(self, elem, coords):
        """Return ``(energy, forces)`` including the bias."""
        res = self.get_forces(elem, coords)
        return res["energy"], res["forces"]

    def get_energy_and_gradient(self, elem, coords):
        """Return ``(energy, gradient)``; the gradient is the negated flat force vector."""
        res = self.get_forces(elem, coords)
        return res["energy"], -np.asarray(res["forces"], dtype=float).reshape(-1)

    def __getattr__(self, name: str):
        """Delegate unknown attributes to the wrapped calculator.

        ``base`` is read directly from the instance ``__dict__``. When an
        instance is created without ``__init__`` (``copy.deepcopy``, pickle),
        a plain ``self.base`` lookup here would re-enter ``__getattr__``
        endlessly and raise ``RecursionError`` instead of ``AttributeError``.
        """
        try:
            base = self.__dict__["base"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(base, name)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def _parse_freeze_atoms(arg: Optional[str]) -> List[int]:
|
|
164
|
+
"""Parse comma-separated 1-based indices (e.g., "1,3,5") into 0-based ints."""
|
|
165
|
+
if arg is None:
|
|
166
|
+
return []
|
|
167
|
+
|
|
168
|
+
items = [chunk.strip() for chunk in str(arg).split(",")]
|
|
169
|
+
indices: List[int] = []
|
|
170
|
+
for idx, chunk in enumerate(items, start=1):
|
|
171
|
+
if not chunk:
|
|
172
|
+
continue
|
|
173
|
+
try:
|
|
174
|
+
value = int(chunk)
|
|
175
|
+
except ValueError as exc:
|
|
176
|
+
raise click.BadParameter(
|
|
177
|
+
f"Invalid integer in --freeze-atoms entry #{idx}: '{chunk}'"
|
|
178
|
+
) from exc
|
|
179
|
+
if value <= 0:
|
|
180
|
+
raise click.BadParameter(
|
|
181
|
+
f"--freeze-atoms expects 1-based positive indices; got {value}"
|
|
182
|
+
)
|
|
183
|
+
indices.append(value - 1)
|
|
184
|
+
return sorted(set(indices))
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def _normalize_geom_freeze(value: Any) -> List[int]:
|
|
188
|
+
"""Normalize YAML-provided geom.freeze_atoms to a sorted 0-based list."""
|
|
189
|
+
if value is None:
|
|
190
|
+
return []
|
|
191
|
+
if isinstance(value, str):
|
|
192
|
+
tokens = [tok.strip() for tok in value.split(",") if tok.strip()]
|
|
193
|
+
try:
|
|
194
|
+
return sorted({int(tok) for tok in tokens})
|
|
195
|
+
except ValueError as exc:
|
|
196
|
+
raise click.BadParameter(
|
|
197
|
+
"geom.freeze_atoms must contain integers (string form)."
|
|
198
|
+
) from exc
|
|
199
|
+
try:
|
|
200
|
+
return sorted({int(idx) for idx in value})
|
|
201
|
+
except TypeError as exc:
|
|
202
|
+
raise click.BadParameter("geom.freeze_atoms must be iterable of integers.") from exc
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def _parse_dist_freeze_args(
|
|
206
|
+
raw_args: Sequence[str],
|
|
207
|
+
one_based: bool,
|
|
208
|
+
atom_meta: Optional[Sequence[Dict[str, Any]]],
|
|
209
|
+
) -> List[Tuple[int, int, Optional[float]]]:
|
|
210
|
+
"""Parse all ``--dist-freeze`` arguments (inline literal or spec file).
|
|
211
|
+
|
|
212
|
+
Accepts the same format as ``--scan-lists``: inline Python literal
|
|
213
|
+
(e.g. ``'[(1,5,1.4)]'``) or a YAML/JSON spec file path. String atom
|
|
214
|
+
specs (e.g. ``'A:SER123:OG'``) are supported when *atom_meta* is
|
|
215
|
+
available. Target distance is optional — omit to freeze at the current
|
|
216
|
+
distance.
|
|
217
|
+
"""
|
|
218
|
+
all_pairs: List[Tuple[int, int, Optional[float]]] = []
|
|
219
|
+
for raw in raw_args:
|
|
220
|
+
if is_scan_spec_file(raw):
|
|
221
|
+
all_pairs.extend(parse_dist_freeze_spec(
|
|
222
|
+
Path(raw),
|
|
223
|
+
one_based_default=one_based,
|
|
224
|
+
atom_meta=atom_meta,
|
|
225
|
+
))
|
|
226
|
+
else:
|
|
227
|
+
all_pairs.extend(parse_dist_freeze_list(
|
|
228
|
+
raw,
|
|
229
|
+
one_based=one_based,
|
|
230
|
+
atom_meta=atom_meta,
|
|
231
|
+
))
|
|
232
|
+
return all_pairs
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def _resolve_dist_freeze_targets(
    geometry,
    tuples: List[Tuple[int, int, Optional[float]]],
) -> List[Tuple[int, int, float]]:
    """Fill in missing target distances (Å) for ``--dist-freeze`` pairs.

    ``geometry.coords3d`` is read as Bohr and converted to Å with
    ``BOHR2ANG``. For each ``(i, j, target)``, a ``None`` target is replaced
    by the current i-j distance in Å; otherwise ``float(target)`` is kept.

    Raises:
        click.BadParameter: if either index lies outside ``[0, N)``.
    """
    xyz_ang = np.array(geometry.coords3d, dtype=float).reshape(-1, 3) * BOHR2ANG
    n = len(xyz_ang)

    out: List[Tuple[int, int, float]] = []
    for (i, j, target) in tuples:
        if not (0 <= i < n and 0 <= j < n):
            raise click.BadParameter(
                f"--dist-freeze indices {(i, j)} are out of bounds for the loaded geometry (N={n})."
            )
        dist = (
            float(np.linalg.norm(xyz_ang[i] - xyz_ang[j]))
            if target is None
            else float(target)
        )
        out.append((i, j, dist))
    return out
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
# -------------------------------
|
|
258
|
+
# PDB helpers for B-factor patch
|
|
259
|
+
# -------------------------------
|
|
260
|
+
|
|
261
|
+
def _pdb_keys_from_line(line: str) -> Tuple[Tuple, Tuple]:
|
|
262
|
+
"""
|
|
263
|
+
Extract robust keys from a PDB ATOM/HETATM record.
|
|
264
|
+
|
|
265
|
+
Returns:
|
|
266
|
+
key_full: (chain, resseq, icode, resname, atomname, altloc)
|
|
267
|
+
key_simple: (chain, resseq, icode, atomname)
|
|
268
|
+
"""
|
|
269
|
+
atom_name = line[12:16].strip()
|
|
270
|
+
altloc = line[16:17].strip()
|
|
271
|
+
resname = line[17:20].strip()
|
|
272
|
+
chain = line[21:22].strip()
|
|
273
|
+
resseq_str = line[22:26].strip()
|
|
274
|
+
try:
|
|
275
|
+
resseq = int(resseq_str)
|
|
276
|
+
except ValueError:
|
|
277
|
+
resseq = -10**9 # unlikely sentinel when missing
|
|
278
|
+
icode = line[26:27].strip()
|
|
279
|
+
key_full = (chain, resseq, icode, resname, atom_name, altloc)
|
|
280
|
+
key_simple = (chain, resseq, icode, atom_name)
|
|
281
|
+
return key_full, key_simple
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def _collect_ml_atom_keys(model_pdb: Path) -> Tuple[Set[Tuple], Set[Tuple]]:
|
|
285
|
+
"""Collect ML-region atom keys from model_pdb."""
|
|
286
|
+
keys_full: Set[Tuple] = set()
|
|
287
|
+
keys_simple: Set[Tuple] = set()
|
|
288
|
+
try:
|
|
289
|
+
with model_pdb.open("r") as fh:
|
|
290
|
+
for line in fh:
|
|
291
|
+
if line.startswith("ATOM") or line.startswith("HETATM"):
|
|
292
|
+
kf, ks = _pdb_keys_from_line(line)
|
|
293
|
+
keys_full.add(kf)
|
|
294
|
+
keys_simple.add(ks)
|
|
295
|
+
except Exception:
|
|
296
|
+
logger.debug("Failed to collect ML atom keys from %s", model_pdb, exc_info=True)
|
|
297
|
+
return keys_full, keys_simple
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def _format_with_bfactor(line: str, b: float) -> str:
|
|
301
|
+
"""Return PDB line with B-factor field (cols 61-66) set to b (6.2f)."""
|
|
302
|
+
if len(line) < 66:
|
|
303
|
+
line = line.rstrip("\n")
|
|
304
|
+
line = line + " " * max(0, 66 - len(line))
|
|
305
|
+
line = line + "\n"
|
|
306
|
+
bf_str = f"{b:6.2f}"
|
|
307
|
+
# Preserve occupancy (cols 55-60), overwrite tempFactor (61-66).
|
|
308
|
+
new_line = line[:60] + bf_str + line[66:]
|
|
309
|
+
return new_line
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def _annotate_b_factors_inplace(
    pdb_path: Path,
    model_pdb: Path,
    freeze_indices_0based: Sequence[int],
    beta_ml: float = 100.0,
    beta_frz: float = 50.0,
    beta_both: float = 150.0,
) -> None:
    """
    Overwrite B-factors of ATOM/HETATM records in *pdb_path* in place.

    Each record is classified by two independent criteria:
      - ML region: its key (from ``_pdb_keys_from_line``) matches a record of
        *model_pdb*, using either the full key or the simple key.
      - frozen: its 0-based position among the ATOM/HETATM records of the
        current MODEL is listed in *freeze_indices_0based*.

    B-factor written (defaults in parentheses):
      - ML only:      ``beta_ml``   (100.00)
      - frozen only:  ``beta_frz``  (50.00)
      - ML ∩ frozen:  ``beta_both`` (150.00)
      - neither:      record left unchanged

    The atom counter resets at each MODEL record. Failures to read or write
    *pdb_path* are logged at DEBUG level and otherwise ignored (best effort).
    """
    ml_full, ml_simple = _collect_ml_atom_keys(model_pdb)
    # Explicit None check: ``array or []`` raises for NumPy index arrays.
    frozen_set = (
        set()
        if freeze_indices_0based is None
        else {int(i) for i in freeze_indices_0based}
    )

    try:
        lines = pdb_path.read_text().splitlines(keepends=True)
    except Exception:
        logger.debug("Failed to read PDB file for B-factor annotation: %s", pdb_path, exc_info=True)
        return

    out_lines: List[str] = []
    atom_idx = 0  # resets per MODEL

    for line in lines:
        rec = line[:6]
        if rec.startswith("MODEL"):
            atom_idx = 0
            out_lines.append(line)
            continue
        if not rec.startswith(("ATOM ", "HETATM")):
            out_lines.append(line)
            continue

        kf, ks = _pdb_keys_from_line(line)
        is_ml = (kf in ml_full) or (ks in ml_simple)
        is_frz = atom_idx in frozen_set
        atom_idx += 1

        if is_ml and is_frz:
            out_lines.append(_format_with_bfactor(line, beta_both))
        elif is_ml:
            out_lines.append(_format_with_bfactor(line, beta_ml))
        elif is_frz:
            out_lines.append(_format_with_bfactor(line, beta_frz))
        else:
            out_lines.append(line)

    try:
        pdb_path.write_text("".join(out_lines))
    except Exception:
        logger.debug("Failed to write B-factor annotated PDB: %s", pdb_path, exc_info=True)
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
def _maybe_convert_outputs_to_pdb(
    input_path: Path,
    out_dir: Path,
    dump: bool,
    get_trj_fn,
    final_xyz_path: Path,
    model_pdb: Path,
    freeze_indices_0based: Sequence[int],
    ml_indices: Optional[List[int]] = None,
    hess_mm_indices: Optional[List[int]] = None,
    movable_mm_indices: Optional[List[int]] = None,
    frozen_layer_indices: Optional[List[int]] = None,
) -> None:
    """
    If the input is a PDB, convert outputs (final_geometry.xyz and, if dump,
    optimization_all_trj.xyz / optimization_trj.xyz) to PDB, and annotate
    B-factors for the 3-layer ML/MM system.

    B-factor encoding (3-layer system, used when ``ml_indices`` is given):
        ML atoms:         BFACTOR_ML          (documented as 0.0)
        Movable MM atoms: BFACTOR_MOVABLE_MM  (documented as 10.0)
        Frozen MM atoms:  BFACTOR_FROZEN      (documented as 20.0)

    If layer indices are not provided, falls back to legacy encoding:
        ML atoms:     100.0
        Frozen atoms:  50.0
        ML ∩ frozen:  150.0

    Parameters
    ----------
    input_path
        Original input structure; conversion only happens for ``.pdb`` input
        and only when file conversion is enabled globally.
    out_dir
        Directory that receives the generated ``.pdb`` files.
    dump
        When True, trajectory files are converted as well.
    get_trj_fn
        Callable mapping a trajectory file name to its ``Path``.
    final_xyz_path
        XYZ file with the final geometry.
    model_pdb, freeze_indices_0based
        Inputs for the legacy B-factor encoding.
    ml_indices, hess_mm_indices, movable_mm_indices, frozen_layer_indices
        Layer indices for the 3-layer encoding.

    Every output is converted independently: a failure on one file is
    reported as a warning and does not prevent the others from being written.
    """
    if not is_convert_file_enabled():
        return
    if input_path.suffix.lower() != ".pdb":
        return

    # The layer-based encoding is selected whenever ML layer indices are supplied.
    use_layer_bfactors = ml_indices is not None
    ref_pdb = input_path.resolve()

    def _annotate(pdb_out: Path) -> None:
        """Overwrite B-factors in *pdb_out* with the selected encoding and report it."""
        if use_layer_bfactors:
            update_pdb_bfactors_from_layers(
                pdb_out,
                ml_indices=ml_indices or [],
                hess_mm_indices=hess_mm_indices,
                movable_mm_indices=movable_mm_indices,
                frozen_indices=frozen_layer_indices,
            )
            click.echo(
                f"[annot] B-factors set in '{pdb_out}' "
                f"(ML={BFACTOR_ML:.0f}, MovableMM={BFACTOR_MOVABLE_MM:.0f}, "
                f"FrozenMM={BFACTOR_FROZEN:.0f})."
            )
        else:
            _annotate_b_factors_inplace(
                pdb_out,
                model_pdb=model_pdb,
                freeze_indices_0based=freeze_indices_0based,
            )
            click.echo(f"[annot] B-factors set in '{pdb_out}' (ML=100, frozen=50, both=150).")

    def _convert_and_annotate(src_xyz: Path, dst_pdb: Path) -> None:
        """Convert *src_xyz* to *dst_pdb* using the input PDB topology, then annotate it."""
        convert_xyz_to_pdb(src_xyz, ref_pdb, dst_pdb)
        click.echo(f"[convert] Wrote '{dst_pdb}'.")
        _annotate(dst_pdb)

    # final_geometry.xyz → final_geometry.pdb
    try:
        _convert_and_annotate(final_xyz_path, out_dir / "final_geometry.pdb")
    except Exception as e:
        click.echo(f"[convert] WARNING: Failed to convert final geometry to PDB: {e}", err=True)

    if not dump:
        return

    # optimization_all_trj.xyz / optimization_trj.xyz → PDB.
    # Each trajectory gets its own try so one failure does not skip the other.
    found_any = False
    for trj_name, pdb_name in (
        ("optimization_all_trj.xyz", "optimization_all.pdb"),
        ("optimization_trj.xyz", "optimization.pdb"),
    ):
        try:
            trj_path = get_trj_fn(trj_name)
            if not trj_path.exists():
                continue
            found_any = True
            _convert_and_annotate(trj_path, out_dir / pdb_name)
        except Exception as e:
            # A problem has already been reported; don't also claim nothing was found.
            found_any = True
            click.echo(f"[convert] WARNING: Failed to convert optimization trajectory to PDB: {e}", err=True)

    if not found_any:
        click.echo(
            "[convert] WARNING: neither 'optimization_all_trj.xyz' nor 'optimization_trj.xyz' was found; "
            "skipping trajectory PDB conversion.",
            err=True,
        )
|
|
503
|
+
|
|
504
|
+
|
|
505
|
+
# -----------------------------------------------
|
|
506
|
+
# Flatten helpers
|
|
507
|
+
# -----------------------------------------------
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
def _calc_energy(geom, calc_kwargs: dict, calc=None) -> float:
|
|
511
|
+
"""Compute energy (Hartree) from ML/MM calculator."""
|
|
512
|
+
owns_calc = calc is None
|
|
513
|
+
if owns_calc:
|
|
514
|
+
kw = dict(calc_kwargs or {})
|
|
515
|
+
kw["out_hess_torch"] = False
|
|
516
|
+
calc = mlmm(**kw)
|
|
517
|
+
result = calc.get_energy(geom.atoms, geom.coords)
|
|
518
|
+
energy = float(result.get("energy", 0.0))
|
|
519
|
+
del result
|
|
520
|
+
if owns_calc:
|
|
521
|
+
del calc
|
|
522
|
+
if torch.cuda.is_available():
|
|
523
|
+
torch.cuda.empty_cache()
|
|
524
|
+
return energy
|
|
525
|
+
|
|
526
|
+
|
|
527
|
+
def _flatten_all_imag_modes_for_geom(
    geom,
    masses_amu: np.ndarray,
    calc_kwargs: dict,
    freqs_cm: np.ndarray,
    modes: torch.Tensor,
    neg_freq_thresh_cm: float,
    flatten_amp_ang: float,
) -> bool:
    """
    Flatten all imaginary modes for a geometry in a single pass.

    Every mode whose frequency is below ``-abs(neg_freq_thresh_cm)`` is
    visited, most negative first. For each mode, the geometry is displaced by
    ``±flatten_amp_ang`` along the mode. The Cartesian direction is rescaled
    per atom by ``sqrt(12.011 / mass)``, so atoms lighter than carbon move
    more. The lower-energy side is kept, and displacements accumulate from
    mode to mode.

    Parameters
    ----------
    geom
        Geometry; its ``coords`` are updated in place.
    masses_amu
        Per-atom masses (amu), one entry per atom.
    calc_kwargs
        Keyword arguments for the ``mlmm`` calculator used for the energies.
    freqs_cm
        Frequencies (cm^-1), indexed like the rows of *modes*.
    modes
        Mass-weighted normal modes; row ``i`` is reshaped to ``(n_atoms, 3)``.
    neg_freq_thresh_cm
        Magnitude of the imaginary-frequency threshold (sign ignored).
    flatten_amp_ang
        Displacement amplitude in Å.

    Returns
    -------
    bool
        True if at least one mode qualified; False if none did (the geometry
        is untouched and no energy is evaluated).
    """
    freqs_cm = np.asarray(freqs_cm)
    neg_idx_all = np.where(freqs_cm < -abs(neg_freq_thresh_cm))[0]
    if len(neg_idx_all) == 0:
        return False

    order = np.argsort(freqs_cm[neg_idx_all])  # most negative first
    targets = [int(x) for x in neg_idx_all[order]]
    masses = np.asarray(masses_amu, dtype=float)
    sqrt_m = np.sqrt(masses)[:, None]
    mass_scale = np.sqrt(12.011 / masses)[:, None]
    amp_bohr = float(flatten_amp_ang) / BOHR2ANG

    # Build the calculator once and reuse it. Constructing an mlmm calculator
    # (model load) for each of the 1 + 2 * n_modes energy evaluations is wasteful.
    kw = dict(calc_kwargs or {})
    kw["out_hess_torch"] = False
    calc = mlmm(**kw)
    try:
        # Reference energy of the geometry before any flattening; the printed
        # ΔE for every mode is measured against this (i.e. it is cumulative).
        E_ref = _calc_energy(geom, calc_kwargs, calc=calc)

        for idx in targets:
            v_mw = modes[idx].detach().cpu().numpy().reshape(-1, 3)
            v_cart = v_mw / sqrt_m
            norm = float(np.linalg.norm(v_cart))
            if not np.isfinite(norm) or norm == 0.0:
                click.echo(f"[Flatten] mode={idx}: zero or non-finite Cartesian norm; skipped.")
                continue
            v_cart = v_cart / norm

            disp = amp_bohr * mass_scale * v_cart
            ref = geom.cart_coords.reshape(-1, 3)
            plus = ref + disp
            minus = ref - disp

            geom.coords = plus.reshape(-1)
            E_plus = _calc_energy(geom, calc_kwargs, calc=calc)

            geom.coords = minus.reshape(-1)
            E_minus = _calc_energy(geom, calc_kwargs, calc=calc)

            # Keep the downhill side (ties go to the + direction).
            use_plus = E_plus <= E_minus
            geom.coords = (plus if use_plus else minus).reshape(-1)
            E_keep = E_plus if use_plus else E_minus
            delta_e = E_keep - E_ref
            click.echo(
                f"[Flatten] mode={idx} freq={freqs_cm[idx]:+.2f} cm^-1 "
                f"E_disp={E_keep:.8f} Ha \u0394E={delta_e:+.8f} Ha"
            )
    finally:
        del calc
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return True
|
|
579
|
+
|
|
580
|
+
|
|
581
|
+
# -----------------------------------------------
|
|
582
|
+
# Microiteration optimizer
|
|
583
|
+
# -----------------------------------------------
|
|
584
|
+
|
|
585
|
+
|
|
586
|
+
def _microiter_seed_hessian_from_cache(geometry, macro_calc, macro_freeze: Sequence[int]) -> bool:
    """Install the cached IRC-endpoint Hessian on *geometry* for the RFO macro step.

    The macro step freezes every MM atom, so only ML DOFs are free. The cached
    Hessian may cover a larger DOF set (e.g. ML + movable MM). When the cache
    records ``active_dofs``, the macro-free DOFs are located in it by DOF
    identity, not just by matrix size, so the block always lands on the
    right coordinates.

    Returns True when a Hessian was installed (with *macro_calc* attached to
    *geometry*). Returns False when no usable cache exists and the caller must
    compute a fresh Hessian.
    """
    from .hessian_cache import load as _hess_load

    cached = _hess_load("irc_endpoint")
    if cached is None:
        return False

    h_raw = cached["hessian"]
    if isinstance(h_raw, torch.Tensor):
        h_init = h_raw.clone()
    else:
        h_init = torch.as_tensor(h_raw, dtype=torch.float64)

    n_full = int(geometry.cart_coords.size)
    frozen = set(int(i) for i in macro_freeze)
    macro_free_atoms = [a for a in range(n_full // 3) if a not in frozen]
    macro_free_dofs = [3 * a + k for a in macro_free_atoms for k in range(3)]
    n_free = len(macro_free_dofs)

    active_dofs = cached.get("active_dofs")
    if active_dofs is not None:
        # Row of each cached DOF in the cached matrix (O(1) lookup instead of list.index).
        row_of = {int(d): row for row, d in enumerate(active_dofs)}
        sub_indices = [row_of[d] for d in macro_free_dofs if d in row_of]
        if len(sub_indices) != n_free:
            click.echo(
                f"[microiter] IRC endpoint Hessian sub-block extraction failed "
                f"(expected {n_free}, got {len(sub_indices)}). Falling back to fresh Hessian."
            )
            return False
        idx = torch.as_tensor(sub_indices, dtype=torch.long, device=h_init.device)
        h_sub = h_init[idx][:, idx]
        # set_calculator() clears cart_hessian, so it must come first.
        geometry.set_calculator(macro_calc)
        geometry.within_partial_hessian = {
            "active_n_dof": n_free,
            "full_n_dof": n_full,
            "active_dofs": macro_free_dofs,
            "active_atoms": macro_free_atoms,
        }
        geometry.cart_hessian = h_sub
        if tuple(h_sub.shape) == tuple(h_init.shape):
            click.echo(f"[microiter] Reusing IRC endpoint Hessian for RFO macro step (shape={h_sub.shape[0]}x{h_sub.shape[1]}).")
        else:
            click.echo(
                f"[microiter] Reusing IRC endpoint Hessian (sub-block) for RFO macro step "
                f"(cached {h_init.shape[0]}x{h_init.shape[1]} → extracted {h_sub.shape[0]}x{h_sub.shape[1]})."
            )
        return True

    # Without DOF labels only a matrix that already has the macro-free size is usable.
    if h_init.shape[0] == n_free:
        geometry.set_calculator(macro_calc)
        geometry.cart_hessian = h_init
        click.echo(f"[microiter] Reusing IRC endpoint Hessian for RFO macro step (shape={h_init.shape[0]}x{h_init.shape[1]}).")
        return True

    click.echo(
        f"[microiter] IRC endpoint Hessian size mismatch "
        f"(cached={h_init.shape[0]}, needed={n_free}). Falling back to fresh Hessian."
    )
    return False


def _run_microiter_opt(
    geometry,
    calc_cfg: Dict[str, Any],
    rfo_cfg: Dict[str, Any],
    lbfgs_cfg: Dict[str, Any],
    opt_cfg: Dict[str, Any],
    microiter_cfg: Dict[str, Any],
    out_dir_path: Path,
    *,
    dump: bool = False,
) -> Optional[Any]:
    """Run macro/micro alternating optimization (Gaussian 16-style microiteration).

    Macro step: 1 RFO step moving only ML region (full ONIOM force).
    Micro step: LBFGS relaxing MM region with MM-only forces until convergence.

    Parameters
    ----------
    geometry
        Geometry to optimize (modified in place).
    calc_cfg
        Keyword arguments for the full ``mlmm`` calculator; also used to
        resolve the ML / movable-MM / frozen-MM layer sets.
    rfo_cfg, lbfgs_cfg
        Base keyword arguments for the macro RFOptimizer and micro LBFGS.
    opt_cfg
        Reads ``max_cycles`` (macro iterations, default 10000) and ``thresh``
        (default "gau").
    microiter_cfg
        Reads ``micro_thresh`` (falls back to ``thresh``) and
        ``micro_max_cycles`` (default 10000).
    out_dir_path
        Output directory for optimizer files and trajectories.
    dump
        Write ``optimization_trj.xyz`` / ``optimization_all_trj.xyz``.

    Returns
    -------
    The same *geometry* object after optimization, restored to the full
    calculator with only frozen-MM atoms frozen. Returns None (without touching
    *geometry*) when no ML atoms are found and the caller should fall back to
    a standard optimization.
    """
    from .freq import (
        _calc_full_hessian_torch as _freq_calc_full_hessian_torch,
        _collect_layer_atom_sets,
        _torch_device as _freq_torch_device,
    )

    # Resolve layer atom sets
    layer_sets = _collect_layer_atom_sets(calc_cfg)
    ml_indices = sorted(layer_sets["ml"])
    movable_mm = sorted(layer_sets["movable_mm"] | layer_sets["hess_mm"])
    frozen_mm = sorted(layer_sets["frozen_mm"])

    if not ml_indices:
        click.echo("[microiter] WARNING: No ML atoms found. Falling back to standard optimization.")
        return None

    ml_set = set(ml_indices)
    mm_indices = [i for i in range(len(geometry.atoms)) if i not in ml_set]

    # Freeze lists: for macro step, freeze all MM; for micro step, freeze ML
    macro_freeze = sorted(set(mm_indices) | set(frozen_mm))
    micro_freeze = sorted(ml_set | set(frozen_mm))

    max_cycles = int(opt_cfg.get("max_cycles", 10000))
    thresh = opt_cfg.get("thresh", "gau")
    micro_thresh = microiter_cfg.get("micro_thresh") or thresh
    micro_max_cycles = int(microiter_cfg.get("micro_max_cycles", 10000))

    click.echo(
        f"[microiter] ML atoms: {len(ml_indices)}, "
        f"Movable MM atoms: {len(movable_mm)}, "
        f"Frozen MM atoms: {len(frozen_mm)}"
    )
    click.echo(f"[microiter] Macro thresh: {thresh}, Micro thresh: {micro_thresh}")

    # Create ONIOM calculator (shared core for MM-only calc)
    base_calc = mlmm(**calc_cfg)
    mm_calc = mlmm_mm_only(base_calc.core, freeze_atoms=micro_freeze)

    hess_device = _freq_torch_device(calc_cfg.get("ml_device", "auto"))

    # Always create macro calculator (needed for optimization loop below)
    macro_calc_cfg = dict(calc_cfg)
    macro_calc_cfg["freeze_atoms"] = macro_freeze
    macro_calc_cfg["hess_mm_atoms"] = []  # the macro step uses an ML-only Hessian
    macro_calc = mlmm(**macro_calc_cfg)

    # Seed initial Hessian for RFO (with macro freeze).
    # Try IRC endpoint cache first; fall back to full Hessian calculation.
    if not _microiter_seed_hessian_from_cache(geometry, macro_calc, macro_freeze):
        click.echo("[microiter] Seeding initial Hessian for RFO macro step.")
        geometry.set_calculator(macro_calc)
        h_init, _ = _freq_calc_full_hessian_torch(
            geometry, macro_calc_cfg, hess_device, refresh_geom_meta=True,
        )
        geometry.cart_hessian = h_init
        click.echo(f"[microiter] Initial Hessian seeded (shape={h_init.shape[0]}x{h_init.shape[1]}).")
        del h_init

    optim_all_path = out_dir_path / "optimization_all_trj.xyz"
    macro_trj_path = out_dir_path / "optimization_trj.xyz"
    total_macro_steps = 0

    # Create persistent RFOptimizer once (LayerOpt pattern).
    # This preserves the BFGS Hessian update chain across macro iterations.
    # NOTE: geometry already has macro_calc set (above); do NOT call
    # set_calculator() again before prepare_opt() as it clears the
    # pre-computed cart_hessian.
    geometry.freeze_atoms = macro_freeze

    rfo_args = dict(rfo_cfg)
    rfo_args["max_cycles"] = max_cycles
    rfo_args["out_dir"] = str(out_dir_path)
    rfo_args["dump"] = False  # trajectory dumping handled externally
    rfo_args["thresh"] = thresh

    macro_optimizer = RFOptimizer(geometry, **rfo_args)
    macro_optimizer.prepare_opt()  # initialise Hessian from geometry.cart_hessian

    # Microiteration progress table (pysisyphus-style with micro_steps column)
    micro_header = "cycle Δ(energy) max(|force|) rms(force) max(|step|) rms(step) micro_steps s/cycle".split()
    micro_col_fmts = "int float float float float float int float_short".split()
    micro_table = TablePrinter(micro_header, micro_col_fmts, width=12)
    micro_table.print_header()

    try:
        for macro_iter in range(max_cycles):
            # ---- Macro step: 1 RFO step with ONIOM forces, MM frozen ----
            geometry.freeze_atoms = macro_freeze
            geometry.set_calculator(macro_calc)

            # Manually feed state to the persistent optimizer (cf. LayerOpt)
            macro_optimizer.coords.append(geometry.coords.copy())
            macro_optimizer.cart_coords.append(geometry.cart_coords.copy())
            macro_optimizer.cur_cycle = macro_iter

            t_start = time.time()
            step = macro_optimizer.optimize()  # housekeeping() triggers BFGS update
            macro_optimizer.steps.append(step)

            # Convergence check
            macro_converged, conv_info = macro_optimizer.check_convergence()
            total_macro_steps += 1

            energies = macro_optimizer.energies
            energy_diff = energies[-1] - energies[-2] if len(energies) >= 2 else float("nan")
            marks = [False, *conv_info.get_convergence()[:-1], False, False]

            if dump:
                with open(macro_trj_path, "a") as f:
                    f.write(geometry.as_xyz() + "\n")
                _append_xyz_trajectory(optim_all_path, macro_trj_path)

            if macro_converged:
                # Print final converged row (no micro steps)
                cycle_time = time.time() - t_start
                micro_table.print_row(
                    (macro_iter, energy_diff, macro_optimizer.max_forces[-1], macro_optimizer.rms_forces[-1],
                     macro_optimizer.max_steps[-1], macro_optimizer.rms_steps[-1], 0, cycle_time),
                    marks=marks,
                )
                click.echo(f"[microiter] Macro convergence reached at iteration {macro_iter + 1}.")
                break

            # Apply step to geometry
            geometry.coords = geometry.coords.copy() + step
            # Record actual step (may differ due to coordinate back-transformation)
            macro_optimizer.steps[-1] = geometry.coords - macro_optimizer.coords[-1]

            # ---- Micro step: LBFGS with MM-only forces, ML frozen ----
            geometry.freeze_atoms = micro_freeze
            geometry.set_calculator(mm_calc)

            micro_lbfgs_args = dict(lbfgs_cfg)
            micro_lbfgs_args["max_cycles"] = micro_max_cycles
            micro_lbfgs_args["thresh"] = micro_thresh
            micro_lbfgs_args["out_dir"] = str(out_dir_path)
            micro_lbfgs_args["dump"] = dump

            micro_opt = LBFGS(geometry, **micro_lbfgs_args)
            with contextlib.redirect_stdout(io.StringIO()):
                micro_opt.run()
            micro_steps = max(int(micro_opt.cur_cycle) + 1, 1)

            if dump:
                # NOTE(review): this is the same file the macro frames are appended
                # to above (macro_trj_path). If the LBFGS dump also writes
                # 'optimization_trj.xyz', frames may be duplicated in
                # optimization_all_trj.xyz. Confirm against the semantics of
                # _append_xyz_trajectory and the optimizer's dump file name.
                _append_xyz_trajectory(optim_all_path, macro_trj_path)

            del micro_opt
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Print progress row with micro_steps
            cycle_time = time.time() - t_start
            if (macro_iter > 1) and (macro_iter % 10 == 0):
                micro_table.print_sep()
            micro_table.print_row(
                (macro_iter, energy_diff, macro_optimizer.max_forces[-1], macro_optimizer.rms_forces[-1],
                 macro_optimizer.max_steps[-1], macro_optimizer.rms_steps[-1], micro_steps, cycle_time),
                marks=marks,
            )
        else:
            click.echo(f"[microiter] Reached max macro iterations ({max_cycles}).")
    finally:
        # Always hand back a geometry with the full calculator and the normal
        # (frozen-MM only) freeze list, even if an iteration raised.
        del macro_optimizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        geometry.freeze_atoms = list(frozen_mm)
        geometry.set_calculator(base_calc)

    click.echo(f"[microiter] Total macro steps: {total_macro_steps}")
    return geometry
|
|
847
|
+
|
|
848
|
+
|
|
849
|
+
# -----------------------------------------------
|
|
850
|
+
# CLI
|
|
851
|
+
# -----------------------------------------------
|
|
852
|
+
|
|
853
|
+
@click.command(
|
|
854
|
+
help="ML/MM geometry optimization with LBFGS (light) or RFO (heavy).",
|
|
855
|
+
context_settings={"help_option_names": ["-h", "--help"]},
|
|
856
|
+
)
|
|
857
|
+
@click.option(
|
|
858
|
+
"-i", "--input",
|
|
859
|
+
"input_path",
|
|
860
|
+
type=click.Path(path_type=Path, exists=True, dir_okay=False),
|
|
861
|
+
required=True,
|
|
862
|
+
help="Input structure file (PDB, XYZ). XYZ provides higher coordinate precision. "
|
|
863
|
+
"If XYZ, use --ref-pdb to specify PDB topology for atom ordering and output conversion.",
|
|
864
|
+
)
|
|
865
|
+
@click.option(
|
|
866
|
+
"--ref-pdb",
|
|
867
|
+
"ref_pdb",
|
|
868
|
+
type=click.Path(path_type=Path, exists=True, dir_okay=False),
|
|
869
|
+
default=None,
|
|
870
|
+
show_default=False,
|
|
871
|
+
help="Reference PDB topology when input is XYZ. XYZ coordinates are used (higher precision) "
|
|
872
|
+
"while PDB provides atom ordering and residue information for output conversion.",
|
|
873
|
+
)
|
|
874
|
+
@click.option(
|
|
875
|
+
"--parm",
|
|
876
|
+
"real_parm7",
|
|
877
|
+
type=click.Path(path_type=Path, exists=True, dir_okay=False),
|
|
878
|
+
required=True,
|
|
879
|
+
help="Amber parm7 topology covering the whole enzyme complex.",
|
|
880
|
+
)
|
|
881
|
+
@click.option(
|
|
882
|
+
"--model-pdb",
|
|
883
|
+
"model_pdb",
|
|
884
|
+
type=click.Path(path_type=Path, exists=True, dir_okay=False),
|
|
885
|
+
required=False,
|
|
886
|
+
help="PDB defining atoms that belong to the ML (high-level) region. "
|
|
887
|
+
"Optional when --detect-layer is enabled.",
|
|
888
|
+
)
|
|
889
|
+
@click.option(
|
|
890
|
+
"--model-indices",
|
|
891
|
+
"model_indices_str",
|
|
892
|
+
type=str,
|
|
893
|
+
default=None,
|
|
894
|
+
show_default=False,
|
|
895
|
+
help="Comma-separated atom indices for the ML region (ranges allowed like 1-5). "
|
|
896
|
+
"Used when --model-pdb is omitted.",
|
|
897
|
+
)
|
|
898
|
+
@click.option(
|
|
899
|
+
"--model-indices-one-based/--model-indices-zero-based",
|
|
900
|
+
"model_indices_one_based",
|
|
901
|
+
default=True,
|
|
902
|
+
show_default=True,
|
|
903
|
+
help="Interpret --model-indices as 1-based (default) or 0-based.",
|
|
904
|
+
)
|
|
905
|
+
@click.option(
|
|
906
|
+
"--detect-layer/--no-detect-layer",
|
|
907
|
+
"detect_layer",
|
|
908
|
+
default=True,
|
|
909
|
+
show_default=True,
|
|
910
|
+
help="Detect ML/MM layers from input PDB B-factors (B=0/10/20). "
|
|
911
|
+
"If disabled, you must provide --model-pdb or --model-indices.",
|
|
912
|
+
)
|
|
913
|
+
@click.option("-q", "--charge", type=int, required=False,
|
|
914
|
+
help="ML region charge. Required unless --ligand-charge is provided.")
|
|
915
|
+
@click.option("-l", "--ligand-charge", type=str, default=None, show_default=False,
|
|
916
|
+
help="Total charge or per-resname mapping (e.g., GPP:-3,SAM:1) used to derive "
|
|
917
|
+
"charge when -q is omitted (requires PDB input or --ref-pdb).")
|
|
918
|
+
@click.option(
|
|
919
|
+
"-m",
|
|
920
|
+
"--multiplicity",
|
|
921
|
+
"spin",
|
|
922
|
+
type=int,
|
|
923
|
+
default=None,
|
|
924
|
+
show_default=False,
|
|
925
|
+
help="Spin multiplicity (2S+1) for the ML region. Defaults to 1 when omitted.",
|
|
926
|
+
)
|
|
927
|
+
@click.option(
|
|
928
|
+
"--freeze-atoms",
|
|
929
|
+
"freeze_atoms_text",
|
|
930
|
+
type=str,
|
|
931
|
+
default=None,
|
|
932
|
+
show_default=False,
|
|
933
|
+
help="Comma-separated 1-based atom indices to freeze (e.g., '1,3,5').",
|
|
934
|
+
)
|
|
935
|
+
@click.option(
|
|
936
|
+
"--radius-partial-hessian",
|
|
937
|
+
"--hess-cutoff",
|
|
938
|
+
"radius_partial_hessian",
|
|
939
|
+
type=float,
|
|
940
|
+
default=None,
|
|
941
|
+
show_default=False,
|
|
942
|
+
help="Distance cutoff (Å) from ML region for MM atoms to include in Hessian calculation. "
|
|
943
|
+
"Applied to movable MM atoms and can be combined with --detect-layer. "
|
|
944
|
+
"`--hess-cutoff` is a compatibility alias.",
|
|
945
|
+
)
|
|
946
|
+
@click.option(
|
|
947
|
+
"--radius-freeze",
|
|
948
|
+
"--movable-cutoff",
|
|
949
|
+
"radius_freeze",
|
|
950
|
+
type=float,
|
|
951
|
+
default=None,
|
|
952
|
+
show_default=False,
|
|
953
|
+
help="Distance cutoff (Å) from ML region for movable MM atoms. "
|
|
954
|
+
"MM atoms beyond this are frozen. "
|
|
955
|
+
"Providing --radius-freeze disables --detect-layer and uses distance-based layer assignment. "
|
|
956
|
+
"`--movable-cutoff` is a compatibility alias.",
|
|
957
|
+
)
|
|
958
|
+
@click.option(
|
|
959
|
+
"--dist-freeze",
|
|
960
|
+
"dist_freeze_raw",
|
|
961
|
+
type=str,
|
|
962
|
+
multiple=True,
|
|
963
|
+
default=(),
|
|
964
|
+
show_default=False,
|
|
965
|
+
help="Distance restraints: inline Python literal (e.g. '[(1,5,1.4)]') or a YAML/JSON spec file path. "
|
|
966
|
+
"Same format as --scan-lists: (i,j,target_A) triples. "
|
|
967
|
+
"Target may be omitted to freeze at the current distance: (i,j).",
|
|
968
|
+
)
|
|
969
|
+
@click.option(
|
|
970
|
+
"--one-based/--zero-based",
|
|
971
|
+
"one_based",
|
|
972
|
+
default=True,
|
|
973
|
+
show_default=True,
|
|
974
|
+
help="Interpret --dist-freeze / --scan-lists indices as 1-based (default) or 0-based.",
|
|
975
|
+
)
|
|
976
|
+
@click.option(
|
|
977
|
+
"--bias-k",
|
|
978
|
+
type=float,
|
|
979
|
+
default=300.0,
|
|
980
|
+
show_default=True,
|
|
981
|
+
help="Harmonic restraint strength k [eV/Å^2] for --dist-freeze.",
|
|
982
|
+
)
|
|
983
|
+
@click.option("--max-cycles", type=int, default=10000, show_default=True, help="Maximum number of optimization cycles.")
|
|
984
|
+
@click.option(
|
|
985
|
+
"--dump/--no-dump",
|
|
986
|
+
default=False,
|
|
987
|
+
show_default=True,
|
|
988
|
+
help="Write optimization trajectories ('optimization_trj.xyz' and 'optimization_all_trj.xyz').",
|
|
989
|
+
)
|
|
990
|
+
@click.option("-o", "--out-dir", type=str, default="./result_opt/", show_default=True, help="Output directory.")
|
|
991
|
+
@click.option(
|
|
992
|
+
"--thresh",
|
|
993
|
+
type=click.Choice(["gau_loose", "gau", "gau_tight", "gau_vtight", "baker", "never"], case_sensitive=False),
|
|
994
|
+
default=None,
|
|
995
|
+
help="Convergence preset.",
|
|
996
|
+
)
|
|
997
|
+
@click.option(
|
|
998
|
+
"--opt-mode",
|
|
999
|
+
type=click.Choice(["grad", "hess", "light", "heavy", "lbfgs", "rfo"], case_sensitive=False),
|
|
1000
|
+
default="grad",
|
|
1001
|
+
show_default=True,
|
|
1002
|
+
help="Optimization mode: grad (lbfgs) or hess (rfo). Aliases light/heavy and lbfgs/rfo are accepted.",
|
|
1003
|
+
)
|
|
1004
|
+
@click.option(
|
|
1005
|
+
"--microiter/--no-microiter",
|
|
1006
|
+
"microiter",
|
|
1007
|
+
default=True,
|
|
1008
|
+
show_default=True,
|
|
1009
|
+
help="Enable microiteration: alternate ML 1-step (RFO) and MM relaxation (LBFGS with MM-only forces). "
|
|
1010
|
+
"Only effective in --opt-mode hess (RFO). Ignored in grad mode.",
|
|
1011
|
+
)
|
|
1012
|
+
@click.option(
|
|
1013
|
+
"--flatten/--no-flatten",
|
|
1014
|
+
"flatten",
|
|
1015
|
+
default=False,
|
|
1016
|
+
show_default=True,
|
|
1017
|
+
help="Enable/disable imaginary-mode flatten loop after optimization.",
|
|
1018
|
+
)
|
|
1019
|
+
@click.option(
|
|
1020
|
+
"--config",
|
|
1021
|
+
"config_yaml",
|
|
1022
|
+
type=click.Path(path_type=Path, exists=True, dir_okay=False),
|
|
1023
|
+
default=None,
|
|
1024
|
+
help="Base YAML configuration file applied before explicit CLI options.",
|
|
1025
|
+
)
|
|
1026
|
+
@click.option(
|
|
1027
|
+
"--show-config/--no-show-config",
|
|
1028
|
+
"show_config",
|
|
1029
|
+
default=False,
|
|
1030
|
+
show_default=True,
|
|
1031
|
+
help="Print resolved configuration and continue execution.",
|
|
1032
|
+
)
|
|
1033
|
+
@click.option(
|
|
1034
|
+
"--dry-run/--no-dry-run",
|
|
1035
|
+
"dry_run",
|
|
1036
|
+
default=False,
|
|
1037
|
+
show_default=True,
|
|
1038
|
+
help="Validate options and print the execution plan without running optimization.",
|
|
1039
|
+
)
|
|
1040
|
+
@click.option(
|
|
1041
|
+
"--convert-files/--no-convert-files",
|
|
1042
|
+
"convert_files",
|
|
1043
|
+
default=True,
|
|
1044
|
+
show_default=True,
|
|
1045
|
+
help="Convert XYZ/TRJ outputs into PDB companions based on the input format.",
|
|
1046
|
+
)
|
|
1047
|
+
@click.option(
|
|
1048
|
+
"-b", "--backend",
|
|
1049
|
+
type=click.Choice(["uma", "orb", "mace", "aimnet2"], case_sensitive=False),
|
|
1050
|
+
default=None,
|
|
1051
|
+
show_default=False,
|
|
1052
|
+
help="ML backend for the ONIOM high-level region (default: uma).",
|
|
1053
|
+
)
|
|
1054
|
+
@click.option(
|
|
1055
|
+
"--embedcharge/--no-embedcharge",
|
|
1056
|
+
"embedcharge",
|
|
1057
|
+
default=False,
|
|
1058
|
+
show_default=True,
|
|
1059
|
+
help="Enable xTB point-charge embedding correction for MM→ML environmental effects.",
|
|
1060
|
+
)
|
|
1061
|
+
@click.option(
|
|
1062
|
+
"--embedcharge-cutoff",
|
|
1063
|
+
"embedcharge_cutoff",
|
|
1064
|
+
type=float,
|
|
1065
|
+
default=None,
|
|
1066
|
+
show_default=False,
|
|
1067
|
+
help="Distance cutoff (Å) from ML region for MM point charges in xTB embedding. "
|
|
1068
|
+
"Default: 12.0 Å when --embedcharge is enabled.",
|
|
1069
|
+
)
|
|
1070
|
+
@click.pass_context
|
|
1071
|
+
def cli(
    ctx: click.Context,
    input_path: Path,
    ref_pdb: Optional[Path],
    real_parm7: Path,
    model_pdb: Optional[Path],
    model_indices_str: Optional[str],
    model_indices_one_based: bool,
    detect_layer: bool,
    charge: Optional[int],
    ligand_charge: Optional[str],
    spin: Optional[int],
    freeze_atoms_text: Optional[str],
    radius_partial_hessian: Optional[float],
    radius_freeze: Optional[float],
    dist_freeze_raw: Sequence[str],
    one_based: bool,
    bias_k: float,
    max_cycles: int,
    dump: bool,
    out_dir: str,
    thresh: Optional[str],
    opt_mode: str,
    microiter: bool,
    flatten: bool,
    config_yaml: Optional[Path],
    show_config: bool,
    dry_run: bool,
    convert_files: bool,
    backend: Optional[str],
    embedcharge: bool,
    embedcharge_cutoff: Optional[float],
) -> None:
    # Command entry point for ML/MM (ONIOM) geometry optimization.
    #
    # Documented with comments rather than a docstring: click falls back to a
    # command function's docstring as its --help text when no help= is given,
    # so adding one here could change the CLI help output.
    #
    # Flow:
    #   1. Resolve YAML sources, load the input structure (.pdb directly, or
    #      .xyz with --ref-pdb supplying the topology), resolve charge/spin,
    #      freeze atoms, model indices, --dist-freeze pairs and --opt-mode.
    #   2. Build geom/calc/opt/lbfgs/rfo settings with precedence
    #      defaults < config YAML < explicit CLI options < override YAML.
    #   3. Resolve the ML region (B-factor layers, --model-pdb or
    #      --model-indices) and the final frozen-atom set; --dry-run stops here
    #      after printing the plan.
    #   4. Run LBFGS, RFO, or RFO + microiterations, optionally followed by an
    #      imaginary-mode flatten loop, then pass the results to
    #      _maybe_convert_outputs_to_pdb.
    #
    # Exit status: 1 for validation errors and unhandled exceptions,
    # 2 for ZeroStepLength, 3 for OptimizationError, 130 on Ctrl-C.
    set_convert_file_enabled(convert_files)
    time_start = time.perf_counter()
    prepared_input = None

    _is_param_explicit = make_is_param_explicit(ctx)

    def _fail(message: str) -> None:
        # Print an error to stderr and exit with status 1.  Every call site is
        # inside either the input-preparation guard or the main try/finally,
        # both of which call prepared_input.cleanup() on the way out.
        click.echo(message, err=True)
        sys.exit(1)

    def _write_final_xyz(geom, path: Path) -> None:
        # Write the current geometry (Bohr -> Å) as an XYZ file.  Uses
        # cart_coords: in pysisyphus, Geometry.coords holds internal
        # coordinates whenever coord_type is not Cartesian.
        from ase import Atoms as _Atoms
        from ase.io import write as _write

        positions_ang = geom.cart_coords.reshape(-1, 3) * BOHR2ANG
        _write(path, _Atoms(geom.atoms, positions=positions_ang, pbc=False))

    config_yaml, override_yaml, _ = resolve_yaml_sources(
        config_yaml=config_yaml,
        override_yaml=None,
        args_yaml_legacy=None,
    )
    merged_yaml_cfg, _, _ = load_merged_yaml_cfg(
        config_yaml=config_yaml,
        override_yaml=None,
    )

    # Input preparation and validation.  Anything that aborts here (explicit
    # exit, click exception, Ctrl-C) must not skip prepared_input.cleanup();
    # the exception itself is re-raised unchanged.
    try:
        # Handle input: PDB directly, or XYZ with --ref-pdb for topology
        suffix = input_path.suffix.lower()
        if suffix == ".pdb":
            prepared_input = prepare_input_structure(input_path)
        elif suffix == ".xyz":
            if ref_pdb is None:
                _fail("ERROR: XYZ/TRJ input requires --ref-pdb to specify PDB topology.")
            prepared_input = prepare_input_structure(input_path)
            apply_ref_pdb_override(prepared_input, ref_pdb)
            click.echo(f"[input] Using XYZ coordinates from {input_path.name}, PDB topology from {ref_pdb.name}")
        else:
            _fail(f"ERROR: Unsupported input format: {suffix}. Use .pdb or .xyz (with --ref-pdb).")

        geom_input_path = prepared_input.geom_path
        charge, spin = resolve_charge_spin_or_raise(
            prepared_input, charge, spin,
            ligand_charge=ligand_charge, prefix="[opt]",
        )

        try:
            freeze_atoms_cli = _parse_freeze_atoms(freeze_atoms_text)
        except click.BadParameter as e:
            _fail(f"ERROR: {e}")

        model_indices: Optional[List[int]] = None
        if model_indices_str:
            try:
                model_indices = parse_indices_string(model_indices_str, one_based=model_indices_one_based)
            except click.BadParameter as e:
                _fail(f"ERROR: {e}")

        # Atom metadata lets --dist-freeze accept PDB-based atom selectors.
        pdb_atom_meta: List[Dict[str, Any]] = []
        if prepared_input.source_path.suffix.lower() == ".pdb":
            pdb_atom_meta = load_pdb_atom_metadata(prepared_input.source_path)

        try:
            dist_freeze = _parse_dist_freeze_args(
                dist_freeze_raw, one_based=bool(one_based), atom_meta=pdb_atom_meta,
            )
        except click.BadParameter as e:
            _fail(f"ERROR: {e}")

        # Resolve optimizer mode
        mode_resolved = normalize_choice(
            opt_mode,
            param="--opt-mode",
            alias_groups=OPT_MODE_ALIASES,
            allowed_hint="grad|hess|lbfgs|rfo",
        )
    except BaseException:
        if prepared_input is not None:
            prepared_input.cleanup()
        raise

    use_rfo = (mode_resolved == "rfo")

    try:
        config_layer_cfg = load_yaml_dict(config_yaml)
        override_layer_cfg = load_yaml_dict(override_yaml)
        geom_cfg = dict(GEOM_KW)
        calc_cfg = dict(CALC_KW)
        opt_cfg = dict(OPT_BASE_KW)
        lbfgs_cfg = dict(LBFGS_KW)
        rfo_cfg = dict(RFO_KW)

        def _apply_yaml_layer(layer_cfg) -> None:
            # One YAML layer (config or override) -> the settings dicts,
            # using the same section paths for both layers.
            apply_yaml_overrides(
                layer_cfg,
                [
                    (geom_cfg, (("geom",),)),
                    (calc_cfg, (("calc",), ("mlmm",))),
                    (opt_cfg, (("opt",),)),
                    (lbfgs_cfg, (("lbfgs",), ("opt", "lbfgs"))),
                    (rfo_cfg, (("rfo",), ("opt", "rfo"))),
                ],
            )

        _apply_yaml_layer(config_layer_cfg)

        # Explicit CLI options override the config YAML layer.
        if _is_param_explicit("max_cycles"):
            opt_cfg["max_cycles"] = int(max_cycles)
        if _is_param_explicit("dump"):
            opt_cfg["dump"] = bool(dump)
        if _is_param_explicit("out_dir"):
            opt_cfg["out_dir"] = out_dir
        if _is_param_explicit("thresh") and thresh is not None:
            opt_cfg["thresh"] = str(thresh)

        if _is_param_explicit("detect_layer"):
            calc_cfg["use_bfactor_layers"] = bool(detect_layer)
        if _is_param_explicit("radius_partial_hessian") and radius_partial_hessian is not None:
            calc_cfg["hess_cutoff"] = float(radius_partial_hessian)
        if _is_param_explicit("radius_freeze") and radius_freeze is not None:
            calc_cfg["movable_cutoff"] = float(radius_freeze)
            calc_cfg["use_bfactor_layers"] = False

        # Model charge/multiplicity: YAML value if present (and not None),
        # otherwise the resolved value; an explicit CLI flag always wins.
        model_charge_value = calc_cfg.get("model_charge", charge)
        if model_charge_value is None:
            model_charge_value = charge
        calc_cfg["model_charge"] = int(model_charge_value)
        if _is_param_explicit("charge"):
            calc_cfg["model_charge"] = int(charge)
        model_mult_value = calc_cfg.get("model_mult", spin)
        if model_mult_value is None:
            model_mult_value = spin
        calc_cfg["model_mult"] = int(model_mult_value)
        if _is_param_explicit("spin"):
            calc_cfg["model_mult"] = int(spin)
        if model_pdb is not None:
            calc_cfg["model_pdb"] = str(model_pdb)
        calc_cfg["input_pdb"] = str(prepared_input.source_path)
        calc_cfg["real_parm7"] = str(real_parm7)
        if backend is not None:
            calc_cfg["backend"] = str(backend).lower()
        if _is_param_explicit("embedcharge"):
            calc_cfg["embedcharge"] = bool(embedcharge)
        if _is_param_explicit("embedcharge_cutoff"):
            calc_cfg["embedcharge_cutoff"] = embedcharge_cutoff

        # The override YAML layer is applied last, on top of the CLI options.
        _apply_yaml_layer(override_layer_cfg)
        calc_paths = (("calc",), ("mlmm",))
        partial_explicit = (
            yaml_section_has_key(config_layer_cfg, calc_paths, "return_partial_hessian")
            or yaml_section_has_key(override_layer_cfg, calc_paths, "return_partial_hessian")
        )
        if not partial_explicit:
            calc_cfg["return_partial_hessian"] = True

        try:
            geom_freeze = _normalize_geom_freeze(geom_cfg.get("freeze_atoms"))
        except click.BadParameter as e:
            _fail(f"ERROR: {e}")
        geom_cfg["freeze_atoms"] = geom_freeze
        if freeze_atoms_cli:
            merge_freeze_atom_indices(geom_cfg, freeze_atoms_cli)
        freeze_atoms_final = list(geom_cfg.get("freeze_atoms") or [])
        calc_cfg["freeze_atoms"] = freeze_atoms_final

        out_dir_path = Path(opt_cfg["out_dir"]).resolve()

        # radius_freeze implies full distance-based layer assignment.
        # radius_partial_hessian alone can be combined with --detect-layer.
        detect_layer_enabled = bool(calc_cfg.get("use_bfactor_layers", True))
        model_pdb_cfg = calc_cfg.get("model_pdb")
        if radius_freeze is not None:
            if detect_layer_enabled:
                click.echo("[layer] --radius-freeze provided; disabling --detect-layer.", err=True)
            detect_layer_enabled = False
            calc_cfg["use_bfactor_layers"] = False

        layer_source_pdb = prepared_input.source_path
        if detect_layer_enabled and layer_source_pdb.suffix.lower() != ".pdb":
            _fail("ERROR: --detect-layer requires a PDB input (or --ref-pdb).")

        if show_config:
            click.echo(
                pretty_block(
                    "yaml_layers",
                    {
                        "config": None if config_yaml is None else str(config_yaml),
                        "override_yaml": None if override_yaml is None else str(override_yaml),
                        "merged_keys": sorted(merged_yaml_cfg.keys()),
                    },
                )
            )

        if dry_run:
            model_region_source = "bfactor"
            if not detect_layer_enabled:
                if model_pdb_cfg is not None:
                    model_region_source = "model_pdb"
                elif model_indices:
                    model_region_source = "model_indices"
                else:
                    _fail("ERROR: Provide --model-pdb or --model-indices when --no-detect-layer.")
            if (
                not detect_layer_enabled
                and model_pdb_cfg is None
                and model_indices
                and layer_source_pdb.suffix.lower() != ".pdb"
            ):
                _fail("ERROR: --model-indices requires a PDB input (or --ref-pdb).")
            click.echo(
                pretty_block(
                    "dry_run_plan",
                    {
                        "input_geometry": str(geom_input_path),
                        "output_dir": str(out_dir_path),
                        "optimizer_mode": "rfo" if use_rfo else "lbfgs",
                        "detect_layer": bool(detect_layer_enabled),
                        "model_region_source": model_region_source,
                        "model_indices_count": 0 if not model_indices else len(model_indices),
                        "will_run_optimization": True,
                        # Reflect --convert-files/--no-convert-files instead of
                        # always claiming outputs will be converted.
                        "will_convert_outputs": bool(convert_files),
                        "backend": calc_cfg.get("backend", "uma"),
                        "embedcharge": bool(calc_cfg.get("embedcharge", False)),
                    },
                )
            )
            click.echo("[dry-run] Validation complete. Optimization execution was skipped.")
            return

        model_pdb_path: Optional[Path] = None
        layer_info: Optional[Dict[str, List[int]]] = None

        if detect_layer_enabled:
            try:
                model_pdb_path, layer_info = build_model_pdb_from_bfactors(layer_source_pdb, out_dir_path)
            except Exception as e:
                # Without an explicit ML region there is nothing to fall back to.
                if model_pdb_cfg is None and not model_indices:
                    _fail(f"ERROR: {e}")
                click.echo(f"[layer] WARNING: {e} Falling back to explicit ML region.", err=True)
                detect_layer_enabled = False
            else:
                calc_cfg["use_bfactor_layers"] = True
                click.echo(
                    f"[layer] Detected B-factor layers: ML={len(layer_info.get('ml_indices', []))}, "
                    f"MovableMM={len(layer_info.get('movable_mm_indices', []))}, "
                    f"FrozenMM={len(layer_info.get('frozen_indices', []))}"
                )

        if not detect_layer_enabled:
            if model_pdb_cfg is None and not model_indices:
                _fail("ERROR: Provide --model-pdb or --model-indices when --no-detect-layer.")
            if model_pdb_cfg is not None:
                model_pdb_path = Path(model_pdb_cfg)
            else:
                if layer_source_pdb.suffix.lower() != ".pdb":
                    _fail("ERROR: --model-indices requires a PDB input (or --ref-pdb).")
                try:
                    model_pdb_path = build_model_pdb_from_indices(layer_source_pdb, out_dir_path, model_indices or [])
                except Exception as e:
                    _fail(f"ERROR: {e}")
            calc_cfg["use_bfactor_layers"] = False

        if model_pdb_path is None:
            _fail("ERROR: Failed to resolve model PDB for the ML region.")

        calc_cfg["model_pdb"] = str(model_pdb_path)

        # When layer detection is enabled, also freeze frozen-layer atoms at the
        # optimizer geometry level (not only inside the calculator).
        # Otherwise LBFGS may still move those coordinates through coupled
        # inverse-Hessian updates, even if raw forces are zeroed there.
        if layer_info is not None:
            frozen_from_layer = [int(i) for i in layer_info.get("frozen_indices", [])]
            if frozen_from_layer:
                before = set(freeze_atoms_final)
                added = len(set(frozen_from_layer) - before)
                freeze_atoms_final = sorted(before.union(frozen_from_layer))
                geom_cfg["freeze_atoms"] = freeze_atoms_final
                calc_cfg["freeze_atoms"] = freeze_atoms_final
                click.echo(
                    f"[layer] Applied optimizer freeze constraints: "
                    f"total={len(freeze_atoms_final)} (added_from_layer={added})"
                )

        # Distance-based overrides for Hessian-target and movable MM selection.
        hess_cutoff_final = calc_cfg.get("hess_cutoff")
        movable_cutoff_final = calc_cfg.get("movable_cutoff")
        if hess_cutoff_final is not None or movable_cutoff_final is not None:
            click.echo(
                f"[layer] Applied distance cutoffs: "
                f"hess={hess_cutoff_final} Å, freeze={movable_cutoff_final} Å"
            )
        from .freq import _align_three_layer_hessian_targets as _freq_align_three_layer_hessian_targets
        _freq_align_three_layer_hessian_targets(calc_cfg, echo_fn=click.echo)

        # Store file/directory settings as absolute paths.
        for key in ("input_pdb", "real_parm7", "model_pdb", "mm_fd_dir"):
            val = calc_cfg.get(key)
            if val:
                calc_cfg[key] = str(Path(val).expanduser().resolve())

        mode_str = "RFO (hess)" if use_rfo else "LBFGS (grad)"
        click.echo(f"\n[mode] Optimizer: {mode_str}\n")
        click.echo(pretty_block("geom", format_freeze_atoms_for_echo(geom_cfg, key="freeze_atoms")))
        echo_calc = format_freeze_atoms_for_echo(filter_calc_for_echo(calc_cfg), key="freeze_atoms")
        click.echo(pretty_block("calc", echo_calc))
        # Show only non-default opt settings
        echo_opt = strip_inherited_keys({**opt_cfg, "out_dir": str(out_dir_path)}, OPT_BASE_KW, mode="same")
        click.echo(pretty_block("opt", echo_opt))
        # Show only optimizer-specific settings, not inherited from opt_cfg
        if use_rfo:
            click.echo(pretty_block("rfo", strip_inherited_keys(rfo_cfg, opt_cfg)))
        else:
            click.echo(pretty_block("lbfgs", strip_inherited_keys(lbfgs_cfg, opt_cfg)))
        if dist_freeze:
            display_pairs = [
                (int(i) + 1, int(j) + 1, f"{target:.4f}" if target is not None else "<current>")
                for (i, j, target) in dist_freeze
            ]
            click.echo(
                pretty_block(
                    "dist_freeze (input)",
                    {
                        "k (eV/Å^2)": float(bias_k),
                        "pairs_1based": display_pairs,
                    },
                )
            )

        out_dir_path.mkdir(parents=True, exist_ok=True)
        coord_type = geom_cfg.get("coord_type", "cart")
        coord_kwargs = dict(geom_cfg)
        coord_kwargs.pop("coord_type", None)
        geometry = geom_loader(
            geom_input_path,
            coord_type=coord_type,
            **coord_kwargs,
        )

        base_calc = mlmm(**calc_cfg)
        geometry.set_calculator(base_calc)

        resolved_dist_freeze: List[Tuple[int, int, float]] = []
        if dist_freeze:
            try:
                resolved_dist_freeze = _resolve_dist_freeze_targets(geometry, dist_freeze)
            except click.BadParameter as e:
                _fail(f"ERROR: {e}")
            click.echo(
                pretty_block(
                    "dist_freeze (active)",
                    {
                        "k (eV/Å^2)": float(bias_k),
                        "pairs_1based": [
                            (int(i) + 1, int(j) + 1, float(f"{t:.4f}"))
                            for (i, j, t) in resolved_dist_freeze
                        ],
                    },
                )
            )
            # Wrap the ML/MM calculator with harmonic distance restraints.
            bias_calc = HarmonicBiasCalculator(base_calc, k=float(bias_k))
            bias_calc.set_pairs(resolved_dist_freeze)
            geometry.set_calculator(bias_calc)

        # Pass only opt-level values that differ from OPT_BASE defaults, so
        # optimizer-specific YAML (e.g. rfo.print_every / lbfgs.print_every)
        # is not overwritten by inherited defaults such as opt.print_every=100.
        common_kwargs = strip_inherited_keys(dict(opt_cfg), OPT_BASE_KW, mode="same")
        common_kwargs["out_dir"] = str(out_dir_path)

        def _build_optimizer(run_kind: str):
            # Construct a fresh LBFGS or RFO optimizer bound to `geometry`.
            if run_kind == "lbfgs":
                return LBFGS(geometry, **{**lbfgs_cfg, **common_kwargs})
            if run_kind == "rfo":
                return RFOptimizer(geometry, **{**rfo_cfg, **common_kwargs})
            raise click.BadParameter(f"Unknown optimizer kind '{run_kind}'.")

        def _seed_rfo_hessian() -> None:
            """Seed initial Hessian via shared freq backend for RFO."""
            # A Hessian cached under "irc_endpoint" is reused when present;
            # otherwise a full Hessian is computed through the freq module.
            from .hessian_cache import load as _hess_load
            cached = _hess_load("irc_endpoint")
            if cached is not None:
                click.echo("[opt] Reusing IRC endpoint Hessian for RFO seeding.")
                active_dofs = cached.get("active_dofs")
                h_raw = cached["hessian"]
                if isinstance(h_raw, torch.Tensor):
                    h_init = h_raw.clone()
                else:
                    h_init = torch.as_tensor(h_raw, dtype=torch.float64)
                if active_dofs is not None:
                    geometry.within_partial_hessian = {
                        "active_n_dof": len(active_dofs),
                        "full_n_dof": geometry.cart_coords.size,
                        "active_dofs": active_dofs,
                        "active_atoms": sorted({d // 3 for d in active_dofs}),
                    }
            else:
                click.echo("[opt] Seeding initial Hessian via shared freq backend.")
                from .freq import (
                    _calc_full_hessian_torch as _freq_calc_full_hessian_torch,
                    _torch_device as _freq_torch_device,
                )
                hess_device = _freq_torch_device(calc_cfg.get("ml_device", "auto"))
                h_init, _ = _freq_calc_full_hessian_torch(
                    geometry,
                    calc_cfg,
                    hess_device,
                    refresh_geom_meta=True,
                )
            geometry.cart_hessian = h_init
            click.echo(f"[opt] Initial Hessian seeded (shape={h_init.shape[0]}x{h_init.shape[1]}).")
            del h_init

        # Resolve microiteration config from YAML (config layer, then override).
        microiter_cfg = dict(MICROITER_KW)
        for layer_cfg in (config_layer_cfg, override_layer_cfg):
            apply_yaml_overrides(layer_cfg, [(microiter_cfg, (("microiter",),))])

        use_microiter = bool(microiter) and use_rfo and not dist_freeze
        if bool(microiter) and not use_rfo:
            click.echo("[microiter] --microiter is only effective with --opt-mode hess (RFO). Ignoring.")
        if bool(microiter) and use_rfo and dist_freeze:
            click.echo("[microiter] --microiter is not compatible with --dist-freeze. Falling back to standard RFO.")

        if use_microiter:
            click.echo("\n====== Optimization (RFO + Microiteration) started ======\n")
            _run_microiter_opt(
                geometry,
                calc_cfg,
                rfo_cfg,
                lbfgs_cfg,
                opt_cfg,
                microiter_cfg,
                out_dir_path,
                dump=bool(opt_cfg["dump"]),
            )
            click.echo("\n====== Optimization (RFO + Microiteration) finished ======\n")

            # Write final geometry
            final_xyz_path = out_dir_path / "final_geometry.xyz"
            _write_final_xyz(geometry, final_xyz_path)

        else:
            main_kind = "rfo" if use_rfo else "lbfgs"
            if use_rfo:
                _seed_rfo_hessian()

            main_label = "RFO" if use_rfo else "LBFGS"
            optimizer = _build_optimizer(main_kind)
            click.echo(f"\n====== Optimization ({main_label}) started ======\n")
            optimizer.run()
            click.echo(f"\n====== Optimization ({main_label}) finished ======\n")

            # Get final geometry path
            final_xyz_path = Path(optimizer.final_fn)

            if bool(opt_cfg["dump"]):
                optim_all_path = out_dir_path / "optimization_all_trj.xyz"
                if not optim_all_path.exists():
                    trj_path = optimizer.get_path_for_fn("optimization_trj.xyz")
                    _append_xyz_trajectory(optim_all_path, trj_path, reset=True)

        # --------------------------
        # Flatten loop (all imaginary modes)
        # --------------------------
        if flatten:
            from .freq import (
                _torch_device,
                _calc_full_hessian_torch,
                _frequencies_cm_and_modes,
                _safe_masses_amu,
            )

            click.echo("\n====== Optimization (Flatten loop) started ======\n")

            geometry.set_calculator(None)
            uma_kwargs_for_flatten = dict(calc_cfg)
            uma_kwargs_for_flatten["out_hess_torch"] = True
            device = _torch_device(calc_cfg.get("ml_device", "auto"))
            # None (rather than []) when nothing is frozen; tolerates a None entry.
            freeze_idx = list(geom_cfg.get("freeze_atoms") or []) or None
            masses_amu = _safe_masses_amu(geometry.atomic_numbers)

            def _attach_opt_calc() -> None:
                # Re-attach the calculator used by the main optimization.
                geometry.set_calculator(
                    bias_calc if resolved_dist_freeze else base_calc
                )

            def _calc_freqs_and_modes() -> Tuple[np.ndarray, torch.Tensor]:
                H, _e = _calc_full_hessian_torch(geometry, uma_kwargs_for_flatten, device)
                freqs_local, modes_local = _frequencies_cm_and_modes(
                    H,
                    geometry.atomic_numbers,
                    geometry.cart_coords.reshape(-1, 3),
                    device,
                    freeze_idx=freeze_idx,
                )
                del H
                return freqs_local, modes_local

            def _count_imag_modes(freqs) -> int:
                # Report and count modes below -|OPT_FLATTEN_NEG_FREQ_THRESH_CM|.
                cutoff = -abs(OPT_FLATTEN_NEG_FREQ_THRESH_CM)
                ims = [float(x) for x in freqs if x < cutoff]
                click.echo(f"[Imaginary modes] n={len(ims)} ({ims})")
                return len(ims)

            freqs_cm, modes = _calc_freqs_and_modes()
            n_imag = _count_imag_modes(freqs_cm)

            flatten_kind = mode_resolved  # reuse same optimizer type
            flatten_iters_done = 0
            for it in range(OPT_FLATTEN_MAX_ITER):
                if n_imag == 0:
                    break
                click.echo(f"[flatten] iteration {it + 1}/{OPT_FLATTEN_MAX_ITER}")
                did_flatten = _flatten_all_imag_modes_for_geom(
                    geometry,
                    masses_amu,
                    uma_kwargs_for_flatten,
                    freqs_cm,
                    modes,
                    OPT_FLATTEN_NEG_FREQ_THRESH_CM,
                    OPT_FLATTEN_AMP_ANG,
                )
                if not did_flatten:
                    click.echo("[flatten] No eligible imaginary modes to flatten; stopping.")
                    break
                flatten_iters_done += 1

                _attach_opt_calc()
                opt_restart = _build_optimizer(flatten_kind)
                restart_label = "LBFGS" if flatten_kind == "lbfgs" else "RFO"
                click.echo(f"\n====== Optimization ({restart_label}) restarted ======\n")
                opt_restart.run()
                click.echo(f"\n====== Optimization ({restart_label}) finished ======\n")

                geometry.set_calculator(None)
                freqs_cm, modes = _calc_freqs_and_modes()
                n_imag = _count_imag_modes(freqs_cm)

            if n_imag > 0:
                # Report the iterations actually performed; the loop can stop
                # before OPT_FLATTEN_MAX_ITER when nothing is eligible.
                click.echo(
                    f"[flatten] WARNING: Remaining imaginary modes after {flatten_iters_done} iterations: {n_imag}",
                    err=True,
                )
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            click.echo("\n====== Optimization (Flatten loop) finished ======\n")

            # Update final geometry after flatten
            final_xyz_path = out_dir_path / "final_geometry.xyz"
            _write_final_xyz(geometry, final_xyz_path)

        # Extract layer indices from calculator for layer-based B-factor encoding
        calc_core = getattr(base_calc, "core", base_calc)
        ml_indices = getattr(calc_core, "ml_indices", None)
        hess_mm_indices = getattr(calc_core, "hess_mm_indices", None)
        movable_mm_indices = getattr(calc_core, "movable_mm_indices", None)
        frozen_layer_indices = getattr(calc_core, "frozen_layer_indices", None)

        _maybe_convert_outputs_to_pdb(
            input_path=prepared_input.source_path,  # Use PDB topology for conversion
            out_dir=out_dir_path,
            dump=bool(opt_cfg["dump"]),
            get_trj_fn=(lambda fn: out_dir_path / fn) if use_microiter else optimizer.get_path_for_fn,
            final_xyz_path=final_xyz_path,
            model_pdb=Path(calc_cfg["model_pdb"]),
            freeze_indices_0based=freeze_atoms_final,
            ml_indices=ml_indices,
            hess_mm_indices=hess_mm_indices,
            movable_mm_indices=movable_mm_indices,
            frozen_layer_indices=frozen_layer_indices,
        )

        # summary.md and key_* outputs are disabled.
        click.echo(format_elapsed("[time] Elapsed Time for Opt", time_start))

    except ZeroStepLength:
        click.echo("ERROR: Step length fell below the minimum allowed (ZeroStepLength).", err=True)
        sys.exit(2)
    except OptimizationError as e:
        click.echo(f"ERROR: Optimization failed - {e}", err=True)
        sys.exit(3)
    except KeyboardInterrupt:
        click.echo("\nInterrupted by user.", err=True)
        sys.exit(130)
    except Exception as e:
        tb = "".join(traceback.format_exception(type(e), e, e.__traceback__))
        click.echo("Unhandled exception during optimization:\n" + textwrap.indent(tb, "  "), err=True)
        sys.exit(1)
    finally:
        if prepared_input is not None:
            prepared_input.cleanup()
        # Release GPU memory so subsequent pipeline stages don't OOM: drop the
        # frame's references to calculators/geometry/optimizer, then collect.
        base_calc = bias_calc = geometry = optimizer = None
        gc.collect()  # break cyclic refs inside torch.nn.Module
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
|
+
|
|
1740
|
+
# Support direct execution via `python -m mlmm.opt` by handing control to the
# click command defined above.
if __name__ == "__main__":
    cli()
|