mlmm-toolkit 0.2.2.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hessian_ff/__init__.py +50 -0
- hessian_ff/analytical_hessian.py +609 -0
- hessian_ff/constants.py +46 -0
- hessian_ff/forcefield.py +339 -0
- hessian_ff/loaders.py +608 -0
- hessian_ff/native/Makefile +8 -0
- hessian_ff/native/__init__.py +28 -0
- hessian_ff/native/analytical_hessian.py +88 -0
- hessian_ff/native/analytical_hessian_ext.cpp +258 -0
- hessian_ff/native/bonded.py +82 -0
- hessian_ff/native/bonded_ext.cpp +640 -0
- hessian_ff/native/loader.py +349 -0
- hessian_ff/native/nonbonded.py +118 -0
- hessian_ff/native/nonbonded_ext.cpp +1150 -0
- hessian_ff/prmtop_parmed.py +23 -0
- hessian_ff/system.py +107 -0
- hessian_ff/terms/__init__.py +14 -0
- hessian_ff/terms/angle.py +73 -0
- hessian_ff/terms/bond.py +44 -0
- hessian_ff/terms/cmap.py +406 -0
- hessian_ff/terms/dihedral.py +141 -0
- hessian_ff/terms/nonbonded.py +209 -0
- hessian_ff/tests/__init__.py +0 -0
- hessian_ff/tests/conftest.py +75 -0
- hessian_ff/tests/data/small/complex.parm7 +1346 -0
- hessian_ff/tests/data/small/complex.pdb +125 -0
- hessian_ff/tests/data/small/complex.rst7 +63 -0
- hessian_ff/tests/test_coords_input.py +44 -0
- hessian_ff/tests/test_energy_force.py +49 -0
- hessian_ff/tests/test_hessian.py +137 -0
- hessian_ff/tests/test_smoke.py +18 -0
- hessian_ff/tests/test_validation.py +40 -0
- hessian_ff/workflows.py +889 -0
- mlmm/__init__.py +36 -0
- mlmm/__main__.py +7 -0
- mlmm/_version.py +34 -0
- mlmm/add_elem_info.py +374 -0
- mlmm/advanced_help.py +91 -0
- mlmm/align_freeze_atoms.py +601 -0
- mlmm/all.py +3535 -0
- mlmm/bond_changes.py +231 -0
- mlmm/bool_compat.py +223 -0
- mlmm/cli.py +574 -0
- mlmm/cli_utils.py +166 -0
- mlmm/default_group.py +337 -0
- mlmm/defaults.py +467 -0
- mlmm/define_layer.py +526 -0
- mlmm/dft.py +1041 -0
- mlmm/energy_diagram.py +253 -0
- mlmm/extract.py +2213 -0
- mlmm/fix_altloc.py +464 -0
- mlmm/freq.py +1406 -0
- mlmm/harmonic_constraints.py +140 -0
- mlmm/hessian_cache.py +44 -0
- mlmm/hessian_calc.py +174 -0
- mlmm/irc.py +638 -0
- mlmm/mlmm_calc.py +2262 -0
- mlmm/mm_parm.py +945 -0
- mlmm/oniom_export.py +1983 -0
- mlmm/oniom_import.py +457 -0
- mlmm/opt.py +1742 -0
- mlmm/path_opt.py +1353 -0
- mlmm/path_search.py +2299 -0
- mlmm/preflight.py +88 -0
- mlmm/py.typed +1 -0
- mlmm/pysis_runner.py +45 -0
- mlmm/scan.py +1047 -0
- mlmm/scan2d.py +1226 -0
- mlmm/scan3d.py +1265 -0
- mlmm/scan_common.py +184 -0
- mlmm/summary_log.py +736 -0
- mlmm/trj2fig.py +448 -0
- mlmm/tsopt.py +2871 -0
- mlmm/utils.py +2309 -0
- mlmm/xtb_embedcharge_correction.py +475 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/METADATA +1159 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/RECORD +372 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/WHEEL +5 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/entry_points.txt +2 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/licenses/LICENSE +674 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/top_level.txt +4 -0
- pysisyphus/Geometry.py +1667 -0
- pysisyphus/LICENSE +674 -0
- pysisyphus/TableFormatter.py +63 -0
- pysisyphus/TablePrinter.py +74 -0
- pysisyphus/__init__.py +12 -0
- pysisyphus/calculators/AFIR.py +452 -0
- pysisyphus/calculators/AnaPot.py +20 -0
- pysisyphus/calculators/AnaPot2.py +48 -0
- pysisyphus/calculators/AnaPot3.py +12 -0
- pysisyphus/calculators/AnaPot4.py +20 -0
- pysisyphus/calculators/AnaPotBase.py +337 -0
- pysisyphus/calculators/AnaPotCBM.py +25 -0
- pysisyphus/calculators/AtomAtomTransTorque.py +154 -0
- pysisyphus/calculators/CFOUR.py +250 -0
- pysisyphus/calculators/Calculator.py +844 -0
- pysisyphus/calculators/CerjanMiller.py +24 -0
- pysisyphus/calculators/Composite.py +123 -0
- pysisyphus/calculators/ConicalIntersection.py +171 -0
- pysisyphus/calculators/DFTBp.py +430 -0
- pysisyphus/calculators/DFTD3.py +66 -0
- pysisyphus/calculators/DFTD4.py +84 -0
- pysisyphus/calculators/Dalton.py +61 -0
- pysisyphus/calculators/Dimer.py +681 -0
- pysisyphus/calculators/Dummy.py +20 -0
- pysisyphus/calculators/EGO.py +76 -0
- pysisyphus/calculators/EnergyMin.py +224 -0
- pysisyphus/calculators/ExternalPotential.py +264 -0
- pysisyphus/calculators/FakeASE.py +35 -0
- pysisyphus/calculators/FourWellAnaPot.py +28 -0
- pysisyphus/calculators/FreeEndNEBPot.py +39 -0
- pysisyphus/calculators/Gaussian09.py +18 -0
- pysisyphus/calculators/Gaussian16.py +726 -0
- pysisyphus/calculators/HardSphere.py +159 -0
- pysisyphus/calculators/IDPPCalculator.py +49 -0
- pysisyphus/calculators/IPIClient.py +133 -0
- pysisyphus/calculators/IPIServer.py +234 -0
- pysisyphus/calculators/LEPSBase.py +24 -0
- pysisyphus/calculators/LEPSExpr.py +139 -0
- pysisyphus/calculators/LennardJones.py +80 -0
- pysisyphus/calculators/MOPAC.py +219 -0
- pysisyphus/calculators/MullerBrownSympyPot.py +51 -0
- pysisyphus/calculators/MultiCalc.py +85 -0
- pysisyphus/calculators/NFK.py +45 -0
- pysisyphus/calculators/OBabel.py +87 -0
- pysisyphus/calculators/ONIOMv2.py +1129 -0
- pysisyphus/calculators/ORCA.py +893 -0
- pysisyphus/calculators/ORCA5.py +6 -0
- pysisyphus/calculators/OpenMM.py +88 -0
- pysisyphus/calculators/OpenMolcas.py +281 -0
- pysisyphus/calculators/OverlapCalculator.py +908 -0
- pysisyphus/calculators/Psi4.py +218 -0
- pysisyphus/calculators/PyPsi4.py +37 -0
- pysisyphus/calculators/PySCF.py +341 -0
- pysisyphus/calculators/PyXTB.py +73 -0
- pysisyphus/calculators/QCEngine.py +106 -0
- pysisyphus/calculators/Rastrigin.py +22 -0
- pysisyphus/calculators/Remote.py +76 -0
- pysisyphus/calculators/Rosenbrock.py +15 -0
- pysisyphus/calculators/SocketCalc.py +97 -0
- pysisyphus/calculators/TIP3P.py +111 -0
- pysisyphus/calculators/TransTorque.py +161 -0
- pysisyphus/calculators/Turbomole.py +965 -0
- pysisyphus/calculators/VRIPot.py +37 -0
- pysisyphus/calculators/WFOWrapper.py +333 -0
- pysisyphus/calculators/WFOWrapper2.py +341 -0
- pysisyphus/calculators/XTB.py +418 -0
- pysisyphus/calculators/__init__.py +81 -0
- pysisyphus/calculators/cosmo_data.py +139 -0
- pysisyphus/calculators/parser.py +150 -0
- pysisyphus/color.py +19 -0
- pysisyphus/config.py +133 -0
- pysisyphus/constants.py +65 -0
- pysisyphus/cos/AdaptiveNEB.py +230 -0
- pysisyphus/cos/ChainOfStates.py +725 -0
- pysisyphus/cos/FreeEndNEB.py +25 -0
- pysisyphus/cos/FreezingString.py +103 -0
- pysisyphus/cos/GrowingChainOfStates.py +71 -0
- pysisyphus/cos/GrowingNT.py +309 -0
- pysisyphus/cos/GrowingString.py +508 -0
- pysisyphus/cos/NEB.py +189 -0
- pysisyphus/cos/SimpleZTS.py +64 -0
- pysisyphus/cos/__init__.py +22 -0
- pysisyphus/cos/stiffness.py +199 -0
- pysisyphus/drivers/__init__.py +17 -0
- pysisyphus/drivers/afir.py +855 -0
- pysisyphus/drivers/barriers.py +271 -0
- pysisyphus/drivers/birkholz.py +138 -0
- pysisyphus/drivers/cluster.py +318 -0
- pysisyphus/drivers/diabatization.py +133 -0
- pysisyphus/drivers/merge.py +368 -0
- pysisyphus/drivers/merge_mol2.py +322 -0
- pysisyphus/drivers/opt.py +375 -0
- pysisyphus/drivers/perf.py +91 -0
- pysisyphus/drivers/pka.py +52 -0
- pysisyphus/drivers/precon_pos_rot.py +669 -0
- pysisyphus/drivers/rates.py +480 -0
- pysisyphus/drivers/replace.py +219 -0
- pysisyphus/drivers/scan.py +212 -0
- pysisyphus/drivers/spectrum.py +166 -0
- pysisyphus/drivers/thermo.py +31 -0
- pysisyphus/dynamics/Gaussian.py +103 -0
- pysisyphus/dynamics/__init__.py +20 -0
- pysisyphus/dynamics/colvars.py +136 -0
- pysisyphus/dynamics/driver.py +297 -0
- pysisyphus/dynamics/helpers.py +256 -0
- pysisyphus/dynamics/lincs.py +105 -0
- pysisyphus/dynamics/mdp.py +364 -0
- pysisyphus/dynamics/rattle.py +121 -0
- pysisyphus/dynamics/thermostats.py +128 -0
- pysisyphus/dynamics/wigner.py +266 -0
- pysisyphus/elem_data.py +3473 -0
- pysisyphus/exceptions.py +2 -0
- pysisyphus/filtertrj.py +69 -0
- pysisyphus/helpers.py +623 -0
- pysisyphus/helpers_pure.py +649 -0
- pysisyphus/init_logging.py +50 -0
- pysisyphus/intcoords/Bend.py +69 -0
- pysisyphus/intcoords/Bend2.py +25 -0
- pysisyphus/intcoords/BondedFragment.py +32 -0
- pysisyphus/intcoords/Cartesian.py +41 -0
- pysisyphus/intcoords/CartesianCoords.py +140 -0
- pysisyphus/intcoords/Coords.py +56 -0
- pysisyphus/intcoords/DLC.py +197 -0
- pysisyphus/intcoords/DistanceFunction.py +34 -0
- pysisyphus/intcoords/DummyImproper.py +70 -0
- pysisyphus/intcoords/DummyTorsion.py +72 -0
- pysisyphus/intcoords/LinearBend.py +105 -0
- pysisyphus/intcoords/LinearDisplacement.py +80 -0
- pysisyphus/intcoords/OutOfPlane.py +59 -0
- pysisyphus/intcoords/PrimTypes.py +286 -0
- pysisyphus/intcoords/Primitive.py +137 -0
- pysisyphus/intcoords/RedundantCoords.py +659 -0
- pysisyphus/intcoords/RobustTorsion.py +59 -0
- pysisyphus/intcoords/Rotation.py +147 -0
- pysisyphus/intcoords/Stretch.py +31 -0
- pysisyphus/intcoords/Torsion.py +101 -0
- pysisyphus/intcoords/Torsion2.py +25 -0
- pysisyphus/intcoords/Translation.py +45 -0
- pysisyphus/intcoords/__init__.py +61 -0
- pysisyphus/intcoords/augment_bonds.py +126 -0
- pysisyphus/intcoords/derivatives.py +10512 -0
- pysisyphus/intcoords/eval.py +80 -0
- pysisyphus/intcoords/exceptions.py +37 -0
- pysisyphus/intcoords/findiffs.py +48 -0
- pysisyphus/intcoords/generate_derivatives.py +414 -0
- pysisyphus/intcoords/helpers.py +235 -0
- pysisyphus/intcoords/logging_conf.py +10 -0
- pysisyphus/intcoords/mp_derivatives.py +10836 -0
- pysisyphus/intcoords/setup.py +962 -0
- pysisyphus/intcoords/setup_fast.py +176 -0
- pysisyphus/intcoords/update.py +272 -0
- pysisyphus/intcoords/valid.py +89 -0
- pysisyphus/interpolate/Geodesic.py +93 -0
- pysisyphus/interpolate/IDPP.py +55 -0
- pysisyphus/interpolate/Interpolator.py +116 -0
- pysisyphus/interpolate/LST.py +70 -0
- pysisyphus/interpolate/Redund.py +152 -0
- pysisyphus/interpolate/__init__.py +9 -0
- pysisyphus/interpolate/helpers.py +34 -0
- pysisyphus/io/__init__.py +22 -0
- pysisyphus/io/aomix.py +178 -0
- pysisyphus/io/cjson.py +24 -0
- pysisyphus/io/crd.py +101 -0
- pysisyphus/io/cube.py +220 -0
- pysisyphus/io/fchk.py +184 -0
- pysisyphus/io/hdf5.py +49 -0
- pysisyphus/io/hessian.py +72 -0
- pysisyphus/io/mol2.py +146 -0
- pysisyphus/io/molden.py +293 -0
- pysisyphus/io/orca.py +189 -0
- pysisyphus/io/pdb.py +269 -0
- pysisyphus/io/psf.py +79 -0
- pysisyphus/io/pubchem.py +31 -0
- pysisyphus/io/qcschema.py +34 -0
- pysisyphus/io/sdf.py +29 -0
- pysisyphus/io/xyz.py +61 -0
- pysisyphus/io/zmat.py +175 -0
- pysisyphus/irc/DWI.py +108 -0
- pysisyphus/irc/DampedVelocityVerlet.py +134 -0
- pysisyphus/irc/Euler.py +22 -0
- pysisyphus/irc/EulerPC.py +345 -0
- pysisyphus/irc/GonzalezSchlegel.py +187 -0
- pysisyphus/irc/IMKMod.py +164 -0
- pysisyphus/irc/IRC.py +878 -0
- pysisyphus/irc/IRCDummy.py +10 -0
- pysisyphus/irc/Instanton.py +307 -0
- pysisyphus/irc/LQA.py +53 -0
- pysisyphus/irc/ModeKill.py +136 -0
- pysisyphus/irc/ParamPlot.py +53 -0
- pysisyphus/irc/RK4.py +36 -0
- pysisyphus/irc/__init__.py +31 -0
- pysisyphus/irc/initial_displ.py +219 -0
- pysisyphus/linalg.py +411 -0
- pysisyphus/line_searches/Backtracking.py +88 -0
- pysisyphus/line_searches/HagerZhang.py +184 -0
- pysisyphus/line_searches/LineSearch.py +232 -0
- pysisyphus/line_searches/StrongWolfe.py +108 -0
- pysisyphus/line_searches/__init__.py +9 -0
- pysisyphus/line_searches/interpol.py +15 -0
- pysisyphus/modefollow/NormalMode.py +40 -0
- pysisyphus/modefollow/__init__.py +10 -0
- pysisyphus/modefollow/davidson.py +199 -0
- pysisyphus/modefollow/lanczos.py +95 -0
- pysisyphus/optimizers/BFGS.py +99 -0
- pysisyphus/optimizers/BacktrackingOptimizer.py +113 -0
- pysisyphus/optimizers/ConjugateGradient.py +98 -0
- pysisyphus/optimizers/CubicNewton.py +75 -0
- pysisyphus/optimizers/FIRE.py +113 -0
- pysisyphus/optimizers/HessianOptimizer.py +1176 -0
- pysisyphus/optimizers/LBFGS.py +228 -0
- pysisyphus/optimizers/LayerOpt.py +411 -0
- pysisyphus/optimizers/MicroOptimizer.py +169 -0
- pysisyphus/optimizers/NCOptimizer.py +90 -0
- pysisyphus/optimizers/Optimizer.py +1084 -0
- pysisyphus/optimizers/PreconLBFGS.py +260 -0
- pysisyphus/optimizers/PreconSteepestDescent.py +7 -0
- pysisyphus/optimizers/QuickMin.py +74 -0
- pysisyphus/optimizers/RFOptimizer.py +181 -0
- pysisyphus/optimizers/RSA.py +99 -0
- pysisyphus/optimizers/StabilizedQNMethod.py +248 -0
- pysisyphus/optimizers/SteepestDescent.py +23 -0
- pysisyphus/optimizers/StringOptimizer.py +173 -0
- pysisyphus/optimizers/__init__.py +41 -0
- pysisyphus/optimizers/closures.py +301 -0
- pysisyphus/optimizers/cls_map.py +58 -0
- pysisyphus/optimizers/exceptions.py +6 -0
- pysisyphus/optimizers/gdiis.py +280 -0
- pysisyphus/optimizers/guess_hessians.py +311 -0
- pysisyphus/optimizers/hessian_updates.py +355 -0
- pysisyphus/optimizers/poly_fit.py +285 -0
- pysisyphus/optimizers/precon.py +153 -0
- pysisyphus/optimizers/restrict_step.py +24 -0
- pysisyphus/pack.py +172 -0
- pysisyphus/peakdetect.py +948 -0
- pysisyphus/plot.py +1031 -0
- pysisyphus/run.py +2106 -0
- pysisyphus/socket_helper.py +74 -0
- pysisyphus/stocastic/FragmentKick.py +132 -0
- pysisyphus/stocastic/Kick.py +81 -0
- pysisyphus/stocastic/Pipeline.py +303 -0
- pysisyphus/stocastic/__init__.py +21 -0
- pysisyphus/stocastic/align.py +127 -0
- pysisyphus/testing.py +96 -0
- pysisyphus/thermo.py +156 -0
- pysisyphus/trj.py +824 -0
- pysisyphus/tsoptimizers/RSIRFOptimizer.py +56 -0
- pysisyphus/tsoptimizers/RSPRFOptimizer.py +182 -0
- pysisyphus/tsoptimizers/TRIM.py +59 -0
- pysisyphus/tsoptimizers/TSHessianOptimizer.py +463 -0
- pysisyphus/tsoptimizers/__init__.py +23 -0
- pysisyphus/wavefunction/Basis.py +239 -0
- pysisyphus/wavefunction/DIIS.py +76 -0
- pysisyphus/wavefunction/__init__.py +25 -0
- pysisyphus/wavefunction/build_ext.py +42 -0
- pysisyphus/wavefunction/cart2sph.py +190 -0
- pysisyphus/wavefunction/diabatization.py +304 -0
- pysisyphus/wavefunction/excited_states.py +435 -0
- pysisyphus/wavefunction/gen_ints.py +1811 -0
- pysisyphus/wavefunction/helpers.py +104 -0
- pysisyphus/wavefunction/ints/__init__.py +0 -0
- pysisyphus/wavefunction/ints/boys.py +193 -0
- pysisyphus/wavefunction/ints/boys_table_N_64_xasym_27.1_step_0.01.npy +0 -0
- pysisyphus/wavefunction/ints/cart_gto3d.py +176 -0
- pysisyphus/wavefunction/ints/coulomb3d.py +25928 -0
- pysisyphus/wavefunction/ints/diag_quadrupole3d.py +10036 -0
- pysisyphus/wavefunction/ints/dipole3d.py +8762 -0
- pysisyphus/wavefunction/ints/int2c2e3d.py +7198 -0
- pysisyphus/wavefunction/ints/int3c2e3d_sph.py +65040 -0
- pysisyphus/wavefunction/ints/kinetic3d.py +8240 -0
- pysisyphus/wavefunction/ints/ovlp3d.py +3777 -0
- pysisyphus/wavefunction/ints/quadrupole3d.py +15054 -0
- pysisyphus/wavefunction/ints/self_ovlp3d.py +198 -0
- pysisyphus/wavefunction/localization.py +458 -0
- pysisyphus/wavefunction/multipole.py +159 -0
- pysisyphus/wavefunction/normalization.py +36 -0
- pysisyphus/wavefunction/pop_analysis.py +134 -0
- pysisyphus/wavefunction/shells.py +1171 -0
- pysisyphus/wavefunction/wavefunction.py +504 -0
- pysisyphus/wrapper/__init__.py +11 -0
- pysisyphus/wrapper/exceptions.py +2 -0
- pysisyphus/wrapper/jmol.py +120 -0
- pysisyphus/wrapper/mwfn.py +169 -0
- pysisyphus/wrapper/packmol.py +71 -0
- pysisyphus/xyzloader.py +168 -0
- pysisyphus/yaml_mods.py +45 -0
- thermoanalysis/LICENSE +674 -0
- thermoanalysis/QCData.py +244 -0
- thermoanalysis/__init__.py +0 -0
- thermoanalysis/config.py +3 -0
- thermoanalysis/constants.py +20 -0
- thermoanalysis/thermo.py +1011 -0
mlmm/tsopt.py
ADDED
|
@@ -0,0 +1,2871 @@
|
|
|
1
|
+
# mlmm/tsopt.py
|
|
2
|
+
|
|
3
|
+
"""Partial Hessian guided Dimer / RS-I-RFO transition-state search with ML/MM.
|
|
4
|
+
|
|
5
|
+
Example:
|
|
6
|
+
mlmm tsopt -i ts_guess.pdb --parm real.parm7 --model-pdb ml_region.pdb \
|
|
7
|
+
-q 0 -m 1 --max-cycles 8000
|
|
8
|
+
|
|
9
|
+
For detailed documentation, see: docs/tsopt.md
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import contextlib
|
|
15
|
+
import gc
|
|
16
|
+
import io
|
|
17
|
+
import logging
|
|
18
|
+
import sys
|
|
19
|
+
import textwrap
|
|
20
|
+
from copy import deepcopy
|
|
21
|
+
|
|
22
|
+
logger = logging.getLogger(__name__)
|
|
23
|
+
from pathlib import Path
|
|
24
|
+
from typing import Dict, Any, Optional, Tuple, List
|
|
25
|
+
|
|
26
|
+
import click
|
|
27
|
+
import numpy as np
|
|
28
|
+
import torch
|
|
29
|
+
from ase import Atoms
|
|
30
|
+
from ase.io import write
|
|
31
|
+
from ase.data import atomic_masses
|
|
32
|
+
import ase.units as units
|
|
33
|
+
import time
|
|
34
|
+
|
|
35
|
+
# ---------------- pysisyphus / mlmm imports ----------------
|
|
36
|
+
from pysisyphus.helpers import geom_loader
|
|
37
|
+
from pysisyphus.optimizers.LBFGS import LBFGS
|
|
38
|
+
from pysisyphus.optimizers.exceptions import OptimizationError, ZeroStepLength
|
|
39
|
+
from pysisyphus.constants import BOHR2ANG, ANG2BOHR, AMU2AU, AU2EV
|
|
40
|
+
from pysisyphus.calculators.Dimer import Dimer # Dimer calculator (orientation-projected forces)
|
|
41
|
+
|
|
42
|
+
# RS-I-RFO optimizer for heavy mode
|
|
43
|
+
from pysisyphus.tsoptimizers.RSIRFOptimizer import RSIRFOptimizer
|
|
44
|
+
from pysisyphus.TablePrinter import TablePrinter
|
|
45
|
+
|
|
46
|
+
# local helpers from mlmm
|
|
47
|
+
from .mlmm_calc import mlmm, mlmm_mm_only
|
|
48
|
+
from .defaults import OUT_DIR_TSOPT
|
|
49
|
+
from .defaults import (
|
|
50
|
+
GEOM_KW_DEFAULT,
|
|
51
|
+
MLMM_CALC_KW,
|
|
52
|
+
OPT_BASE_KW,
|
|
53
|
+
LBFGS_KW,
|
|
54
|
+
DIMER_KW,
|
|
55
|
+
HESSIAN_DIMER_KW,
|
|
56
|
+
RSIRFO_KW,
|
|
57
|
+
MICROITER_KW,
|
|
58
|
+
TSOPT_MODE_ALIASES,
|
|
59
|
+
BFACTOR_ML,
|
|
60
|
+
BFACTOR_MOVABLE_MM,
|
|
61
|
+
BFACTOR_FROZEN,
|
|
62
|
+
)
|
|
63
|
+
from .opt import (
|
|
64
|
+
_parse_freeze_atoms as _parse_freeze_atoms_opt,
|
|
65
|
+
_normalize_geom_freeze as _normalize_geom_freeze_opt,
|
|
66
|
+
)
|
|
67
|
+
from .utils import (
|
|
68
|
+
append_xyz_trajectory as _append_xyz_trajectory,
|
|
69
|
+
apply_layer_freeze_constraints,
|
|
70
|
+
convert_xyz_to_pdb,
|
|
71
|
+
set_convert_file_enabled,
|
|
72
|
+
is_convert_file_enabled,
|
|
73
|
+
convert_xyz_like_outputs,
|
|
74
|
+
deep_update,
|
|
75
|
+
load_yaml_dict,
|
|
76
|
+
apply_yaml_overrides,
|
|
77
|
+
pretty_block,
|
|
78
|
+
strip_inherited_keys,
|
|
79
|
+
filter_calc_for_echo,
|
|
80
|
+
format_freeze_atoms_for_echo,
|
|
81
|
+
format_elapsed,
|
|
82
|
+
merge_freeze_atom_indices,
|
|
83
|
+
prepare_input_structure,
|
|
84
|
+
apply_ref_pdb_override,
|
|
85
|
+
resolve_charge_spin_or_raise,
|
|
86
|
+
parse_indices_string,
|
|
87
|
+
build_model_pdb_from_bfactors,
|
|
88
|
+
build_model_pdb_from_indices,
|
|
89
|
+
update_pdb_bfactors_from_layers,
|
|
90
|
+
normalize_choice,
|
|
91
|
+
yaml_section_has_key,
|
|
92
|
+
)
|
|
93
|
+
from .cli_utils import resolve_yaml_sources, load_merged_yaml_cfg, make_is_param_explicit
|
|
94
|
+
from .freq import (
|
|
95
|
+
_calc_full_hessian_torch as _freq_calc_full_hessian_torch,
|
|
96
|
+
_torch_device,
|
|
97
|
+
_build_tr_basis,
|
|
98
|
+
_tr_orthonormal_basis,
|
|
99
|
+
_mass_weighted_hessian,
|
|
100
|
+
_align_three_layer_hessian_targets,
|
|
101
|
+
_resolve_active_atom_indices,
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
# ===================================================================
|
|
106
|
+
# Mass-weighted projection & vib analysis
|
|
107
|
+
# ===================================================================
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _calc_full_hessian_torch(geom, calc_kwargs: Dict[str, Any], device: torch.device) -> torch.Tensor:
    """Compute the full Hessian through the backend shared with ``freq.py``.

    Delegates to ``_freq_calc_full_hessian_torch`` with
    ``refresh_geom_meta=True`` (the tsopt-specific metadata refresh) and keeps
    only the first element of the returned pair.

    Args:
        geom: Geometry object forwarded unchanged to the shared backend.
        calc_kwargs: Calculator keyword arguments forwarded unchanged.
        device: Torch device forwarded unchanged.

    Returns:
        The Hessian tensor produced by the shared backend.
    """
    backend_result = _freq_calc_full_hessian_torch(
        geom, calc_kwargs, device, refresh_geom_meta=True
    )
    # The backend returns a pair; the second element is not needed here.
    hessian, _ = backend_result
    return hessian
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _calc_energy(geom, calc_kwargs: Dict[str, Any], calc=None) -> float:
    """Return the energy of ``geom`` as a plain float.

    When ``calc`` is omitted, a temporary ``mlmm`` calculator is built from
    ``calc_kwargs`` (with ``out_hess_torch`` forced to ``False``), used once,
    dropped, and followed by a CUDA cache clear. A caller-supplied ``calc`` is
    used as-is and no cleanup is performed for it.

    Args:
        geom: Object exposing ``atoms`` and ``coords``; both are passed to
            ``calc.get_energy``.
        calc_kwargs: Keyword arguments for the temporary calculator; ``None``
            is treated as an empty dict. Ignored when ``calc`` is given.
        calc: Optional ready-made calculator to reuse.

    Returns:
        The ``"energy"`` entry of the calculator result converted to float;
        ``0.0`` when the result has no such entry.
    """
    if calc is not None:
        # Borrowed calculator: evaluate only, the caller owns its lifetime.
        outcome = calc.get_energy(geom.atoms, geom.coords)
        return float(outcome.get("energy", 0.0))

    settings = dict(calc_kwargs or {})
    settings.update(out_hess_torch=False)
    temp_calc = mlmm(**settings)
    outcome = temp_calc.get_energy(geom.atoms, geom.coords)
    energy = float(outcome.get("energy", 0.0))
    # Drop the result and the temporary calculator before clearing the cache.
    del outcome, temp_calc
    _clear_cuda_cache()
    return energy
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _omega2_to_freqs_cm(omega2: torch.Tensor) -> np.ndarray:
    """Map eigenvalues (omega^2) to signed vibrational wavenumbers in cm^-1.

    Negative eigenvalues come back as negative numbers, i.e. imaginary
    frequencies are reported with a minus sign.

    Args:
        omega2: Tensor of eigenvalues of the mass-weighted Hessian.

    Returns:
        NumPy array (on CPU) with one signed value per eigenvalue.
    """
    # Unit prefactor assembled from ase.units constants (_hbar, _e, _amu)
    # and the AU2EV / BOHR2ANG conversion factors.
    prefactor = units._hbar * 1e10 / np.sqrt(units._e * units._amu) * np.sqrt(AU2EV) / BOHR2ANG
    magnitude = prefactor * torch.sqrt(torch.abs(omega2))
    # Carry the sign of the eigenvalue over to the frequency.
    signed = torch.where(omega2 < 0, -magnitude, magnitude)
    wavenumbers = signed / units.invcm
    return wavenumbers.detach().cpu().numpy()
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _clear_cuda_cache(tensor: Optional[torch.Tensor] = None) -> None:
    """Empty the CUDA caching allocator when that is meaningful.

    Nothing happens when CUDA is unavailable, or when ``tensor`` is given and
    does not live on a CUDA device.

    Args:
        tensor: Optional tensor whose placement gates the cache clear.
    """
    if not torch.cuda.is_available():
        return
    if tensor is not None and not tensor.is_cuda:
        return
    torch.cuda.empty_cache()
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def _mw_projected_hessian_inplace(H_t: torch.Tensor,
                                  coords_bohr_t: torch.Tensor,
                                  masses_au_t: torch.Tensor,
                                  freeze_idx: Optional[List[int]] = None) -> torch.Tensor:
    """Mass-weight a Hessian and project translations/rotations out of it.

    With a non-empty ``freeze_idx`` the mass-weighted matrix is first reduced
    to the degrees of freedom of the non-frozen atoms (PHVA) and the TR basis
    is built from those atoms only. The projection itself is applied with
    in-place ``addmm_`` updates:

        H <- H - Q (Q^T H) - (Q^T H)^T Q^T + Q (Q^T H Q) Q^T

    Args:
        H_t: Cartesian Hessian, indexed as a (3N, 3N) matrix.
        coords_bohr_t: Coordinates with one row per atom (its first dimension
            defines N).
        masses_au_t: Per-atom masses, handed to the mass-weighting and
            TR-basis helpers.
        freeze_idx: Atom indices to exclude; entries outside ``[0, N)`` are
            ignored. ``None`` or an empty list selects the full-DOF path.

    Returns:
        The projected matrix to diagonalize. On the frozen-atom path this is
        a new, smaller tensor created by boolean-mask indexing, not the input.
        NOTE(review): whether the caller's ``H_t`` is modified depends on
        whether ``_mass_weighted_hessian`` works in place - confirm there.

    Raises:
        RuntimeError: If every atom is frozen.
    """
    device = H_t.device
    with torch.no_grad():
        n_atoms = coords_bohr_t.shape[0]

        # Resolve the active atoms (and fail) before the Hessian is touched.
        frozen = None
        active_atoms = None
        if freeze_idx:
            frozen = {int(i) for i in freeze_idx if 0 <= int(i) < n_atoms}
            active_atoms = [atom for atom in range(n_atoms) if atom not in frozen]
            if not active_atoms:
                raise RuntimeError("All atoms are frozen; no active DOF left for TR projection.")

        # Mass-weighting always acts on the full matrix.
        H_t = _mass_weighted_hessian(H_t, masses_au_t)

        if frozen is None:
            tr_coords, tr_masses = coords_bohr_t, masses_au_t
        else:
            # Keep only the rows/columns belonging to non-frozen atoms.
            keep_dof = torch.ones(3 * n_atoms, dtype=torch.bool, device=device)
            for atom in frozen:
                keep_dof[3 * atom:3 * atom + 3] = False
            H_t = H_t[keep_dof][:, keep_dof]
            tr_coords = coords_bohr_t[active_atoms, :]
            tr_masses = masses_au_t[active_atoms]
            del keep_dof

        # TR basis of the (possibly reduced) system; columns span TR motions.
        basis, _ = _tr_orthonormal_basis(tr_coords, tr_masses)
        basis_t = basis.T
        # Q^T H is taken from the unprojected matrix and reused below.
        bt_h = basis_t @ H_t
        H_t.addmm_(basis, bt_h, beta=1.0, alpha=-1.0)
        H_t.addmm_(bt_h.T, basis_t, beta=1.0, alpha=-1.0)
        bt_h_b = bt_h @ basis
        H_t.addmm_(basis @ bt_h_b, basis_t, beta=1.0, alpha=1.0)

        del basis, basis_t, bt_h, bt_h_b, tr_coords, tr_masses, active_atoms, frozen
        _clear_cuda_cache()
        return H_t
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def _mode_direction_by_root(H_t: torch.Tensor,
                            coords_bohr_t: torch.Tensor,
                            masses_au_t: torch.Tensor,
                            root: int = 0,
                            freeze_idx: Optional[List[int]] = None) -> np.ndarray:
    """Return the Cartesian direction of the ``root``-th most negative mode.

    The Hessian is mass-weighted and TR-projected (restricted to the active
    atoms when ``freeze_idx`` is given), symmetrized, and diagonalized.

    * ``root == 0``: the lowest eigenpair is requested from ``torch.lobpcg``;
      on any exception a full ``torch.linalg.eigh`` (``UPLO="U"``) is used and
      the eigenvector of the smallest eigenvalue is taken.
    * otherwise: a full ``eigh`` is done and ``root`` is clamped to the range
      of negative eigenvalues; with no negative eigenvalue at all the
      eigenvector of the smallest eigenvalue is used.

    Args:
        H_t: Cartesian Hessian, forwarded to the projection helper.
        coords_bohr_t: Coordinates with one row per atom.
        masses_au_t: Per-atom masses; divided by ``AMU2AU`` before the
            back-transformation from mass-weighted coordinates.
        root: Which negative eigenvalue to follow (0 = most negative).
        freeze_idx: Atom indices excluded from the eigenproblem; their
            components in the returned vector are zero.

    Returns:
        NumPy array reshaped to ``(-1, 3)`` holding the unit-norm Cartesian
        mode vector.
    """
    with torch.no_grad():
        # Mass weighting + (active-subspace) TR projection.
        hess = _mw_projected_hessian_inplace(H_t, coords_bohr_t, masses_au_t, freeze_idx=freeze_idx)

        # Symmetrize in place: H <- (H + H^T) / 2.
        transposed = hess.T.clone()
        hess.add_(transposed).mul_(0.5)
        del transposed

        if int(root) == 0:
            try:
                _, lowest = torch.lobpcg(hess, k=1, largest=False)
                sub_vec = lowest[:, 0]
            except Exception:
                # Iterative solver unavailable or failed: dense fallback.
                eigvals, eigvecs = torch.linalg.eigh(hess, UPLO="U")
                sub_vec = eigvecs[:, torch.argmin(eigvals)]
                del eigvals, eigvecs
        else:
            eigvals, eigvecs = torch.linalg.eigh(hess, UPLO="U")  # ascending
            negative = torch.nonzero(eigvals < 0, as_tuple=False).view(-1)
            n_negative = negative.numel()
            if n_negative == 0:
                chosen = int(torch.argmin(eigvals).item())
            else:
                rank = max(0, min(int(root), n_negative - 1))
                chosen = int(negative[rank].item())
            sub_vec = eigvecs[:, chosen]
            del eigvals, eigvecs

        # Scatter the sub-space vector into the full 3N layout; frozen DOF stay zero.
        n_atoms = coords_bohr_t.shape[0]
        if freeze_idx:
            frozen = {int(i) for i in freeze_idx if 0 <= int(i) < n_atoms}
            keep_dof = torch.ones(3 * n_atoms, dtype=torch.bool, device=hess.device)
            for atom in frozen:
                keep_dof[3 * atom:3 * atom + 3] = False
            full_vec = torch.zeros(3 * n_atoms, dtype=hess.dtype, device=hess.device)
            full_vec[keep_dof] = sub_vec
            del keep_dof, frozen
        else:
            full_vec = sub_vec

        # Mass-weighted -> Cartesian (divide by sqrt(m) per DOF), then normalize.
        masses_amu = (masses_au_t / AMU2AU).to(dtype=hess.dtype, device=hess.device)
        per_dof_mass = torch.repeat_interleave(masses_amu, 3)
        inv_sqrt_mass = torch.sqrt(1.0 / per_dof_mass)
        cart = inv_sqrt_mass * full_vec
        cart = cart / torch.linalg.norm(cart)
        mode = cart.reshape(-1, 3).detach().cpu().numpy()

        del masses_amu, per_dof_mass, inv_sqrt_mass, cart, full_vec, sub_vec
        _clear_cuda_cache()
        return mode
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def _calc_gradient(geom, calc_kwargs: Dict[str, Any]) -> np.ndarray:
    """Return true Cartesian gradient (shape 3N,) in Hartree/Bohr.

    A temporary ``mlmm`` calculator is built from ``calc_kwargs`` (with
    ``out_hess_torch`` forced to ``False``), attached to ``geom`` for the
    evaluation and detached again afterwards. Any calculator that was attached
    to ``geom`` before the call is replaced and not restored; ``geom`` is left
    with its calculator set to ``None``.

    Args:
        geom: Geometry object providing ``set_calculator`` and ``gradient``.
        calc_kwargs: Keyword arguments for the temporary calculator; ``None``
            is treated as an empty dict.

    Returns:
        The gradient as a flat float NumPy array.

    Raises:
        Whatever the gradient evaluation raises; the temporary calculator is
        detached and the CUDA cache cleared before the error propagates.
    """
    kw = dict(calc_kwargs or {})
    kw["out_hess_torch"] = False
    calc = mlmm(**kw)
    geom.set_calculator(calc)
    try:
        g = np.array(geom.gradient, dtype=float).reshape(-1)
    finally:
        # Always detach, even on failure, so the geometry never keeps a
        # reference to the throw-away calculator (and its device memory).
        geom.set_calculator(None)
        del calc
        _clear_cuda_cache()
    return g
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def _frequencies_cm_and_modes(H_t: torch.Tensor,
                              atomic_numbers: List[int],
                              coords_bohr: np.ndarray,
                              device: torch.device,
                              tol: float = 1e-6,
                              freeze_idx: Optional[List[int]] = None) -> Tuple[np.ndarray, torch.Tensor]:
    """Diagonalize the mass-weighted, TR-projected Hessian.

    Masses are looked up from ASE's ``atomic_masses`` table by atomic number.
    When ``freeze_idx`` is given the eigenproblem is solved in the sub-space
    of the non-frozen atoms (PHVA) and the eigenvectors are scattered back
    into full 3N vectors with zeros at frozen DOF. Eigenvalues whose absolute
    value does not exceed ``tol`` are dropped together with their vectors.

    Args:
        H_t: Cartesian Hessian; its dtype is used for all derived tensors.
        atomic_numbers: One atomic number per atom.
        coords_bohr: Coordinates, reshaped here to ``(-1, 3)``.
        device: Device for the mass and coordinate tensors.
        tol: Threshold on ``|omega^2|`` below which modes are discarded.
        freeze_idx: Atom indices to exclude; entries outside ``[0, N)`` are
            ignored.

    Returns:
        freqs_cm : (nmode,) numpy (negatives are imaginary)
        modes    : (nmode, 3N) torch (mass-weighted eigenvectors embedded to full 3N)
    """
    with torch.no_grad():
        numbers = np.array(atomic_numbers, dtype=int)
        n_atoms = int(len(numbers))
        masses_amu = np.array([atomic_masses[z] for z in numbers])  # amu
        masses_au_t = torch.as_tensor(masses_amu * AMU2AU, dtype=H_t.dtype, device=device)
        coords_t = torch.as_tensor(coords_bohr.reshape(-1, 3), dtype=H_t.dtype, device=device)

        # Mass weighting + (active-subspace) TR projection.
        hess = _mw_projected_hessian_inplace(H_t, coords_t, masses_au_t, freeze_idx=freeze_idx)

        # Symmetrize in place before the eigendecomposition.
        transposed = hess.T.clone()
        hess.add_(transposed).mul_(0.5)
        del transposed
        eigvals, eigvecs = torch.linalg.eigh(hess, UPLO="U")

        # Discard (near-)zero modes.
        keep_modes = torch.abs(eigvals) > tol
        eigvals = eigvals[keep_modes]
        eigvecs = eigvecs[:, keep_modes]  # (3N_act or 3N, nsel)

        if freeze_idx:
            # Scatter sub-space vectors into the full 3N layout.
            frozen = {int(i) for i in freeze_idx if 0 <= int(i) < n_atoms}
            keep_dof = torch.ones(3 * n_atoms, dtype=torch.bool, device=hess.device)
            for atom in frozen:
                keep_dof[3 * atom:3 * atom + 3] = False
            modes = torch.zeros((eigvecs.shape[1], 3 * n_atoms), dtype=hess.dtype, device=hess.device)
            modes[:, keep_dof] = eigvecs.T
            del keep_dof, frozen
        else:
            modes = eigvecs.T  # (nsel, 3N)

        freqs_cm = _omega2_to_freqs_cm(eigvals)

        del eigvals, eigvecs, keep_modes, masses_amu, masses_au_t, coords_t, hess
        _clear_cuda_cache(H_t)
        return freqs_cm, modes
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def _write_mode_trj_and_pdb(geom,
                            mode_vec_3N: np.ndarray,
                            out_trj: Path,
                            out_pdb: Path,
                            amplitude_ang: float = 0.25,
                            n_frames: int = 20,
                            comment: str = "imag mode",
                            ref_pdb: Optional[Path] = None) -> None:
    """Write one vibrational-mode animation as ``_trj.xyz`` and ``.pdb``.

    Frame ``i`` displaces the reference structure along the normalized mode
    by ``sin(2*pi*i/n_frames) * amplitude_ang``. The XYZ trajectory is always
    written. The PDB is produced by converting that trajectory with
    ``ref_pdb`` as template when ``ref_pdb`` is a ``.pdb`` file and file
    conversion is enabled; otherwise, or if that conversion fails, the frames
    are written through ASE (no topology).

    Args:
        geom: Geometry providing ``coords`` (converted with ``BOHR2ANG``) and
            ``atoms`` (element symbols).
        mode_vec_3N: Displacement vector with 3N components; any array-like
            that can be reshaped to ``(-1, 3)``. It is copied and normalized.
        out_trj: Destination of the XYZ-like trajectory.
        out_pdb: Destination of the PDB animation.
        amplitude_ang: Maximum displacement of the normalized mode.
        n_frames: Number of frames in one period.
        comment: Prefix of the per-frame XYZ comment line.
        ref_pdb: Optional template PDB for the conversion.

    Raises:
        ValueError: If the mode has zero or non-finite norm. The check runs
            before any file is opened, so no NaN-filled output is created.
    """
    # Validate first: dividing by a zero/NaN norm would silently fill both
    # output files with NaN coordinates.
    mode = np.array(mode_vec_3N, dtype=float).reshape(-1, 3)
    norm = float(np.linalg.norm(mode))
    if not np.isfinite(norm) or norm == 0.0:
        raise ValueError("mode_vec_3N must have a finite, non-zero norm.")
    mode /= norm

    ref_ang = geom.coords.reshape(-1, 3) * BOHR2ANG
    n_atoms = len(geom.atoms)

    def _frame_coords(i: int) -> np.ndarray:
        # One full sine period is spread over n_frames frames.
        phase = np.sin(2.0 * np.pi * i / n_frames)
        return ref_ang + phase * amplitude_ang * mode

    # _trj.xyz (XYZ-like concatenation) - always write
    with out_trj.open("w", encoding="utf-8") as f:
        for i in range(n_frames):
            lines = [f"{n_atoms}\n{comment} frame={i+1}/{n_frames}\n"]
            lines.extend(
                f"{sym:2s} {x: .8f} {y: .8f} {z: .8f}\n"
                for sym, (x, y, z) in zip(geom.atoms, _frame_coords(i))
            )
            f.writelines(lines)

    # .pdb - use ref_pdb template when available
    if ref_pdb is not None and ref_pdb.suffix.lower() == ".pdb" and is_convert_file_enabled():
        try:
            convert_xyz_to_pdb(out_trj, ref_pdb, out_pdb)
        except Exception as exc:
            # Best effort: keep going with the ASE writer, but leave a trace
            # of why the template-based PDB is missing.
            logger.warning(
                "Template-based PDB conversion of %s failed (%s); "
                "writing a topology-free PDB via ASE instead.",
                out_trj,
                exc,
            )
        else:
            return

    # Fallback: MODEL/ENDMDL via ASE (no topology)
    atoms0 = Atoms(geom.atoms, positions=ref_ang, pbc=False)
    for i in range(n_frames):
        ai = atoms0.copy()
        ai.set_positions(_frame_coords(i))
        # The first frame creates/overwrites the file, later ones append.
        write(out_pdb, ai, append=(i != 0))
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
def _write_all_imag_modes(
|
|
385
|
+
geom,
|
|
386
|
+
freqs_cm: np.ndarray,
|
|
387
|
+
modes: torch.Tensor,
|
|
388
|
+
neg_freq_thresh_cm: float,
|
|
389
|
+
vib_dir: Path,
|
|
390
|
+
*,
|
|
391
|
+
ref_pdb: Optional[Path] = None,
|
|
392
|
+
filename_prefix: str = "final_imag_mode",
|
|
393
|
+
amplitude_ang: float = 0.8,
|
|
394
|
+
n_frames: int = 20,
|
|
395
|
+
) -> int:
|
|
396
|
+
"""
|
|
397
|
+
Write all imaginary modes (freq < -|threshold|) to vib_dir.
|
|
398
|
+
|
|
399
|
+
Returns:
|
|
400
|
+
Number of mode trajectories written.
|
|
401
|
+
"""
|
|
402
|
+
neg_idx = np.where(freqs_cm < -abs(neg_freq_thresh_cm))[0]
|
|
403
|
+
if len(neg_idx) == 0:
|
|
404
|
+
return 0
|
|
405
|
+
|
|
406
|
+
masses_amu = np.array([atomic_masses[int(z)] for z in geom.atomic_numbers], dtype=float)
|
|
407
|
+
sqrt_m3 = np.sqrt(np.repeat(masses_amu, 3))
|
|
408
|
+
order = np.argsort(freqs_cm[neg_idx]) # most negative first
|
|
409
|
+
written = 0
|
|
410
|
+
|
|
411
|
+
for rank, rel_i in enumerate(order, start=1):
|
|
412
|
+
mode_idx = int(neg_idx[int(rel_i)])
|
|
413
|
+
freq = float(freqs_cm[mode_idx])
|
|
414
|
+
mode_mw = modes[mode_idx].detach().cpu().numpy().reshape(-1)
|
|
415
|
+
v_cart = mode_mw / sqrt_m3
|
|
416
|
+
norm = float(np.linalg.norm(v_cart))
|
|
417
|
+
if norm <= 0.0:
|
|
418
|
+
del mode_mw, v_cart
|
|
419
|
+
continue
|
|
420
|
+
v_cart = v_cart / norm
|
|
421
|
+
|
|
422
|
+
stem = f"{filename_prefix}_{rank:02d}_mode{mode_idx:04d}_{freq:+.2f}cm-1"
|
|
423
|
+
out_trj = vib_dir / f"{stem}_trj.xyz"
|
|
424
|
+
out_pdb = vib_dir / f"{stem}.pdb"
|
|
425
|
+
_write_mode_trj_and_pdb(
|
|
426
|
+
geom,
|
|
427
|
+
v_cart,
|
|
428
|
+
out_trj,
|
|
429
|
+
out_pdb,
|
|
430
|
+
amplitude_ang=amplitude_ang,
|
|
431
|
+
n_frames=n_frames,
|
|
432
|
+
comment=f"imag#{rank} mode={mode_idx} {freq:+.2f} cm^-1",
|
|
433
|
+
ref_pdb=ref_pdb,
|
|
434
|
+
)
|
|
435
|
+
del mode_mw, v_cart
|
|
436
|
+
written += 1
|
|
437
|
+
|
|
438
|
+
del masses_amu, sqrt_m3, order, neg_idx
|
|
439
|
+
_clear_cuda_cache()
|
|
440
|
+
return written
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
# ===================================================================
|
|
444
|
+
# Active-subspace helpers & Bofill update
|
|
445
|
+
# ===================================================================
|
|
446
|
+
|
|
447
|
+
def _active_indices(N: int, freeze_idx: Optional[List[int]]) -> List[int]:
|
|
448
|
+
if not freeze_idx:
|
|
449
|
+
return list(range(N))
|
|
450
|
+
fz = set(int(i) for i in freeze_idx if 0 <= int(i) < N)
|
|
451
|
+
return [i for i in range(N) if i not in fz]
|
|
452
|
+
|
|
453
|
+
|
|
454
|
+
def _active_mask_dof(N: int, freeze_idx: Optional[List[int]]) -> np.ndarray:
|
|
455
|
+
mask = np.ones(3 * N, dtype=bool)
|
|
456
|
+
if freeze_idx:
|
|
457
|
+
for i in freeze_idx:
|
|
458
|
+
if 0 <= int(i) < N:
|
|
459
|
+
mask[3 * int(i):3 * int(i) + 3] = False
|
|
460
|
+
return mask
|
|
461
|
+
|
|
462
|
+
|
|
463
|
+
def _mask_dof_from_active_idx(N: int, active_idx: List[int]) -> np.ndarray:
|
|
464
|
+
mask = np.zeros(3 * N, dtype=bool)
|
|
465
|
+
for i in active_idx:
|
|
466
|
+
j = int(i)
|
|
467
|
+
if 0 <= j < N:
|
|
468
|
+
mask[3 * j:3 * j + 3] = True
|
|
469
|
+
return mask
|
|
470
|
+
|
|
471
|
+
|
|
472
|
+
def _extract_active_block(H_full: torch.Tensor, mask_dof: np.ndarray) -> torch.Tensor:
|
|
473
|
+
"""
|
|
474
|
+
Return the active-DOF block as a torch.Tensor sharing device/dtype.
|
|
475
|
+
"""
|
|
476
|
+
device = H_full.device
|
|
477
|
+
m = torch.as_tensor(mask_dof, device=device, dtype=torch.bool)
|
|
478
|
+
return H_full[m][:, m].clone()
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
def _mw_tr_project_active_inplace(H_act: torch.Tensor,
                                  coords_act_t: torch.Tensor,
                                  masses_act_au_t: torch.Tensor) -> torch.Tensor:
    """
    Mass-weight & project TR in the *active* subspace (in-place).

    ``H_act`` is overwritten in two stages:

    1. every entry (i, j) is multiplied by ``1/sqrt(m_i) * 1/sqrt(m_j)``, with
       the per-atom masses repeated three times (one per Cartesian DOF);
    2. with ``Q`` returned by ``_tr_orthonormal_basis``, the matrix becomes
       ``H - Q(QᵀH) - (QᵀH)ᵀQᵀ + Q(QᵀHQ)Qᵀ``, which equals
       ``(I - QQᵀ) H (I - QQᵀ)`` when ``H`` is symmetric.

    Args:
        H_act: square active-block Hessian; modified in place.
        coords_act_t: active-atom coordinates, forwarded to ``_tr_orthonormal_basis``.
        masses_act_au_t: one mass per active atom. Divided by ``AMU2AU`` before
            weighting (presumably converting atomic units to amu — confirm
            against the constant's definition).

    Returns:
        The same ``H_act`` tensor object that was passed in.
    """
    with torch.no_grad():
        # mass-weight
        masses_amu_t = (masses_act_au_t / AMU2AU).to(dtype=H_act.dtype, device=H_act.device)
        m3 = torch.repeat_interleave(masses_amu_t, 3)
        inv_sqrt_m_col = torch.sqrt(1.0 / m3).view(1, -1)
        inv_sqrt_m_row = inv_sqrt_m_col.view(-1, 1)
        # Broadcasted in-place scaling: rows first, then columns.
        H_act.mul_(inv_sqrt_m_row)
        H_act.mul_(inv_sqrt_m_col)
        # TR basis & projection
        Q, _ = _tr_orthonormal_basis(coords_act_t, masses_act_au_t)  # (3N_act, r)
        Qt = Q.T
        # QtH must be taken from the mass-weighted matrix *before* any of the
        # projection terms below modify H_act; all three terms reuse it.
        QtH = Qt @ H_act
        H_act.addmm_(Q, QtH, beta=1.0, alpha=-1.0)      # H -= Q (QᵀH)
        H_act.addmm_(QtH.T, Qt, beta=1.0, alpha=-1.0)   # H -= (QᵀH)ᵀ Qᵀ
        QtHQ = QtH @ Q
        H_act.addmm_(Q @ QtHQ, Qt, beta=1.0, alpha=1.0)  # H += Q (QᵀHQ) Qᵀ
        del masses_amu_t, m3, inv_sqrt_m_col, inv_sqrt_m_row, Q, Qt, QtH, QtHQ
        return H_act
|
|
505
|
+
|
|
506
|
+
|
|
507
|
+
def _frequencies_from_Hact(H_act: torch.Tensor,
                           atomic_numbers: List[int],
                           coords_bohr: np.ndarray,
                           active_idx: List[int],
                           device: torch.device,
                           tol: float = 1e-6) -> np.ndarray:
    """
    Frequencies (cm^-1) computed from active-block Hessian with active-space TR projection.

    A copy of ``H_act`` is mass-weighted and TR-projected by
    ``_mw_tr_project_active_inplace``, symmetrised, and diagonalised with
    ``eigvalsh``. Eigenvalues with ``|value| <= tol`` are discarded before the
    remainder is passed to ``_omega2_to_freqs_cm``. ``H_act`` itself is not modified.
    """
    with torch.no_grad():
        work_dtype = H_act.dtype
        xyz_act = torch.as_tensor(
            coords_bohr.reshape(-1, 3)[active_idx, :], dtype=work_dtype, device=device
        )
        z_act = np.array(atomic_numbers, int)[active_idx]
        m_act_au = torch.as_tensor(
            [atomic_masses[int(z)] * AMU2AU for z in z_act],
            dtype=work_dtype, device=device,
        )

        H_mw = H_act.clone()
        _mw_tr_project_active_inplace(H_mw, xyz_act, m_act_au)

        # Explicit symmetrization before eigendecomposition: H <- (H + Hᵀ) / 2
        transposed = H_mw.T.clone()
        H_mw.add_(transposed).mul_(0.5)
        del transposed

        eigvals = torch.linalg.eigvalsh(H_mw, UPLO="U")
        keep = torch.abs(eigvals) > tol
        eigvals = eigvals[keep]
        freqs_cm = _omega2_to_freqs_cm(eigvals)

        del xyz_act, m_act_au, H_mw, eigvals, keep
        _clear_cuda_cache(H_act)
        return freqs_cm
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
def _modes_from_Hact_embedded(H_act: torch.Tensor,
                              atomic_numbers: List[int],
                              coords_bohr: np.ndarray,
                              active_idx: List[int],
                              device: torch.device,
                              tol: float = 1e-6) -> Tuple[np.ndarray, torch.Tensor]:
    """
    Diagonalize active-block Hessian with mass-weight/TR in active space and return:
      freqs_cm : (nmode,)
      modes    : (nmode, 3N) mass-weighted eigenvectors embedded to full 3N (torch)

    Eigenpairs with ``|eigenvalue| <= tol`` are dropped. Columns of the returned
    ``modes`` that belong to atoms outside ``active_idx`` stay zero. ``H_act``
    itself is not modified.
    """
    with torch.no_grad():
        n_atoms = len(atomic_numbers)
        work_dtype = H_act.dtype
        xyz_act = torch.as_tensor(
            coords_bohr.reshape(-1, 3)[active_idx, :], dtype=work_dtype, device=device
        )
        z_act = np.array(atomic_numbers, int)[active_idx]
        m_act_au = torch.as_tensor(
            [atomic_masses[int(z)] * AMU2AU for z in z_act],
            dtype=work_dtype, device=device,
        )

        H_mw = H_act.clone()
        _mw_tr_project_active_inplace(H_mw, xyz_act, m_act_au)

        # Explicit symmetrization before eigendecomposition: H <- (H + Hᵀ) / 2
        transposed = H_mw.T.clone()
        H_mw.add_(transposed).mul_(0.5)
        del transposed

        eigvals, eigvecs = torch.linalg.eigh(H_mw, UPLO="U")
        keep = torch.abs(eigvals) > tol
        eigvals = eigvals[keep]
        eigvecs = eigvecs[:, keep]  # (3N_act, nsel)

        # Embed to full 3N (mass-weighted eigenvectors); the mask helper takes
        # the *frozen* atoms, i.e. the complement of active_idx.
        modes_full = torch.zeros((eigvecs.shape[1], 3 * n_atoms), dtype=H_mw.dtype, device=device)
        frozen_atoms = list(set(range(n_atoms)) - set(active_idx))
        active_dofs = torch.as_tensor(
            _active_mask_dof(n_atoms, frozen_atoms), dtype=torch.bool, device=device
        )
        modes_full[:, active_dofs] = eigvecs.T

        freqs_cm = _omega2_to_freqs_cm(eigvals)

        del xyz_act, m_act_au, H_mw, eigvals, eigvecs, active_dofs
        _clear_cuda_cache(H_act)
        return freqs_cm, modes_full
|
|
575
|
+
|
|
576
|
+
|
|
577
|
+
def _mode_direction_by_root_from_Hact(H_act: torch.Tensor,
                                      coords_bohr: np.ndarray,
                                      atomic_numbers: List[int],
                                      masses_au_t: torch.Tensor,
                                      active_idx: List[int],
                                      device: torch.device,
                                      root: int = 0) -> np.ndarray:
    """
    TS direction from the *active* Hessian block. Mass-weighting/TR are done in the
    active space. Result is embedded back to full 3N in Cartesian space.

    Root selection:
      * ``root == 0``: the lowest eigenpair from ``torch.lobpcg``; if that call
        raises, a dense ``eigh`` is used and the eigenvector of the smallest
        eigenvalue is taken.
      * otherwise: dense ``eigh``; among negative eigenvalues the one at position
        ``root`` (clamped to the available count) is taken, or the smallest
        eigenvalue when none is negative.

    Returns:
        ``(N, 3)`` NumPy array: the unit-norm, mass-unweighted direction, with
        zero rows for atoms outside ``active_idx``.
    """
    with torch.no_grad():
        n_atoms = len(atomic_numbers)
        work_dtype = H_act.dtype
        xyz_act = torch.as_tensor(
            coords_bohr.reshape(-1, 3)[active_idx, :], dtype=work_dtype, device=device
        )
        m_act_au = masses_au_t[active_idx].to(device=device, dtype=work_dtype)

        # mass-weight + TR in active space
        H_mw = H_act.clone()
        _mw_tr_project_active_inplace(H_mw, xyz_act, m_act_au)

        # Explicit symmetrization before eigendecomposition: H <- (H + Hᵀ) / 2
        transposed = H_mw.T.clone()
        H_mw.add_(transposed).mul_(0.5)
        del transposed

        # eigenvector for requested root
        if int(root) == 0:
            try:
                _, lobpcg_vecs = torch.lobpcg(H_mw, k=1, largest=False)
                u_mw = lobpcg_vecs[:, 0]
            except Exception:
                evals, evecs = torch.linalg.eigh(H_mw, UPLO="U")
                u_mw = evecs[:, torch.argmin(evals)]
                del evals, evecs
        else:
            evals, evecs = torch.linalg.eigh(H_mw, UPLO="U")
            neg_pos = torch.nonzero(evals < 0, as_tuple=False).view(-1)
            n_neg = neg_pos.numel()
            if n_neg == 0:
                pick = int(torch.argmin(evals).item())
            else:
                clamped = max(0, min(int(root), n_neg - 1))
                pick = int(neg_pos[clamped].item())
            u_mw = evecs[:, pick]
            del evals, evecs

        # Mass un-weight to Cartesian in the active space, then embed to full 3N
        m_act_amu = (m_act_au / AMU2AU).to(dtype=work_dtype, device=device)
        m3 = torch.repeat_interleave(m_act_amu, 3)
        v_act = u_mw / torch.sqrt(m3)
        v_act = v_act / torch.linalg.norm(v_act)

        embedded = torch.zeros(3 * n_atoms, dtype=work_dtype, device=device)
        frozen_atoms = list(set(range(n_atoms)) - set(active_idx))
        active_dofs = torch.as_tensor(
            _active_mask_dof(n_atoms, frozen_atoms), dtype=torch.bool, device=device
        )
        embedded[active_dofs] = v_act
        mode = embedded.reshape(-1, 3).detach().cpu().numpy()

        del xyz_act, m_act_au, m_act_amu, m3, v_act, embedded, active_dofs, H_mw, u_mw
        _clear_cuda_cache(H_act)
        return mode
|
|
636
|
+
|
|
637
|
+
|
|
638
|
+
def _representative_atoms_for_mode(mode: torch.Tensor, flatten_k: int) -> np.ndarray:
|
|
639
|
+
"""
|
|
640
|
+
Return indices of the top-k atoms with largest displacement norm in mode.
|
|
641
|
+
"""
|
|
642
|
+
vec = mode.reshape(-1, 3)
|
|
643
|
+
norms = torch.linalg.norm(vec, dim=1)
|
|
644
|
+
k = min(int(flatten_k), vec.shape[0])
|
|
645
|
+
if k <= 0:
|
|
646
|
+
return np.zeros(0, dtype=int)
|
|
647
|
+
topk = torch.topk(norms, k=k, largest=True)
|
|
648
|
+
return topk.indices.detach().cpu().numpy()
|
|
649
|
+
|
|
650
|
+
|
|
651
|
+
def _select_flatten_targets_for_geom(
|
|
652
|
+
freqs_cm: np.ndarray,
|
|
653
|
+
modes: torch.Tensor,
|
|
654
|
+
coords_bohr: np.ndarray,
|
|
655
|
+
neg_freq_thresh_cm: float,
|
|
656
|
+
root: int,
|
|
657
|
+
flatten_sep_cutoff: float,
|
|
658
|
+
flatten_k: int,
|
|
659
|
+
) -> List[int]:
|
|
660
|
+
"""
|
|
661
|
+
Select a subset of imaginary modes to flatten for a geometry.
|
|
662
|
+
"""
|
|
663
|
+
neg_idx_all = np.where(freqs_cm < -abs(neg_freq_thresh_cm))[0]
|
|
664
|
+
if len(neg_idx_all) <= 1:
|
|
665
|
+
return []
|
|
666
|
+
|
|
667
|
+
order = np.argsort(freqs_cm[neg_idx_all])
|
|
668
|
+
sorted_neg = neg_idx_all[order]
|
|
669
|
+
root_clamped = max(0, min(int(root), len(order) - 1))
|
|
670
|
+
primary_idx = sorted_neg[root_clamped]
|
|
671
|
+
candidates = [int(i) for i in sorted_neg if int(i) != int(primary_idx)]
|
|
672
|
+
if not candidates:
|
|
673
|
+
return []
|
|
674
|
+
|
|
675
|
+
coords_ang = torch.as_tensor(
|
|
676
|
+
coords_bohr.reshape(-1, 3) * BOHR2ANG,
|
|
677
|
+
dtype=modes.dtype,
|
|
678
|
+
device=modes.device,
|
|
679
|
+
)
|
|
680
|
+
|
|
681
|
+
targets: List[int] = []
|
|
682
|
+
reps_list: List[np.ndarray] = []
|
|
683
|
+
|
|
684
|
+
for idx in candidates:
|
|
685
|
+
rep = _representative_atoms_for_mode(modes[idx], flatten_k)
|
|
686
|
+
if rep.size == 0:
|
|
687
|
+
continue
|
|
688
|
+
rep_coords = coords_ang[rep]
|
|
689
|
+
if not reps_list:
|
|
690
|
+
targets.append(idx)
|
|
691
|
+
reps_list.append(rep)
|
|
692
|
+
continue
|
|
693
|
+
|
|
694
|
+
accept = True
|
|
695
|
+
for prev_rep in reps_list:
|
|
696
|
+
prev_coords = coords_ang[prev_rep]
|
|
697
|
+
dmat = torch.cdist(rep_coords, prev_coords)
|
|
698
|
+
min_dist = float(torch.min(dmat).item())
|
|
699
|
+
if min_dist < float(flatten_sep_cutoff):
|
|
700
|
+
accept = False
|
|
701
|
+
break
|
|
702
|
+
if accept:
|
|
703
|
+
targets.append(idx)
|
|
704
|
+
reps_list.append(rep)
|
|
705
|
+
|
|
706
|
+
return targets
|
|
707
|
+
|
|
708
|
+
|
|
709
|
+
def _flatten_once_with_modes_for_geom(
|
|
710
|
+
geom,
|
|
711
|
+
masses_amu: np.ndarray,
|
|
712
|
+
calc_kwargs: dict,
|
|
713
|
+
freqs_cm: np.ndarray,
|
|
714
|
+
modes: torch.Tensor,
|
|
715
|
+
neg_freq_thresh_cm: float,
|
|
716
|
+
flatten_amp_ang: float,
|
|
717
|
+
flatten_sep_cutoff: float,
|
|
718
|
+
flatten_k: int,
|
|
719
|
+
root: int,
|
|
720
|
+
) -> bool:
|
|
721
|
+
"""
|
|
722
|
+
Flatten extra imaginary modes for a geometry (single pass).
|
|
723
|
+
"""
|
|
724
|
+
neg_idx_all = np.where(freqs_cm < -abs(neg_freq_thresh_cm))[0]
|
|
725
|
+
if len(neg_idx_all) <= 1:
|
|
726
|
+
return False
|
|
727
|
+
|
|
728
|
+
targets = _select_flatten_targets_for_geom(
|
|
729
|
+
freqs_cm,
|
|
730
|
+
modes,
|
|
731
|
+
geom.cart_coords,
|
|
732
|
+
neg_freq_thresh_cm,
|
|
733
|
+
root,
|
|
734
|
+
flatten_sep_cutoff,
|
|
735
|
+
flatten_k,
|
|
736
|
+
)
|
|
737
|
+
if not targets:
|
|
738
|
+
return False
|
|
739
|
+
|
|
740
|
+
mass_scale = np.sqrt(12.011 / masses_amu)[:, None]
|
|
741
|
+
amp_bohr = float(flatten_amp_ang) / BOHR2ANG
|
|
742
|
+
|
|
743
|
+
for idx in targets:
|
|
744
|
+
v_mw = modes[idx].detach().cpu().numpy().reshape(-1, 3)
|
|
745
|
+
m3 = np.repeat(masses_amu, 3).reshape(-1, 3)
|
|
746
|
+
v_cart = v_mw / np.sqrt(m3)
|
|
747
|
+
v_cart /= np.linalg.norm(v_cart)
|
|
748
|
+
|
|
749
|
+
disp = amp_bohr * mass_scale * v_cart
|
|
750
|
+
ref = geom.cart_coords.reshape(-1, 3)
|
|
751
|
+
|
|
752
|
+
plus = ref + disp
|
|
753
|
+
minus = ref - disp
|
|
754
|
+
|
|
755
|
+
geom.coords = plus.reshape(-1)
|
|
756
|
+
E_plus = _calc_energy(geom, calc_kwargs)
|
|
757
|
+
|
|
758
|
+
geom.coords = minus.reshape(-1)
|
|
759
|
+
E_minus = _calc_energy(geom, calc_kwargs)
|
|
760
|
+
|
|
761
|
+
# Move towards lower energy
|
|
762
|
+
if E_plus <= E_minus:
|
|
763
|
+
geom.coords = plus.reshape(-1)
|
|
764
|
+
else:
|
|
765
|
+
geom.coords = minus.reshape(-1)
|
|
766
|
+
|
|
767
|
+
return True
|
|
768
|
+
|
|
769
|
+
|
|
770
|
+
def _get_active_dof_indices(
    calc_cfg: Dict[str, Any],
    n_atoms: int,
    active_dof_mode: str,
    freeze_atoms_final: List[int],
) -> Optional[List[int]]:
    """
    Return the sorted active indices reported by ``_resolve_active_atom_indices``,
    or ``None`` when that helper reports none.

    ``freeze_atoms_final`` is accepted only to keep the signature
    backward-compatible; it is not used.
    """
    del freeze_atoms_final  # Kept for backward-compatible signature.
    resolved, _ = _resolve_active_atom_indices(calc_cfg, n_atoms, active_dof_mode)
    return None if resolved is None else sorted(resolved)
|
|
781
|
+
|
|
782
|
+
|
|
783
|
+
def _bofill_update_active(H_act: torch.Tensor,
|
|
784
|
+
delta_act: np.ndarray,
|
|
785
|
+
g_new_act: np.ndarray,
|
|
786
|
+
g_old_act: np.ndarray,
|
|
787
|
+
eps: float = 1e-12) -> torch.Tensor:
|
|
788
|
+
"""
|
|
789
|
+
Memory-efficient Bofill update on the *active* Cartesian Hessian block.
|
|
790
|
+
Apply symmetric rank-1/2 updates directly **in place** using only the **upper triangle**
|
|
791
|
+
index set (and mirror to the lower) to avoid allocating large NxN temporaries.
|
|
792
|
+
Explicit symmetrization is applied at eigendecomposition sites.
|
|
793
|
+
"""
|
|
794
|
+
device = H_act.device
|
|
795
|
+
dtype = H_act.dtype
|
|
796
|
+
|
|
797
|
+
# as torch vectors
|
|
798
|
+
d = torch.as_tensor(delta_act, dtype=dtype, device=device).reshape(-1)
|
|
799
|
+
g0 = torch.as_tensor(g_old_act, dtype=dtype, device=device).reshape(-1)
|
|
800
|
+
g1 = torch.as_tensor(g_new_act, dtype=dtype, device=device).reshape(-1)
|
|
801
|
+
y = g1 - g0
|
|
802
|
+
|
|
803
|
+
# Use current symmetric H_act for matvec (no extra allocation)
|
|
804
|
+
Hd = H_act @ d
|
|
805
|
+
xi = y - Hd
|
|
806
|
+
|
|
807
|
+
d_dot_xi = torch.dot(d, xi)
|
|
808
|
+
d_norm2 = torch.dot(d, d)
|
|
809
|
+
xi_norm2 = torch.dot(xi, xi)
|
|
810
|
+
|
|
811
|
+
# guards
|
|
812
|
+
if torch.abs(d_dot_xi) > eps:
|
|
813
|
+
denom_ms = d_dot_xi
|
|
814
|
+
else:
|
|
815
|
+
sign = torch.sign(d_dot_xi)
|
|
816
|
+
denom_ms = (sign if sign != 0 else torch.tensor(1.0, device=device)) * eps
|
|
817
|
+
denom_psb_d4 = d_norm2 * d_norm2 if d_norm2 > eps else eps
|
|
818
|
+
denom_psb_d2 = d_norm2 if d_norm2 > eps else eps
|
|
819
|
+
denom_phi = d_norm2 * xi_norm2 if (d_norm2 > eps and xi_norm2 > eps) else (1.0)
|
|
820
|
+
|
|
821
|
+
phi = 1.0 - (d_dot_xi * d_dot_xi) / denom_phi
|
|
822
|
+
phi = torch.clamp(phi, 0.0, 1.0)
|
|
823
|
+
|
|
824
|
+
# coefficients for rank updates
|
|
825
|
+
alpha = (1.0 - phi) / denom_ms # for xi xi^T
|
|
826
|
+
beta = - phi * (d_dot_xi / denom_psb_d4) # for d d^T
|
|
827
|
+
gamma = phi / denom_psb_d2 # for d xi^T + xi d^T
|
|
828
|
+
|
|
829
|
+
n = H_act.shape[0]
|
|
830
|
+
iu0, iu1 = torch.triu_indices(n, n, device=device)
|
|
831
|
+
is_diag = (iu0 == iu1)
|
|
832
|
+
off = ~is_diag
|
|
833
|
+
|
|
834
|
+
# Diagonal contributions (i == j): alpha*xi_i^2 + beta*d_i^2 + 2*gamma*d_i*xi_i
|
|
835
|
+
if is_diag.any():
|
|
836
|
+
idx = iu0[is_diag]
|
|
837
|
+
H_act[idx, idx].add_(alpha * xi[idx] * xi[idx]
|
|
838
|
+
+ beta * d[idx] * d[idx]
|
|
839
|
+
+ 2.0 * gamma * d[idx] * xi[idx])
|
|
840
|
+
|
|
841
|
+
# Off-diagonal (i < j): symmetric update
|
|
842
|
+
if off.any():
|
|
843
|
+
i = iu0[off]; j = iu1[off]
|
|
844
|
+
inc = (alpha * xi[i] * xi[j]
|
|
845
|
+
+ beta * d[i] * d[j]
|
|
846
|
+
+ gamma * (d[i] * xi[j] + xi[i] * d[j]))
|
|
847
|
+
H_act[i, j].add_(inc)
|
|
848
|
+
H_act[j, i].add_(inc)
|
|
849
|
+
|
|
850
|
+
return H_act
|
|
851
|
+
|
|
852
|
+
|
|
853
|
+
# ===================================================================
|
|
854
|
+
# HessianDimer (extended)
|
|
855
|
+
# ===================================================================
|
|
856
|
+
|
|
857
|
+
class HessianDimer:
|
|
858
|
+
"""
|
|
859
|
+
Dimer-based TS search with periodic Hessian updates.
|
|
860
|
+
|
|
861
|
+
Extensions in this implementation:
|
|
862
|
+
- `root` parameter: choose which imaginary mode to follow (0 = most negative).
|
|
863
|
+
- Pass-through kwargs: `dimer_kwargs` and `lbfgs_kwargs` to tune internals.
|
|
864
|
+
- Hard cap on total LBFGS steps across segments: `max_total_cycles`.
|
|
865
|
+
- PHVA (active DOF subspace) + TR projection for mode picking,
|
|
866
|
+
respecting ``freeze_atoms``, with in-place operations. When ``root == 0`` the
|
|
867
|
+
implementation prefers LOBPCG.
|
|
868
|
+
- The flatten loop uses a *Bofill*-updated active Hessian block, so the
|
|
869
|
+
expensive exact Hessian is evaluated only once before the flatten loop and
|
|
870
|
+
once at the end for the final frequency analysis.
|
|
871
|
+
- UMA calculator kwargs accept ``freeze_atoms`` and ``hessian_calc_mode`` and
|
|
872
|
+
default to ``return_partial_hessian=True`` (active-block Hessian when frozen).
|
|
873
|
+
"""
|
|
874
|
+
|
|
875
|
+
def __init__(self,
             fn: str,
             out_dir: str = "./result_dimer",
             thresh_loose: str = "gau_loose",
             thresh: str = "baker",
             update_interval_hessian: int = 500,
             neg_freq_thresh_cm: float = 5.0,
             flatten_amp_ang: float = 0.10,
             flatten_max_iter: int = 20,
             mem: int = 100000,
             use_lobpcg: bool = True,  # kept for backward compat (not used when root!=0)
             calc_kwargs: Optional[dict] = None,
             device: str = "auto",
             dump: bool = False,
             #
             # New:
             root: int = 0,
             dimer_kwargs: Optional[Dict[str, Any]] = None,
             lbfgs_kwargs: Optional[Dict[str, Any]] = None,
             max_total_cycles: int = 10000,
             #
             # Pass geom kwargs so freeze-atoms and YAML geometry overrides apply on the light path (fix #1)
             geom_kwargs: Optional[Dict[str, Any]] = None,
             # New: Use partial Hessian for imaginary mode detection in flatten loop
             partial_hessian_flatten: bool = True,
             # Spatial separation for flatten mode selection (from pdb2reaction)
             flatten_sep_cutoff: float = 0.0,
             flatten_k: int = 10,
             flatten_loop_bofill: bool = False,
             ml_only_hessian_dimer: bool = False,
             source_path: Optional[Path] = None,
             ) -> None:
    """
    Store the run settings, create the output directories, load the geometry
    and prepare the calculator keyword sets.

    Side effects: ``out_dir`` and ``out_dir/vib`` are created, and the geometry
    is loaded with ``geom_loader(fn, coord_type=..., **geom_kwargs)``.

    The ``freeze_atoms`` entries of ``geom_kwargs`` and ``calc_kwargs`` are
    merged into one sorted, de-duplicated list that is written back to both.
    Three derived calculator kwarg dicts are built from ``calc_kwargs``:
    ``calc_kwargs_partial``, ``calc_kwargs_full`` and ``calc_kwargs_ml_only``.
    """
    self.fn = fn
    self.source_path = None if source_path is None else Path(source_path)

    # Output locations are created eagerly.
    self.out_dir = Path(out_dir)
    self.out_dir.mkdir(parents=True, exist_ok=True)
    self.vib_dir = self.out_dir / "vib"
    self.vib_dir.mkdir(parents=True, exist_ok=True)

    # Plain settings, coerced to their declared types.
    self.thresh_loose = thresh_loose
    self.thresh = thresh
    self.update_interval_hessian = int(update_interval_hessian)
    self.neg_freq_thresh_cm = float(neg_freq_thresh_cm)
    self.flatten_amp_ang = float(flatten_amp_ang)
    self.flatten_max_iter = int(flatten_max_iter)
    self.mem = int(mem)
    self.use_lobpcg = bool(use_lobpcg)  # used only when root==0 shortcut
    self.root = int(root)
    self.dimer_kwargs = dict(dimer_kwargs or {})
    self.lbfgs_kwargs = dict(lbfgs_kwargs or {})
    self.max_total_cycles = int(max_total_cycles)
    self.partial_hessian_flatten = bool(partial_hessian_flatten)
    # Spatial separation for flatten mode selection
    self.flatten_sep_cutoff = float(flatten_sep_cutoff)
    self.flatten_k = int(flatten_k)
    self.flatten_loop_bofill = bool(flatten_loop_bofill)
    self.ml_only_hessian_dimer = bool(ml_only_hessian_dimer)

    # Track total cycles globally across ALL loops/segments (fix #2)
    self._cycles_spent = 0

    # Hessian caching for 0-step convergence (avoid redundant recalculation)
    self._raw_hessian_cache_cpu: Optional[torch.Tensor] = None
    self._raw_hessian_coords_cpu: Optional[np.ndarray] = None
    self._last_active_idx: Optional[List[int]] = None
    self._last_active_mask_dof: Optional[np.ndarray] = None

    # ML/MM calculator settings
    self.calc_kwargs = dict(calc_kwargs or {})
    self.calc_kwargs.setdefault("out_hess_torch", False)

    # Geometry kwargs (use provided geom kwargs so freeze_atoms etc. apply)
    gkw = dict(geom_kwargs or {})
    coord_type = gkw.pop("coord_type", "cart")

    # Merge the freeze lists coming from the geometry and calculator settings.
    geom_has_freeze = "freeze_atoms" in gkw
    freeze_geom = list(gkw["freeze_atoms"]) if geom_has_freeze else []
    freeze_calc_raw = self.calc_kwargs.get("freeze_atoms") or []
    try:
        freeze_calc = [int(i) for i in freeze_calc_raw]
    except TypeError:
        # A single scalar index was supplied instead of an iterable.
        freeze_calc = [int(freeze_calc_raw)]
    merged_freeze = sorted(set(map(int, freeze_geom + freeze_calc)))
    if merged_freeze:
        gkw["freeze_atoms"] = merged_freeze
    elif geom_has_freeze:
        gkw["freeze_atoms"] = []
    self.calc_kwargs["freeze_atoms"] = merged_freeze

    # Derived calculator kwarg sets.
    self.calc_kwargs_partial = {
        **self.calc_kwargs,
        "mm_fd": False,
        "return_partial_hessian": False,
        "out_hess_torch": True,
    }
    self.calc_kwargs_full = dict(self.calc_kwargs)
    self.calc_kwargs_full.setdefault("mm_fd", True)
    self.calc_kwargs_full.update(return_partial_hessian=False, out_hess_torch=True)
    # ML-only Hessian kwargs: skip MM Hessian entirely, use ML partial Hessian only
    self.calc_kwargs_ml_only = {
        **self.calc_kwargs,
        "mm_fd": False,
        "return_partial_hessian": True,
        "out_hess_torch": True,
        "hess_cutoff": 0.0,  # ML atoms only in Hessian
    }

    self.geom = geom_loader(fn, coord_type=coord_type, **gkw)

    # If partial Hessian is requested (explicitly or via B-factor layers),
    # avoid full 3N Hessian allocations in light TS dimer runs.
    if self.calc_kwargs.get("return_partial_hessian") or (
        self.partial_hessian_flatten and self.calc_kwargs.get("use_bfactor_layers")
    ):
        self.calc_kwargs_partial["return_partial_hessian"] = True
        self.calc_kwargs_full["return_partial_hessian"] = True

    self.masses_amu = np.array([atomic_masses[z] for z in self.geom.atomic_numbers])
    self.masses_au_t = torch.as_tensor(self.masses_amu * AMU2AU, dtype=torch.float32)

    # --- Preserve freeze list (for PHVA) ---
    self.freeze_atoms: List[int] = list(gkw.get("freeze_atoms", []))

    # Device
    self.device = _torch_device(device)
    self.masses_au_t = self.masses_au_t.to(self.device)

    # temp file for Dimer orientation (N_raw)
    self.mode_path = self.out_dir / ".dimer_mode.dat"

    self.dump = bool(dump)
    self.optim_all_path = self.out_dir / "optimization_all_trj.xyz"
|
|
999
|
+
|
|
1000
|
+
# ----- One dimer segment for up to n_steps; returns (steps_done, converged) -----
|
|
1001
|
+
def _dimer_segment(self, threshold: str, n_steps: int) -> Tuple[int, bool]:
    """
    Run one Dimer + LBFGS segment of at most ``n_steps`` cycles.

    A fresh ``mlmm`` calculator is wrapped in a ``Dimer`` whose initial
    orientation is read from ``self.mode_path``, and ``self.geom`` is optimised
    with ``LBFGS``. User-supplied ``dimer_kwargs`` / ``lbfgs_kwargs`` are
    honoured except for the keys overridden below.

    Args:
        threshold: convergence threshold name passed to LBFGS as ``thresh``.
        n_steps: cycle budget for this segment (``max_cycles``).

    Returns:
        ``(steps, converged)`` where ``steps`` is ``opt.cur_cycle + 1`` clamped
        to ``[1, n_steps]`` and ``converged`` is ``opt.is_converged``.
    """
    # Dimer calculator using current mode as initial N
    calc_sp = mlmm(**self.calc_kwargs)

    # Merge user dimer kwargs (but enforce N_raw & write_orientations)
    dimer_kwargs = dict(self.dimer_kwargs)
    dimer_kwargs.update({
        "calculator": calc_sp,
        "N_raw": str(self.mode_path),
        "write_orientations": False,  # runner override to reduce IO
        "seed": 0,                    # runner override for determinism
        "mem": self.mem,              # accepted by Calculator base through **kwargs
    })
    dimer = Dimer(**dimer_kwargs)
    # Drop the dict now: it holds a reference to calc_sp that would otherwise
    # keep the calculator alive across the cache flush below.
    del dimer_kwargs

    self.geom.set_calculator(dimer)

    # LBFGS kwargs: enforce thresh/max_cycles/out_dir/dump; allow others
    lbfgs_kwargs = dict(self.lbfgs_kwargs)
    lbfgs_kwargs.update({
        "max_cycles": n_steps,
        "thresh": threshold,
        "out_dir": str(self.out_dir),
        "dump": self.dump,
    })
    opt = LBFGS(self.geom, **lbfgs_kwargs)
    opt.run()
    # pysisyphus uses 0-indexed cur_cycle; keep budget accounting strict by clamping
    # to the requested segment step count.
    steps = min(max(int(opt.cur_cycle) + 1, 1), int(n_steps))
    converged = opt.is_converged
    self.geom.set_calculator(None)

    # Free dimer/optimizer GPU resources before next Hessian computation
    del calc_sp, dimer, opt
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Append to concatenated trajectory if dump enabled
    if self.dump:
        _append_xyz_trajectory(self.optim_all_path, self.out_dir / "optimization_trj.xyz")
    return steps, converged
|
|
1043
|
+
|
|
1044
|
+
# ----- Hessian caching for 0-step convergence -----
|
|
1045
|
+
def _cache_raw_hessian_cpu(self, H: torch.Tensor) -> None:
|
|
1046
|
+
"""Cache the raw Hessian on CPU for the current geometry."""
|
|
1047
|
+
self._raw_hessian_cache_cpu = H.detach().cpu().clone()
|
|
1048
|
+
self._raw_hessian_coords_cpu = self.geom.cart_coords.copy()
|
|
1049
|
+
|
|
1050
|
+
def _reuse_cached_hessian(self) -> Optional[torch.Tensor]:
|
|
1051
|
+
"""If the cached geometry matches current, return cached Hessian on device."""
|
|
1052
|
+
if self._raw_hessian_cache_cpu is None or self._raw_hessian_coords_cpu is None:
|
|
1053
|
+
return None
|
|
1054
|
+
if not np.array_equal(self.geom.cart_coords, self._raw_hessian_coords_cpu):
|
|
1055
|
+
return None
|
|
1056
|
+
H_dev = self._raw_hessian_cache_cpu.to(self.device)
|
|
1057
|
+
if self.device.type == "cpu":
|
|
1058
|
+
H_dev = H_dev.clone()
|
|
1059
|
+
return H_dev
|
|
1060
|
+
|
|
1061
|
+
def _calc_full_hessian_cached(
    self, calc_kwargs: Dict[str, Any], allow_reuse: bool
) -> torch.Tensor:
    """
    Return the Hessian for the current geometry, reusing the cache when allowed.

    With ``allow_reuse`` true and a cache entry matching the current
    coordinates, the cached tensor is returned and a notice is echoed.
    Otherwise the Hessian is computed with ``_calc_full_hessian_torch`` using
    ``calc_kwargs`` and stored in the CPU cache before being returned.
    """
    cached = self._reuse_cached_hessian() if allow_reuse else None
    if cached is not None:
        click.echo("[tsopt] Reusing cached raw Hessian (0-step convergence).")
        return cached

    fresh = _calc_full_hessian_torch(self.geom, calc_kwargs, self.device)
    self._cache_raw_hessian_cpu(fresh)
    return fresh
|
|
1073
|
+
|
|
1074
|
+
def _resolve_hessian_active_subspace(self, H_t: torch.Tensor, N: int) -> Tuple[List[int], np.ndarray]:
    """
    Work out which atoms / Cartesian DOFs the rows of ``H_t`` correspond to.

    If ``H_t`` has ``3 * N`` rows the answer comes straight from the freeze
    list (``_active_indices`` / ``_active_mask_dof``). Otherwise several
    metadata sources are tried in order and the first one whose DOF mask has
    exactly ``H_t.size(0)`` active entries wins:

      1. ``self.geom.hess_active_atom_indices`` / ``hess_active_dof_indices``
      2. ``self.geom.within_partial_hessian`` (when it is a dict)
      3. ``self.geom._hess_active_atoms_last`` / ``_hess_active_dofs_last``
      4. the subspace remembered from the previous successful call
      5. the freeze-list based subspace

    The result is stored in ``self._last_active_idx`` and
    ``self._last_active_mask_dof`` before it is returned.

    Returns:
        (active_idx, mask_dof): list of active atom indices and a boolean
        mask of length ``3 * N``.

    Raises:
        RuntimeError: if no source matches the Hessian dimension.
    """
    n_rows = int(H_t.size(0))
    n_cart = 3 * int(N)
    frozen = self.freeze_atoms if len(self.freeze_atoms) > 0 else []

    # Full Hessian: the active subspace is defined by the freeze list alone.
    if n_rows == n_cart:
        atoms_full = _active_indices(N, frozen)
        mask_full = _active_mask_dof(N, frozen)
        self._last_active_idx = list(atoms_full)
        self._last_active_mask_dof = mask_full.copy()
        return atoms_full, mask_full

    def _clip_indices(vals: Optional[Any], upper) -> np.ndarray:
        # Flatten to an int array and drop indices outside [0, upper).
        if vals is None:
            return np.zeros(0, dtype=int)
        flat = np.asarray(vals, dtype=int).reshape(-1)
        return flat[(flat >= 0) & (flat < upper)]

    def _as_atoms(vals: Optional[Any]) -> np.ndarray:
        return _clip_indices(vals, N)

    def _as_dofs(vals: Optional[Any]) -> np.ndarray:
        return _clip_indices(vals, n_cart)

    def _dedup_keep_order(vals: np.ndarray) -> np.ndarray:
        # dict preserves first-seen order, so duplicates are removed stably.
        return np.asarray(list(dict.fromkeys(int(v) for v in vals.tolist())), dtype=int)

    # (label, atom indices, DOF indices); the label is informational only.
    sources: List[Tuple[str, np.ndarray, np.ndarray]] = []

    try:
        geom_atoms = _as_atoms(self.geom.hess_active_atom_indices)
        geom_dofs = _as_dofs(self.geom.hess_active_dof_indices)
        sources.append(("geom.hess_active_*", geom_atoms, geom_dofs))
    except Exception:
        logger.debug("Failed to read hess_active_* indices", exc_info=True)

    within = getattr(self.geom, "within_partial_hessian", None)
    if isinstance(within, dict):
        sources.append((
            "geom.within_partial_hessian",
            _as_atoms(within.get("active_atoms")),
            _as_dofs(within.get("active_dofs")),
        ))

    sources.append((
        "geom._hess_active_*_last",
        _as_atoms(getattr(self.geom, "_hess_active_atoms_last", None)),
        _as_dofs(getattr(self.geom, "_hess_active_dofs_last", None)),
    ))

    if self._last_active_idx is not None or self._last_active_mask_dof is not None:
        remembered_atoms = _as_atoms(self._last_active_idx)
        if self._last_active_mask_dof is not None:
            remembered_dofs = np.flatnonzero(self._last_active_mask_dof).astype(int)
        else:
            remembered_dofs = np.zeros(0, dtype=int)
        sources.append(("cached_active_subspace", remembered_atoms, _as_dofs(remembered_dofs)))

    fallback_atoms = _as_atoms(_active_indices(N, frozen))
    fallback_dofs = np.flatnonzero(_active_mask_dof(N, frozen)).astype(int)
    sources.append(("freeze_based", fallback_atoms, _as_dofs(fallback_dofs)))

    for _label, atom_ids, dof_ids in sources:
        has_dofs = dof_ids.size > 0
        if has_dofs:
            # Explicit DOF indices take precedence over atom indices.
            mask = np.zeros(n_cart, dtype=bool)
            mask[dof_ids] = True
        elif atom_ids.size > 0:
            mask = _mask_dof_from_active_idx(N, atom_ids.tolist())
        else:
            continue

        # A source is only usable if it describes exactly as many DOFs as H_t has rows.
        if int(mask.sum()) != n_rows:
            continue

        if has_dofs:
            atom_ids = _dedup_keep_order((dof_ids // 3).astype(int))

        active_idx = [int(i) for i in atom_ids.tolist()]
        self._last_active_idx = list(active_idx)
        self._last_active_mask_dof = mask.copy()
        return active_idx, mask

    raise RuntimeError(
        f"Failed to resolve active subspace for partial Hessian: "
        f"H_dim={n_rows}, full_dim={n_cart}, freeze_active_dof={int(fallback_dofs.size)}"
    )
|
|
1171
|
+
|
|
1172
|
+
# ----- Loop dimer segments, updating mode from Hessian every interval -----
|
|
1173
|
+
def _dimer_loop(self, threshold: str) -> Tuple[int, bool]:
    """
    Alternate LBFGS dimer segments with Hessian-based refreshes of the dimer mode.

    Each pass runs ``self._dimer_segment`` for at most
    ``self.update_interval_hessian`` steps, capped by what is left of the
    global budget (``self.max_total_cycles - self._cycles_spent``). Unless the
    segment reported convergence or the budget is exhausted, a Hessian is then
    recomputed, the mode direction for ``self.root`` is derived from it and
    written to ``self.mode_path``.

    Returns:
        (steps_in_this_call, zero_step_converged)
        ``zero_step_converged`` is True only when the last segment reported
        convergence and ``self.geom.cart_coords`` is exactly equal to its
        value before that segment.
    """
    total_steps = 0
    converged_in_place = False

    while True:
        budget_left = self.max_total_cycles - self._cycles_spent
        if budget_left <= 0:
            break

        segment_len = min(self.update_interval_hessian, budget_left)
        start_coords = self.geom.cart_coords.copy()
        n_done, converged = self._dimer_segment(threshold, segment_len)
        self._cycles_spent += n_done
        total_steps += n_done

        if converged:
            # Geometry untouched by the converging segment => 0-step convergence.
            converged_in_place = bool(np.array_equal(self.geom.cart_coords, start_coords))
            break

        # Budget used up by this segment: stop before paying for a Hessian.
        if self.max_total_cycles - self._cycles_spent <= 0:
            break

        # Release cached VRAM before the Hessian evaluation.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # ML-only or partial Hessian, depending on configuration.
        if self.ml_only_hessian_dimer:
            hess_kw = self.calc_kwargs_ml_only
        else:
            hess_kw = self.calc_kwargs_partial
        H_t = _calc_full_hessian_torch(self.geom, hess_kw, self.device)
        n_atoms = len(self.geom.atomic_numbers)
        coords_bohr_t = torch.as_tensor(
            self.geom.coords.reshape(-1, 3), dtype=H_t.dtype, device=H_t.device
        )

        if H_t.size(0) != 3 * n_atoms:
            # Calculator returned only the active block.
            active_idx, _ = self._resolve_hessian_active_subspace(H_t, n_atoms)
            mode_xyz = _mode_direction_by_root_from_Hact(
                H_t, self.geom.coords.reshape(-1, 3), self.geom.atomic_numbers,
                self.masses_au_t, active_idx, self.device, root=self.root
            )
        else:
            # Full 3N x 3N Hessian; frozen atoms are passed through explicitly.
            mode_xyz = _mode_direction_by_root(
                H_t, coords_bohr_t, self.masses_au_t,
                root=self.root,
                freeze_idx=self.freeze_atoms if len(self.freeze_atoms) > 0 else None,
            )

        np.savetxt(self.mode_path, mode_xyz, fmt="%.12f")
        del H_t, coords_bohr_t, mode_xyz
        _clear_cuda_cache()

    return total_steps, converged_in_place
|
|
1229
|
+
|
|
1230
|
+
# ----- Flatten (resolve multiple imaginary modes) -----
|
|
1231
|
+
def _flatten_once(self) -> bool:
    """
    Legacy flattening step driven by a freshly computed Hessian.

    Computes frequencies and modes from a Hessian evaluated with
    ``self.calc_kwargs_full``. Every mode below ``-abs(self.neg_freq_thresh_cm)``
    except the one selected by ``self.root`` is probed by displacing the
    reference structure in both directions; the lower-energy direction of each
    is accumulated and the summed displacement is applied to ``self.geom``.

    Returns:
        True if a displacement was applied, False if at most one mode lies
        below the threshold (geometry is then left unchanged).
    """
    hessian = _calc_full_hessian_torch(self.geom, self.calc_kwargs_full, self.device)
    frozen = self.freeze_atoms if len(self.freeze_atoms) > 0 else None
    freqs_cm, modes = _frequencies_cm_and_modes(
        hessian, self.geom.atomic_numbers, self.geom.coords.reshape(-1, 3), self.device,
        freeze_idx=frozen,
    )
    del hessian

    imag_idx = np.where(freqs_cm < -abs(self.neg_freq_thresh_cm))[0]
    if len(imag_idx) <= 1:
        return False

    # Rank the imaginary modes (most negative first) and pick the one kept as "primary".
    ranking = np.argsort(freqs_cm[imag_idx])
    pick = max(0, min(self.root, len(ranking) - 1))
    primary = imag_idx[ranking[pick]]

    secondary = [i for i in imag_idx if i != primary]
    if not secondary:
        return False

    # Reference structure; the energy call is kept although its value is not used.
    origin = self.geom.coords.reshape(-1, 3).copy()
    _calc_energy(self.geom, self.calc_kwargs)

    # Per-atom scale relative to carbon, and the amplitude converted via BOHR2ANG.
    carbon_scale = np.sqrt(12.011 / self.masses_amu)[:, None]
    amplitude = self.flatten_amp_ang / BOHR2ANG
    # Square-root masses laid out per Cartesian component; identical for every mode.
    sqrt_mass_xyz = np.sqrt(np.repeat(self.masses_amu, 3).reshape(-1, 3))

    accumulated = np.zeros_like(origin)
    for mode_i in secondary:
        # Un-weight the mass-weighted mode and normalise it.
        direction = modes[mode_i].detach().cpu().numpy().reshape(-1, 3) / sqrt_mass_xyz
        direction /= np.linalg.norm(direction)
        shift = amplitude * carbon_scale * direction

        self.geom.coords = (origin + shift).reshape(-1)
        e_forward = _calc_energy(self.geom, self.calc_kwargs)
        self.geom.coords = (origin - shift).reshape(-1)
        e_backward = _calc_energy(self.geom, self.calc_kwargs)
        self.geom.coords = origin.reshape(-1)

        # Keep whichever direction gave the lower (or equal) energy.
        if e_forward <= e_backward:
            accumulated += shift
        else:
            accumulated -= shift

    del modes
    _clear_cuda_cache()

    self.geom.coords = (origin + accumulated).reshape(-1)
    return True
|
|
1288
|
+
|
|
1289
|
+
def _flatten_once_with_modes(self, freqs_cm: np.ndarray, modes: torch.Tensor) -> bool:
    """
    Flatten extra imaginary modes using frequencies/modes supplied by the caller.

    Target selection:
      * ``self.flatten_sep_cutoff > 0``: delegated to
        ``_select_flatten_targets_for_geom`` (given the current Cartesian
        coordinates, ``self.root``, the cutoff and ``self.flatten_k``).
      * otherwise: every mode below ``-abs(self.neg_freq_thresh_cm)`` except
        the one selected by ``self.root``.

    Targets are applied one after another: for each, the current geometry is
    displaced in both directions, the lower-energy side is kept, and the next
    target starts from that new geometry.

    Returns:
        True if at least one target was processed, False otherwise.
    """
    imag_idx = np.where(freqs_cm < -abs(self.neg_freq_thresh_cm))[0]
    if len(imag_idx) <= 1:
        return False

    if self.flatten_sep_cutoff > 0:
        targets = _select_flatten_targets_for_geom(
            freqs_cm,
            modes,
            self.geom.cart_coords,
            self.neg_freq_thresh_cm,
            self.root,
            self.flatten_sep_cutoff,
            self.flatten_k,
        )
    else:
        # No separation filter: all imaginary modes but the primary one.
        ranking = np.argsort(freqs_cm[imag_idx])
        pick = max(0, min(self.root, len(ranking) - 1))
        primary = imag_idx[ranking[pick]]
        targets = [i for i in imag_idx if i != primary]

    if not targets:
        return False

    # Per-atom scale relative to carbon, and the amplitude converted via BOHR2ANG.
    carbon_scale = np.sqrt(12.011 / self.masses_amu)[:, None]
    amplitude = self.flatten_amp_ang / BOHR2ANG
    # Square-root masses laid out per Cartesian component; identical for every mode.
    sqrt_mass_xyz = np.sqrt(np.repeat(self.masses_amu, 3).reshape(-1, 3))

    # Energy before any displacement; every reported ΔE is relative to this value.
    e_start = _calc_energy(self.geom, self.calc_kwargs)

    for mode_i in targets:
        direction = modes[mode_i].detach().cpu().numpy().reshape(-1, 3) / sqrt_mass_xyz
        direction /= np.linalg.norm(direction)
        shift = amplitude * carbon_scale * direction

        # Start from wherever the previous target left the geometry.
        base = self.geom.coords.reshape(-1, 3)
        forward = base + shift
        backward = base - shift

        self.geom.coords = forward.reshape(-1)
        e_forward = _calc_energy(self.geom, self.calc_kwargs)

        self.geom.coords = backward.reshape(-1)
        e_backward = _calc_energy(self.geom, self.calc_kwargs)

        if e_forward <= e_backward:
            kept_coords, kept_energy = forward, e_forward
        else:
            kept_coords, kept_energy = backward, e_backward
        self.geom.coords = kept_coords.reshape(-1)

        energy_change = kept_energy - e_start
        click.echo(
            f"[Flatten] mode={mode_i} freq={freqs_cm[mode_i]:+.2f} cm^-1 "
            f"E_disp={kept_energy:.8f} Ha ΔE={energy_change:+.8f} Ha"
        )

    _clear_cuda_cache()
    return True
|
|
1361
|
+
|
|
1362
|
+
# ----- Run full procedure -----
|
|
1363
|
+
def run(self) -> None:
    """
    Execute the complete Hessian-dimer TS search on ``self.geom`` and write outputs.

    Stages, in order:
      1. Compute an initial Hessian (ML-only or partial, depending on
         ``self.ml_only_hessian_dimer``) and write the mode direction for
         ``self.root`` to ``self.mode_path``.
      2. Run the dimer loop with ``self.thresh_loose``. If ``self.root != 0``
         the root is reset to 0 and ``self.thresh_loose`` is overwritten with
         ``self.thresh`` first.
      3. If cycle budget remains, refresh the mode from a (possibly cached)
         Hessian and run the dimer loop with ``self.thresh``.
      4. If ``self.flatten_max_iter > 0`` and budget remains, run the flatten
         loop: one exact Hessian at the start, then Bofill updates of the
         active block around each flatten displacement and re-optimization.
      5. Write ``final_geometry.xyz`` into ``self.out_dir`` and the imaginary
         modes of the final Hessian via ``_write_all_imag_modes``.

    The cycle budget is ``self.max_total_cycles - self._cycles_spent``.
    """
    if self.dump and self.optim_all_path.exists():
        self.optim_all_path.unlink()

    N = len(self.geom.atomic_numbers)
    # Hessian/coordinates captured at flatten start; reused in stage 5 if the geometry did not move.
    H_final_reuse_cpu: Optional[torch.Tensor] = None
    H_final_reuse_coords: Optional[np.ndarray] = None

    # (1) Initial Hessian → pick direction by `root`
    hess_kw_init = self.calc_kwargs_ml_only if self.ml_only_hessian_dimer else self.calc_kwargs_partial
    if self.ml_only_hessian_dimer:
        click.echo("[tsopt] Using ML-only Hessian for dimer orientation.")
    H_t = _calc_full_hessian_torch(self.geom, hess_kw_init, self.device)
    coords_bohr_t = torch.as_tensor(self.geom.coords.reshape(-1, 3),
                                    dtype=H_t.dtype, device=H_t.device)
    active_idx, mask_dof = self._resolve_hessian_active_subspace(H_t, N)
    if H_t.size(0) != 3 * N:
        # The attribute is treated as optional elsewhere in this class, so a
        # diagnostic message must not fail when it is absent.
        within_meta = getattr(self.geom, "within_partial_hessian", None)
        click.echo(
            f"[tsopt] H_act={int(H_t.size(0))} active_atoms={len(active_idx)} "
            f"active_dofs={int(mask_dof.sum())} within={within_meta is not None}"
        )

    if H_t.size(0) == 3 * N:
        # Skip heavy TR-projection residual check to conserve VRAM.
        click.echo("[tsopt] TR-projection residual check skipped to conserve VRAM.")
        mode_xyz = _mode_direction_by_root(
            H_t, coords_bohr_t, self.masses_au_t,
            root=self.root, freeze_idx=self.freeze_atoms if len(self.freeze_atoms) > 0 else None
        )
    else:
        click.echo("[tsopt] Using active-block Hessian from UMA (partial Hessian). Skip full-space TR check.")
        mode_xyz = _mode_direction_by_root_from_Hact(
            H_t, self.geom.coords.reshape(-1, 3), self.geom.atomic_numbers,
            self.masses_au_t, active_idx, self.device, root=self.root
        )
    np.savetxt(self.mode_path, mode_xyz, fmt="%.12f")
    del mode_xyz, coords_bohr_t, H_t
    _clear_cuda_cache()

    # (2) Loose loop
    if self.root != 0:
        click.echo("[tsopt] root != 0. Use this 'root' in first dimer loop", err=True)
        click.echo(f"[tsopt] Dimer Loop with initial direction from mode {self.root}...")
        # From here on the lowest mode is followed, and the first loop
        # already uses the final threshold.
        self.root = 0
        self.thresh_loose = self.thresh
    else:
        click.echo("[tsopt] Loose Dimer Loop...")

    _, zero_step_loose = self._dimer_loop(self.thresh_loose)

    zero_step_normal = False
    if (self.max_total_cycles - self._cycles_spent) > 0:
        # (3) Update mode & normal loop (reuse Hessian if 0-step converged)
        H_t = self._calc_full_hessian_cached(self.calc_kwargs_partial, allow_reuse=zero_step_loose)
        coords_bohr_t = torch.as_tensor(self.geom.coords.reshape(-1, 3),
                                        dtype=H_t.dtype, device=H_t.device)
        if H_t.size(0) == 3 * N:
            click.echo("[tsopt] TR-projection residual check skipped to conserve VRAM.")
            mode_xyz = _mode_direction_by_root(
                H_t, coords_bohr_t, self.masses_au_t,
                root=self.root, freeze_idx=self.freeze_atoms if len(self.freeze_atoms) > 0 else None
            )
        else:
            click.echo("[tsopt] Using active-block Hessian from UMA (partial Hessian). Skip full-space TR check.")
            active_idx, mask_dof = self._resolve_hessian_active_subspace(H_t, N)
            mode_xyz = _mode_direction_by_root_from_Hact(
                H_t, self.geom.coords.reshape(-1, 3), self.geom.atomic_numbers,
                self.masses_au_t, active_idx, self.device, root=self.root
            )
        np.savetxt(self.mode_path, mode_xyz, fmt="%.12f")
        del mode_xyz, coords_bohr_t, H_t
        _clear_cuda_cache()

        click.echo("[tsopt] Normal Dimer Loop...")
        _, zero_step_normal = self._dimer_loop(self.thresh)
    else:
        click.echo("[tsopt] Reached --max-cycles budget after loose loop; skipping normal dimer loop.")

    if self.flatten_max_iter > 0 and (self.max_total_cycles - self._cycles_spent) > 0:
        # (4) Flatten Loop — *reduced* exact Hessian calls via Bofill updates (active DOF only)
        click.echo("[tsopt] Flatten Loop with Bofill-updated active Hessian...")

        # (4.1) Evaluate one exact Hessian at the loop start and prepare the active block
        #       (reuse Hessian if 0-step converged)
        H_any = self._calc_full_hessian_cached(self.calc_kwargs_full, allow_reuse=zero_step_normal)
        # Keep a CPU copy so we can skip the final Hessian recomputation
        # when the flatten loop leaves geometry unchanged.
        H_final_reuse_cpu = H_any.detach().cpu().clone()
        H_final_reuse_coords = self.geom.cart_coords.copy()
        if H_any.size(0) == 3 * N:
            # full → extract active
            H_act = _extract_active_block(H_any, mask_dof)  # torch (3N_act,3N_act)
        else:
            # UMA already returned active-block Hessian
            active_idx, mask_dof = self._resolve_hessian_active_subspace(H_any, N)
            H_act = H_any
        del H_any
        _clear_cuda_cache()

        # Gradient snapshot at the current geometry for the first quasi-Newton update
        g_prev = _calc_gradient(self.geom, self.calc_kwargs).reshape(-1)  # (3N,)

        # Flatten iterations with *approximate* Hessian updates
        for _ in range(self.flatten_max_iter):
            if (self.max_total_cycles - self._cycles_spent) <= 0:
                break

            # (a) Estimate current imaginary modes using the *active* Hessian
            freqs_est = _frequencies_from_Hact(H_act, self.geom.atomic_numbers,
                                               self.geom.coords.reshape(-1, 3), active_idx, self.device)
            n_imag = int(np.sum(freqs_est < -abs(self.neg_freq_thresh_cm)))
            click.echo(f"[tsopt] n≈{n_imag} (approx imag: {[float(x) for x in freqs_est if x < -abs(self.neg_freq_thresh_cm)]})")
            if n_imag <= 1:
                break

            # (b) Get approximate modes for flattening (embedded, mass-weighted)
            freqs_cm_approx, modes_embedded = _modes_from_Hact_embedded(
                H_act, self.geom.atomic_numbers, self.geom.coords.reshape(-1, 3), active_idx, self.device
            )

            # (c) Do flatten step using the approximate modes
            x_before_flat = self.geom.coords.copy().reshape(-1)
            did_flatten = self._flatten_once_with_modes(freqs_cm_approx, modes_embedded)
            # Free GPU tensors from mode computation immediately after use
            del freqs_cm_approx, modes_embedded
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            if not did_flatten:
                break
            x_after_flat = self.geom.coords.copy().reshape(-1)

            # (d) Bofill update using UMA gradients across the flatten displacement
            g_after_flat = _calc_gradient(self.geom, self.calc_kwargs).reshape(-1)
            delta_flat_full = x_after_flat - x_before_flat
            delta_flat_act = delta_flat_full[mask_dof]
            g_old_act = g_prev[mask_dof]
            g_new_act = g_after_flat[mask_dof]
            H_act = _bofill_update_active(H_act, delta_flat_act, g_new_act, g_old_act)

            # (e) Refresh dimer direction from updated active Hessian
            mode_xyz = _mode_direction_by_root_from_Hact(
                H_act, self.geom.coords.reshape(-1, 3), self.geom.atomic_numbers,
                self.masses_au_t, active_idx, self.device, root=self.root
            )
            np.savetxt(self.mode_path, mode_xyz, fmt="%.12f")
            del mode_xyz

            # (f) Re-optimize with Dimer (consumes global cycle budget)
            # Clear VRAM before dimer loop to ensure space for Hessian recomputation
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            self._dimer_loop(self.thresh)

            # (g) Bofill update again across the optimization displacement
            x_after_opt = self.geom.coords.copy().reshape(-1)
            g_after_opt = _calc_gradient(self.geom, self.calc_kwargs).reshape(-1)
            delta_opt_full = x_after_opt - x_after_flat
            delta_opt_act = delta_opt_full[mask_dof]
            g_old_act2 = g_after_flat[mask_dof]
            g_new_act2 = g_after_opt[mask_dof]
            H_act = _bofill_update_active(H_act, delta_opt_act, g_new_act2, g_old_act2)

            # (h) The gradient at the re-optimized geometry seeds the next iteration
            g_prev = g_after_opt
    elif self.flatten_max_iter > 0:
        click.echo("[tsopt] Reached --max-cycles budget; skipping flatten loop.")

    # (5) Final outputs
    final_xyz = self.out_dir / "final_geometry.xyz"
    atoms_final = Atoms(self.geom.atoms, positions=(self.geom.coords.reshape(-1, 3) * BOHR2ANG), pbc=False)
    write(final_xyz, atoms_final)

    # Final Hessian → imaginary mode animation
    # The flatten-start Hessian is only valid if the coordinates are bit-for-bit unchanged.
    reuse_final_hessian = (
        H_final_reuse_cpu is not None
        and H_final_reuse_coords is not None
        and np.array_equal(self.geom.cart_coords, H_final_reuse_coords)
    )
    if reuse_final_hessian:
        click.echo("[tsopt] Reusing flatten-start Hessian for final frequency analysis (geometry unchanged).")
        H_t = H_final_reuse_cpu.to(self.device)
    else:
        H_t = _calc_full_hessian_torch(self.geom, self.calc_kwargs_full, self.device)
    if H_t.size(0) == 3 * N:
        freqs_cm, modes = _frequencies_cm_and_modes(
            H_t, self.geom.atomic_numbers, self.geom.coords.reshape(-1, 3), self.device,
            freeze_idx=self.freeze_atoms if len(self.freeze_atoms) > 0 else None
        )
    else:
        active_idx_final, _ = self._resolve_hessian_active_subspace(H_t, N)
        freqs_cm, modes = _modes_from_Hact_embedded(
            H_t, self.geom.atomic_numbers, self.geom.coords.reshape(-1, 3),
            active_idx_final, self.device
        )

    del H_t
    del H_final_reuse_cpu, H_final_reuse_coords
    # A PDB reference is passed on only when the source file itself is a PDB.
    _ref_pdb_light = (
        self.source_path
        if self.source_path is not None and self.source_path.suffix.lower() == ".pdb"
        else None
    )
    n_written = _write_all_imag_modes(
        self.geom,
        freqs_cm,
        modes,
        self.neg_freq_thresh_cm,
        self.vib_dir,
        ref_pdb=_ref_pdb_light,
    )
    if n_written == 0:
        click.echo(
            "[tsopt] No imaginary mode found at the end (nu_min = %.2f cm^-1)." % (float(freqs_cm.min()),),
            err=True,
        )
    else:
        click.echo(f"[tsopt] Wrote {n_written} final imaginary mode(s).")
    del modes, freqs_cm

    _clear_cuda_cache()
    click.echo(f"[tsopt] Saved final geometry → {final_xyz}")
    click.echo(f"[tsopt] Mode files → {self.vib_dir}")
|
|
1587
|
+
|
|
1588
|
+
|
|
1589
|
+
# ===================================================================
|
|
1590
|
+
# Microiteration loop for RS-I-RFO heavy mode
|
|
1591
|
+
# ===================================================================
|
|
1592
|
+
|
|
1593
|
+
|
|
1594
|
+
def _run_microiter_tsopt(
    geometry,
    calc_cfg: Dict[str, Any],
    rsirfo_cfg: Dict[str, Any],
    lbfgs_cfg: Dict[str, Any],
    opt_cfg: Dict[str, Any],
    microiter_cfg: Dict[str, Any],
    out_dir_path: Path,
    *,
    dump: bool = False,
    thresh: Optional[str] = None,
) -> Optional[Any]:
    """Run macro/micro alternating TS optimization (Gaussian 16-style microiteration).

    Macro step: 1 RS-I-RFO step moving only ML region (full ONIOM force).
    Micro step: LBFGS relaxing MM region with MM-only forces until convergence.

    Args:
        geometry: Geometry object to optimize; it is modified in place.
        calc_cfg: Keyword arguments for the ``mlmm`` calculator.
        rsirfo_cfg: Keyword arguments for ``RSIRFOptimizer`` (DIIS-related
            keys are stripped before use).
        lbfgs_cfg: Keyword arguments for the micro-step ``LBFGS`` optimizer.
        opt_cfg: Read for ``max_cycles`` (default 10000), the macro iteration cap.
        microiter_cfg: Read for ``micro_thresh`` and ``micro_max_cycles``.
        out_dir_path: Directory for trajectory files and optimizer output.
        dump: When true, append macro/micro frames to the trajectory files.
        thresh: Macro convergence threshold; falls back to
            ``rsirfo_cfg["thresh"]`` and then ``"baker"``.

    Returns:
        The same ``geometry`` object, with a fresh ``mlmm(**calc_cfg)``
        calculator attached and only the frozen-MM atoms frozen, or ``None``
        when the layer sets contain no ML atoms (the printed warning says the
        caller is expected to fall back to standard RS-I-RFO).
    """
    from .freq import _collect_layer_atom_sets

    # Resolve layer atom sets
    layer_sets = _collect_layer_atom_sets(calc_cfg)
    ml_indices = sorted(layer_sets["ml"])
    movable_mm = sorted(layer_sets["movable_mm"] | layer_sets["hess_mm"])
    frozen_mm = sorted(layer_sets["frozen_mm"])

    if not ml_indices:
        click.echo("[microiter] WARNING: No ML atoms found. Falling back to standard RS-I-RFO.")
        return None

    n_atoms = len(geometry.atoms)
    mm_indices = sorted(set(range(n_atoms)) - set(ml_indices))

    # Freeze lists: for macro step, freeze all MM; for micro step, freeze ML
    macro_freeze = sorted(set(mm_indices) | set(frozen_mm))
    micro_freeze = sorted(set(ml_indices) | set(frozen_mm))

    max_cycles = int(opt_cfg.get("max_cycles", 10000))
    macro_thresh = thresh if thresh is not None else rsirfo_cfg.get("thresh", "baker")
    micro_thresh = microiter_cfg.get("micro_thresh") or macro_thresh
    micro_max_cycles = int(microiter_cfg.get("micro_max_cycles", 10000))

    click.echo(
        f"[microiter] ML atoms: {len(ml_indices)}, "
        f"Movable MM atoms: {len(movable_mm)}, "
        f"Frozen MM atoms: {len(frozen_mm)}"
    )
    click.echo(f"[microiter] Macro thresh: {macro_thresh}, Micro thresh: {micro_thresh}")

    # Create ONIOM calculator (shared core for MM-only calc)
    macro_calc_cfg = dict(calc_cfg)
    macro_calc_cfg["freeze_atoms"] = macro_freeze
    macro_calc_cfg["hess_mm_atoms"] = []  # macro step uses an ML-only Hessian
    macro_calc = mlmm(**macro_calc_cfg)
    mm_calc = mlmm_mm_only(macro_calc.core, freeze_atoms=micro_freeze)

    # Seed initial Hessian for RS-I-RFO (with macro freeze)
    # Try TS Hessian cache first; fall back to full Hessian calculation.
    from .hessian_cache import load as _hess_load_ts
    hess_device = _torch_device(calc_cfg.get("ml_device", "auto"))

    cached_ts = _hess_load_ts("ts")
    if cached_ts is not None:
        click.echo("[microiter] Reusing cached TS Hessian for RS-I-RFO macro step.")
        active_dofs = cached_ts.get("active_dofs")
        h_raw = cached_ts["hessian"]
        if isinstance(h_raw, torch.Tensor):
            h_init = h_raw.clone()
        else:
            h_init = torch.as_tensor(h_raw, dtype=torch.float64)
        geometry.freeze_atoms = macro_freeze
        geometry.set_calculator(macro_calc)
        if active_dofs is not None:
            # Describe the cached Hessian's active subspace on the geometry.
            geometry.within_partial_hessian = {
                "active_n_dof": len(active_dofs),
                "full_n_dof": geometry.cart_coords.size,
                "active_dofs": active_dofs,
                "active_atoms": sorted(set(d // 3 for d in active_dofs)),
            }
        geometry.cart_hessian = h_init
        click.echo(f"[microiter] Initial Hessian seeded from cache (shape={h_init.shape[0]}x{h_init.shape[1]}).")
        del h_init
    else:
        click.echo("[microiter] Seeding initial Hessian for RS-I-RFO macro step.")

        geometry.freeze_atoms = macro_freeze
        geometry.set_calculator(macro_calc)

        h_init = _calc_full_hessian_torch(geometry, macro_calc_cfg, hess_device)
        geometry.cart_hessian = h_init
        click.echo(f"[microiter] Initial Hessian seeded (shape={h_init.shape[0]}x{h_init.shape[1]}).")
        del h_init

    optim_all_path = out_dir_path / "optimization_all_trj.xyz"
    macro_trj_path = out_dir_path / "optimization_trj.xyz"
    total_macro_steps = 0

    # Create persistent RSIRFOptimizer once (LayerOpt pattern).
    # This preserves the BFGS Hessian update chain across macro iterations.
    # NOTE: geometry already has macro_calc set (line above); do NOT call
    # set_calculator() again as it clears the pre-computed cart_hessian.
    geometry.freeze_atoms = macro_freeze

    rsirfo_args = dict(rsirfo_cfg)
    rsirfo_args["max_cycles"] = max_cycles
    rsirfo_args["out_dir"] = str(out_dir_path)
    rsirfo_args["dump"] = False  # trajectory dumping handled externally
    if macro_thresh is not None:
        rsirfo_args["thresh"] = str(macro_thresh)
    # RSIRFOptimizer does not accept RFOptimizer-specific DIIS knobs; strip them.
    for _diis_kw in ("gediis", "gdiis", "gdiis_thresh", "gediis_thresh", "gdiis_test_direction", "adapt_step_func"):
        rsirfo_args.pop(_diis_kw, None)

    macro_optimizer = RSIRFOptimizer(geometry, **rsirfo_args)
    macro_optimizer.prepare_opt()  # initialise Hessian from geometry.cart_hessian

    # Microiteration progress table (pysisyphus-style with micro_steps column)
    micro_header = "cycle Δ(energy) max(|force|) rms(force) max(|step|) rms(step) micro_steps s/cycle".split()
    micro_col_fmts = "int float float float float float int float_short".split()
    micro_table = TablePrinter(micro_header, micro_col_fmts, width=12)
    micro_table.print_header()

    for macro_iter in range(max_cycles):
        # ---- Macro step: 1 RS-I-RFO step with ONIOM forces, MM frozen ----
        geometry.freeze_atoms = macro_freeze
        geometry.set_calculator(macro_calc)

        # Manually feed state to the persistent optimizer (cf. LayerOpt lines 358-364)
        macro_optimizer.coords.append(geometry.coords.copy())
        macro_optimizer.cart_coords.append(geometry.cart_coords.copy())
        macro_optimizer.cur_cycle = macro_iter

        t_start = time.time()
        step = macro_optimizer.optimize()  # housekeeping() triggers BFGS update
        macro_optimizer.steps.append(step)

        # Convergence check
        macro_converged, conv_info = macro_optimizer.check_convergence()
        total_macro_steps += 1

        # Values for this cycle's table row; the same on both the converged
        # and the non-converged path, so compute them once.
        energies = macro_optimizer.energies
        energy_diff = energies[-1] - energies[-2] if len(energies) >= 2 else float("nan")
        marks = [False, *conv_info.get_convergence()[:-1], False, False]

        if dump:
            # NOTE(review): macro_trj_path is the same file name the micro
            # LBFGS writes to with dump=True; confirm frames are not appended
            # to optim_all_path more than once.
            with open(macro_trj_path, "a") as f:
                f.write(geometry.as_xyz() + "\n")
            _append_xyz_trajectory(optim_all_path, macro_trj_path)

        if macro_converged:
            # Print final converged row (no micro steps)
            cycle_time = time.time() - t_start
            micro_table.print_row(
                (macro_iter, energy_diff, macro_optimizer.max_forces[-1], macro_optimizer.rms_forces[-1],
                 macro_optimizer.max_steps[-1], macro_optimizer.rms_steps[-1], 0, cycle_time),
                marks=marks,
            )
            click.echo(f"[microiter] Macro convergence reached at iteration {macro_iter + 1}.")
            break

        # Apply step to geometry
        new_coords = geometry.coords.copy() + step
        geometry.coords = new_coords
        # Record actual step (may differ due to coordinate back-transformation)
        macro_optimizer.steps[-1] = geometry.coords - macro_optimizer.coords[-1]

        # ---- Micro step: LBFGS with MM-only forces, ML frozen ----
        geometry.freeze_atoms = micro_freeze
        geometry.set_calculator(mm_calc)

        micro_lbfgs_args = dict(lbfgs_cfg)
        micro_lbfgs_args["max_cycles"] = micro_max_cycles
        micro_lbfgs_args["thresh"] = micro_thresh
        micro_lbfgs_args["out_dir"] = str(out_dir_path)
        micro_lbfgs_args["dump"] = dump

        micro_opt = LBFGS(geometry, **micro_lbfgs_args)
        # Silence the inner optimizer's own progress output.
        with contextlib.redirect_stdout(io.StringIO()):
            micro_opt.run()
        micro_steps = max(int(micro_opt.cur_cycle) + 1, 1)

        if dump:
            _append_xyz_trajectory(optim_all_path, out_dir_path / "optimization_trj.xyz")

        del micro_opt
        _clear_cuda_cache()

        # Print progress row with micro_steps
        cycle_time = time.time() - t_start
        if (macro_iter > 1) and (macro_iter % 10 == 0):
            micro_table.print_sep()
        micro_table.print_row(
            (macro_iter, energy_diff, macro_optimizer.max_forces[-1], macro_optimizer.rms_forces[-1],
             macro_optimizer.max_steps[-1], macro_optimizer.rms_steps[-1], micro_steps, cycle_time),
            marks=marks,
        )

    else:
        # for/else: reached only when the loop ran out without a `break`.
        click.echo(f"[microiter] Reached max macro iterations ({max_cycles}).")

    del macro_optimizer
    _clear_cuda_cache()

    click.echo(f"[microiter] Total macro steps: {total_macro_steps}")
    # Restore full calculator with only frozen MM frozen.
    # frozen_mm is already sorted and unique; copy it so the order stays deterministic.
    geometry.freeze_atoms = list(frozen_mm)
    base_calc = mlmm(**calc_cfg)
    geometry.set_calculator(base_calc)

    return geometry
|
|
1804
|
+
|
|
1805
|
+
|
|
1806
|
+
# ===================================================================
|
|
1807
|
+
# Defaults for CLI
|
|
1808
|
+
# ===================================================================
|
|
1809
|
+
|
|
1810
|
+
# Configuration defaults (imported from defaults.py)
|
|
1811
|
+
# Independent deep copies so edits made here never leak back into the imported defaults.
GEOM_KW: Dict[str, Any] = deepcopy(GEOM_KW_DEFAULT)
CALC_KW: Dict[str, Any] = deepcopy(MLMM_CALC_KW)

# HessianDimer defaults: start from HESSIAN_DIMER_KW, then attach shallow
# copies of the DIMER_KW and LBFGS_KW sections under their own keys.
hessian_dimer_KW = dict(HESSIAN_DIMER_KW)
hessian_dimer_KW["dimer"] = dict(DIMER_KW)
hessian_dimer_KW["lbfgs"] = dict(LBFGS_KW)
|
|
1820
|
+
|
|
1821
|
+
# ===================================================================
|
|
1822
|
+
# CLI
|
|
1823
|
+
# ===================================================================
|
|
1824
|
+
|
|
1825
|
+
@click.command(
|
|
1826
|
+
help="TS optimization: grad (Dimer) or hess (RS-I-RFO) for the ML/MM calculator.",
|
|
1827
|
+
context_settings={"help_option_names": ["-h", "--help"]},
|
|
1828
|
+
)
|
|
1829
|
+
@click.option(
|
|
1830
|
+
"-i", "--input",
|
|
1831
|
+
"input_path",
|
|
1832
|
+
type=click.Path(path_type=Path, exists=True, dir_okay=False),
|
|
1833
|
+
required=True,
|
|
1834
|
+
help="Starting geometry (PDB or XYZ). XYZ provides higher coordinate precision. "
|
|
1835
|
+
"If XYZ, use --ref-pdb to specify PDB topology for atom ordering and output conversion.",
|
|
1836
|
+
)
|
|
1837
|
+
@click.option(
|
|
1838
|
+
"--ref-pdb",
|
|
1839
|
+
"ref_pdb",
|
|
1840
|
+
type=click.Path(path_type=Path, exists=True, dir_okay=False),
|
|
1841
|
+
default=None,
|
|
1842
|
+
show_default=False,
|
|
1843
|
+
help="Reference PDB topology when input is XYZ. XYZ coordinates are used (higher precision) "
|
|
1844
|
+
"while PDB provides atom ordering and residue information for output conversion.",
|
|
1845
|
+
)
|
|
1846
|
+
@click.option(
|
|
1847
|
+
"--parm",
|
|
1848
|
+
"real_parm7",
|
|
1849
|
+
type=click.Path(path_type=Path, exists=True, dir_okay=False),
|
|
1850
|
+
required=True,
|
|
1851
|
+
help="Amber parm7 topology for the whole enzyme (MM region).",
|
|
1852
|
+
)
|
|
1853
|
+
@click.option(
|
|
1854
|
+
"--model-pdb",
|
|
1855
|
+
"model_pdb",
|
|
1856
|
+
type=click.Path(path_type=Path, exists=True, dir_okay=False),
|
|
1857
|
+
required=False,
|
|
1858
|
+
help="PDB containing the ML-region atoms. Optional when --detect-layer is enabled.",
|
|
1859
|
+
)
|
|
1860
|
+
@click.option(
|
|
1861
|
+
"--model-indices",
|
|
1862
|
+
"model_indices_str",
|
|
1863
|
+
type=str,
|
|
1864
|
+
default=None,
|
|
1865
|
+
show_default=False,
|
|
1866
|
+
help="Comma-separated atom indices for the ML region (ranges allowed like 1-5). "
|
|
1867
|
+
"Used when --model-pdb is omitted.",
|
|
1868
|
+
)
|
|
1869
|
+
@click.option(
|
|
1870
|
+
"--model-indices-one-based/--model-indices-zero-based",
|
|
1871
|
+
"model_indices_one_based",
|
|
1872
|
+
default=True,
|
|
1873
|
+
show_default=True,
|
|
1874
|
+
help="Interpret --model-indices as 1-based (default) or 0-based.",
|
|
1875
|
+
)
|
|
1876
|
+
@click.option(
|
|
1877
|
+
"--detect-layer/--no-detect-layer",
|
|
1878
|
+
"detect_layer",
|
|
1879
|
+
default=True,
|
|
1880
|
+
show_default=True,
|
|
1881
|
+
help="Detect ML/MM layers from input PDB B-factors (B=0/10/20). "
|
|
1882
|
+
"If disabled, you must provide --model-pdb or --model-indices.",
|
|
1883
|
+
)
|
|
1884
|
+
@click.option(
|
|
1885
|
+
"-q",
|
|
1886
|
+
"--charge",
|
|
1887
|
+
type=int,
|
|
1888
|
+
required=False,
|
|
1889
|
+
help="Total charge of the ML region. Required unless --ligand-charge is provided.",
|
|
1890
|
+
)
|
|
1891
|
+
@click.option("-l", "--ligand-charge", type=str, default=None, show_default=False,
|
|
1892
|
+
help="Total charge or per-resname mapping (e.g., GPP:-3,SAM:1) used to derive "
|
|
1893
|
+
"charge when -q is omitted (requires PDB input or --ref-pdb).")
|
|
1894
|
+
@click.option(
|
|
1895
|
+
"-m",
|
|
1896
|
+
"--multiplicity",
|
|
1897
|
+
"spin",
|
|
1898
|
+
type=int,
|
|
1899
|
+
default=None,
|
|
1900
|
+
show_default=False,
|
|
1901
|
+
help="Spin multiplicity (2S+1) for the ML region.",
|
|
1902
|
+
)
|
|
1903
|
+
@click.option(
|
|
1904
|
+
"--freeze-atoms",
|
|
1905
|
+
"freeze_atoms_text",
|
|
1906
|
+
type=str,
|
|
1907
|
+
default=None,
|
|
1908
|
+
show_default=False,
|
|
1909
|
+
help="Comma-separated 1-based indices to freeze (e.g., '1,3,5').",
|
|
1910
|
+
)
|
|
1911
|
+
@click.option(
|
|
1912
|
+
"--radius-hessian",
|
|
1913
|
+
"--hess-cutoff",
|
|
1914
|
+
"hess_cutoff",
|
|
1915
|
+
type=float,
|
|
1916
|
+
default=0.0,
|
|
1917
|
+
show_default=True,
|
|
1918
|
+
help="Distance cutoff (Å) from ML region for MM atoms to include in Hessian calculation. "
|
|
1919
|
+
"Applied to movable MM atoms. Default 0.0 means ML-only partial Hessian.",
|
|
1920
|
+
)
|
|
1921
|
+
@click.option(
|
|
1922
|
+
"--movable-cutoff",
|
|
1923
|
+
"movable_cutoff",
|
|
1924
|
+
type=float,
|
|
1925
|
+
default=None,
|
|
1926
|
+
show_default=False,
|
|
1927
|
+
help="Distance cutoff (Å) from ML region for movable MM atoms. "
|
|
1928
|
+
"MM atoms beyond this are frozen. "
|
|
1929
|
+
"Providing --movable-cutoff disables --detect-layer.",
|
|
1930
|
+
)
|
|
1931
|
+
@click.option(
|
|
1932
|
+
"--hessian-calc-mode",
|
|
1933
|
+
type=click.Choice(["Analytical", "FiniteDifference"], case_sensitive=False),
|
|
1934
|
+
default=None,
|
|
1935
|
+
help="How the ML backend builds the Hessian (Analytical or FiniteDifference); "
|
|
1936
|
+
"overrides calc.hessian_calc_mode from YAML. "
|
|
1937
|
+
"Default: 'FiniteDifference'. Use 'Analytical' when VRAM is sufficient.",
|
|
1938
|
+
)
|
|
1939
|
+
@click.option("--max-cycles", type=int, default=10000, show_default=True, help="Maximum total optimization cycles.")
|
|
1940
|
+
@click.option(
|
|
1941
|
+
"--dump/--no-dump",
|
|
1942
|
+
default=False,
|
|
1943
|
+
show_default=True,
|
|
1944
|
+
help="Write concatenated trajectory 'optimization_all_trj.xyz'.",
|
|
1945
|
+
)
|
|
1946
|
+
@click.option("-o", "--out-dir", type=str, default=OUT_DIR_TSOPT, show_default=True, help="Output directory.")
|
|
1947
|
+
@click.option(
|
|
1948
|
+
"--thresh",
|
|
1949
|
+
type=click.Choice(["gau_loose", "gau", "gau_tight", "gau_vtight", "baker", "never"], case_sensitive=False),
|
|
1950
|
+
default=None,
|
|
1951
|
+
help="Convergence preset.",
|
|
1952
|
+
)
|
|
1953
|
+
@click.option(
|
|
1954
|
+
"--opt-mode",
|
|
1955
|
+
type=click.Choice(["grad", "hess", "light", "heavy", "dimer", "rsirfo"], case_sensitive=False),
|
|
1956
|
+
default="hess",
|
|
1957
|
+
show_default=True,
|
|
1958
|
+
help="grad (dimer) or hess (rsirfo). Aliases light/heavy and dimer/rsirfo are accepted.",
|
|
1959
|
+
)
|
|
1960
|
+
@click.option(
|
|
1961
|
+
"--microiter/--no-microiter",
|
|
1962
|
+
"microiter",
|
|
1963
|
+
default=True,
|
|
1964
|
+
show_default=True,
|
|
1965
|
+
help="Enable microiteration: alternate ML 1-step (RS-I-RFO) and MM relaxation (LBFGS with MM-only forces). "
|
|
1966
|
+
"Only effective in --opt-mode hess. Ignored in grad mode.",
|
|
1967
|
+
)
|
|
1968
|
+
@click.option(
|
|
1969
|
+
"--partial-hessian-flatten/--full-hessian-flatten",
|
|
1970
|
+
"partial_hessian_flatten",
|
|
1971
|
+
default=True,
|
|
1972
|
+
show_default=True,
|
|
1973
|
+
help="Use partial Hessian (ML region only) for imaginary mode detection in flatten loop.",
|
|
1974
|
+
)
|
|
1975
|
+
@click.option(
|
|
1976
|
+
"--flatten/--no-flatten",
|
|
1977
|
+
"flatten",
|
|
1978
|
+
default=None,
|
|
1979
|
+
show_default=False,
|
|
1980
|
+
help="Enable/disable extra imaginary-mode flattening loop. "
|
|
1981
|
+
"--flatten uses the default flatten_max_iter (50); --no-flatten forces it to 0. "
|
|
1982
|
+
"When not provided, the value is determined by the YAML config or defaults.",
|
|
1983
|
+
)
|
|
1984
|
+
@click.option(
|
|
1985
|
+
"--ml-only-hessian-dimer/--no-ml-only-hessian-dimer",
|
|
1986
|
+
"ml_only_hessian_dimer",
|
|
1987
|
+
default=False,
|
|
1988
|
+
show_default=True,
|
|
1989
|
+
help="Use ML-region-only Hessian (no MM Hessian contribution) for dimer orientation "
|
|
1990
|
+
"in grad mode. Faster but less accurate for mode direction.",
|
|
1991
|
+
)
|
|
1992
|
+
@click.option(
|
|
1993
|
+
"--active-dof-mode",
|
|
1994
|
+
type=click.Choice(["all", "ml-only", "partial", "unfrozen"], case_sensitive=False),
|
|
1995
|
+
default="partial",
|
|
1996
|
+
show_default=True,
|
|
1997
|
+
help="Active DOF selection for final frequency analysis: "
|
|
1998
|
+
"all (all atoms), ml-only (ML only), partial (ML + MovableMM, default), "
|
|
1999
|
+
"unfrozen (all except frozen layer).",
|
|
2000
|
+
)
|
|
2001
|
+
@click.option(
|
|
2002
|
+
"--config",
|
|
2003
|
+
"config_yaml",
|
|
2004
|
+
type=click.Path(path_type=Path, exists=True, dir_okay=False),
|
|
2005
|
+
default=None,
|
|
2006
|
+
help="Base YAML configuration file applied before explicit CLI options.",
|
|
2007
|
+
)
|
|
2008
|
+
@click.option(
|
|
2009
|
+
"--show-config/--no-show-config",
|
|
2010
|
+
"show_config",
|
|
2011
|
+
default=False,
|
|
2012
|
+
show_default=True,
|
|
2013
|
+
help="Print resolved configuration and continue execution.",
|
|
2014
|
+
)
|
|
2015
|
+
@click.option(
|
|
2016
|
+
"--dry-run/--no-dry-run",
|
|
2017
|
+
"dry_run",
|
|
2018
|
+
default=False,
|
|
2019
|
+
show_default=True,
|
|
2020
|
+
help="Validate options and print the execution plan without running TS optimization.",
|
|
2021
|
+
)
|
|
2022
|
+
@click.option(
|
|
2023
|
+
"--convert-files/--no-convert-files",
|
|
2024
|
+
"convert_files",
|
|
2025
|
+
default=True,
|
|
2026
|
+
show_default=True,
|
|
2027
|
+
help="Convert XYZ/TRJ outputs into PDB companions based on the input format.",
|
|
2028
|
+
)
|
|
2029
|
+
@click.option(
|
|
2030
|
+
"-b", "--backend",
|
|
2031
|
+
type=click.Choice(["uma", "orb", "mace", "aimnet2"], case_sensitive=False),
|
|
2032
|
+
default=None,
|
|
2033
|
+
show_default=False,
|
|
2034
|
+
help="ML backend for the ONIOM high-level region (default: uma).",
|
|
2035
|
+
)
|
|
2036
|
+
@click.option(
|
|
2037
|
+
"--embedcharge/--no-embedcharge",
|
|
2038
|
+
"embedcharge",
|
|
2039
|
+
default=False,
|
|
2040
|
+
show_default=True,
|
|
2041
|
+
help="Enable xTB point-charge embedding correction for MM→ML environmental effects.",
|
|
2042
|
+
)
|
|
2043
|
+
@click.option(
|
|
2044
|
+
"--embedcharge-cutoff",
|
|
2045
|
+
"embedcharge_cutoff",
|
|
2046
|
+
type=float,
|
|
2047
|
+
default=None,
|
|
2048
|
+
show_default=False,
|
|
2049
|
+
help="Distance cutoff (Å) from ML region for MM point charges in xTB embedding. "
|
|
2050
|
+
"Default: 12.0 Å when --embedcharge is enabled.",
|
|
2051
|
+
)
|
|
2052
|
+
@click.pass_context
|
|
2053
|
+
def cli(
|
|
2054
|
+
ctx: click.Context,
|
|
2055
|
+
input_path: Path,
|
|
2056
|
+
ref_pdb: Optional[Path],
|
|
2057
|
+
real_parm7: Path,
|
|
2058
|
+
model_pdb: Optional[Path],
|
|
2059
|
+
model_indices_str: Optional[str],
|
|
2060
|
+
model_indices_one_based: bool,
|
|
2061
|
+
detect_layer: bool,
|
|
2062
|
+
charge: Optional[int],
|
|
2063
|
+
ligand_charge: Optional[str],
|
|
2064
|
+
spin: Optional[int],
|
|
2065
|
+
freeze_atoms_text: Optional[str],
|
|
2066
|
+
hess_cutoff: Optional[float],
|
|
2067
|
+
movable_cutoff: Optional[float],
|
|
2068
|
+
hessian_calc_mode: Optional[str],
|
|
2069
|
+
max_cycles: int,
|
|
2070
|
+
dump: bool,
|
|
2071
|
+
out_dir: str,
|
|
2072
|
+
thresh: Optional[str],
|
|
2073
|
+
opt_mode: str,
|
|
2074
|
+
microiter: bool,
|
|
2075
|
+
partial_hessian_flatten: bool,
|
|
2076
|
+
flatten: Optional[bool],
|
|
2077
|
+
ml_only_hessian_dimer: bool,
|
|
2078
|
+
active_dof_mode: str,
|
|
2079
|
+
config_yaml: Optional[Path],
|
|
2080
|
+
show_config: bool,
|
|
2081
|
+
dry_run: bool,
|
|
2082
|
+
convert_files: bool,
|
|
2083
|
+
backend: Optional[str],
|
|
2084
|
+
embedcharge: bool,
|
|
2085
|
+
embedcharge_cutoff: Optional[float],
|
|
2086
|
+
) -> None:
|
|
2087
|
+
set_convert_file_enabled(convert_files)
|
|
2088
|
+
_is_param_explicit = make_is_param_explicit(ctx)
|
|
2089
|
+
|
|
2090
|
+
config_yaml, override_yaml, used_legacy_yaml = resolve_yaml_sources(
|
|
2091
|
+
config_yaml=config_yaml,
|
|
2092
|
+
override_yaml=None,
|
|
2093
|
+
args_yaml_legacy=None,
|
|
2094
|
+
)
|
|
2095
|
+
merged_yaml_cfg, _, _ = load_merged_yaml_cfg(
|
|
2096
|
+
config_yaml=config_yaml,
|
|
2097
|
+
override_yaml=None,
|
|
2098
|
+
)
|
|
2099
|
+
|
|
2100
|
+
# Handle input: PDB directly, or XYZ with --ref-pdb for topology
|
|
2101
|
+
suffix = input_path.suffix.lower()
|
|
2102
|
+
if suffix == ".pdb":
|
|
2103
|
+
# PDB input: use directly
|
|
2104
|
+
prepared_input = prepare_input_structure(input_path)
|
|
2105
|
+
elif suffix == ".xyz":
|
|
2106
|
+
# XYZ input: require --ref-pdb for topology
|
|
2107
|
+
if ref_pdb is None:
|
|
2108
|
+
click.echo("ERROR: XYZ/TRJ input requires --ref-pdb to specify PDB topology.", err=True)
|
|
2109
|
+
sys.exit(1)
|
|
2110
|
+
prepared_input = prepare_input_structure(input_path)
|
|
2111
|
+
apply_ref_pdb_override(prepared_input, ref_pdb)
|
|
2112
|
+
click.echo(f"[input] Using XYZ coordinates from {input_path.name}, PDB topology from {ref_pdb.name}")
|
|
2113
|
+
else:
|
|
2114
|
+
click.echo(f"ERROR: Unsupported input format: {suffix}. Use .pdb or .xyz (with --ref-pdb).", err=True)
|
|
2115
|
+
sys.exit(1)
|
|
2116
|
+
|
|
2117
|
+
geom_input_path = prepared_input.geom_path
|
|
2118
|
+
source_path = prepared_input.source_path
|
|
2119
|
+
charge, spin = resolve_charge_spin_or_raise(
|
|
2120
|
+
prepared_input, charge, spin,
|
|
2121
|
+
ligand_charge=ligand_charge, prefix="[tsopt]",
|
|
2122
|
+
)
|
|
2123
|
+
|
|
2124
|
+
try:
|
|
2125
|
+
freeze_atoms_cli = _parse_freeze_atoms_opt(freeze_atoms_text)
|
|
2126
|
+
except click.BadParameter as e:
|
|
2127
|
+
click.echo(f"ERROR: {e}", err=True)
|
|
2128
|
+
prepared_input.cleanup()
|
|
2129
|
+
sys.exit(1)
|
|
2130
|
+
|
|
2131
|
+
model_indices: Optional[List[int]] = None
|
|
2132
|
+
if model_indices_str:
|
|
2133
|
+
try:
|
|
2134
|
+
model_indices = parse_indices_string(model_indices_str, one_based=model_indices_one_based)
|
|
2135
|
+
except click.BadParameter as e:
|
|
2136
|
+
click.echo(f"ERROR: {e}", err=True)
|
|
2137
|
+
prepared_input.cleanup()
|
|
2138
|
+
sys.exit(1)
|
|
2139
|
+
|
|
2140
|
+
time_start = time.perf_counter()
|
|
2141
|
+
|
|
2142
|
+
# Resolve optimizer mode (default is now hess/RS-I-RFO)
|
|
2143
|
+
mode_resolved = normalize_choice(
|
|
2144
|
+
opt_mode,
|
|
2145
|
+
param="--opt-mode",
|
|
2146
|
+
alias_groups=TSOPT_MODE_ALIASES,
|
|
2147
|
+
allowed_hint="grad|hess|dimer|rsirfo",
|
|
2148
|
+
)
|
|
2149
|
+
use_heavy = (mode_resolved == "rsirfo")
|
|
2150
|
+
|
|
2151
|
+
config_layer_cfg = load_yaml_dict(config_yaml)
|
|
2152
|
+
override_layer_cfg = load_yaml_dict(override_yaml)
|
|
2153
|
+
geom_cfg: Dict[str, Any] = deepcopy(GEOM_KW)
|
|
2154
|
+
calc_cfg: Dict[str, Any] = deepcopy(CALC_KW)
|
|
2155
|
+
opt_cfg: Dict[str, Any] = dict(OPT_BASE_KW)
|
|
2156
|
+
lbfgs_cfg: Dict[str, Any] = dict(LBFGS_KW)
|
|
2157
|
+
simple_cfg: Dict[str, Any] = dict(hessian_dimer_KW)
|
|
2158
|
+
rsirfo_cfg: Dict[str, Any] = dict(RSIRFO_KW)
|
|
2159
|
+
|
|
2160
|
+
apply_yaml_overrides(
|
|
2161
|
+
config_layer_cfg,
|
|
2162
|
+
[
|
|
2163
|
+
(geom_cfg, (("geom",),)),
|
|
2164
|
+
(calc_cfg, (("calc",), ("mlmm",))),
|
|
2165
|
+
(opt_cfg, (("opt",),)),
|
|
2166
|
+
(simple_cfg, (("hessian_dimer",),)),
|
|
2167
|
+
(rsirfo_cfg, (("rsirfo",),)),
|
|
2168
|
+
],
|
|
2169
|
+
)
|
|
2170
|
+
if _is_param_explicit("hessian_calc_mode") and hessian_calc_mode is not None:
|
|
2171
|
+
calc_cfg["hessian_calc_mode"] = str(hessian_calc_mode)
|
|
2172
|
+
if _is_param_explicit("max_cycles"):
|
|
2173
|
+
opt_cfg["max_cycles"] = int(max_cycles)
|
|
2174
|
+
if _is_param_explicit("dump"):
|
|
2175
|
+
opt_cfg["dump"] = bool(dump)
|
|
2176
|
+
if _is_param_explicit("out_dir"):
|
|
2177
|
+
opt_cfg["out_dir"] = out_dir
|
|
2178
|
+
if _is_param_explicit("thresh") and thresh is not None:
|
|
2179
|
+
opt_cfg["thresh"] = str(thresh)
|
|
2180
|
+
simple_cfg["thresh"] = str(thresh)
|
|
2181
|
+
rsirfo_cfg["thresh"] = str(thresh)
|
|
2182
|
+
# Handle --flatten/--no-flatten CLI toggle
|
|
2183
|
+
if flatten is not None:
|
|
2184
|
+
if flatten:
|
|
2185
|
+
# Use default from HESSIAN_DIMER_KW if not already set
|
|
2186
|
+
simple_cfg.setdefault("flatten_max_iter", HESSIAN_DIMER_KW["flatten_max_iter"])
|
|
2187
|
+
else:
|
|
2188
|
+
simple_cfg["flatten_max_iter"] = 0
|
|
2189
|
+
if _is_param_explicit("detect_layer"):
|
|
2190
|
+
calc_cfg["use_bfactor_layers"] = bool(detect_layer)
|
|
2191
|
+
if _is_param_explicit("hess_cutoff") and hess_cutoff is not None:
|
|
2192
|
+
calc_cfg["hess_cutoff"] = float(hess_cutoff)
|
|
2193
|
+
if _is_param_explicit("movable_cutoff") and movable_cutoff is not None:
|
|
2194
|
+
calc_cfg["movable_cutoff"] = float(movable_cutoff)
|
|
2195
|
+
calc_cfg["use_bfactor_layers"] = False
|
|
2196
|
+
|
|
2197
|
+
model_charge_value = calc_cfg.get("model_charge", charge)
|
|
2198
|
+
if model_charge_value is None:
|
|
2199
|
+
model_charge_value = charge
|
|
2200
|
+
calc_cfg["model_charge"] = int(model_charge_value)
|
|
2201
|
+
if _is_param_explicit("charge"):
|
|
2202
|
+
calc_cfg["model_charge"] = int(charge)
|
|
2203
|
+
|
|
2204
|
+
model_mult_value = calc_cfg.get("model_mult", spin)
|
|
2205
|
+
if model_mult_value is None:
|
|
2206
|
+
model_mult_value = spin
|
|
2207
|
+
calc_cfg["model_mult"] = int(model_mult_value)
|
|
2208
|
+
if _is_param_explicit("spin"):
|
|
2209
|
+
calc_cfg["model_mult"] = int(spin)
|
|
2210
|
+
|
|
2211
|
+
if model_pdb is not None:
|
|
2212
|
+
calc_cfg["model_pdb"] = str(model_pdb)
|
|
2213
|
+
calc_cfg["input_pdb"] = str(source_path)
|
|
2214
|
+
calc_cfg["real_parm7"] = str(real_parm7)
|
|
2215
|
+
|
|
2216
|
+
if backend is not None:
|
|
2217
|
+
calc_cfg["backend"] = str(backend).lower()
|
|
2218
|
+
if _is_param_explicit("embedcharge"):
|
|
2219
|
+
calc_cfg["embedcharge"] = bool(embedcharge)
|
|
2220
|
+
if _is_param_explicit("embedcharge_cutoff"):
|
|
2221
|
+
calc_cfg["embedcharge_cutoff"] = embedcharge_cutoff
|
|
2222
|
+
|
|
2223
|
+
apply_yaml_overrides(
|
|
2224
|
+
override_layer_cfg,
|
|
2225
|
+
[
|
|
2226
|
+
(geom_cfg, (("geom",),)),
|
|
2227
|
+
(calc_cfg, (("calc",), ("mlmm",))),
|
|
2228
|
+
(opt_cfg, (("opt",),)),
|
|
2229
|
+
(simple_cfg, (("hessian_dimer",),)),
|
|
2230
|
+
(rsirfo_cfg, (("rsirfo",),)),
|
|
2231
|
+
],
|
|
2232
|
+
)
|
|
2233
|
+
calc_paths = (("calc",), ("mlmm",))
|
|
2234
|
+
partial_explicit = (
|
|
2235
|
+
yaml_section_has_key(config_layer_cfg, calc_paths, "return_partial_hessian")
|
|
2236
|
+
or yaml_section_has_key(override_layer_cfg, calc_paths, "return_partial_hessian")
|
|
2237
|
+
)
|
|
2238
|
+
if not partial_explicit:
|
|
2239
|
+
calc_cfg["return_partial_hessian"] = True
|
|
2240
|
+
|
|
2241
|
+
# Resolve microiteration config from YAML
|
|
2242
|
+
microiter_cfg = dict(MICROITER_KW)
|
|
2243
|
+
apply_yaml_overrides(
|
|
2244
|
+
config_layer_cfg,
|
|
2245
|
+
[(microiter_cfg, (("microiter",),))],
|
|
2246
|
+
)
|
|
2247
|
+
apply_yaml_overrides(
|
|
2248
|
+
override_layer_cfg,
|
|
2249
|
+
[(microiter_cfg, (("microiter",),))],
|
|
2250
|
+
)
|
|
2251
|
+
|
|
2252
|
+
use_microiter = bool(microiter) and use_heavy
|
|
2253
|
+
if bool(microiter) and not use_heavy:
|
|
2254
|
+
click.echo("[microiter] --microiter is only effective with --opt-mode hess (RS-I-RFO). Ignoring.")
|
|
2255
|
+
|
|
2256
|
+
try:
|
|
2257
|
+
geom_freeze = _normalize_geom_freeze_opt(geom_cfg.get("freeze_atoms"))
|
|
2258
|
+
except click.BadParameter as e:
|
|
2259
|
+
click.echo(f"ERROR: {e}", err=True)
|
|
2260
|
+
prepared_input.cleanup()
|
|
2261
|
+
sys.exit(1)
|
|
2262
|
+
geom_cfg["freeze_atoms"] = geom_freeze
|
|
2263
|
+
if freeze_atoms_cli:
|
|
2264
|
+
merge_freeze_atom_indices(geom_cfg, freeze_atoms_cli)
|
|
2265
|
+
freeze_atoms_final = list(geom_cfg.get("freeze_atoms") or [])
|
|
2266
|
+
calc_cfg["freeze_atoms"] = freeze_atoms_final
|
|
2267
|
+
|
|
2268
|
+
# Propagate opt.print_every only when it is explicitly different from the
|
|
2269
|
+
# base default. This avoids clobbering optimizer-specific YAML settings
|
|
2270
|
+
# (e.g. hessian_dimer.lbfgs.print_every / rsirfo.print_every) with the
|
|
2271
|
+
# inherited OPT_BASE default value.
|
|
2272
|
+
try:
|
|
2273
|
+
pe_opt = int(opt_cfg.get("print_every", OPT_BASE_KW.get("print_every", 100)))
|
|
2274
|
+
pe_base = int(OPT_BASE_KW.get("print_every", 100))
|
|
2275
|
+
if pe_opt >= 1 and pe_opt != pe_base:
|
|
2276
|
+
simple_cfg.setdefault("lbfgs", {})
|
|
2277
|
+
simple_cfg["lbfgs"]["print_every"] = pe_opt
|
|
2278
|
+
rsirfo_cfg["print_every"] = pe_opt
|
|
2279
|
+
except Exception:
|
|
2280
|
+
logger.debug("Failed to configure print_every", exc_info=True)
|
|
2281
|
+
|
|
2282
|
+
out_dir_path = Path(opt_cfg["out_dir"]).resolve()
|
|
2283
|
+
|
|
2284
|
+
# movable_cutoff implies full distance-based layer assignment.
|
|
2285
|
+
# hess_cutoff alone is allowed with detect-layer and is applied on movable MM atoms.
|
|
2286
|
+
detect_layer_enabled = bool(calc_cfg.get("use_bfactor_layers", True))
|
|
2287
|
+
model_pdb_cfg = calc_cfg.get("model_pdb")
|
|
2288
|
+
if calc_cfg.get("movable_cutoff") is not None:
|
|
2289
|
+
if detect_layer_enabled:
|
|
2290
|
+
click.echo("[layer] movable_cutoff is set; disabling --detect-layer.", err=True)
|
|
2291
|
+
detect_layer_enabled = False
|
|
2292
|
+
calc_cfg["use_bfactor_layers"] = False
|
|
2293
|
+
|
|
2294
|
+
layer_source_pdb = source_path
|
|
2295
|
+
if detect_layer_enabled and layer_source_pdb.suffix.lower() != ".pdb":
|
|
2296
|
+
click.echo("ERROR: --detect-layer requires a PDB input (or --ref-pdb).", err=True)
|
|
2297
|
+
prepared_input.cleanup()
|
|
2298
|
+
sys.exit(1)
|
|
2299
|
+
|
|
2300
|
+
if show_config:
|
|
2301
|
+
click.echo(
|
|
2302
|
+
pretty_block(
|
|
2303
|
+
"yaml_layers",
|
|
2304
|
+
{
|
|
2305
|
+
"config": None if config_yaml is None else str(config_yaml),
|
|
2306
|
+
"override_yaml": None if override_yaml is None else str(override_yaml),
|
|
2307
|
+
"merged_keys": sorted(merged_yaml_cfg.keys()),
|
|
2308
|
+
},
|
|
2309
|
+
)
|
|
2310
|
+
)
|
|
2311
|
+
|
|
2312
|
+
if dry_run:
|
|
2313
|
+
model_region_source = "bfactor"
|
|
2314
|
+
if not detect_layer_enabled:
|
|
2315
|
+
if model_pdb_cfg is not None:
|
|
2316
|
+
model_region_source = "model_pdb"
|
|
2317
|
+
elif model_indices:
|
|
2318
|
+
model_region_source = "model_indices"
|
|
2319
|
+
else:
|
|
2320
|
+
click.echo("ERROR: Provide --model-pdb or --model-indices when --no-detect-layer.", err=True)
|
|
2321
|
+
prepared_input.cleanup()
|
|
2322
|
+
sys.exit(1)
|
|
2323
|
+
if (
|
|
2324
|
+
not detect_layer_enabled
|
|
2325
|
+
and model_pdb_cfg is None
|
|
2326
|
+
and model_indices
|
|
2327
|
+
and layer_source_pdb.suffix.lower() != ".pdb"
|
|
2328
|
+
):
|
|
2329
|
+
click.echo("ERROR: --model-indices requires a PDB input (or --ref-pdb).", err=True)
|
|
2330
|
+
prepared_input.cleanup()
|
|
2331
|
+
sys.exit(1)
|
|
2332
|
+
click.echo(
|
|
2333
|
+
pretty_block(
|
|
2334
|
+
"dry_run_plan",
|
|
2335
|
+
{
|
|
2336
|
+
"input_geometry": str(geom_input_path),
|
|
2337
|
+
"output_dir": str(out_dir_path),
|
|
2338
|
+
"optimizer_mode": ("hess-rsirfo" if use_heavy else "grad-dimer"),
|
|
2339
|
+
"detect_layer": bool(detect_layer_enabled),
|
|
2340
|
+
"model_region_source": model_region_source,
|
|
2341
|
+
"model_indices_count": 0 if not model_indices else len(model_indices),
|
|
2342
|
+
"hessian_calc_mode": calc_cfg.get("hessian_calc_mode"),
|
|
2343
|
+
"partial_hessian_flatten": bool(partial_hessian_flatten),
|
|
2344
|
+
"active_dof_mode": str(active_dof_mode),
|
|
2345
|
+
"will_run_tsopt": True,
|
|
2346
|
+
"will_write_summary": True,
|
|
2347
|
+
"backend": calc_cfg.get("backend", "uma"),
|
|
2348
|
+
"embedcharge": bool(calc_cfg.get("embedcharge", False)),
|
|
2349
|
+
},
|
|
2350
|
+
)
|
|
2351
|
+
)
|
|
2352
|
+
click.echo("[dry-run] Validation complete. TS optimization execution was skipped.")
|
|
2353
|
+
prepared_input.cleanup()
|
|
2354
|
+
return
|
|
2355
|
+
|
|
2356
|
+
model_pdb_path: Optional[Path] = None
|
|
2357
|
+
layer_info: Optional[Dict[str, List[int]]] = None
|
|
2358
|
+
|
|
2359
|
+
if detect_layer_enabled:
|
|
2360
|
+
try:
|
|
2361
|
+
model_pdb_path, layer_info = build_model_pdb_from_bfactors(layer_source_pdb, out_dir_path)
|
|
2362
|
+
calc_cfg["use_bfactor_layers"] = True
|
|
2363
|
+
click.echo(
|
|
2364
|
+
f"[layer] Detected B-factor layers: ML={len(layer_info.get('ml_indices', []))}, "
|
|
2365
|
+
f"MovableMM={len(layer_info.get('movable_mm_indices', []))}, "
|
|
2366
|
+
f"FrozenMM={len(layer_info.get('frozen_indices', []))}"
|
|
2367
|
+
)
|
|
2368
|
+
except Exception as e:
|
|
2369
|
+
if model_pdb_cfg is None and not model_indices:
|
|
2370
|
+
click.echo(f"ERROR: {e}", err=True)
|
|
2371
|
+
prepared_input.cleanup()
|
|
2372
|
+
sys.exit(1)
|
|
2373
|
+
click.echo(f"[layer] WARNING: {e} Falling back to explicit ML region.", err=True)
|
|
2374
|
+
detect_layer_enabled = False
|
|
2375
|
+
|
|
2376
|
+
if not detect_layer_enabled:
|
|
2377
|
+
if model_pdb_cfg is None and not model_indices:
|
|
2378
|
+
click.echo("ERROR: Provide --model-pdb or --model-indices when --no-detect-layer.", err=True)
|
|
2379
|
+
prepared_input.cleanup()
|
|
2380
|
+
sys.exit(1)
|
|
2381
|
+
if model_pdb_cfg is not None:
|
|
2382
|
+
model_pdb_path = Path(model_pdb_cfg)
|
|
2383
|
+
else:
|
|
2384
|
+
if layer_source_pdb.suffix.lower() != ".pdb":
|
|
2385
|
+
click.echo("ERROR: --model-indices requires a PDB input (or --ref-pdb).", err=True)
|
|
2386
|
+
prepared_input.cleanup()
|
|
2387
|
+
sys.exit(1)
|
|
2388
|
+
try:
|
|
2389
|
+
model_pdb_path = build_model_pdb_from_indices(layer_source_pdb, out_dir_path, model_indices or [])
|
|
2390
|
+
except Exception as e:
|
|
2391
|
+
click.echo(f"ERROR: {e}", err=True)
|
|
2392
|
+
prepared_input.cleanup()
|
|
2393
|
+
sys.exit(1)
|
|
2394
|
+
calc_cfg["use_bfactor_layers"] = False
|
|
2395
|
+
|
|
2396
|
+
if model_pdb_path is None:
|
|
2397
|
+
click.echo("ERROR: Failed to resolve model PDB for the ML region.", err=True)
|
|
2398
|
+
prepared_input.cleanup()
|
|
2399
|
+
sys.exit(1)
|
|
2400
|
+
|
|
2401
|
+
calc_cfg["model_pdb"] = str(model_pdb_path)
|
|
2402
|
+
freeze_atoms_final = apply_layer_freeze_constraints(
|
|
2403
|
+
geom_cfg,
|
|
2404
|
+
calc_cfg,
|
|
2405
|
+
layer_info,
|
|
2406
|
+
echo_fn=click.echo,
|
|
2407
|
+
)
|
|
2408
|
+
_align_three_layer_hessian_targets(calc_cfg, echo_fn=click.echo)
|
|
2409
|
+
|
|
2410
|
+
for key in ("input_pdb", "real_parm7", "model_pdb", "mm_fd_dir"):
|
|
2411
|
+
val = calc_cfg.get(key)
|
|
2412
|
+
if val:
|
|
2413
|
+
calc_cfg[key] = str(Path(val).expanduser().resolve())
|
|
2414
|
+
|
|
2415
|
+
# Pretty-print config summary (only non-default values for concise logging)
|
|
2416
|
+
mode_desc = "RS-I-RFO (hess)" if use_heavy else "Dimer (grad)"
|
|
2417
|
+
if use_microiter:
|
|
2418
|
+
mode_desc += " + Microiteration"
|
|
2419
|
+
click.echo(f"\n[mode] TS Optimizer: {mode_desc}\n")
|
|
2420
|
+
click.echo(pretty_block("geom", format_freeze_atoms_for_echo(geom_cfg, key="freeze_atoms")))
|
|
2421
|
+
echo_calc = format_freeze_atoms_for_echo(filter_calc_for_echo(calc_cfg), key="freeze_atoms")
|
|
2422
|
+
click.echo(pretty_block("calc", echo_calc))
|
|
2423
|
+
echo_opt = strip_inherited_keys({**opt_cfg, "out_dir": str(out_dir_path)}, OPT_BASE_KW, mode="same")
|
|
2424
|
+
click.echo(pretty_block("opt", echo_opt))
|
|
2425
|
+
# Show only optimizer-specific settings, not inherited from opt_cfg
|
|
2426
|
+
if use_heavy:
|
|
2427
|
+
echo_rsirfo = strip_inherited_keys(rsirfo_cfg, opt_cfg)
|
|
2428
|
+
click.echo(pretty_block("rsirfo", echo_rsirfo))
|
|
2429
|
+
else:
|
|
2430
|
+
sd_cfg_for_echo: Dict[str, Any] = {}
|
|
2431
|
+
sd_cfg_for_echo["dimer"] = dict(simple_cfg.get("dimer", {}))
|
|
2432
|
+
sd_cfg_for_echo["lbfgs"] = strip_inherited_keys(
|
|
2433
|
+
dict(simple_cfg.get("lbfgs", {})), opt_cfg
|
|
2434
|
+
)
|
|
2435
|
+
click.echo(pretty_block("hessian_dimer", sd_cfg_for_echo))
|
|
2436
|
+
|
|
2437
|
+
# --------------------------
|
|
2438
|
+
# 2) Prepare geometry dir
|
|
2439
|
+
# --------------------------
|
|
2440
|
+
out_dir_path.mkdir(parents=True, exist_ok=True)
|
|
2441
|
+
|
|
2442
|
+
# --------------------------
|
|
2443
|
+
# 3) Run
|
|
2444
|
+
# --------------------------
|
|
2445
|
+
try:
|
|
2446
|
+
if use_heavy:
|
|
2447
|
+
# Heavy mode: RS-I-RFO with full Hessian
|
|
2448
|
+
rsirfo_label = "RS-I-RFO heavy mode"
|
|
2449
|
+
if use_microiter:
|
|
2450
|
+
rsirfo_label += " + Microiteration"
|
|
2451
|
+
optim_all_path = out_dir_path / "optimization_all_trj.xyz"
|
|
2452
|
+
if bool(opt_cfg["dump"]) and optim_all_path.exists():
|
|
2453
|
+
optim_all_path.unlink()
|
|
2454
|
+
|
|
2455
|
+
coord_type = geom_cfg.get("coord_type", "cart")
|
|
2456
|
+
coord_kwargs = dict(geom_cfg)
|
|
2457
|
+
coord_kwargs.pop("coord_type", None)
|
|
2458
|
+
geometry = geom_loader(
|
|
2459
|
+
geom_input_path,
|
|
2460
|
+
coord_type=coord_type,
|
|
2461
|
+
**coord_kwargs,
|
|
2462
|
+
)
|
|
2463
|
+
|
|
2464
|
+
if use_microiter:
|
|
2465
|
+
# --- Microiteration path ---
|
|
2466
|
+
click.echo(f"\n=== TS optimization ({rsirfo_label}) started ===\n")
|
|
2467
|
+
_run_microiter_tsopt(
|
|
2468
|
+
geometry,
|
|
2469
|
+
calc_cfg,
|
|
2470
|
+
rsirfo_cfg,
|
|
2471
|
+
lbfgs_cfg,
|
|
2472
|
+
opt_cfg,
|
|
2473
|
+
microiter_cfg,
|
|
2474
|
+
out_dir_path,
|
|
2475
|
+
dump=bool(opt_cfg["dump"]),
|
|
2476
|
+
thresh=thresh,
|
|
2477
|
+
)
|
|
2478
|
+
click.echo(f"\n=== TS optimization ({rsirfo_label}) finished ===\n")
|
|
2479
|
+
|
|
2480
|
+
# Write final geometry
|
|
2481
|
+
final_xyz = out_dir_path / "final_geometry.xyz"
|
|
2482
|
+
final_xyz.write_text(geometry.as_xyz(), encoding="utf-8")
|
|
2483
|
+
|
|
2484
|
+
# For post-analysis, get hess_active_atoms from a fresh calc
|
|
2485
|
+
_temp_calc = mlmm(**calc_cfg)
|
|
2486
|
+
_temp_core = _temp_calc.core if hasattr(_temp_calc, "core") else _temp_calc
|
|
2487
|
+
hess_active_atoms = list(getattr(_temp_core, "hess_active_atoms", []))
|
|
2488
|
+
del _temp_calc, _temp_core
|
|
2489
|
+
_clear_cuda_cache()
|
|
2490
|
+
_rsirfo_cycles_spent = int(opt_cfg.get("max_cycles", 10000)) # budget consumed
|
|
2491
|
+
else:
|
|
2492
|
+
# --- Standard RS-I-RFO path ---
|
|
2493
|
+
click.echo(f"\n=== TS optimization ({rsirfo_label}) started ===\n")
|
|
2494
|
+
|
|
2495
|
+
base_calc = mlmm(**calc_cfg)
|
|
2496
|
+
geometry.set_calculator(base_calc)
|
|
2497
|
+
|
|
2498
|
+
click.echo("[tsopt] Seeding initial Hessian via shared freq backend.")
|
|
2499
|
+
hess_device = _torch_device(simple_cfg.get("device", calc_cfg.get("ml_device", "auto")))
|
|
2500
|
+
h_init = _calc_full_hessian_torch(geometry, calc_cfg, hess_device)
|
|
2501
|
+
geometry.cart_hessian = h_init
|
|
2502
|
+
click.echo(
|
|
2503
|
+
f"[tsopt] Initial Hessian seeded (shape={h_init.shape[0]}x{h_init.shape[1]})."
|
|
2504
|
+
)
|
|
2505
|
+
del h_init
|
|
2506
|
+
|
|
2507
|
+
rsirfo_args = {**rsirfo_cfg}
|
|
2508
|
+
rsirfo_args["out_dir"] = str(out_dir_path)
|
|
2509
|
+
rsirfo_args["max_cycles"] = int(opt_cfg["max_cycles"])
|
|
2510
|
+
rsirfo_args["dump"] = bool(opt_cfg["dump"])
|
|
2511
|
+
if thresh is not None:
|
|
2512
|
+
rsirfo_args["thresh"] = str(thresh)
|
|
2513
|
+
# RSIRFOptimizer does not accept RFOptimizer-specific DIIS knobs; strip them.
|
|
2514
|
+
for _diis_kw in ("gediis", "gdiis", "gdiis_thresh", "gediis_thresh", "gdiis_test_direction", "adapt_step_func"):
|
|
2515
|
+
rsirfo_args.pop(_diis_kw, None)
|
|
2516
|
+
|
|
2517
|
+
calc_core = base_calc.core if hasattr(base_calc, "core") else base_calc
|
|
2518
|
+
hess_active_atoms = list(getattr(calc_core, "hess_active_atoms", []))
|
|
2519
|
+
optimizer = RSIRFOptimizer(geometry, **rsirfo_args)
|
|
2520
|
+
optimizer.run()
|
|
2521
|
+
if bool(opt_cfg["dump"]):
|
|
2522
|
+
_append_xyz_trajectory(optim_all_path, out_dir_path / "optimization_trj.xyz")
|
|
2523
|
+
|
|
2524
|
+
click.echo(f"\n=== TS optimization ({rsirfo_label}) finished ===\n")
|
|
2525
|
+
|
|
2526
|
+
# --- Post-RSIRFO: count imaginary modes and optional flatten loop ---
|
|
2527
|
+
# Save cycle count before deleting optimizer for budget check.
|
|
2528
|
+
_rsirfo_cycles_spent = getattr(optimizer, "cur_cycle", 0) + 1
|
|
2529
|
+
geometry.set_calculator(None)
|
|
2530
|
+
del optimizer
|
|
2531
|
+
del calc_core
|
|
2532
|
+
del base_calc
|
|
2533
|
+
_clear_cuda_cache()
|
|
2534
|
+
mlmm_kwargs_for_heavy = dict(calc_cfg)
|
|
2535
|
+
mlmm_kwargs_for_heavy["out_hess_torch"] = True
|
|
2536
|
+
device = _torch_device(simple_cfg.get("device", calc_cfg.get("ml_device", "auto")))
|
|
2537
|
+
|
|
2538
|
+
# Determine active atoms for frequency analysis based on --active-dof-mode.
|
|
2539
|
+
active_atoms_freq = _get_active_dof_indices(
|
|
2540
|
+
calc_cfg, len(geometry.atomic_numbers), active_dof_mode, freeze_atoms_final
|
|
2541
|
+
)
|
|
2542
|
+
n_atoms = len(geometry.atomic_numbers)
|
|
2543
|
+
if active_atoms_freq is not None:
|
|
2544
|
+
active_set = set(active_atoms_freq)
|
|
2545
|
+
freeze_atoms_freq = [i for i in range(n_atoms) if i not in active_set]
|
|
2546
|
+
else:
|
|
2547
|
+
freeze_atoms_freq = freeze_atoms_final if freeze_atoms_final else None
|
|
2548
|
+
|
|
2549
|
+
def _calc_freqs_and_modes() -> Tuple[np.ndarray, torch.Tensor]:
    """Recompute the Hessian for ``geometry`` and return its frequencies and modes.

    The Hessian returned by ``_freq_calc_full_hessian_torch`` is stored in the
    Hessian cache under the key ``"ts"`` and then handed to
    ``_frequencies_cm_and_modes``.  When the Hessian has fewer rows than
    ``3 * n_atoms`` (partial Hessian), the modes are computed for the active
    atoms only and embedded into zero-filled rows of length ``3 * n_atoms``
    on the CPU.

    Returns:
        ``(freqs, modes)``; ``modes`` is moved to the CPU before returning.
        Frequencies are presumably in cm^-1 judging by the helper name --
        TODO confirm against ``_frequencies_cm_and_modes``.

    Raises:
        RuntimeError: If the Hessian is partial and none of the active-atom
            sources provides a non-empty index list.
    """
    H, _energy_ha = _freq_calc_full_hessian_torch(
        geometry, mlmm_kwargs_for_heavy, device, refresh_geom_meta=True,
    )
    from .hessian_cache import store as _hess_store

    n_full = len(geometry.atomic_numbers)
    n_full_dofs = 3 * n_full

    # Determine active_dofs for partial Hessian: every Cartesian DOF whose atom
    # is not listed in calc_cfg["freeze_atoms"].
    # NOTE(review): these DOFs come from calc_cfg["freeze_atoms"], while the mode
    # embedding below uses the active-atom list -- confirm both describe the
    # same set of degrees of freedom.
    if H.shape[0] < n_full_dofs:
        frozen_atoms = set(calc_cfg.get("freeze_atoms", []))
        _active_dofs = [
            3 * atom + axis
            for atom in range(n_full)
            if atom not in frozen_atoms
            for axis in range(3)
        ]
    else:
        _active_dofs = None
    _hess_store("ts", H, active_dofs=_active_dofs, meta={"energy_ha": _energy_ha})

    if H.shape[0] != n_full_dofs:
        # Partial Hessian: use Hessian-target atoms only and embed modes back.
        # Sources are tried in priority order.  An *empty* list counts as
        # "not available" so that the later fallbacks are actually reachable
        # (hess_active_atoms is always a list, never None).
        partial_meta = getattr(geometry, "within_partial_hessian", None)
        candidates = (
            partial_meta.get("active_atoms") if partial_meta is not None else None,
            hess_active_atoms,
            active_atoms_freq,
        )
        active_atoms = []
        for candidate in candidates:
            if candidate is None:
                continue
            active_atoms = [
                int(i) for i in np.asarray(candidate, dtype=int).reshape(-1).tolist()
            ]
            if active_atoms:
                break
        if not active_atoms:
            raise RuntimeError(
                "No active atoms available for partial Hessian frequency analysis."
            )
        coords_act = geometry.cart_coords.reshape(-1, 3)[active_atoms]
        nums_act = np.asarray(geometry.atomic_numbers)[active_atoms]
        freqs_local, modes_act = _frequencies_cm_and_modes(
            H,
            nums_act,
            coords_act,
            device,
            freeze_idx=None,
        )
        # Embed modes into full 3N on CPU to reduce VRAM peak.
        modes_local = torch.zeros(
            (modes_act.shape[0], n_full_dofs), dtype=modes_act.dtype, device="cpu"
        )
        mask_dof = torch.as_tensor(
            _mask_dof_from_active_idx(n_full, active_atoms), dtype=torch.bool
        )
        modes_local[:, mask_dof] = modes_act.detach().cpu()
        del coords_act, nums_act, modes_act, mask_dof
    else:
        freqs_local, modes_gpu = _frequencies_cm_and_modes(
            H,
            geometry.atomic_numbers,
            geometry.cart_coords.reshape(-1, 3),
            device,
            freeze_idx=freeze_atoms_freq,
        )
        modes_local = modes_gpu.detach().cpu()
        del modes_gpu
    # Drop the local Hessian reference before clearing the CUDA cache.
    del H
    _clear_cuda_cache()
    return freqs_local, modes_local
|
|
2611
|
+
|
|
2612
|
+
try:
|
|
2613
|
+
freqs_cm, modes = _calc_freqs_and_modes()
|
|
2614
|
+
except Exception as exc:
|
|
2615
|
+
is_oom = isinstance(exc, torch.OutOfMemoryError) or ("cuda out of memory" in str(exc).lower())
|
|
2616
|
+
if is_oom:
|
|
2617
|
+
click.echo(
|
|
2618
|
+
"[tsopt] WARNING: CUDA OOM during final frequency analysis; "
|
|
2619
|
+
"skipping imaginary-mode analysis/flatten loop.",
|
|
2620
|
+
err=True,
|
|
2621
|
+
)
|
|
2622
|
+
_clear_cuda_cache()
|
|
2623
|
+
freqs_cm, modes = None, None
|
|
2624
|
+
else:
|
|
2625
|
+
raise
|
|
2626
|
+
neg_freq_thresh_cm = float(simple_cfg.get("neg_freq_thresh_cm", 5.0))
|
|
2627
|
+
|
|
2628
|
+
if freqs_cm is not None and modes is not None:
|
|
2629
|
+
neg_mask = freqs_cm < -abs(neg_freq_thresh_cm)
|
|
2630
|
+
n_imag = int(np.sum(neg_mask))
|
|
2631
|
+
ims = [float(x) for x in freqs_cm if x < -abs(neg_freq_thresh_cm)]
|
|
2632
|
+
click.echo(f"[Imaginary modes] n={n_imag} ({ims})")
|
|
2633
|
+
|
|
2634
|
+
flatten_max_iter = int(simple_cfg.get("flatten_max_iter", 0))
|
|
2635
|
+
user_max_cycles = int(opt_cfg.get("max_cycles", 10000))
|
|
2636
|
+
budget_remaining = user_max_cycles - _rsirfo_cycles_spent > 0
|
|
2637
|
+
|
|
2638
|
+
if flatten_max_iter > 0 and n_imag > 1 and not budget_remaining:
|
|
2639
|
+
click.echo("[tsopt] Reached --max-cycles budget; skipping flatten loop.")
|
|
2640
|
+
elif flatten_max_iter > 0 and n_imag > 1 and budget_remaining:
|
|
2641
|
+
click.echo("[flatten] Extra imaginary modes detected; starting RS-I-RFO flatten loop.")
|
|
2642
|
+
masses_amu = np.array([atomic_masses[z] for z in geometry.atomic_numbers])
|
|
2643
|
+
main_root = int(simple_cfg.get("root", 0))
|
|
2644
|
+
|
|
2645
|
+
for it in range(flatten_max_iter):
|
|
2646
|
+
click.echo(f"[flatten] RS-I-RFO iteration {it + 1}/{flatten_max_iter}")
|
|
2647
|
+
did_flatten = _flatten_once_with_modes_for_geom(
|
|
2648
|
+
geometry,
|
|
2649
|
+
masses_amu,
|
|
2650
|
+
mlmm_kwargs_for_heavy,
|
|
2651
|
+
freqs_cm,
|
|
2652
|
+
modes,
|
|
2653
|
+
neg_freq_thresh_cm,
|
|
2654
|
+
float(simple_cfg.get("flatten_amp_ang", 0.10)),
|
|
2655
|
+
float(simple_cfg.get("flatten_sep_cutoff", 0.0)),
|
|
2656
|
+
int(simple_cfg.get("flatten_k", 10)),
|
|
2657
|
+
main_root,
|
|
2658
|
+
)
|
|
2659
|
+
if not did_flatten:
|
|
2660
|
+
click.echo("[flatten] No eligible modes to flatten; stopping.")
|
|
2661
|
+
break
|
|
2662
|
+
|
|
2663
|
+
del freqs_cm, modes
|
|
2664
|
+
_clear_cuda_cache()
|
|
2665
|
+
base_calc = mlmm(**calc_cfg)
|
|
2666
|
+
geometry.set_calculator(base_calc)
|
|
2667
|
+
optimizer = RSIRFOptimizer(geometry, **rsirfo_args)
|
|
2668
|
+
click.echo("\n=== TS optimization (RS-I-RFO) restarted ===\n")
|
|
2669
|
+
optimizer.run()
|
|
2670
|
+
click.echo("\n=== TS optimization (RS-I-RFO) finished ===\n")
|
|
2671
|
+
if bool(opt_cfg["dump"]):
|
|
2672
|
+
_append_xyz_trajectory(optim_all_path, out_dir_path / "optimization_trj.xyz")
|
|
2673
|
+
geometry.set_calculator(None)
|
|
2674
|
+
del optimizer, base_calc
|
|
2675
|
+
_clear_cuda_cache()
|
|
2676
|
+
|
|
2677
|
+
try:
|
|
2678
|
+
freqs_cm, modes = _calc_freqs_and_modes()
|
|
2679
|
+
except Exception as exc:
|
|
2680
|
+
is_oom = isinstance(exc, torch.OutOfMemoryError) or ("cuda out of memory" in str(exc).lower())
|
|
2681
|
+
if is_oom:
|
|
2682
|
+
click.echo(
|
|
2683
|
+
"[tsopt] WARNING: CUDA OOM during final frequency analysis; "
|
|
2684
|
+
"stopping flatten loop.",
|
|
2685
|
+
err=True,
|
|
2686
|
+
)
|
|
2687
|
+
_clear_cuda_cache()
|
|
2688
|
+
freqs_cm, modes = None, None
|
|
2689
|
+
break
|
|
2690
|
+
raise
|
|
2691
|
+
|
|
2692
|
+
neg_mask = freqs_cm < -abs(neg_freq_thresh_cm)
|
|
2693
|
+
n_imag = int(np.sum(neg_mask))
|
|
2694
|
+
ims = [float(x) for x in freqs_cm if x < -abs(neg_freq_thresh_cm)]
|
|
2695
|
+
click.echo(f"[Imaginary modes] n={n_imag} ({ims})")
|
|
2696
|
+
if n_imag <= 1:
|
|
2697
|
+
break
|
|
2698
|
+
|
|
2699
|
+
if freqs_cm is not None and modes is not None:
|
|
2700
|
+
# --- Write all final imaginary modes like light mode ---
|
|
2701
|
+
vib_dir = out_dir_path / "vib"
|
|
2702
|
+
vib_dir.mkdir(parents=True, exist_ok=True)
|
|
2703
|
+
_ref_pdb_for_modes = source_path if source_path.suffix.lower() == ".pdb" else None
|
|
2704
|
+
n_written = _write_all_imag_modes(
|
|
2705
|
+
geometry,
|
|
2706
|
+
freqs_cm,
|
|
2707
|
+
modes,
|
|
2708
|
+
neg_freq_thresh_cm,
|
|
2709
|
+
vib_dir,
|
|
2710
|
+
ref_pdb=_ref_pdb_for_modes,
|
|
2711
|
+
)
|
|
2712
|
+
if n_written == 0:
|
|
2713
|
+
click.echo("[INFO] No imaginary mode found at the end for RS-I-RFO.")
|
|
2714
|
+
else:
|
|
2715
|
+
click.echo(f"[DONE] Wrote {n_written} final imaginary mode(s).")
|
|
2716
|
+
click.echo(f"[DONE] Mode files → {vib_dir}")
|
|
2717
|
+
else:
|
|
2718
|
+
click.echo("[INFO] Skipped final imaginary-mode export due to frequency-analysis fallback.")
|
|
2719
|
+
|
|
2720
|
+
if modes is not None:
|
|
2721
|
+
del modes
|
|
2722
|
+
if freqs_cm is not None:
|
|
2723
|
+
del freqs_cm
|
|
2724
|
+
_clear_cuda_cache()
|
|
2725
|
+
|
|
2726
|
+
# Ensure final_geometry.xyz exists (partial-micro path may not write it).
|
|
2727
|
+
final_xyz = out_dir_path / "final_geometry.xyz"
|
|
2728
|
+
if not final_xyz.exists():
|
|
2729
|
+
final_xyz.write_text(geometry.as_xyz(), encoding="utf-8")
|
|
2730
|
+
|
|
2731
|
+
else:
|
|
2732
|
+
# Light mode: Partial Hessian guided Dimer
|
|
2733
|
+
runner = HessianDimer(
|
|
2734
|
+
fn=str(geom_input_path),
|
|
2735
|
+
out_dir=str(out_dir_path),
|
|
2736
|
+
thresh_loose=simple_cfg.get("thresh_loose", "gau_loose"),
|
|
2737
|
+
thresh=simple_cfg.get("thresh", "gau"),
|
|
2738
|
+
update_interval_hessian=int(simple_cfg.get("update_interval_hessian", 500)),
|
|
2739
|
+
neg_freq_thresh_cm=float(simple_cfg.get("neg_freq_thresh_cm", 5.0)),
|
|
2740
|
+
flatten_amp_ang=float(simple_cfg.get("flatten_amp_ang", 0.10)),
|
|
2741
|
+
flatten_max_iter=int(simple_cfg.get("flatten_max_iter", 20)),
|
|
2742
|
+
mem=int(simple_cfg.get("mem", 100000)),
|
|
2743
|
+
use_lobpcg=bool(simple_cfg.get("use_lobpcg", True)),
|
|
2744
|
+
calc_kwargs=dict(calc_cfg),
|
|
2745
|
+
device=str(simple_cfg.get("device", calc_cfg.get("ml_device", "auto"))),
|
|
2746
|
+
dump=bool(opt_cfg["dump"]),
|
|
2747
|
+
root=int(simple_cfg.get("root", 0)),
|
|
2748
|
+
dimer_kwargs=dict(simple_cfg.get("dimer", {})),
|
|
2749
|
+
lbfgs_kwargs=dict(simple_cfg.get("lbfgs", {})),
|
|
2750
|
+
max_total_cycles=int(opt_cfg["max_cycles"]),
|
|
2751
|
+
geom_kwargs=dict(geom_cfg),
|
|
2752
|
+
partial_hessian_flatten=partial_hessian_flatten,
|
|
2753
|
+
flatten_sep_cutoff=float(simple_cfg.get("flatten_sep_cutoff", 0.0)),
|
|
2754
|
+
flatten_k=int(simple_cfg.get("flatten_k", 10)),
|
|
2755
|
+
flatten_loop_bofill=bool(simple_cfg.get("flatten_loop_bofill", False)),
|
|
2756
|
+
ml_only_hessian_dimer=bool(simple_cfg.get("ml_only_hessian_dimer", ml_only_hessian_dimer)),
|
|
2757
|
+
source_path=source_path,
|
|
2758
|
+
)
|
|
2759
|
+
|
|
2760
|
+
click.echo("\n=== TS optimization (Partial Hessian Dimer) started ===\n")
|
|
2761
|
+
runner.run()
|
|
2762
|
+
click.echo("\n=== TS optimization (Partial Hessian Dimer) finished ===\n")
|
|
2763
|
+
|
|
2764
|
+
if is_convert_file_enabled() and source_path.suffix.lower() == ".pdb":
|
|
2765
|
+
ref_pdb = source_path.resolve()
|
|
2766
|
+
final_xyz = out_dir_path / "final_geometry.xyz"
|
|
2767
|
+
final_pdb = out_dir_path / "final_geometry.pdb"
|
|
2768
|
+
|
|
2769
|
+
# Get layer indices for B-factor annotation
|
|
2770
|
+
# For heavy mode, base_calc is available; for light mode, create temporary calc
|
|
2771
|
+
layer_indices = None
|
|
2772
|
+
if use_heavy and 'base_calc' in dir():
|
|
2773
|
+
calc_core = base_calc.core if hasattr(base_calc, 'core') else base_calc
|
|
2774
|
+
layer_indices = {
|
|
2775
|
+
"ml": getattr(calc_core, 'ml_indices', None),
|
|
2776
|
+
"hess_mm": getattr(calc_core, 'hess_mm_indices', None),
|
|
2777
|
+
"movable_mm": getattr(calc_core, 'movable_mm_indices', None),
|
|
2778
|
+
"frozen": getattr(calc_core, 'frozen_layer_indices', None),
|
|
2779
|
+
}
|
|
2780
|
+
else:
|
|
2781
|
+
# For light mode, create a temporary calculator to get layer indices
|
|
2782
|
+
try:
|
|
2783
|
+
temp_calc = mlmm(**calc_cfg)
|
|
2784
|
+
calc_core = temp_calc.core if hasattr(temp_calc, 'core') else temp_calc
|
|
2785
|
+
layer_indices = {
|
|
2786
|
+
"ml": getattr(calc_core, 'ml_indices', None),
|
|
2787
|
+
"hess_mm": getattr(calc_core, 'hess_mm_indices', None),
|
|
2788
|
+
"movable_mm": getattr(calc_core, 'movable_mm_indices', None),
|
|
2789
|
+
"frozen": getattr(calc_core, 'frozen_layer_indices', None),
|
|
2790
|
+
}
|
|
2791
|
+
del temp_calc
|
|
2792
|
+
except Exception:
|
|
2793
|
+
layer_indices = None
|
|
2794
|
+
|
|
2795
|
+
try:
|
|
2796
|
+
convert_xyz_to_pdb(final_xyz, ref_pdb, final_pdb)
|
|
2797
|
+
click.echo(f"[convert] Wrote '{final_pdb}'.")
|
|
2798
|
+
|
|
2799
|
+
# Annotate B-factors with layer-based encoding
|
|
2800
|
+
if layer_indices and layer_indices.get("ml") is not None:
|
|
2801
|
+
update_pdb_bfactors_from_layers(
|
|
2802
|
+
final_pdb,
|
|
2803
|
+
ml_indices=layer_indices["ml"] or [],
|
|
2804
|
+
hess_mm_indices=layer_indices.get("hess_mm"),
|
|
2805
|
+
movable_mm_indices=layer_indices.get("movable_mm"),
|
|
2806
|
+
frozen_indices=layer_indices.get("frozen"),
|
|
2807
|
+
)
|
|
2808
|
+
click.echo(
|
|
2809
|
+
f"[annot] B-factors set in '{final_pdb}' "
|
|
2810
|
+
f"(ML={BFACTOR_ML:.0f}, MovableMM={BFACTOR_MOVABLE_MM:.0f}, "
|
|
2811
|
+
f"FrozenMM={BFACTOR_FROZEN:.0f})."
|
|
2812
|
+
)
|
|
2813
|
+
except Exception as e:
|
|
2814
|
+
click.echo(f"[convert] WARNING: Failed to convert final geometry to PDB: {e}", err=True)
|
|
2815
|
+
|
|
2816
|
+
all_trj = out_dir_path / "optimization_all_trj.xyz"
|
|
2817
|
+
if all_trj.exists():
|
|
2818
|
+
try:
|
|
2819
|
+
opt_pdb = out_dir_path / "optimization_all.pdb"
|
|
2820
|
+
convert_xyz_to_pdb(all_trj, ref_pdb, opt_pdb)
|
|
2821
|
+
click.echo(f"[convert] Wrote '{opt_pdb}'.")
|
|
2822
|
+
|
|
2823
|
+
# Annotate B-factors with layer-based encoding
|
|
2824
|
+
if layer_indices and layer_indices.get("ml") is not None:
|
|
2825
|
+
update_pdb_bfactors_from_layers(
|
|
2826
|
+
opt_pdb,
|
|
2827
|
+
ml_indices=layer_indices["ml"] or [],
|
|
2828
|
+
hess_mm_indices=layer_indices.get("hess_mm"),
|
|
2829
|
+
movable_mm_indices=layer_indices.get("movable_mm"),
|
|
2830
|
+
frozen_indices=layer_indices.get("frozen"),
|
|
2831
|
+
)
|
|
2832
|
+
click.echo(
|
|
2833
|
+
f"[annot] B-factors set in '{opt_pdb}' "
|
|
2834
|
+
f"(ML={BFACTOR_ML:.0f}, MovableMM={BFACTOR_MOVABLE_MM:.0f}, "
|
|
2835
|
+
f"FrozenMM={BFACTOR_FROZEN:.0f})."
|
|
2836
|
+
)
|
|
2837
|
+
except Exception as e:
|
|
2838
|
+
click.echo(f"[convert] WARNING: Failed to convert optimization trajectory to PDB: {e}", err=True)
|
|
2839
|
+
else:
|
|
2840
|
+
final_xyz = out_dir_path / "final_geometry.xyz"
|
|
2841
|
+
|
|
2842
|
+
# summary.md and key_* outputs are disabled.
|
|
2843
|
+
click.echo(format_elapsed("[time] Elapsed Time for TS Opt", time_start))
|
|
2844
|
+
|
|
2845
|
+
except ZeroStepLength:
|
|
2846
|
+
click.echo("ERROR: Proposed step length dropped below the minimum allowed (ZeroStepLength).", err=True)
|
|
2847
|
+
sys.exit(2)
|
|
2848
|
+
except OptimizationError as e:
|
|
2849
|
+
click.echo(f"ERROR: Optimization failed — {e}", err=True)
|
|
2850
|
+
sys.exit(3)
|
|
2851
|
+
except KeyboardInterrupt:
|
|
2852
|
+
click.echo("\nInterrupted by user.", err=True)
|
|
2853
|
+
sys.exit(130)
|
|
2854
|
+
except Exception as e:
|
|
2855
|
+
import traceback
|
|
2856
|
+
tb = "".join(traceback.format_exception(type(e), e, e.__traceback__))
|
|
2857
|
+
click.echo("Unhandled error during optimization:\n" + textwrap.indent(tb, " "), err=True)
|
|
2858
|
+
sys.exit(1)
|
|
2859
|
+
finally:
|
|
2860
|
+
prepared_input.cleanup()
|
|
2861
|
+
# Release GPU memory (model + Hessian) so subsequent stages don't OOM
|
|
2862
|
+
base_calc = geometry = optimizer = last_optimizer = None
|
|
2863
|
+
macro_calc = macro_optimizer = mm_calc = None
|
|
2864
|
+
gc.collect() # break cyclic refs inside torch.nn.Module
|
|
2865
|
+
if torch.cuda.is_available():
|
|
2866
|
+
torch.cuda.empty_cache()
|
|
2867
|
+
|
|
2868
|
+
|
|
2869
|
+
# Allow `python -m mlmm.tsopt` direct execution: when this module is run as a
# script (rather than imported), call `cli()`.
if __name__ == "__main__":
    cli()
|