mlmm-toolkit 0.2.2.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hessian_ff/__init__.py +50 -0
- hessian_ff/analytical_hessian.py +609 -0
- hessian_ff/constants.py +46 -0
- hessian_ff/forcefield.py +339 -0
- hessian_ff/loaders.py +608 -0
- hessian_ff/native/Makefile +8 -0
- hessian_ff/native/__init__.py +28 -0
- hessian_ff/native/analytical_hessian.py +88 -0
- hessian_ff/native/analytical_hessian_ext.cpp +258 -0
- hessian_ff/native/bonded.py +82 -0
- hessian_ff/native/bonded_ext.cpp +640 -0
- hessian_ff/native/loader.py +349 -0
- hessian_ff/native/nonbonded.py +118 -0
- hessian_ff/native/nonbonded_ext.cpp +1150 -0
- hessian_ff/prmtop_parmed.py +23 -0
- hessian_ff/system.py +107 -0
- hessian_ff/terms/__init__.py +14 -0
- hessian_ff/terms/angle.py +73 -0
- hessian_ff/terms/bond.py +44 -0
- hessian_ff/terms/cmap.py +406 -0
- hessian_ff/terms/dihedral.py +141 -0
- hessian_ff/terms/nonbonded.py +209 -0
- hessian_ff/tests/__init__.py +0 -0
- hessian_ff/tests/conftest.py +75 -0
- hessian_ff/tests/data/small/complex.parm7 +1346 -0
- hessian_ff/tests/data/small/complex.pdb +125 -0
- hessian_ff/tests/data/small/complex.rst7 +63 -0
- hessian_ff/tests/test_coords_input.py +44 -0
- hessian_ff/tests/test_energy_force.py +49 -0
- hessian_ff/tests/test_hessian.py +137 -0
- hessian_ff/tests/test_smoke.py +18 -0
- hessian_ff/tests/test_validation.py +40 -0
- hessian_ff/workflows.py +889 -0
- mlmm/__init__.py +36 -0
- mlmm/__main__.py +7 -0
- mlmm/_version.py +34 -0
- mlmm/add_elem_info.py +374 -0
- mlmm/advanced_help.py +91 -0
- mlmm/align_freeze_atoms.py +601 -0
- mlmm/all.py +3535 -0
- mlmm/bond_changes.py +231 -0
- mlmm/bool_compat.py +223 -0
- mlmm/cli.py +574 -0
- mlmm/cli_utils.py +166 -0
- mlmm/default_group.py +337 -0
- mlmm/defaults.py +467 -0
- mlmm/define_layer.py +526 -0
- mlmm/dft.py +1041 -0
- mlmm/energy_diagram.py +253 -0
- mlmm/extract.py +2213 -0
- mlmm/fix_altloc.py +464 -0
- mlmm/freq.py +1406 -0
- mlmm/harmonic_constraints.py +140 -0
- mlmm/hessian_cache.py +44 -0
- mlmm/hessian_calc.py +174 -0
- mlmm/irc.py +638 -0
- mlmm/mlmm_calc.py +2262 -0
- mlmm/mm_parm.py +945 -0
- mlmm/oniom_export.py +1983 -0
- mlmm/oniom_import.py +457 -0
- mlmm/opt.py +1742 -0
- mlmm/path_opt.py +1353 -0
- mlmm/path_search.py +2299 -0
- mlmm/preflight.py +88 -0
- mlmm/py.typed +1 -0
- mlmm/pysis_runner.py +45 -0
- mlmm/scan.py +1047 -0
- mlmm/scan2d.py +1226 -0
- mlmm/scan3d.py +1265 -0
- mlmm/scan_common.py +184 -0
- mlmm/summary_log.py +736 -0
- mlmm/trj2fig.py +448 -0
- mlmm/tsopt.py +2871 -0
- mlmm/utils.py +2309 -0
- mlmm/xtb_embedcharge_correction.py +475 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/METADATA +1159 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/RECORD +372 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/WHEEL +5 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/entry_points.txt +2 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/licenses/LICENSE +674 -0
- mlmm_toolkit-0.2.2.dev0.dist-info/top_level.txt +4 -0
- pysisyphus/Geometry.py +1667 -0
- pysisyphus/LICENSE +674 -0
- pysisyphus/TableFormatter.py +63 -0
- pysisyphus/TablePrinter.py +74 -0
- pysisyphus/__init__.py +12 -0
- pysisyphus/calculators/AFIR.py +452 -0
- pysisyphus/calculators/AnaPot.py +20 -0
- pysisyphus/calculators/AnaPot2.py +48 -0
- pysisyphus/calculators/AnaPot3.py +12 -0
- pysisyphus/calculators/AnaPot4.py +20 -0
- pysisyphus/calculators/AnaPotBase.py +337 -0
- pysisyphus/calculators/AnaPotCBM.py +25 -0
- pysisyphus/calculators/AtomAtomTransTorque.py +154 -0
- pysisyphus/calculators/CFOUR.py +250 -0
- pysisyphus/calculators/Calculator.py +844 -0
- pysisyphus/calculators/CerjanMiller.py +24 -0
- pysisyphus/calculators/Composite.py +123 -0
- pysisyphus/calculators/ConicalIntersection.py +171 -0
- pysisyphus/calculators/DFTBp.py +430 -0
- pysisyphus/calculators/DFTD3.py +66 -0
- pysisyphus/calculators/DFTD4.py +84 -0
- pysisyphus/calculators/Dalton.py +61 -0
- pysisyphus/calculators/Dimer.py +681 -0
- pysisyphus/calculators/Dummy.py +20 -0
- pysisyphus/calculators/EGO.py +76 -0
- pysisyphus/calculators/EnergyMin.py +224 -0
- pysisyphus/calculators/ExternalPotential.py +264 -0
- pysisyphus/calculators/FakeASE.py +35 -0
- pysisyphus/calculators/FourWellAnaPot.py +28 -0
- pysisyphus/calculators/FreeEndNEBPot.py +39 -0
- pysisyphus/calculators/Gaussian09.py +18 -0
- pysisyphus/calculators/Gaussian16.py +726 -0
- pysisyphus/calculators/HardSphere.py +159 -0
- pysisyphus/calculators/IDPPCalculator.py +49 -0
- pysisyphus/calculators/IPIClient.py +133 -0
- pysisyphus/calculators/IPIServer.py +234 -0
- pysisyphus/calculators/LEPSBase.py +24 -0
- pysisyphus/calculators/LEPSExpr.py +139 -0
- pysisyphus/calculators/LennardJones.py +80 -0
- pysisyphus/calculators/MOPAC.py +219 -0
- pysisyphus/calculators/MullerBrownSympyPot.py +51 -0
- pysisyphus/calculators/MultiCalc.py +85 -0
- pysisyphus/calculators/NFK.py +45 -0
- pysisyphus/calculators/OBabel.py +87 -0
- pysisyphus/calculators/ONIOMv2.py +1129 -0
- pysisyphus/calculators/ORCA.py +893 -0
- pysisyphus/calculators/ORCA5.py +6 -0
- pysisyphus/calculators/OpenMM.py +88 -0
- pysisyphus/calculators/OpenMolcas.py +281 -0
- pysisyphus/calculators/OverlapCalculator.py +908 -0
- pysisyphus/calculators/Psi4.py +218 -0
- pysisyphus/calculators/PyPsi4.py +37 -0
- pysisyphus/calculators/PySCF.py +341 -0
- pysisyphus/calculators/PyXTB.py +73 -0
- pysisyphus/calculators/QCEngine.py +106 -0
- pysisyphus/calculators/Rastrigin.py +22 -0
- pysisyphus/calculators/Remote.py +76 -0
- pysisyphus/calculators/Rosenbrock.py +15 -0
- pysisyphus/calculators/SocketCalc.py +97 -0
- pysisyphus/calculators/TIP3P.py +111 -0
- pysisyphus/calculators/TransTorque.py +161 -0
- pysisyphus/calculators/Turbomole.py +965 -0
- pysisyphus/calculators/VRIPot.py +37 -0
- pysisyphus/calculators/WFOWrapper.py +333 -0
- pysisyphus/calculators/WFOWrapper2.py +341 -0
- pysisyphus/calculators/XTB.py +418 -0
- pysisyphus/calculators/__init__.py +81 -0
- pysisyphus/calculators/cosmo_data.py +139 -0
- pysisyphus/calculators/parser.py +150 -0
- pysisyphus/color.py +19 -0
- pysisyphus/config.py +133 -0
- pysisyphus/constants.py +65 -0
- pysisyphus/cos/AdaptiveNEB.py +230 -0
- pysisyphus/cos/ChainOfStates.py +725 -0
- pysisyphus/cos/FreeEndNEB.py +25 -0
- pysisyphus/cos/FreezingString.py +103 -0
- pysisyphus/cos/GrowingChainOfStates.py +71 -0
- pysisyphus/cos/GrowingNT.py +309 -0
- pysisyphus/cos/GrowingString.py +508 -0
- pysisyphus/cos/NEB.py +189 -0
- pysisyphus/cos/SimpleZTS.py +64 -0
- pysisyphus/cos/__init__.py +22 -0
- pysisyphus/cos/stiffness.py +199 -0
- pysisyphus/drivers/__init__.py +17 -0
- pysisyphus/drivers/afir.py +855 -0
- pysisyphus/drivers/barriers.py +271 -0
- pysisyphus/drivers/birkholz.py +138 -0
- pysisyphus/drivers/cluster.py +318 -0
- pysisyphus/drivers/diabatization.py +133 -0
- pysisyphus/drivers/merge.py +368 -0
- pysisyphus/drivers/merge_mol2.py +322 -0
- pysisyphus/drivers/opt.py +375 -0
- pysisyphus/drivers/perf.py +91 -0
- pysisyphus/drivers/pka.py +52 -0
- pysisyphus/drivers/precon_pos_rot.py +669 -0
- pysisyphus/drivers/rates.py +480 -0
- pysisyphus/drivers/replace.py +219 -0
- pysisyphus/drivers/scan.py +212 -0
- pysisyphus/drivers/spectrum.py +166 -0
- pysisyphus/drivers/thermo.py +31 -0
- pysisyphus/dynamics/Gaussian.py +103 -0
- pysisyphus/dynamics/__init__.py +20 -0
- pysisyphus/dynamics/colvars.py +136 -0
- pysisyphus/dynamics/driver.py +297 -0
- pysisyphus/dynamics/helpers.py +256 -0
- pysisyphus/dynamics/lincs.py +105 -0
- pysisyphus/dynamics/mdp.py +364 -0
- pysisyphus/dynamics/rattle.py +121 -0
- pysisyphus/dynamics/thermostats.py +128 -0
- pysisyphus/dynamics/wigner.py +266 -0
- pysisyphus/elem_data.py +3473 -0
- pysisyphus/exceptions.py +2 -0
- pysisyphus/filtertrj.py +69 -0
- pysisyphus/helpers.py +623 -0
- pysisyphus/helpers_pure.py +649 -0
- pysisyphus/init_logging.py +50 -0
- pysisyphus/intcoords/Bend.py +69 -0
- pysisyphus/intcoords/Bend2.py +25 -0
- pysisyphus/intcoords/BondedFragment.py +32 -0
- pysisyphus/intcoords/Cartesian.py +41 -0
- pysisyphus/intcoords/CartesianCoords.py +140 -0
- pysisyphus/intcoords/Coords.py +56 -0
- pysisyphus/intcoords/DLC.py +197 -0
- pysisyphus/intcoords/DistanceFunction.py +34 -0
- pysisyphus/intcoords/DummyImproper.py +70 -0
- pysisyphus/intcoords/DummyTorsion.py +72 -0
- pysisyphus/intcoords/LinearBend.py +105 -0
- pysisyphus/intcoords/LinearDisplacement.py +80 -0
- pysisyphus/intcoords/OutOfPlane.py +59 -0
- pysisyphus/intcoords/PrimTypes.py +286 -0
- pysisyphus/intcoords/Primitive.py +137 -0
- pysisyphus/intcoords/RedundantCoords.py +659 -0
- pysisyphus/intcoords/RobustTorsion.py +59 -0
- pysisyphus/intcoords/Rotation.py +147 -0
- pysisyphus/intcoords/Stretch.py +31 -0
- pysisyphus/intcoords/Torsion.py +101 -0
- pysisyphus/intcoords/Torsion2.py +25 -0
- pysisyphus/intcoords/Translation.py +45 -0
- pysisyphus/intcoords/__init__.py +61 -0
- pysisyphus/intcoords/augment_bonds.py +126 -0
- pysisyphus/intcoords/derivatives.py +10512 -0
- pysisyphus/intcoords/eval.py +80 -0
- pysisyphus/intcoords/exceptions.py +37 -0
- pysisyphus/intcoords/findiffs.py +48 -0
- pysisyphus/intcoords/generate_derivatives.py +414 -0
- pysisyphus/intcoords/helpers.py +235 -0
- pysisyphus/intcoords/logging_conf.py +10 -0
- pysisyphus/intcoords/mp_derivatives.py +10836 -0
- pysisyphus/intcoords/setup.py +962 -0
- pysisyphus/intcoords/setup_fast.py +176 -0
- pysisyphus/intcoords/update.py +272 -0
- pysisyphus/intcoords/valid.py +89 -0
- pysisyphus/interpolate/Geodesic.py +93 -0
- pysisyphus/interpolate/IDPP.py +55 -0
- pysisyphus/interpolate/Interpolator.py +116 -0
- pysisyphus/interpolate/LST.py +70 -0
- pysisyphus/interpolate/Redund.py +152 -0
- pysisyphus/interpolate/__init__.py +9 -0
- pysisyphus/interpolate/helpers.py +34 -0
- pysisyphus/io/__init__.py +22 -0
- pysisyphus/io/aomix.py +178 -0
- pysisyphus/io/cjson.py +24 -0
- pysisyphus/io/crd.py +101 -0
- pysisyphus/io/cube.py +220 -0
- pysisyphus/io/fchk.py +184 -0
- pysisyphus/io/hdf5.py +49 -0
- pysisyphus/io/hessian.py +72 -0
- pysisyphus/io/mol2.py +146 -0
- pysisyphus/io/molden.py +293 -0
- pysisyphus/io/orca.py +189 -0
- pysisyphus/io/pdb.py +269 -0
- pysisyphus/io/psf.py +79 -0
- pysisyphus/io/pubchem.py +31 -0
- pysisyphus/io/qcschema.py +34 -0
- pysisyphus/io/sdf.py +29 -0
- pysisyphus/io/xyz.py +61 -0
- pysisyphus/io/zmat.py +175 -0
- pysisyphus/irc/DWI.py +108 -0
- pysisyphus/irc/DampedVelocityVerlet.py +134 -0
- pysisyphus/irc/Euler.py +22 -0
- pysisyphus/irc/EulerPC.py +345 -0
- pysisyphus/irc/GonzalezSchlegel.py +187 -0
- pysisyphus/irc/IMKMod.py +164 -0
- pysisyphus/irc/IRC.py +878 -0
- pysisyphus/irc/IRCDummy.py +10 -0
- pysisyphus/irc/Instanton.py +307 -0
- pysisyphus/irc/LQA.py +53 -0
- pysisyphus/irc/ModeKill.py +136 -0
- pysisyphus/irc/ParamPlot.py +53 -0
- pysisyphus/irc/RK4.py +36 -0
- pysisyphus/irc/__init__.py +31 -0
- pysisyphus/irc/initial_displ.py +219 -0
- pysisyphus/linalg.py +411 -0
- pysisyphus/line_searches/Backtracking.py +88 -0
- pysisyphus/line_searches/HagerZhang.py +184 -0
- pysisyphus/line_searches/LineSearch.py +232 -0
- pysisyphus/line_searches/StrongWolfe.py +108 -0
- pysisyphus/line_searches/__init__.py +9 -0
- pysisyphus/line_searches/interpol.py +15 -0
- pysisyphus/modefollow/NormalMode.py +40 -0
- pysisyphus/modefollow/__init__.py +10 -0
- pysisyphus/modefollow/davidson.py +199 -0
- pysisyphus/modefollow/lanczos.py +95 -0
- pysisyphus/optimizers/BFGS.py +99 -0
- pysisyphus/optimizers/BacktrackingOptimizer.py +113 -0
- pysisyphus/optimizers/ConjugateGradient.py +98 -0
- pysisyphus/optimizers/CubicNewton.py +75 -0
- pysisyphus/optimizers/FIRE.py +113 -0
- pysisyphus/optimizers/HessianOptimizer.py +1176 -0
- pysisyphus/optimizers/LBFGS.py +228 -0
- pysisyphus/optimizers/LayerOpt.py +411 -0
- pysisyphus/optimizers/MicroOptimizer.py +169 -0
- pysisyphus/optimizers/NCOptimizer.py +90 -0
- pysisyphus/optimizers/Optimizer.py +1084 -0
- pysisyphus/optimizers/PreconLBFGS.py +260 -0
- pysisyphus/optimizers/PreconSteepestDescent.py +7 -0
- pysisyphus/optimizers/QuickMin.py +74 -0
- pysisyphus/optimizers/RFOptimizer.py +181 -0
- pysisyphus/optimizers/RSA.py +99 -0
- pysisyphus/optimizers/StabilizedQNMethod.py +248 -0
- pysisyphus/optimizers/SteepestDescent.py +23 -0
- pysisyphus/optimizers/StringOptimizer.py +173 -0
- pysisyphus/optimizers/__init__.py +41 -0
- pysisyphus/optimizers/closures.py +301 -0
- pysisyphus/optimizers/cls_map.py +58 -0
- pysisyphus/optimizers/exceptions.py +6 -0
- pysisyphus/optimizers/gdiis.py +280 -0
- pysisyphus/optimizers/guess_hessians.py +311 -0
- pysisyphus/optimizers/hessian_updates.py +355 -0
- pysisyphus/optimizers/poly_fit.py +285 -0
- pysisyphus/optimizers/precon.py +153 -0
- pysisyphus/optimizers/restrict_step.py +24 -0
- pysisyphus/pack.py +172 -0
- pysisyphus/peakdetect.py +948 -0
- pysisyphus/plot.py +1031 -0
- pysisyphus/run.py +2106 -0
- pysisyphus/socket_helper.py +74 -0
- pysisyphus/stocastic/FragmentKick.py +132 -0
- pysisyphus/stocastic/Kick.py +81 -0
- pysisyphus/stocastic/Pipeline.py +303 -0
- pysisyphus/stocastic/__init__.py +21 -0
- pysisyphus/stocastic/align.py +127 -0
- pysisyphus/testing.py +96 -0
- pysisyphus/thermo.py +156 -0
- pysisyphus/trj.py +824 -0
- pysisyphus/tsoptimizers/RSIRFOptimizer.py +56 -0
- pysisyphus/tsoptimizers/RSPRFOptimizer.py +182 -0
- pysisyphus/tsoptimizers/TRIM.py +59 -0
- pysisyphus/tsoptimizers/TSHessianOptimizer.py +463 -0
- pysisyphus/tsoptimizers/__init__.py +23 -0
- pysisyphus/wavefunction/Basis.py +239 -0
- pysisyphus/wavefunction/DIIS.py +76 -0
- pysisyphus/wavefunction/__init__.py +25 -0
- pysisyphus/wavefunction/build_ext.py +42 -0
- pysisyphus/wavefunction/cart2sph.py +190 -0
- pysisyphus/wavefunction/diabatization.py +304 -0
- pysisyphus/wavefunction/excited_states.py +435 -0
- pysisyphus/wavefunction/gen_ints.py +1811 -0
- pysisyphus/wavefunction/helpers.py +104 -0
- pysisyphus/wavefunction/ints/__init__.py +0 -0
- pysisyphus/wavefunction/ints/boys.py +193 -0
- pysisyphus/wavefunction/ints/boys_table_N_64_xasym_27.1_step_0.01.npy +0 -0
- pysisyphus/wavefunction/ints/cart_gto3d.py +176 -0
- pysisyphus/wavefunction/ints/coulomb3d.py +25928 -0
- pysisyphus/wavefunction/ints/diag_quadrupole3d.py +10036 -0
- pysisyphus/wavefunction/ints/dipole3d.py +8762 -0
- pysisyphus/wavefunction/ints/int2c2e3d.py +7198 -0
- pysisyphus/wavefunction/ints/int3c2e3d_sph.py +65040 -0
- pysisyphus/wavefunction/ints/kinetic3d.py +8240 -0
- pysisyphus/wavefunction/ints/ovlp3d.py +3777 -0
- pysisyphus/wavefunction/ints/quadrupole3d.py +15054 -0
- pysisyphus/wavefunction/ints/self_ovlp3d.py +198 -0
- pysisyphus/wavefunction/localization.py +458 -0
- pysisyphus/wavefunction/multipole.py +159 -0
- pysisyphus/wavefunction/normalization.py +36 -0
- pysisyphus/wavefunction/pop_analysis.py +134 -0
- pysisyphus/wavefunction/shells.py +1171 -0
- pysisyphus/wavefunction/wavefunction.py +504 -0
- pysisyphus/wrapper/__init__.py +11 -0
- pysisyphus/wrapper/exceptions.py +2 -0
- pysisyphus/wrapper/jmol.py +120 -0
- pysisyphus/wrapper/mwfn.py +169 -0
- pysisyphus/wrapper/packmol.py +71 -0
- pysisyphus/xyzloader.py +168 -0
- pysisyphus/yaml_mods.py +45 -0
- thermoanalysis/LICENSE +674 -0
- thermoanalysis/QCData.py +244 -0
- thermoanalysis/__init__.py +0 -0
- thermoanalysis/config.py +3 -0
- thermoanalysis/constants.py +20 -0
- thermoanalysis/thermo.py +1011 -0
pysisyphus/optimizers/HessianOptimizer.py
@@ -0,0 +1,1176 @@
from math import sqrt
from pathlib import Path
from typing import Literal, Optional

import numpy as np
from scipy.optimize import root_scalar

from pysisyphus.cos.ChainOfStates import ChainOfStates
from pysisyphus.Geometry import Geometry
from pysisyphus.helpers_pure import rms
# from pysisyphus.io.hessian import save_hessian
from pysisyphus.optimizers.guess_hessians import (
    get_guess_hessian,
    xtb_hessian,
    HessInit,
)
from pysisyphus.optimizers.hessian_updates import (
    bfgs_update,
    flowchart_update,
    damped_bfgs_update,
    bofill_update,
    ts_bfgs_update,
    ts_bfgs_update_org,
    ts_bfgs_update_revised,
)
from pysisyphus.optimizers.Optimizer import Optimizer
from pysisyphus.optimizers.exceptions import OptimizationError

from pysisyphus.helpers import array2string
import torch

def dummy_hessian_update(H, dx, dg):
    return np.zeros_like(H), "no"


HESS_UPDATE_FUNCS = {
    "none": dummy_hessian_update,
    None: dummy_hessian_update,
    False: dummy_hessian_update,
    "bfgs": bfgs_update,
    "damped_bfgs": damped_bfgs_update,
    "flowchart": flowchart_update,
    "bofill": bofill_update,
    "ts_bfgs": ts_bfgs_update,
    "ts_bfgs_org": ts_bfgs_update_org,
    "ts_bfgs_rev": ts_bfgs_update_revised,
}
HessUpdate = Literal[
    "none",
    None,
    False,
    "bfgs",
    "damped_bfgs",
    "flowchart",
    "bofill",
    "ts_bfgs",
    "ts_bfgs_org",
    "ts_bfgs_rev",
]
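
# Minimal dispatch sketch (hypothetical shapes): every update function in
# HESS_UPDATE_FUNCS shares the signature update(H, dx, dg) -> (dH, label),
# with dx the last step and dg the gradient difference, so the optimizer can
# apply any of them uniformly:
#
#     update = HESS_UPDATE_FUNCS["bfgs"]
#     dH, label = update(H, dx, dg)   # label e.g. "BFGS"; "no" for the dummy
#     H = H + dH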


class HessianOptimizer(Optimizer):
    rfo_dict = {
        "min": (0, "min"),
        "max": (-1, "max"),
    }

    def __init__(
        self,
        geometry: Geometry,
        trust_radius: float = 0.5,
        trust_update: bool = True,
        trust_min: float = 0.1,
        trust_max: float = 1,
        max_energy_incr: Optional[float] = None,
        hessian_update: HessUpdate = "bfgs",
        hessian_init: HessInit = "fischer",
        hessian_recalc: Optional[int] = None,
        hessian_recalc_adapt: Optional[float] = None,
        hessian_xtb: bool = False,
        hessian_recalc_reset: bool = False,
        small_eigval_thresh: float = 1e-8,
        line_search: bool = False,
        alpha0: float = 1.0,
        max_micro_cycles: int = 25,
        rfo_overlaps: bool = False,
        **kwargs,
    ) -> None:
        """Baseclass for optimizers utilizing Hessian information.

        Parameters
        ----------
        geometry
            Geometry to be optimized.
        trust_radius
            Initial trust radius in whatever unit the optimization is carried out.
        trust_update
            Whether to update the trust radius throughout the optimization.
        trust_min
            Minimum trust radius.
        trust_max
            Maximum trust radius.
        max_energy_incr
            Maximum allowed energy increase after a faulty step. Optimization is
            aborted when the threshold is exceeded.
        hessian_update
            Type of Hessian update. Defaults to BFGS for minimizations and Bofill
            for saddle point searches.
        hessian_init
            Type of initial model Hessian.
        hessian_recalc
            Recalculate the exact Hessian every n-th cycle instead of updating it.
        hessian_recalc_adapt
            Use a more flexible scheme to determine Hessian recalculation. Undocumented.
        hessian_xtb
            Recalculate the Hessian at the GFN2-XTB level of theory.
        hessian_recalc_reset
            Whether to skip Hessian recalculation after a reset. Undocumented.
        small_eigval_thresh
            Threshold for small eigenvalues. Eigenvectors belonging to eigenvalues
            below this threshold are discarded.
        line_search
            Whether to carry out a line search. Not implemented by all subclassing
            optimizers.
        alpha0
            Initial alpha for the restricted-step (RS) procedure.
        max_micro_cycles
            Maximum number of RS iterations.
        rfo_overlaps
            Enable mode-following in the RS procedure.

        Other Parameters
        ----------------
        **kwargs
            Keyword arguments passed to the Optimizer baseclass.
        """
        super().__init__(geometry, **kwargs)

        assert not issubclass(
            type(geometry), ChainOfStates
        ), "HessianOptimizer can't be used with ChainOfStates objects!"

        self.trust_update = bool(trust_update)
        assert trust_min <= trust_max, "trust_min must be <= trust_max!"
        self.trust_min = float(trust_min)
        self.trust_max = float(trust_max)
        self.max_energy_incr = max_energy_incr
        # Constrain initial trust radius if trust_max > trust_radius
        self.trust_radius = min(trust_radius, trust_max)
        self.log(f"Initial trust radius: {self.trust_radius:.6f}")
        self.hessian_update = hessian_update
        self.hessian_update_func = HESS_UPDATE_FUNCS[hessian_update]
        self.hessian_init = hessian_init
        self.hessian_recalc = hessian_recalc
        self.hessian_recalc_adapt = hessian_recalc_adapt
        self.hessian_xtb = hessian_xtb
        self.hessian_recalc_reset = hessian_recalc_reset
        self.small_eigval_thresh = float(small_eigval_thresh)
        self.line_search = bool(line_search)
        # Restricted-step related
        self.alpha0 = alpha0
        self.max_micro_cycles = int(max_micro_cycles)
        assert max_micro_cycles >= 0
        self.rfo_overlaps = rfo_overlaps

        assert self.small_eigval_thresh > 0.0, "small_eigval_thresh must be > 0!"
        if not self.restarted:
            self.hessian_recalc_in = None
            self.adapt_norm = None
            self.predicted_energy_changes = list()
            hessian_init_exists = Path(self.hessian_init).exists()
            if (
                # Allow actually calculated Hessians for all coordinate systems
                not hessian_init_exists
                and self.hessian_init not in ("calc", "xtb", "xtb1", "xtbff")
                # But disable model Hessians for Cartesian optimizations
                and self.geometry.coord_type in ("cart", "cartesian", "mwcartesian")
            ):
                self.hessian_init = "unit"
                self.log(
                    "Chosen initial (model) Hessian is incompatible with current "
                    f"coord_type: {self.geometry.coord_type}!"
                )

        self._prev_eigvec_min = None
        self._prev_eigvec_max = None
        self._using_active_dofs = False
        self._active_dof_indices = None
        self.cur_H = None

    @property
    def using_active_dofs(self):
        return self._using_active_dofs

    @property
    def active_dof_indices(self):
        return self._active_dof_indices

    def _set_active_dofs(self, use_active):
        self._using_active_dofs = use_active
        if not use_active:
            self._active_dof_indices = None
            return
        if getattr(self.geometry, "within_partial_hessian", None) is not None:
            self._active_dof_indices = self.geometry.hess_active_dof_indices
            return
        # Fallback: infer active DOFs from calculator's Hessian-active atoms
        calc = getattr(self.geometry, "calculator", None)
        core = getattr(calc, "core", calc)
        hess_atoms = getattr(core, "hess_active_atoms", None)
        if hess_atoms is not None and len(hess_atoms) > 0:
            act = []
            for a in hess_atoms:
                base = 3 * int(a)
                act.extend([base, base + 1, base + 2])
            self._active_dof_indices = np.asarray(act, dtype=int)
            return
        self._active_dof_indices = self.geometry.active_dof_indices

    def active_from_full(self, vec):
        if not self.using_active_dofs:
            return vec
        inds = self._active_dof_indices
        if inds is None:
            return vec
        # Sanitize indices (drop negatives / out-of-bounds)
        try:
            inds = np.asarray(inds, dtype=int)
            if len(inds) > 0:
                if np.any(inds < 0):
                    inds = inds[inds >= 0]
                if len(inds) > 0:
                    max_valid = vec.shape[0] - 1
                    if np.any(inds > max_valid):
                        inds = inds[inds <= max_valid]
                if len(inds) > 0:
                    if np.min(inds) < 0 or np.max(inds) >= vec.shape[0]:
                        return vec
        except (ValueError, IndexError, TypeError):
            pass
        # Avoid double-slicing if the vector is already in active space
        try:
            if vec.shape[0] == len(inds):
                return vec
        except (ValueError, IndexError, TypeError):
            pass
        try:
            if len(inds) > 0 and vec.shape[0] <= int(np.max(inds)):
                # Indices exceed vector length → already active (compact) vector.
                return vec
        except (ValueError, IndexError, TypeError):
            pass
        if isinstance(vec, torch.Tensor):
            if vec.device.type == "cuda":
                try:
                    vec_cpu = vec.detach().cpu().numpy()
                    return torch.as_tensor(
                        vec_cpu[inds], dtype=vec.dtype, device=vec.device
                    )
                except (ValueError, IndexError, TypeError):
                    return vec
            idx = torch.as_tensor(inds, dtype=torch.long, device=vec.device)
            return vec.index_select(0, idx)
        return vec[inds]

    def full_from_active(self, vec):
        if not self.using_active_dofs:
            return vec
        inds = self._active_dof_indices
        if inds is None:
            return vec
        if isinstance(vec, torch.Tensor):
            idx = torch.as_tensor(inds, dtype=torch.long, device=vec.device)
            full = torch.zeros(
                self.geometry.cart_coords.size, dtype=vec.dtype, device=vec.device
            )
            full.index_copy_(0, idx, vec)
            return full
        full = np.zeros(
            self.geometry.cart_coords.size,
            dtype=vec.dtype if hasattr(vec, "dtype") else float,
        )
        full[inds] = vec
        return full
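
    # Round-trip sketch (hypothetical 2-atom system with only atom 1 active):
    #     inds = [3, 4, 5]                  # x/y/z DOFs of atom 1
    #     act  = active_from_full(full)     # gather:  full[inds]
    #     full = full_from_active(act)      # scatter: zeros, then full[inds] = act
    # Frozen DOFs therefore re-enter any step as exact zeros.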

    def active_hessian(self, hessian):
        if not self.using_active_dofs:
            return hessian

        if getattr(self.geometry, "within_partial_hessian", None) is not None:
            act_n_dof = int(self.geometry.within_partial_hessian.get("active_n_dof", 0))
            if hessian.shape == (act_n_dof, act_n_dof):
                return hessian

        inds = self.active_dof_indices
        try:
            inds_arr = np.asarray(inds, dtype=int)
            if hessian.shape[0] == len(inds_arr) and len(inds_arr) > 0:
                if np.max(inds_arr) >= hessian.shape[0]:
                    # Likely already in active order; avoid double-slicing.
                    return hessian
        except (ValueError, IndexError, TypeError):
            pass
        try:
            inds = np.asarray(inds, dtype=int)
            if len(inds) > 0:
                max_valid = hessian.shape[0] - 1
                inds = inds[(inds >= 0) & (inds <= max_valid)]
        except (ValueError, IndexError, TypeError):
            pass
        if isinstance(hessian, torch.Tensor):
            if hessian.device.type == "cuda":
                try:
                    hess_cpu = hessian.detach().cpu().numpy()
                    return torch.as_tensor(
                        hess_cpu[np.ix_(inds, inds)],
                        dtype=hessian.dtype,
                        device=hessian.device,
                    )
                except (ValueError, IndexError, TypeError):
                    return hessian
            idx = torch.as_tensor(inds, device=hessian.device, dtype=torch.int64)
            return hessian.index_select(0, idx).index_select(1, idx)
        return hessian[np.ix_(inds, inds)]
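
    # np.ix_ sketch: for inds = [0, 2], H[np.ix_(inds, inds)] selects the 2x2
    # sub-block [[H00, H02], [H20, H22]] — rows AND columns restricted to the
    # active DOFs, matching the chained index_select calls on the torch path.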

    def active_list(self, seq):
        if not self.using_active_dofs:
            return seq
        return [self.active_from_full(item) for item in seq]

    @property
    def prev_eigvec_min(self):
        return self._prev_eigvec_min

    @prev_eigvec_min.setter
    def prev_eigvec_min(self, prev_eigvec_min):
        if self.rfo_overlaps:
            self._prev_eigvec_min = prev_eigvec_min

    @property
    def prev_eigvec_max(self):
        return self._prev_eigvec_max

    @prev_eigvec_max.setter
    def prev_eigvec_max(self, prev_eigvec_max):
        if self.rfo_overlaps:
            self._prev_eigvec_max = prev_eigvec_max

    def reset(self):
        # Don't recalculate the Hessian if we have to reset the optimizer
        hessian_init = self.hessian_init
        if (
            (not self.hessian_recalc_reset)
            and hessian_init == "calc"
            and self.geometry.coord_type != "cart"
        ):
            hessian_init = "fischer"
        self.prepare_opt(hessian_init)

    # def save_hessian(self):
    #     # Don't try to save Hessians of analytical potentials
    #     if self.geometry.is_analytical_2d:
    #         return

    #     h5_fn = self.get_path_for_fn(f"hess_calc_cyc_{self.cur_cycle}.h5")
    #     # Save the cartesian Hessian, as it is independent of the
    #     # actual coordinate system that is used.
    #     save_hessian(
    #         h5_fn,
    #         self.geometry,
    #         self.geometry.cart_hessian,
    #         self.geometry.energy,
    #         self.geometry.calculator.mult,
    #     )
    #     self.log(f"Wrote calculated cartesian Hessian to '{h5_fn}'")

    def prepare_opt(self, hessian_init=None):
        if hessian_init is None:
            hessian_init = self.hessian_init

        self.H, hess_str = get_guess_hessian(self.geometry, hessian_init)
        if self.hessian_init != "calc" and self.geometry.is_analytical_2d:
            assert self.H.shape == (3, 3)
            self.H[2, 2] = 0.0

        msg = f"Using {hess_str} Hessian"
        if hess_str == "saved":
            msg += f" from '{hessian_init}'"
        self.log(msg)

        # # Dump to disk if the Hessian was calculated
        # if self.hessian_init == "calc":
        #     self.save_hessian()

        if (
            hasattr(self.geometry, "coord_type")
            and self.geometry.coord_type == "dlc"
            # A calculated Hessian is already in DLC
            and hessian_init != "calc"
        ):
            U = self.geometry.internal.U
            self.H = U.T.dot(self.H).dot(U)

        if self.hessian_recalc_adapt:
            self.adapt_norm = np.linalg.norm(self.geometry.forces)

        if self.hessian_recalc:
            # Subtract one, as we don't do a Hessian update in
            # the first cycle.
            self.hessian_recalc_in = self.hessian_recalc - 1
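
    # DLC projection sketch (assuming U is the pysisyphus transformation matrix
    # whose columns hold the delocalized internal coordinate expansion
    # coefficients): a model Hessian H_prim built in primitive internals maps
    # into the active DLC space via the similarity transform
    #     H_dlc = U.T @ H_prim @ U
    # which is exactly what the `self.H = U.T.dot(self.H).dot(U)` line above does.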

    def _get_opt_restart_info(self):
        opt_restart_info = {
            "adapt_norm": self.adapt_norm,
            "H": self.H.tolist(),
            "hessian_recalc_in": self.hessian_recalc_in,
            "predicted_energy_changes": self.predicted_energy_changes,
        }
        return opt_restart_info

    def _set_opt_restart_info(self, opt_restart_info):
        self.adapt_norm = opt_restart_info["adapt_norm"]
        self.H = np.array(opt_restart_info["H"])
        self.hessian_recalc_in = opt_restart_info["hessian_recalc_in"]
        self.predicted_energy_changes = opt_restart_info["predicted_energy_changes"]

    def update_trust_radius(self):
        # The predicted change should be calculated at the end of optimize
        # of the previous cycle.
        assert (
            len(self.predicted_energy_changes) == len(self.forces) - 1
        ), "Did you forget to append to self.predicted_energy_changes?"
        self.log("Trust radius update")
        self.log(f"\tCurrent trust radius: {self.trust_radius:.6f}")
        predicted_change = self.predicted_energy_changes[-1]
        actual_change = self.energies[-1] - self.energies[-2]
        # Only report an unexpected increase if we actually predicted a
        # decrease.
        unexpected_increase = (actual_change > 0) and (predicted_change < 0)
        old_trust = self.trust_radius
        if unexpected_increase:
            self.log(f"Energy increased by {actual_change:.6f} au!")
            if self.max_energy_incr and (actual_change > self.max_energy_incr):
                raise OptimizationError("Actual energy change too high!")
        coeff = actual_change / predicted_change
        self.log(f"\tPredicted change: {predicted_change:.4e} au")
        self.log(f"\tActual change: {actual_change:.4e} au")
        self.log(f"\tCoefficient: {coeff:.2%}")
        step = self.steps[-1]
        last_step_norm = np.linalg.norm(step)
        self.set_new_trust_radius(coeff, last_step_norm)
        if unexpected_increase:
            self.table.print(
                f"Unexpected energy increase ({actual_change:.6f} au)! "
                f"Trust radius: old={old_trust:.4}, new={self.trust_radius:.4}"
            )

    def set_new_trust_radius(self, coeff, last_step_norm):
        # Nocedal, Numerical Optimization, Chapter 4, Algorithm 4.1

        # If actual and predicted energy change have different signs,
        # coeff will be negative and lead to a decreased trust radius,
        # which is fine.
        if coeff < 0.25:
            self.trust_radius = max(self.trust_radius / 4, self.trust_min)
            self.log("\tDecreasing trust radius.")
        # Only increase the trust radius if the last step norm was at least 80% of it
        # See [5], Appendix, step size and direction control
        # elif coeff > 0.75 and (last_step_norm >= .8*self.trust_radius):
        #
        # Only increase the trust radius if the last step norm corresponded
        # approximately to the trust radius.
        elif coeff > 0.75 and abs(self.trust_radius - last_step_norm) <= 1e-3:
            self.trust_radius = min(self.trust_radius * 2, self.trust_max)
            self.log("\tIncreasing trust radius.")
        else:
            self.log(f"\tKeeping current trust radius at {self.trust_radius:.6f}")
            return
        self.log(f"\tUpdated trust radius: {self.trust_radius:.6f}")

    def update_hessian(self):
        # Compare current forces to reference forces to decide whether we
        # shall recalculate the Hessian.
        try:
            cur_norm = np.linalg.norm(self.forces[-1])
            ref_norm = self.adapt_norm / self.hessian_recalc_adapt
            recalc_adapt = cur_norm <= ref_norm
            self.log(
                "Check for adaptive Hessian recalculation: "
                f"{cur_norm:.6f} <= {ref_norm:.6f}, {recalc_adapt}"
            )
        except TypeError:
            recalc_adapt = False

        try:
            self.hessian_recalc_in = max(self.hessian_recalc_in - 1, 0)
            self.log(f"Recalculation of Hessian in {self.hessian_recalc_in} cycle(s).")
        except TypeError:
            self.hessian_recalc_in = None

        # Update reference norm if needed
        # TODO: Decide on whether to update the norm when the recalculation is
        # initiated by 'recalc'.
        if recalc_adapt:
            self.adapt_norm = cur_norm

        recalc = self.hessian_recalc_in == 0

        if recalc or recalc_adapt:
            # Free the old Hessian from the GPU before recalculating
            H_old = self.H
            self.H = None
            del H_old
            try:
                import torch

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except ImportError:
                pass
            self.log("Requested Hessian recalculation.")
            # Use the xtb Hessian
            if self.hessian_xtb:
                self.H = xtb_hessian(self.geometry)
                key = "xtb"
            # Calculated Hessian at the actual level of theory
            else:
                self.H = self.geometry.hessian
                key = "exact"
                # self.save_hessian()
            if self.using_active_dofs:
                # Keep the optimizer Hessian in active-DOF space to avoid
                # shape mismatches during quasi-Newton updates.
                self.H = self.active_hessian(self.H)
            if not (self.cur_cycle == 0):
                self.log(f"Recalculated {key} Hessian in cycle {self.cur_cycle}.")
            # Reset the counter. It is also reset when the recalculation was
            # initiated by the adaptive formulation.
            self.hessian_recalc_in = self.hessian_recalc
        # Simple Hessian update
        else:
            dx = self.steps[-1]
            dg = -(self.forces[-1] - self.forces[-2])
            H_work = self.H
            if self.using_active_dofs:
                H_work = self.active_hessian(self.H)
                dx = self.active_from_full(dx)
                dg = self.active_from_full(dg)
            curv_cond = dx.dot(dg)
            if curv_cond < 0.0:
                self.log(
                    f"Curvature condition (s·y = {curv_cond:.4f} < 0) not satisfied!"
                )
            dH, key = self.hessian_update_func(H_work, dx, dg)
            self.H = H_work + dH
            self.log(f"Did {key} Hessian update.")

    def solve_rfo(self, rfo_mat, kind="min", prev_eigvec=None):
        # When using the restricted-step variant of RFO the RFO matrix
        # may not be symmetric. That's why we can't use eigh here.
        is_torch = isinstance(rfo_mat, torch.Tensor)

        if is_torch:
            if not torch.isfinite(rfo_mat).all():
                self.log("RFO matrix contains NaN/inf; sanitizing entries.")
                rfo_mat = torch.nan_to_num(
                    rfo_mat, nan=0.0, posinf=1e8, neginf=-1e8
                )
        else:
            if not np.isfinite(rfo_mat).all():
                self.log("RFO matrix contains NaN/inf; sanitizing entries.")
                rfo_mat = np.nan_to_num(rfo_mat, nan=0.0, posinf=1e8, neginf=-1e8)

        if is_torch:
            is_sym = torch.allclose(rfo_mat, rfo_mat.T)
        else:
            is_sym = np.allclose(rfo_mat, rfo_mat.T)

        if is_sym:
            try:
                eigenvalues, eigenvectors = (
                    torch.linalg.eigh(rfo_mat) if is_torch else np.linalg.eigh(rfo_mat)
                )
            except (torch._C._LinAlgError, np.linalg.LinAlgError):
                self.log("eigh failed; falling back to eig.")
                eigenvalues, eigenvectors = (
                    torch.linalg.eig(rfo_mat) if is_torch else np.linalg.eig(rfo_mat)
                )
                eigenvalues = eigenvalues.real
                eigenvectors = eigenvectors.real
        else:
            eigenvalues, eigenvectors = (
                torch.linalg.eig(rfo_mat) if is_torch else np.linalg.eig(rfo_mat)
            )
            eigenvalues = eigenvalues.real
            eigenvectors = eigenvectors.real

        self.log("\tdiagonalized augmented Hessian")

        if isinstance(eigenvectors, torch.Tensor):
            sorted_inds = torch.argsort(eigenvalues)
        else:
            sorted_inds = np.argsort(eigenvalues)

        # Depending on whether we want to minimize (maximize) along
        # the mode(s) in the RFO matrix we have to select the smallest
        # (biggest) eigenvalue and corresponding eigenvector.
        first_or_last, verbose = self.rfo_dict[kind]
        # Given sorted eigenvalue-indices (sorted_inds) use the first
        # (smallest eigenvalue) or the last (largest eigenvalue) index.
        if prev_eigvec is None:
            ind = sorted_inds[first_or_last]
        else:
            if isinstance(prev_eigvec, torch.Tensor):
                ovlps = prev_eigvec.matmul(eigenvectors)
            else:
                ovlps = np.array([prev_eigvec.dot(ev) for ev in eigenvectors.T])
            naive_ind = sorted_inds[first_or_last]
            ind = (
                np.abs(ovlps).argmax()
                if isinstance(ovlps, np.ndarray)
                else torch.argmax(torch.abs(ovlps)).item()
            )
            self.log(
                f"Overlap: {ind} ({eigenvalues[ind]:.6f}), "
                f"Naive: {naive_ind} ({eigenvalues[naive_ind]:.6f})"
            )
        follow_eigvec = eigenvectors.T[ind]
        if isinstance(follow_eigvec, torch.Tensor):
            step_nu = follow_eigvec.clone()
        else:
            step_nu = follow_eigvec.copy()
        nu = step_nu[-1]
        self.log(f"\tnu_{verbose}={nu:.8e}")
        # Scale the eigenvector so that its last element equals 1. The
        # final step is the scaled eigenvector without the last element.
        step = step_nu[:-1] / nu
        eigval = eigenvalues[ind]
        self.log(f"\teigenvalue_{verbose}={eigval:.8e}")
        return step, eigval, nu, follow_eigvec
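
    # RFO identity behind the scaling: the augmented eigenproblem
    #     [[H, g], [g^T, 0]] (s, 1)^T = lambda (s, 1)^T
    # gives H s + g = lambda s, i.e. s = -(H - lambda I)^{-1} g — a Newton step
    # with an automatic level shift lambda. That is why the chosen eigenvector
    # is rescaled so its last component nu equals 1 before the step is read off.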

    def solve_rfo_secular(self, eigvals, gradient, alpha=1.0, kind="min",
                          prev_eigvec=None, max_iter=50, tol=1e-12):
        """Solve the RFO eigenvalue problem via the secular equation.

        The augmented Hessian has arrowhead structure, so its eigenvalue
        problem reduces to: f(mu) = sum g_i^2/(alpha*mu - lam_i) - mu = 0,
        solvable in O(N) instead of O(N^3).

        Returns (step, eigval, nu, eigvec) on success, None on failure.
        """
        is_torch = isinstance(eigvals, torch.Tensor)

        # Convert all inputs to plain numpy/float for root-finding
        def _np(x):
            if isinstance(x, torch.Tensor):
                return x.detach().cpu().numpy().copy()
            return np.asarray(x, dtype=np.float64)

        g = _np(gradient)
        lam = _np(eigvals)
        alpha = (
            alpha.detach().cpu().item()
            if isinstance(alpha, torch.Tensor)
            else float(alpha)
        )

        n = len(lam)
        g2 = g ** 2
        nz = g2 > 1e-30
        if not nz.any():
            step = np.zeros(n)
            eigvec = np.zeros(n + 1)
            eigvec[-1] = 1.0
            if is_torch:
                step = torch.zeros(n, device=eigvals.device, dtype=eigvals.dtype)
                eigvec = torch.zeros(n + 1, device=eigvals.device, dtype=eigvals.dtype)
                eigvec[-1] = 1.0
            return step, 0.0, 1.0, eigvec

        g2_nz, lam_nz = g2[nz], lam[nz]

        # Guard against alpha ≈ 0 (can arise from trust-radius adaptation)
        if abs(alpha) < 1e-14:
            return None

        def f_df(mu):
            d = alpha * mu - lam_nz
            return (
                float(np.sum(g2_nz / d) - mu),
                float(-alpha * np.sum(g2_nz / d**2) - 1.0),
            )

        _, verbose = self.rfo_dict[kind]

        # Bracket the root
        if kind == "min":
            pole = float(lam.min() / alpha)
            mu = pole - max(float(np.sqrt(g2_nz.sum())) / alpha, 1.0)
            for _ in range(20):
                if f_df(mu)[0] > 0:
                    break
                mu = pole - 2.0 * (pole - mu)
            else:
                return None
            lo, hi = mu, pole - 1e-15 * max(abs(pole), 1.0)
        elif kind == "max":
            pole = float(lam.max() / alpha)
            mu = pole + max(float(np.sqrt(g2_nz.sum())) / alpha, 1.0)
            for _ in range(20):
                if f_df(mu)[0] < 0:
                    break
                mu = pole + 2.0 * (mu - pole)
            else:
                return None
            lo, hi = pole + 1e-15 * max(abs(pole), 1.0), mu
        else:
            return None

        # Newton-Raphson with bisection safeguard
        mu_cur = (lo + hi) / 2.0
        for _ in range(max_iter):
            fval, dfval = f_df(mu_cur)
            if abs(fval) < tol:
                break
            mu_new = mu_cur - fval / dfval if abs(dfval) > 1e-30 else (lo + hi) / 2.0
            if mu_new <= lo or mu_new >= hi:
                mu_new = (lo + hi) / 2.0
            f_new = f_df(mu_new)[0]
            if f_new > 0:
                lo = mu_new
            else:
                hi = mu_new
            mu_cur = mu_new
        else:
            self.log(f"Secular equation did not converge in {max_iter} iters.")
            return None

        self.log(f"\teigenvalue_{verbose}={mu_cur:.8e} (secular)")
        self.log(f"\tnu_{verbose}={1.0:.8e}")

        # Compute the step: s_i = g_i / (alpha * mu - lam_i)
        denom = alpha * mu_cur - lam
        step_np = np.where(nz, g / denom, 0.0)

        # Eigenvector for mode tracking
        eigvec_np = np.append(step_np, 1.0)
        eigvec_np /= np.linalg.norm(eigvec_np)

        # Mode-tracking check
        if prev_eigvec is not None:
            prev_np = _np(prev_eigvec)
            if abs(float(np.dot(prev_np, eigvec_np))) < 0.3:
                self.log("Secular eigvec overlap too low; falling back.")
                return None

        if is_torch:
            step_np = torch.tensor(step_np, device=eigvals.device, dtype=eigvals.dtype)
            eigvec_np = torch.tensor(eigvec_np, device=eigvals.device, dtype=eigvals.dtype)

        return step_np, mu_cur, 1.0, eigvec_np
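
    # Sanity sketch (hypothetical 1-D numbers): lam = [2.0], g = [1.0], alpha = 1.
    # f(mu) = 1/(mu - 2) - mu = 0  =>  mu^2 - 2 mu - 1 = 0  =>  mu = 1 - sqrt(2)
    # for kind="min", and the step g/(mu - lam) = 1/(1 - sqrt(2) - 2) ≈ -0.4142
    # matches the lowest eigenpair of the dense augmented matrix [[2, 1], [1, 0]].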

    def filter_small_eigvals(self, eigvals, eigvecs, mask=False):
        if isinstance(eigvals, torch.Tensor):
            small_inds = torch.abs(eigvals) < self.small_eigval_thresh
        else:
            small_inds = np.abs(eigvals) < self.small_eigval_thresh
        eigvals = eigvals[~small_inds]
        eigvecs = eigvecs[:, ~small_inds]
        small_num = sum(small_inds)
        self.log(
            f"Found {small_num} small eigenvalues in Hessian. Removed "
            "corresponding eigenvalues and eigenvectors."
        )
        # assert small_num <= 6, (
        #     "Expected at most 6 small eigenvalues in cartesian hessian "
        #     f"but found {small_num}!"
        # )
        if mask:
            return eigvals, eigvecs, small_inds
        else:
            return eigvals, eigvecs

    def log_negative_eigenvalues(self, eigvals, pre_str=""):
        neg_inds = eigvals < -self.small_eigval_thresh
        neg_eigval_str = array2string(eigvals[neg_inds], precision=6)
        self.log(f"{pre_str}Hessian has {neg_inds.sum()} negative eigenvalue(s).")
        self.log(f"\t{neg_eigval_str}")

    def housekeeping(self):
        """Calculate gradient and energy. Update trust radius and Hessian
        if needed. Return energy, gradient and Hessian for the current cycle."""
        gradient_full = self.geometry.gradient
        energy = self.geometry.energy
        self.energies.append(energy)
        self.log(f"    Energy: {energy: >12.6f} au")
        self.log(
            f"norm(grad): {np.linalg.norm(gradient_full): >12.6f} au / bohr (rad)"
        )
        self.log(
            f" rms(grad): {np.sqrt(np.mean(gradient_full**2)): >12.6f} au / bohr (rad)"
        )
        self.forces.append(-gradient_full)

        can_update = (
            # Allows gradient differences
            len(self.forces) > 1
            and (self.forces[-2].shape == gradient_full.shape)
            and len(self.coords) > 1
            # Coordinates may have been rebuilt. Take care of that.
            and (self.coords[-2].shape == self.coords[-1].shape)
            and len(self.energies) > 1
        )
        if can_update:
            if self.trust_update:
                self.update_trust_radius()
            self.update_hessian()

        # Convert the gradient to match H's device/dtype AFTER update_hessian(),
        # so that hessian_recalc (which may replace self.H with a new tensor
        # on a different device) is accounted for.
        if isinstance(self.H, torch.Tensor):
            gradient_full = torch.from_numpy(gradient_full).to(
                self.H.device, self.H.dtype
            )

        H = self.H
        if self.geometry.internal:
            # Shift eigenvalues of the orthogonal part to high values, so they
            # don't contribute to the actual step.
            H_proj = self.geometry.internal.project_hessian(self.H)
            # Symmetrize the Hessian, as the projection may break it?!
            H = (H_proj + H_proj.T) / 2

        if getattr(self.geometry, "within_partial_hessian", None) is not None:
            use_active = True
        elif (
            H.shape[0] != self.geometry.cart_coords.size
            and self.geometry.coord_type in ("cart", "cartesian", "mwcartesian")
        ):
            # Partial Hessian without explicit metadata: still use active slicing.
            # Only applies to Cartesian coordinate types; for internal coordinates
            # (e.g. DLC), the Hessian is naturally smaller than cart_coords.size.
            use_active = True
        else:
            use_active = (
                len(self.geometry.freeze_atoms) > 0
                and self.geometry.coord_type in ("cart", "cartesian", "mwcartesian")
                and H.shape[0] == self.geometry.cart_coords.size
            )
        self._set_active_dofs(use_active)

        H = self.active_hessian(H)
        if gradient_full.shape[0] == H.shape[0]:
            gradient = gradient_full
        else:
            gradient = self.active_from_full(gradient_full)

        if isinstance(H, torch.Tensor):
            eigvals, eigvecs = torch.linalg.eigh(H)
        else:
            eigvals, eigvecs = np.linalg.eigh(H)
        # Neglect small eigenvalues
        eigvals, eigvecs = self.filter_small_eigvals(eigvals, eigvecs)

        resetted = not can_update
        self.cur_H = H
        return energy, gradient, H, eigvals, eigvecs, resetted
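
    # Shape contract sketch: after housekeeping(), gradient, H, eigvals and
    # eigvecs all live in the same (possibly active-DOF) space, so callers can
    # combine them directly; a step built from them would then, presumably, be
    # expanded back with full_from_active() before coordinates are updated.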

    def get_augmented_hessian(self, eigvals, gradient, alpha=1.0):
        if isinstance(gradient, torch.Tensor):
            dim_ = eigvals.size(0) + 1
            H_aug = torch.zeros(
                (dim_, dim_), device=gradient.device, dtype=gradient.dtype
            )
            H_aug[: dim_ - 1, : dim_ - 1] = torch.diag(eigvals / alpha)
        else:
            dim_ = eigvals.size + 1
            H_aug = np.zeros((dim_, dim_))
            H_aug[: dim_ - 1, : dim_ - 1] = np.diag(eigvals / alpha)
        H_aug[-1, :-1] = gradient
        H_aug[:-1, -1] = gradient

        H_aug[:-1, -1] /= alpha

        return H_aug
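
    # Layout sketch (gradient already rotated into the eigenvector basis):
    #
    #     H_aug = [ diag(lam)/alpha   g/alpha ]
    #             [ g^T                0      ]
    #
    # an arrowhead matrix — exactly the structure solve_rfo_secular() exploits
    # to replace the O(N^3) diagonalization with an O(N) root search.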

    def get_alpha_step(self, cur_alpha, rfo_eigval, step_norm, eigvals, gradient):
        # Derivative of the squared step w.r.t. alpha
        numer = gradient**2
        denom = (eigvals - rfo_eigval * cur_alpha) ** 3
        if isinstance(gradient, torch.Tensor):
            quot = torch.sum(numer / denom)
        else:
            quot = np.sum(numer / denom)
        self.log(f"quot={quot:.6f}")
        dstep2_dalpha = 2 * rfo_eigval / (1 + step_norm**2 * cur_alpha) * quot
        if isinstance(gradient, torch.Tensor):
            dstep2_valid = bool(
                torch.isfinite(dstep2_dalpha)
                & (torch.abs(dstep2_dalpha) > 1e-12)
            )
        else:
            dstep2_valid = np.isfinite(dstep2_dalpha) and abs(dstep2_dalpha) > 1e-12
        if not dstep2_valid:
            self.log(
                "alpha update skipped due to invalid derivative "
                f"(dstep2_dalpha={dstep2_dalpha})"
            )
            return 0.0
        self.log(f"analytic deriv.={dstep2_dalpha:.6f}")
        # Update alpha
        alpha_step = (
            2 * (self.trust_radius * step_norm - step_norm**2) / dstep2_dalpha
        )
        self.log(f"alpha_step={alpha_step:.4f}")
        min_alpha = 1e-8
        if (cur_alpha + alpha_step) <= min_alpha:
            self.log(
                "alpha update would make alpha non-positive; "
                f"clamping to min_alpha={min_alpha:.1e}"
            )
            alpha_step = min_alpha - cur_alpha
        return alpha_step
|
|
885
|
+
|
|
886
|
+
def get_rs_step(self, eigvals, eigvecs, gradient, name="RS"):
|
|
887
|
+
# Transform gradient to basis of eigenvectors
|
|
888
|
+
if isinstance(eigvecs, torch.Tensor):
|
|
889
|
+
            if not isinstance(gradient, torch.Tensor):
                gradient = torch.as_tensor(
                    gradient, device=eigvecs.device, dtype=eigvecs.dtype
                )
            elif gradient.device != eigvecs.device:
                gradient = gradient.to(device=eigvecs.device)
            gradient_ = eigvecs.T @ gradient
        else:
            gradient_ = eigvecs.T.dot(gradient)

        alpha = self.alpha0
        for mu in range(self.max_micro_cycles):
            self.log(f"{name} micro cycle {mu:02d}, alpha={alpha:.6f}")
            # Try the secular-equation solver first (O(N) instead of the
            # O(N^3) eigendecomposition of the augmented Hessian).
            secular_result = self.solve_rfo_secular(
                eigvals, gradient_, alpha, kind="min",
                prev_eigvec=self.prev_eigvec_min,
            )
            if secular_result is not None:
                rfo_step_, eigval_min, nu, self.prev_eigvec_min = secular_result
            else:
                # Fall back to the full eigendecomposition.
                self.log("Secular solver failed; using full eigendecomposition.")
                H_aug = self.get_augmented_hessian(eigvals, gradient_, alpha)
                rfo_step_, eigval_min, nu, self.prev_eigvec_min = self.solve_rfo(
                    H_aug, "min", prev_eigvec=self.prev_eigvec_min
                )
            if isinstance(rfo_step_, torch.Tensor):
                rfo_norm_ = torch.linalg.norm(rfo_step_)
            else:
                rfo_norm_ = np.linalg.norm(rfo_step_)
            self.log(f"norm(rfo step)={rfo_norm_:.6f}")

            if rfo_norm_ <= 0:
                self.log(
                    "RFO step length is zero; falling back to trust-region Newton step."
                )
                step_ = self.get_newton_step_on_trust(
                    eigvals, eigvecs, gradient, transform=False
                )
                break

            if (rfo_norm_ < self.trust_radius) or (
                abs(rfo_norm_ - self.trust_radius) <= 1e-3
            ):
                step_ = rfo_step_
                break

            alpha_step = self.get_alpha_step(
                alpha, eigval_min, rfo_norm_, eigvals, gradient_
            )
            alpha += alpha_step
            self.log("")
        # Micro cycles exhausted: retry RFO with alpha=1.0 and, if the step
        # is still too long, fall back to the trust-region Newton step.
        else:
            self.log(
                "RS algorithm did not produce a desired step length in "
                f"{self.max_micro_cycles} micro cycles. Trying RFO with α=1.0."
            )
            secular_result = self.solve_rfo_secular(
                eigvals, gradient_, alpha=1.0, kind="min"
            )
            if secular_result is not None:
                rfo_step_, eigval_min, nu, _ = secular_result
            else:
                H_aug = self.get_augmented_hessian(eigvals, gradient_, alpha=1.0)
                rfo_step_, eigval_min, nu, _ = self.solve_rfo(H_aug, "min")
            if isinstance(rfo_step_, torch.Tensor):
                rfo_norm_ = torch.linalg.norm(rfo_step_)
            else:
                rfo_norm_ = np.linalg.norm(rfo_step_)

            # This should always be True if the micro cycles above failed,
            # but we keep the check nonetheless to make the logic obvious.
            if rfo_norm_ > self.trust_radius:
                self.log(
                    f"Proposed RFO step with norm {rfo_norm_:.4f} is outside trust "
                    f"radius Δ={self.trust_radius:.4f}."
                )
                step_ = self.get_newton_step_on_trust(
                    eigvals, eigvecs, gradient, transform=False
                )
                # Simple, downscaled RFO step:
                # step_ = rfo_step_ / rfo_norm_ * self.trust_radius
            else:
                step_ = rfo_step_
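
        # Note on the loop above: this is the "restricted step" part of the
        # RS algorithm. alpha is adjusted every micro cycle via
        # get_alpha_step, which damps the RFO step until its length
        # approximately matches the trust radius.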

        # Transform step back to original basis
        if isinstance(eigvecs, torch.Tensor):
            if not isinstance(step_, torch.Tensor):
                step_ = torch.as_tensor(
                    step_, device=eigvecs.device, dtype=eigvecs.dtype
                )
            step = (eigvecs @ step_).cpu().numpy()
        else:
            step = eigvecs.dot(step_)

        min_step_norm = getattr(self, "min_step_norm", 0.0)
        step_norm = np.linalg.norm(step)
        if step_norm <= min_step_norm:
            self.log(
                "RFO step length below minimum threshold; "
                "falling back to trust-region Newton step."
            )
            step = self.get_newton_step_on_trust(eigvals, eigvecs, gradient)
        if not np.isfinite(step).all():
            self.log(
                "RFO step contains NaN/inf; falling back to trust-region Newton step."
            )
            step = self.get_newton_step_on_trust(eigvals, eigvecs, gradient)
            if not np.isfinite(step).all():
                raise ValueError(
                    "Fallback Newton step still contains NaN/inf; "
                    "aborting to avoid corrupting coordinates."
                )
        return step
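
    # Hypothetical usage sketch (driver names assumed, not defined here):
    #
    #     eigvals, eigvecs = torch.linalg.eigh(hessian)
    #     step = opt.get_rs_step(eigvals, eigvecs, gradient, name="RS-RFO")
    #     new_coords = coords + step
    #
    # Both branches above hand back a NumPy array, even when the
    # eigendecomposition lives on a GPU as torch tensors.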

    @staticmethod
    def get_shifted_step_trans(eigvals, gradient_trans, shift):
        return -gradient_trans / (eigvals + shift)
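
    # For shift > -min(eigvals) every component of the shifted step is
    # finite, and its norm decreases monotonically as the shift grows;
    # this is what makes the bracketed root search in
    # get_newton_step_on_trust well behaved.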

    @staticmethod
    def get_newton_step(eigvals, eigvecs, gradient):
        # Newton step in the Hessian eigenbasis, s = -V (V^T g / eigvals);
        # the minus sign matches get_shifted_step_trans with shift=0.
        if isinstance(eigvecs, torch.Tensor):
            eigvals = eigvals.to(eigvecs.device, dtype=eigvecs.dtype)
            gradient = gradient.to(eigvecs.device, dtype=eigvecs.dtype)
            return (eigvecs @ (-(eigvecs.T @ gradient) / eigvals)).cpu().numpy()
        else:
            return eigvecs.dot(-eigvecs.T.dot(gradient) / eigvals)
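
    # Worked example (hypothetical numbers): with eigvals = [1.0, 4.0],
    # eigvecs = identity and gradient = [2.0, 8.0], the Newton step is
    # -[2.0 / 1.0, 8.0 / 4.0] = [-2.0, -2.0].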

    def get_newton_step_on_trust(self, eigvals, eigvecs, gradient, transform=True):
        """Step on the trust radius.

        See Nocedal & Wright, section 4.3, "Iterative solution of the
        subproblem".
        """
        if isinstance(eigvals, torch.Tensor):
            eigvals = eigvals.cpu().numpy()

        min_ind = eigvals.argmin()
        min_eigval = eigvals[min_ind]
        pos_definite = bool((eigvals > 0.0).all())
        if isinstance(eigvecs, torch.Tensor):
            if not isinstance(gradient, torch.Tensor):
                gradient = torch.as_tensor(
                    gradient, device=eigvecs.device, dtype=eigvecs.dtype
                )
            else:
                gradient = gradient.to(device=eigvecs.device, dtype=eigvecs.dtype)
            gradient_trans = (eigvecs.T @ gradient).cpu().numpy()
        else:
            gradient_trans = eigvecs.T.dot(gradient)

        # This will also be True when we come close to a minimizer, but then
        # the Hessian will be positive definite as well and a simple Newton
        # step will be used.
        hard_case = abs(gradient_trans[min_ind]) <= 1e-6
        self.log(f"Smallest eigenvalue: {min_eigval:.6f}")
        self.log(f"Positive definite Hessian: {pos_definite}")
        self.log(f"Hard case: {hard_case}")

        def get_step(shift):
            return -gradient_trans / (eigvals + shift)

        # Unshifted Newton step
        newton_step_trans = get_step(0.0)
        newton_norm = np.linalg.norm(newton_step_trans)

        def on_trust_radius_lin(step):
            return 1 / self.trust_radius - 1 / np.linalg.norm(step)

        def finalize_step(shift):
            step = get_step(shift)
            if transform:
                if isinstance(eigvecs, torch.Tensor):
                    step = torch.as_tensor(
                        step, device=eigvecs.device, dtype=eigvecs.dtype
                    )
                    return (eigvecs @ step).cpu().numpy()
                else:
                    return eigvecs.dot(step)
            return step

        # Simplest case: positive definite Hessian and the predicted step is
        # already inside the trust radius.
        if pos_definite and newton_norm <= self.trust_radius:
            self.log("Using unshifted Newton step.")
            if isinstance(eigvecs, torch.Tensor):
                newton_step_trans = torch.as_tensor(
                    newton_step_trans, device=eigvecs.device, dtype=eigvecs.dtype
                )
                return (eigvecs @ newton_step_trans).cpu().numpy()
            else:
                return eigvecs.dot(newton_step_trans)
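
        # Case overview (cf. Nocedal & Wright, ch. 4): (1) positive definite
        # Hessian with a short enough Newton step -- handled above; (2) find
        # a shift with ||step(shift)|| = trust_radius by a root search on
        # the bracket below; (3) the "hard case", where the gradient has
        # (almost) no component along the lowest mode, so the boundary
        # equation may lack a root and tau is determined instead.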

        # If the Hessian is not positive definite, or if the step is too
        # long, we have to determine the shift parameter lambda.
        rs_kwargs = {
            "f": lambda shift: on_trust_radius_lin(get_step(shift)),
            "xtol": 1e-3,
            # Would otherwise be chosen automatically, but we set it here
            # explicitly for verbosity.
            "method": "brentq",
        }

        def root_search(bracket):
            rs_kwargs.update(
                {
                    "bracket": bracket,
                    "x0": bracket[0] + 1e-3,
                }
            )
            res = root_scalar(**rs_kwargs)
            return res

        BRACKET_END = 1e10
        if not hard_case:
            bracket_start = 0.0 if pos_definite else -min_eigval + 1e-2
            bracket = (bracket_start, BRACKET_END)
            try:
                res = root_search(bracket)
                assert res.converged
                return finalize_step(res.root)
            # A ValueError may be raised when the function values for the
            # initial bracket have the same sign. If so, we continue and
            # treat the problem as a hard case.
            except ValueError:
                pass

        # Now we would try the bracket (-b2, -b1). The resulting step would
        # have a suitable length, but the (shifted) Hessian would have an
        # incorrect eigenvalue spectrum (not positive definite). To solve
        # this we use a different formula to calculate the step.
        mask = np.ones_like(gradient_trans, dtype=bool)
        mask[min_ind] = False
        without_min = gradient_trans[mask] / (eigvals[mask] - min_eigval)
        tau_sq = self.trust_radius**2 - (without_min**2).sum()
        if tau_sq >= 0.0:
            tau = sqrt(tau_sq)
        else:
            # Hard case. Search in the open interval (endpoints excluded)
            # (-min_eigval, inf).
            bracket = (-min_eigval + 1e-6, BRACKET_END)
            try:
                res = root_search(bracket)
                if res.converged:
                    return finalize_step(res.root)
            except ValueError:
                pass
            # Fallback: clamp tau to 0 so the step excludes the
            # minimum-eigenvalue component but remains valid.
            self.log("Hard case fallback: tau clamped to 0.")
            tau = 0.0
        # Assemble the step in the eigenbasis; tau goes into the component
        # belonging to the smallest eigenvalue, instead of assuming that
        # min_ind == 0.
        step_trans = np.empty_like(gradient_trans)
        step_trans[mask] = -without_min
        step_trans[min_ind] = tau

        if not transform:
            return step_trans

        if isinstance(eigvecs, torch.Tensor):
            step_trans = torch.as_tensor(
                step_trans, device=eigvecs.device, dtype=eigvecs.dtype
            )
            return (eigvecs @ step_trans).cpu().numpy()
        else:
            return eigvecs.dot(step_trans)
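
    # Minimal sketch of the boundary root search used above (illustrative,
    # standalone numbers; not part of the class):
    #
    #     import numpy as np
    #     from scipy.optimize import root_scalar
    #
    #     eigvals = np.array([0.5, 2.0])   # model Hessian eigenvalues
    #     g = np.array([1.0, 1.0])         # gradient in the eigenbasis
    #     trust = 0.5
    #     f = lambda s: 1 / trust - 1 / np.linalg.norm(-g / (eigvals + s))
    #     shift = root_scalar(f, bracket=(0.0, 1e10), method="brentq").root
    #     step = -g / (eigvals + shift)    # now norm(step) == trust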

    @staticmethod
    def quadratic_model(gradient, hessian, step):
        if isinstance(gradient, torch.Tensor):
            step = torch.as_tensor(step, device=gradient.device, dtype=gradient.dtype)
            return (step @ gradient + 0.5 * step @ hessian @ step).cpu().numpy()
        else:
            step = np.asarray(step).ravel()
            return step.dot(gradient) + 0.5 * step.dot(hessian).dot(step)
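
    # quadratic_model evaluates the local second-order model
    # m(s) = g^T s + (1/2) s^T H s, i.e. the predicted energy change for a
    # step s, which trust-region updates typically compare against the
    # actual energy change.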

    @staticmethod
    def rfo_model(gradient, hessian, step):
        return HessianOptimizer.quadratic_model(gradient, hessian, step) / (
            1 + step.dot(step)
        )
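
    # rfo_model is the rational-function counterpart of quadratic_model:
    # E_RFO(s) = (g^T s + (1/2) s^T H s) / (1 + s^T s).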

    def get_step_func(self, eigvals, gradient, grad_rms_thresh=1e-2):
        positive_definite = (eigvals < 0).sum() == 0
        gradient_small = rms(gradient) < grad_rms_thresh

        if self.adapt_step_func and gradient_small and positive_definite:
            return self.get_newton_step_on_trust, self.quadratic_model
        # RFO fallback
        else:
            return self.get_rs_step, self.rfo_model
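
    # Hypothetical call site (names assumed): an optimization cycle might do
    #
    #     step_func, model_func = opt.get_step_func(eigvals, gradient)
    #     step = step_func(eigvals, eigvecs, gradient)
    #     predicted_change = model_func(gradient, hessian, step)
    #
    # so that near a minimum (small gradient RMS, positive definite Hessian)
    # the cheaper trust-region Newton step replaces the full RS-RFO search.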