mlmm-toolkit 0.2.2.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (372)
  1. hessian_ff/__init__.py +50 -0
  2. hessian_ff/analytical_hessian.py +609 -0
  3. hessian_ff/constants.py +46 -0
  4. hessian_ff/forcefield.py +339 -0
  5. hessian_ff/loaders.py +608 -0
  6. hessian_ff/native/Makefile +8 -0
  7. hessian_ff/native/__init__.py +28 -0
  8. hessian_ff/native/analytical_hessian.py +88 -0
  9. hessian_ff/native/analytical_hessian_ext.cpp +258 -0
  10. hessian_ff/native/bonded.py +82 -0
  11. hessian_ff/native/bonded_ext.cpp +640 -0
  12. hessian_ff/native/loader.py +349 -0
  13. hessian_ff/native/nonbonded.py +118 -0
  14. hessian_ff/native/nonbonded_ext.cpp +1150 -0
  15. hessian_ff/prmtop_parmed.py +23 -0
  16. hessian_ff/system.py +107 -0
  17. hessian_ff/terms/__init__.py +14 -0
  18. hessian_ff/terms/angle.py +73 -0
  19. hessian_ff/terms/bond.py +44 -0
  20. hessian_ff/terms/cmap.py +406 -0
  21. hessian_ff/terms/dihedral.py +141 -0
  22. hessian_ff/terms/nonbonded.py +209 -0
  23. hessian_ff/tests/__init__.py +0 -0
  24. hessian_ff/tests/conftest.py +75 -0
  25. hessian_ff/tests/data/small/complex.parm7 +1346 -0
  26. hessian_ff/tests/data/small/complex.pdb +125 -0
  27. hessian_ff/tests/data/small/complex.rst7 +63 -0
  28. hessian_ff/tests/test_coords_input.py +44 -0
  29. hessian_ff/tests/test_energy_force.py +49 -0
  30. hessian_ff/tests/test_hessian.py +137 -0
  31. hessian_ff/tests/test_smoke.py +18 -0
  32. hessian_ff/tests/test_validation.py +40 -0
  33. hessian_ff/workflows.py +889 -0
  34. mlmm/__init__.py +36 -0
  35. mlmm/__main__.py +7 -0
  36. mlmm/_version.py +34 -0
  37. mlmm/add_elem_info.py +374 -0
  38. mlmm/advanced_help.py +91 -0
  39. mlmm/align_freeze_atoms.py +601 -0
  40. mlmm/all.py +3535 -0
  41. mlmm/bond_changes.py +231 -0
  42. mlmm/bool_compat.py +223 -0
  43. mlmm/cli.py +574 -0
  44. mlmm/cli_utils.py +166 -0
  45. mlmm/default_group.py +337 -0
  46. mlmm/defaults.py +467 -0
  47. mlmm/define_layer.py +526 -0
  48. mlmm/dft.py +1041 -0
  49. mlmm/energy_diagram.py +253 -0
  50. mlmm/extract.py +2213 -0
  51. mlmm/fix_altloc.py +464 -0
  52. mlmm/freq.py +1406 -0
  53. mlmm/harmonic_constraints.py +140 -0
  54. mlmm/hessian_cache.py +44 -0
  55. mlmm/hessian_calc.py +174 -0
  56. mlmm/irc.py +638 -0
  57. mlmm/mlmm_calc.py +2262 -0
  58. mlmm/mm_parm.py +945 -0
  59. mlmm/oniom_export.py +1983 -0
  60. mlmm/oniom_import.py +457 -0
  61. mlmm/opt.py +1742 -0
  62. mlmm/path_opt.py +1353 -0
  63. mlmm/path_search.py +2299 -0
  64. mlmm/preflight.py +88 -0
  65. mlmm/py.typed +1 -0
  66. mlmm/pysis_runner.py +45 -0
  67. mlmm/scan.py +1047 -0
  68. mlmm/scan2d.py +1226 -0
  69. mlmm/scan3d.py +1265 -0
  70. mlmm/scan_common.py +184 -0
  71. mlmm/summary_log.py +736 -0
  72. mlmm/trj2fig.py +448 -0
  73. mlmm/tsopt.py +2871 -0
  74. mlmm/utils.py +2309 -0
  75. mlmm/xtb_embedcharge_correction.py +475 -0
  76. mlmm_toolkit-0.2.2.dev0.dist-info/METADATA +1159 -0
  77. mlmm_toolkit-0.2.2.dev0.dist-info/RECORD +372 -0
  78. mlmm_toolkit-0.2.2.dev0.dist-info/WHEEL +5 -0
  79. mlmm_toolkit-0.2.2.dev0.dist-info/entry_points.txt +2 -0
  80. mlmm_toolkit-0.2.2.dev0.dist-info/licenses/LICENSE +674 -0
  81. mlmm_toolkit-0.2.2.dev0.dist-info/top_level.txt +4 -0
  82. pysisyphus/Geometry.py +1667 -0
  83. pysisyphus/LICENSE +674 -0
  84. pysisyphus/TableFormatter.py +63 -0
  85. pysisyphus/TablePrinter.py +74 -0
  86. pysisyphus/__init__.py +12 -0
  87. pysisyphus/calculators/AFIR.py +452 -0
  88. pysisyphus/calculators/AnaPot.py +20 -0
  89. pysisyphus/calculators/AnaPot2.py +48 -0
  90. pysisyphus/calculators/AnaPot3.py +12 -0
  91. pysisyphus/calculators/AnaPot4.py +20 -0
  92. pysisyphus/calculators/AnaPotBase.py +337 -0
  93. pysisyphus/calculators/AnaPotCBM.py +25 -0
  94. pysisyphus/calculators/AtomAtomTransTorque.py +154 -0
  95. pysisyphus/calculators/CFOUR.py +250 -0
  96. pysisyphus/calculators/Calculator.py +844 -0
  97. pysisyphus/calculators/CerjanMiller.py +24 -0
  98. pysisyphus/calculators/Composite.py +123 -0
  99. pysisyphus/calculators/ConicalIntersection.py +171 -0
  100. pysisyphus/calculators/DFTBp.py +430 -0
  101. pysisyphus/calculators/DFTD3.py +66 -0
  102. pysisyphus/calculators/DFTD4.py +84 -0
  103. pysisyphus/calculators/Dalton.py +61 -0
  104. pysisyphus/calculators/Dimer.py +681 -0
  105. pysisyphus/calculators/Dummy.py +20 -0
  106. pysisyphus/calculators/EGO.py +76 -0
  107. pysisyphus/calculators/EnergyMin.py +224 -0
  108. pysisyphus/calculators/ExternalPotential.py +264 -0
  109. pysisyphus/calculators/FakeASE.py +35 -0
  110. pysisyphus/calculators/FourWellAnaPot.py +28 -0
  111. pysisyphus/calculators/FreeEndNEBPot.py +39 -0
  112. pysisyphus/calculators/Gaussian09.py +18 -0
  113. pysisyphus/calculators/Gaussian16.py +726 -0
  114. pysisyphus/calculators/HardSphere.py +159 -0
  115. pysisyphus/calculators/IDPPCalculator.py +49 -0
  116. pysisyphus/calculators/IPIClient.py +133 -0
  117. pysisyphus/calculators/IPIServer.py +234 -0
  118. pysisyphus/calculators/LEPSBase.py +24 -0
  119. pysisyphus/calculators/LEPSExpr.py +139 -0
  120. pysisyphus/calculators/LennardJones.py +80 -0
  121. pysisyphus/calculators/MOPAC.py +219 -0
  122. pysisyphus/calculators/MullerBrownSympyPot.py +51 -0
  123. pysisyphus/calculators/MultiCalc.py +85 -0
  124. pysisyphus/calculators/NFK.py +45 -0
  125. pysisyphus/calculators/OBabel.py +87 -0
  126. pysisyphus/calculators/ONIOMv2.py +1129 -0
  127. pysisyphus/calculators/ORCA.py +893 -0
  128. pysisyphus/calculators/ORCA5.py +6 -0
  129. pysisyphus/calculators/OpenMM.py +88 -0
  130. pysisyphus/calculators/OpenMolcas.py +281 -0
  131. pysisyphus/calculators/OverlapCalculator.py +908 -0
  132. pysisyphus/calculators/Psi4.py +218 -0
  133. pysisyphus/calculators/PyPsi4.py +37 -0
  134. pysisyphus/calculators/PySCF.py +341 -0
  135. pysisyphus/calculators/PyXTB.py +73 -0
  136. pysisyphus/calculators/QCEngine.py +106 -0
  137. pysisyphus/calculators/Rastrigin.py +22 -0
  138. pysisyphus/calculators/Remote.py +76 -0
  139. pysisyphus/calculators/Rosenbrock.py +15 -0
  140. pysisyphus/calculators/SocketCalc.py +97 -0
  141. pysisyphus/calculators/TIP3P.py +111 -0
  142. pysisyphus/calculators/TransTorque.py +161 -0
  143. pysisyphus/calculators/Turbomole.py +965 -0
  144. pysisyphus/calculators/VRIPot.py +37 -0
  145. pysisyphus/calculators/WFOWrapper.py +333 -0
  146. pysisyphus/calculators/WFOWrapper2.py +341 -0
  147. pysisyphus/calculators/XTB.py +418 -0
  148. pysisyphus/calculators/__init__.py +81 -0
  149. pysisyphus/calculators/cosmo_data.py +139 -0
  150. pysisyphus/calculators/parser.py +150 -0
  151. pysisyphus/color.py +19 -0
  152. pysisyphus/config.py +133 -0
  153. pysisyphus/constants.py +65 -0
  154. pysisyphus/cos/AdaptiveNEB.py +230 -0
  155. pysisyphus/cos/ChainOfStates.py +725 -0
  156. pysisyphus/cos/FreeEndNEB.py +25 -0
  157. pysisyphus/cos/FreezingString.py +103 -0
  158. pysisyphus/cos/GrowingChainOfStates.py +71 -0
  159. pysisyphus/cos/GrowingNT.py +309 -0
  160. pysisyphus/cos/GrowingString.py +508 -0
  161. pysisyphus/cos/NEB.py +189 -0
  162. pysisyphus/cos/SimpleZTS.py +64 -0
  163. pysisyphus/cos/__init__.py +22 -0
  164. pysisyphus/cos/stiffness.py +199 -0
  165. pysisyphus/drivers/__init__.py +17 -0
  166. pysisyphus/drivers/afir.py +855 -0
  167. pysisyphus/drivers/barriers.py +271 -0
  168. pysisyphus/drivers/birkholz.py +138 -0
  169. pysisyphus/drivers/cluster.py +318 -0
  170. pysisyphus/drivers/diabatization.py +133 -0
  171. pysisyphus/drivers/merge.py +368 -0
  172. pysisyphus/drivers/merge_mol2.py +322 -0
  173. pysisyphus/drivers/opt.py +375 -0
  174. pysisyphus/drivers/perf.py +91 -0
  175. pysisyphus/drivers/pka.py +52 -0
  176. pysisyphus/drivers/precon_pos_rot.py +669 -0
  177. pysisyphus/drivers/rates.py +480 -0
  178. pysisyphus/drivers/replace.py +219 -0
  179. pysisyphus/drivers/scan.py +212 -0
  180. pysisyphus/drivers/spectrum.py +166 -0
  181. pysisyphus/drivers/thermo.py +31 -0
  182. pysisyphus/dynamics/Gaussian.py +103 -0
  183. pysisyphus/dynamics/__init__.py +20 -0
  184. pysisyphus/dynamics/colvars.py +136 -0
  185. pysisyphus/dynamics/driver.py +297 -0
  186. pysisyphus/dynamics/helpers.py +256 -0
  187. pysisyphus/dynamics/lincs.py +105 -0
  188. pysisyphus/dynamics/mdp.py +364 -0
  189. pysisyphus/dynamics/rattle.py +121 -0
  190. pysisyphus/dynamics/thermostats.py +128 -0
  191. pysisyphus/dynamics/wigner.py +266 -0
  192. pysisyphus/elem_data.py +3473 -0
  193. pysisyphus/exceptions.py +2 -0
  194. pysisyphus/filtertrj.py +69 -0
  195. pysisyphus/helpers.py +623 -0
  196. pysisyphus/helpers_pure.py +649 -0
  197. pysisyphus/init_logging.py +50 -0
  198. pysisyphus/intcoords/Bend.py +69 -0
  199. pysisyphus/intcoords/Bend2.py +25 -0
  200. pysisyphus/intcoords/BondedFragment.py +32 -0
  201. pysisyphus/intcoords/Cartesian.py +41 -0
  202. pysisyphus/intcoords/CartesianCoords.py +140 -0
  203. pysisyphus/intcoords/Coords.py +56 -0
  204. pysisyphus/intcoords/DLC.py +197 -0
  205. pysisyphus/intcoords/DistanceFunction.py +34 -0
  206. pysisyphus/intcoords/DummyImproper.py +70 -0
  207. pysisyphus/intcoords/DummyTorsion.py +72 -0
  208. pysisyphus/intcoords/LinearBend.py +105 -0
  209. pysisyphus/intcoords/LinearDisplacement.py +80 -0
  210. pysisyphus/intcoords/OutOfPlane.py +59 -0
  211. pysisyphus/intcoords/PrimTypes.py +286 -0
  212. pysisyphus/intcoords/Primitive.py +137 -0
  213. pysisyphus/intcoords/RedundantCoords.py +659 -0
  214. pysisyphus/intcoords/RobustTorsion.py +59 -0
  215. pysisyphus/intcoords/Rotation.py +147 -0
  216. pysisyphus/intcoords/Stretch.py +31 -0
  217. pysisyphus/intcoords/Torsion.py +101 -0
  218. pysisyphus/intcoords/Torsion2.py +25 -0
  219. pysisyphus/intcoords/Translation.py +45 -0
  220. pysisyphus/intcoords/__init__.py +61 -0
  221. pysisyphus/intcoords/augment_bonds.py +126 -0
  222. pysisyphus/intcoords/derivatives.py +10512 -0
  223. pysisyphus/intcoords/eval.py +80 -0
  224. pysisyphus/intcoords/exceptions.py +37 -0
  225. pysisyphus/intcoords/findiffs.py +48 -0
  226. pysisyphus/intcoords/generate_derivatives.py +414 -0
  227. pysisyphus/intcoords/helpers.py +235 -0
  228. pysisyphus/intcoords/logging_conf.py +10 -0
  229. pysisyphus/intcoords/mp_derivatives.py +10836 -0
  230. pysisyphus/intcoords/setup.py +962 -0
  231. pysisyphus/intcoords/setup_fast.py +176 -0
  232. pysisyphus/intcoords/update.py +272 -0
  233. pysisyphus/intcoords/valid.py +89 -0
  234. pysisyphus/interpolate/Geodesic.py +93 -0
  235. pysisyphus/interpolate/IDPP.py +55 -0
  236. pysisyphus/interpolate/Interpolator.py +116 -0
  237. pysisyphus/interpolate/LST.py +70 -0
  238. pysisyphus/interpolate/Redund.py +152 -0
  239. pysisyphus/interpolate/__init__.py +9 -0
  240. pysisyphus/interpolate/helpers.py +34 -0
  241. pysisyphus/io/__init__.py +22 -0
  242. pysisyphus/io/aomix.py +178 -0
  243. pysisyphus/io/cjson.py +24 -0
  244. pysisyphus/io/crd.py +101 -0
  245. pysisyphus/io/cube.py +220 -0
  246. pysisyphus/io/fchk.py +184 -0
  247. pysisyphus/io/hdf5.py +49 -0
  248. pysisyphus/io/hessian.py +72 -0
  249. pysisyphus/io/mol2.py +146 -0
  250. pysisyphus/io/molden.py +293 -0
  251. pysisyphus/io/orca.py +189 -0
  252. pysisyphus/io/pdb.py +269 -0
  253. pysisyphus/io/psf.py +79 -0
  254. pysisyphus/io/pubchem.py +31 -0
  255. pysisyphus/io/qcschema.py +34 -0
  256. pysisyphus/io/sdf.py +29 -0
  257. pysisyphus/io/xyz.py +61 -0
  258. pysisyphus/io/zmat.py +175 -0
  259. pysisyphus/irc/DWI.py +108 -0
  260. pysisyphus/irc/DampedVelocityVerlet.py +134 -0
  261. pysisyphus/irc/Euler.py +22 -0
  262. pysisyphus/irc/EulerPC.py +345 -0
  263. pysisyphus/irc/GonzalezSchlegel.py +187 -0
  264. pysisyphus/irc/IMKMod.py +164 -0
  265. pysisyphus/irc/IRC.py +878 -0
  266. pysisyphus/irc/IRCDummy.py +10 -0
  267. pysisyphus/irc/Instanton.py +307 -0
  268. pysisyphus/irc/LQA.py +53 -0
  269. pysisyphus/irc/ModeKill.py +136 -0
  270. pysisyphus/irc/ParamPlot.py +53 -0
  271. pysisyphus/irc/RK4.py +36 -0
  272. pysisyphus/irc/__init__.py +31 -0
  273. pysisyphus/irc/initial_displ.py +219 -0
  274. pysisyphus/linalg.py +411 -0
  275. pysisyphus/line_searches/Backtracking.py +88 -0
  276. pysisyphus/line_searches/HagerZhang.py +184 -0
  277. pysisyphus/line_searches/LineSearch.py +232 -0
  278. pysisyphus/line_searches/StrongWolfe.py +108 -0
  279. pysisyphus/line_searches/__init__.py +9 -0
  280. pysisyphus/line_searches/interpol.py +15 -0
  281. pysisyphus/modefollow/NormalMode.py +40 -0
  282. pysisyphus/modefollow/__init__.py +10 -0
  283. pysisyphus/modefollow/davidson.py +199 -0
  284. pysisyphus/modefollow/lanczos.py +95 -0
  285. pysisyphus/optimizers/BFGS.py +99 -0
  286. pysisyphus/optimizers/BacktrackingOptimizer.py +113 -0
  287. pysisyphus/optimizers/ConjugateGradient.py +98 -0
  288. pysisyphus/optimizers/CubicNewton.py +75 -0
  289. pysisyphus/optimizers/FIRE.py +113 -0
  290. pysisyphus/optimizers/HessianOptimizer.py +1176 -0
  291. pysisyphus/optimizers/LBFGS.py +228 -0
  292. pysisyphus/optimizers/LayerOpt.py +411 -0
  293. pysisyphus/optimizers/MicroOptimizer.py +169 -0
  294. pysisyphus/optimizers/NCOptimizer.py +90 -0
  295. pysisyphus/optimizers/Optimizer.py +1084 -0
  296. pysisyphus/optimizers/PreconLBFGS.py +260 -0
  297. pysisyphus/optimizers/PreconSteepestDescent.py +7 -0
  298. pysisyphus/optimizers/QuickMin.py +74 -0
  299. pysisyphus/optimizers/RFOptimizer.py +181 -0
  300. pysisyphus/optimizers/RSA.py +99 -0
  301. pysisyphus/optimizers/StabilizedQNMethod.py +248 -0
  302. pysisyphus/optimizers/SteepestDescent.py +23 -0
  303. pysisyphus/optimizers/StringOptimizer.py +173 -0
  304. pysisyphus/optimizers/__init__.py +41 -0
  305. pysisyphus/optimizers/closures.py +301 -0
  306. pysisyphus/optimizers/cls_map.py +58 -0
  307. pysisyphus/optimizers/exceptions.py +6 -0
  308. pysisyphus/optimizers/gdiis.py +280 -0
  309. pysisyphus/optimizers/guess_hessians.py +311 -0
  310. pysisyphus/optimizers/hessian_updates.py +355 -0
  311. pysisyphus/optimizers/poly_fit.py +285 -0
  312. pysisyphus/optimizers/precon.py +153 -0
  313. pysisyphus/optimizers/restrict_step.py +24 -0
  314. pysisyphus/pack.py +172 -0
  315. pysisyphus/peakdetect.py +948 -0
  316. pysisyphus/plot.py +1031 -0
  317. pysisyphus/run.py +2106 -0
  318. pysisyphus/socket_helper.py +74 -0
  319. pysisyphus/stocastic/FragmentKick.py +132 -0
  320. pysisyphus/stocastic/Kick.py +81 -0
  321. pysisyphus/stocastic/Pipeline.py +303 -0
  322. pysisyphus/stocastic/__init__.py +21 -0
  323. pysisyphus/stocastic/align.py +127 -0
  324. pysisyphus/testing.py +96 -0
  325. pysisyphus/thermo.py +156 -0
  326. pysisyphus/trj.py +824 -0
  327. pysisyphus/tsoptimizers/RSIRFOptimizer.py +56 -0
  328. pysisyphus/tsoptimizers/RSPRFOptimizer.py +182 -0
  329. pysisyphus/tsoptimizers/TRIM.py +59 -0
  330. pysisyphus/tsoptimizers/TSHessianOptimizer.py +463 -0
  331. pysisyphus/tsoptimizers/__init__.py +23 -0
  332. pysisyphus/wavefunction/Basis.py +239 -0
  333. pysisyphus/wavefunction/DIIS.py +76 -0
  334. pysisyphus/wavefunction/__init__.py +25 -0
  335. pysisyphus/wavefunction/build_ext.py +42 -0
  336. pysisyphus/wavefunction/cart2sph.py +190 -0
  337. pysisyphus/wavefunction/diabatization.py +304 -0
  338. pysisyphus/wavefunction/excited_states.py +435 -0
  339. pysisyphus/wavefunction/gen_ints.py +1811 -0
  340. pysisyphus/wavefunction/helpers.py +104 -0
  341. pysisyphus/wavefunction/ints/__init__.py +0 -0
  342. pysisyphus/wavefunction/ints/boys.py +193 -0
  343. pysisyphus/wavefunction/ints/boys_table_N_64_xasym_27.1_step_0.01.npy +0 -0
  344. pysisyphus/wavefunction/ints/cart_gto3d.py +176 -0
  345. pysisyphus/wavefunction/ints/coulomb3d.py +25928 -0
  346. pysisyphus/wavefunction/ints/diag_quadrupole3d.py +10036 -0
  347. pysisyphus/wavefunction/ints/dipole3d.py +8762 -0
  348. pysisyphus/wavefunction/ints/int2c2e3d.py +7198 -0
  349. pysisyphus/wavefunction/ints/int3c2e3d_sph.py +65040 -0
  350. pysisyphus/wavefunction/ints/kinetic3d.py +8240 -0
  351. pysisyphus/wavefunction/ints/ovlp3d.py +3777 -0
  352. pysisyphus/wavefunction/ints/quadrupole3d.py +15054 -0
  353. pysisyphus/wavefunction/ints/self_ovlp3d.py +198 -0
  354. pysisyphus/wavefunction/localization.py +458 -0
  355. pysisyphus/wavefunction/multipole.py +159 -0
  356. pysisyphus/wavefunction/normalization.py +36 -0
  357. pysisyphus/wavefunction/pop_analysis.py +134 -0
  358. pysisyphus/wavefunction/shells.py +1171 -0
  359. pysisyphus/wavefunction/wavefunction.py +504 -0
  360. pysisyphus/wrapper/__init__.py +11 -0
  361. pysisyphus/wrapper/exceptions.py +2 -0
  362. pysisyphus/wrapper/jmol.py +120 -0
  363. pysisyphus/wrapper/mwfn.py +169 -0
  364. pysisyphus/wrapper/packmol.py +71 -0
  365. pysisyphus/xyzloader.py +168 -0
  366. pysisyphus/yaml_mods.py +45 -0
  367. thermoanalysis/LICENSE +674 -0
  368. thermoanalysis/QCData.py +244 -0
  369. thermoanalysis/__init__.py +0 -0
  370. thermoanalysis/config.py +3 -0
  371. thermoanalysis/constants.py +20 -0
  372. thermoanalysis/thermo.py +1011 -0
@@ -0,0 +1,1176 @@
1
+ from math import sqrt
2
+ from pathlib import Path
3
+ from typing import Literal, Optional
4
+
5
+ import numpy as np
6
+ from scipy.optimize import root_scalar
7
+
8
+ from pysisyphus.cos.ChainOfStates import ChainOfStates
9
+ from pysisyphus.Geometry import Geometry
10
+ from pysisyphus.helpers_pure import rms
11
+ # from pysisyphus.io.hessian import save_hessian
12
+ from pysisyphus.optimizers.guess_hessians import (
13
+ get_guess_hessian,
14
+ xtb_hessian,
15
+ HessInit,
16
+ )
17
+ from pysisyphus.optimizers.hessian_updates import (
18
+ bfgs_update,
19
+ flowchart_update,
20
+ damped_bfgs_update,
21
+ bofill_update,
22
+ ts_bfgs_update,
23
+ ts_bfgs_update_org,
24
+ ts_bfgs_update_revised,
25
+ )
26
+ from pysisyphus.optimizers.Optimizer import Optimizer
27
+ from pysisyphus.optimizers.exceptions import OptimizationError
28
+
29
+ from pysisyphus.helpers import array2string
30
+ import torch
31
+
32
def dummy_hessian_update(H, dx, dg):
    """No-op Hessian update: return an all-zero update matrix.

    Selected when Hessian updating is disabled ("none"/None/False in
    HESS_UPDATE_FUNCS). The second return value is a short tag naming
    the update, mirroring the real update functions' interface.
    """
    zero_update = np.zeros(H.shape, dtype=H.dtype)
    return zero_update, "no"
34
+
35
+
36
# Mapping from user-facing names to the actual Hessian-update functions.
# "none"/None/False all disable updating via dummy_hessian_update.
HESS_UPDATE_FUNCS = {
    "none": dummy_hessian_update,
    None: dummy_hessian_update,
    False: dummy_hessian_update,
    "bfgs": bfgs_update,
    "damped_bfgs": damped_bfgs_update,
    "flowchart": flowchart_update,
    "bofill": bofill_update,
    "ts_bfgs": ts_bfgs_update,
    "ts_bfgs_org": ts_bfgs_update_org,
    "ts_bfgs_rev": ts_bfgs_update_revised,
}
# Allowed values for the 'hessian_update' argument; must stay in sync
# with the keys of HESS_UPDATE_FUNCS above.
HessUpdate = Literal[
    "none",
    None,
    False,
    "bfgs",
    "damped_bfgs",
    "flowchart",
    "bofill",
    "ts_bfgs",
    "ts_bfgs_org",
    "ts_bfgs_rev",
]
60
+
61
+
62
class HessianOptimizer(Optimizer):
    # (index, label) pairs selecting which RFO root to follow; presumably
    # indexes into ascending-sorted eigenvalues (0 = lowest for
    # minimization, -1 = highest for maximization) — confirm in subclasses.
    rfo_dict = {
        "min": (0, "min"),
        "max": (-1, "max"),
    }
67
+
68
    def __init__(
        self,
        geometry: Geometry,
        trust_radius: float = 0.5,
        trust_update: bool = True,
        trust_min: float = 0.1,
        trust_max: float = 1,
        max_energy_incr: Optional[float] = None,
        hessian_update: HessUpdate = "bfgs",
        hessian_init: HessInit = "fischer",
        hessian_recalc: Optional[int] = None,
        hessian_recalc_adapt: Optional[float] = None,
        hessian_xtb: bool = False,
        hessian_recalc_reset: bool = False,
        small_eigval_thresh: float = 1e-8,
        line_search: bool = False,
        alpha0: float = 1.0,
        max_micro_cycles: int = 25,
        rfo_overlaps: bool = False,
        **kwargs,
    ) -> None:
        """Baseclass for optimizers utilizing Hessian information.

        Parameters
        ----------
        geometry
            Geometry to be optimized.
        trust_radius
            Initial trust radius in whatever unit the optimization is carried out.
        trust_update
            Whether to update the trust radius throughout the optimization.
        trust_min
            Minimum trust radius.
        trust_max
            Maximum trust radius.
        max_energy_incr
            Maximum allowed energy increase after a faulty step. Optimization is
            aborted when the threshold is exceeded.
        hessian_update
            Type of Hessian update. Defaults to BFGS for minimizations and Bofill
            for saddle point searches.
        hessian_init
            Type of initial model Hessian.
        hessian_recalc
            Recalculate exact Hessian every n-th cycle instead of updating it.
        hessian_recalc_adapt
            Use a more flexible scheme to determine Hessian recalculation. Undocumented.
        hessian_xtb
            Recalculate the Hessian at the GFN2-XTB level of theory.
        hessian_recalc_reset
            Whether to skip Hessian recalculation after reset. Undocumented.
        small_eigval_thresh
            Threshold for small eigenvalues. Eigenvectors belonging to eigenvalues
            below this threshold are discarded.
        line_search
            Whether to carry out a line search. Not implemented by all subclassing
            optimizers.
        alpha0
            Initial alpha for restricted-step (RS) procedure.
        max_micro_cycles
            Maximum number of RS iterations.
        rfo_overlaps
            Enable mode-following in RS procedure.

        Other Parameters
        ----------------
        **kwargs
            Keyword arguments passed to the Optimizer baseclass.
        """
        super().__init__(geometry, **kwargs)

        # Hessian optimizers operate on a single geometry, never on a band.
        assert not issubclass(
            type(geometry), ChainOfStates
        ), "HessianOptimizer can't be used for and ChainOfStates objects!"

        self.trust_update = bool(trust_update)
        assert trust_min <= trust_max, "trust_min must be <= trust_max!"
        self.trust_min = float(trust_min)
        self.trust_max = float(trust_max)
        self.max_energy_incr = max_energy_incr
        # Constrain initial trust radius if trust_max > trust_radius
        self.trust_radius = min(trust_radius, trust_max)
        self.log(f"Initial trust radius: {self.trust_radius:.6f}")
        self.hessian_update = hessian_update
        # Raises KeyError for unknown update names — see HESS_UPDATE_FUNCS.
        self.hessian_update_func = HESS_UPDATE_FUNCS[hessian_update]
        self.hessian_init = hessian_init
        self.hessian_recalc = hessian_recalc
        self.hessian_recalc_adapt = hessian_recalc_adapt
        self.hessian_xtb = hessian_xtb
        self.hessian_recalc_reset = hessian_recalc_reset
        self.small_eigval_thresh = float(small_eigval_thresh)
        self.line_search = bool(line_search)
        # Restricted-step related
        self.alpha0 = alpha0
        self.max_micro_cycles = int(max_micro_cycles)
        assert max_micro_cycles >= 0
        self.rfo_overlaps = rfo_overlaps

        assert self.small_eigval_thresh > 0.0, "small_eigval_thresh must be > 0.!"
        # On restart this state is restored via _set_opt_restart_info instead.
        if not self.restarted:
            self.hessian_recalc_in = None
            self.adapt_norm = None
            self.predicted_energy_changes = list()
            # hessian_init may also be a path to a saved Hessian file.
            hessian_init_exists = Path(self.hessian_init).exists()
            if (
                # Allow actually calculated Hessians for all coordinate systems
                not hessian_init_exists
                and self.hessian_init not in ("calc", "xtb", "xtb1", "xtbff")
                # But disable model Hessian for Cartesian optimizations
                and self.geometry.coord_type in ("cart", "cartesian", "mwcartesian")
            ):
                self.hessian_init = "unit"
                self.log(
                    f"Chosen initial (model) Hessian is incompatible with current "
                    f"coord_type: {self.geometry.coord_type}!"
                )

        # Mode-following state for the RS procedure (see prev_eigvec_* props).
        self._prev_eigvec_min = None
        self._prev_eigvec_max = None
        # Active-DOF (partial Hessian) bookkeeping; set via _set_active_dofs.
        self._using_active_dofs = False
        self._active_dof_indices = None
        self.cur_H = None
190
+
191
    @property
    def using_active_dofs(self):
        """Whether vectors/Hessians are restricted to a subset of active DOFs."""
        return self._using_active_dofs
194
+
195
    @property
    def active_dof_indices(self):
        """Indices of active degrees of freedom, or None when unrestricted."""
        return self._active_dof_indices
198
+
199
    def _set_active_dofs(self, use_active):
        """Enable/disable the active-DOF restriction and resolve the indices.

        Resolution order when enabling:
        1. geometry.hess_active_dof_indices (partial-Hessian setup present),
        2. 3 Cartesian DOFs per atom from the calculator's hess_active_atoms,
        3. geometry.active_dof_indices as last resort.
        """
        self._using_active_dofs = use_active
        if not use_active:
            self._active_dof_indices = None
            return
        if getattr(self.geometry, "within_partial_hessian", None) is not None:
            self._active_dof_indices = self.geometry.hess_active_dof_indices
            return
        # Fallback: infer active DOFs from calculator's Hessian-active atoms
        calc = getattr(self.geometry, "calculator", None)
        # Some calculators wrap the real implementation in a .core attribute.
        core = getattr(calc, "core", calc)
        hess_atoms = getattr(core, "hess_active_atoms", None)
        if hess_atoms is not None and len(hess_atoms) > 0:
            act = []
            for a in hess_atoms:
                # Expand atom index into its x/y/z Cartesian DOF indices.
                base = 3 * int(a)
                act.extend([base, base + 1, base + 2])
            self._active_dof_indices = np.asarray(act, dtype=int)
            return
        self._active_dof_indices = self.geometry.active_dof_indices
219
+
220
    def active_from_full(self, vec):
        """Slice a full-space vector down to the active-DOF subspace.

        Returns *vec* unchanged when the active-DOF restriction is off,
        when no indices are stored, or when *vec* already appears to be a
        compact (active-space) vector. Handles numpy arrays and torch
        tensors; CUDA tensors are sliced via a CPU round-trip.
        """
        if not self.using_active_dofs:
            return vec
        inds = self._active_dof_indices
        if inds is None:
            return vec
        # Sanitize indices (drop negatives / out-of-bounds)
        try:
            inds = np.asarray(inds, dtype=int)
            if len(inds) > 0:
                if np.any(inds < 0):
                    inds = inds[inds >= 0]
                if len(inds) > 0:
                    max_valid = vec.shape[0] - 1
                    if np.any(inds > max_valid):
                        inds = inds[inds <= max_valid]
                    if len(inds) > 0:
                        # NOTE(review): after the two filters above this
                        # condition looks unreachable — confirm before removing.
                        if np.min(inds) < 0 or np.max(inds) >= vec.shape[0]:
                            return vec
        except (ValueError, IndexError, TypeError):
            pass
        # Avoid double-slicing if vector is already in active space
        try:
            if vec.shape[0] == len(inds):
                return vec
        except (ValueError, IndexError, TypeError):
            pass
        try:
            if len(inds) > 0 and vec.shape[0] <= int(np.max(inds)):
                # Indices exceed vector length → already active (compact) vector.
                return vec
        except (ValueError, IndexError, TypeError):
            pass
        if isinstance(vec, torch.Tensor):
            if vec.device.type == "cuda":
                # Fancy indexing on CPU, then move the result back to the
                # original device/dtype.
                try:
                    vec_cpu = vec.detach().cpu().numpy()
                    return torch.as_tensor(vec_cpu[inds], dtype=vec.dtype, device=vec.device)
                except (ValueError, IndexError, TypeError):
                    return vec
            idx = torch.as_tensor(inds, dtype=torch.long, device=vec.device)
            return vec.index_select(0, idx)
        return vec[inds]
263
+
264
    def full_from_active(self, vec):
        """Scatter an active-space vector back into a zero-padded full vector.

        Inverse of active_from_full(): entries of *vec* are written at the
        active-DOF indices of a zero vector sized like the full Cartesian
        coordinate vector. Returns *vec* unchanged when the restriction is
        off or no indices are stored.
        """
        if not self.using_active_dofs:
            return vec
        inds = self._active_dof_indices
        if inds is None:
            return vec
        if isinstance(vec, torch.Tensor):
            idx = torch.as_tensor(inds, dtype=torch.long, device=vec.device)
            full = torch.zeros(self.geometry.cart_coords.size, dtype=vec.dtype, device=vec.device)
            full.index_copy_(0, idx, vec)
            return full
        # Fall back to float dtype for plain sequences without a .dtype.
        full = np.zeros(self.geometry.cart_coords.size, dtype=vec.dtype if hasattr(vec, "dtype") else float)
        full[inds] = vec
        return full
278
+
279
    def active_hessian(self, hessian):
        """Restrict a Hessian to the active-DOF rows/columns.

        Returns *hessian* unchanged when the active-DOF restriction is off
        or when the matrix already has compact (active-space) dimensions.
        Handles numpy arrays and torch tensors; CUDA tensors are sliced
        via a CPU round-trip.
        """
        if not self.using_active_dofs:
            return hessian

        # Partial-Hessian setups record the compact dimension explicitly.
        if getattr(self.geometry, "within_partial_hessian", None) is not None:
            act_n_dof = int(self.geometry.within_partial_hessian.get("active_n_dof", 0))
            if hessian.shape == (act_n_dof, act_n_dof):
                return hessian

        inds = self.active_dof_indices
        try:
            inds_arr = np.asarray(inds, dtype=int)
            # Compact matrix paired with full-space indices (max index past
            # the matrix edge) → nothing left to slice.
            if hessian.shape[0] == len(inds_arr) and len(inds_arr) > 0:
                if np.max(inds_arr) >= hessian.shape[0]:
                    # Likely already in active order; avoid double-slicing.
                    return hessian
        except (ValueError, IndexError, TypeError):
            pass
        try:
            inds = np.asarray(inds, dtype=int)
            if len(inds) > 0:
                max_valid = hessian.shape[0] - 1
                # Keep only indices addressable in this matrix.
                inds = inds[(inds >= 0) & (inds <= max_valid)]
        except (ValueError, IndexError, TypeError):
            pass
        if isinstance(hessian, torch.Tensor):
            if hessian.device.type == "cuda":
                try:
                    hess_cpu = hessian.detach().cpu().numpy()
                    return torch.as_tensor(hess_cpu[np.ix_(inds, inds)], dtype=hessian.dtype, device=hessian.device)
                except (ValueError, IndexError, TypeError):
                    return hessian
            idx = torch.as_tensor(inds, device=hessian.device, dtype=torch.int64)
            # Row slice followed by column slice == hessian[np.ix_(inds, inds)].
            return hessian.index_select(0, idx).index_select(1, idx)
        return hessian[np.ix_(inds, inds)]
314
+
315
+ def active_list(self, seq):
316
+ if not self.using_active_dofs:
317
+ return seq
318
+ return [self.active_from_full(item) for item in seq]
319
+
320
    @property
    def prev_eigvec_min(self):
        # Eigenvector followed in the previous minimization RS step; only
        # tracked when rfo_overlaps (mode following) is enabled.
        return self._prev_eigvec_min

    @prev_eigvec_min.setter
    def prev_eigvec_min(self, prev_eigvec_min):
        # Assignments are silently ignored unless mode following is on.
        if self.rfo_overlaps:
            self._prev_eigvec_min = prev_eigvec_min
328
+
329
    @property
    def prev_eigvec_max(self):
        # Eigenvector followed in the previous maximization RS step; only
        # tracked when rfo_overlaps (mode following) is enabled.
        return self._prev_eigvec_max

    @prev_eigvec_max.setter
    def prev_eigvec_max(self, prev_eigvec_max):
        # Assignments are silently ignored unless mode following is on.
        if self.rfo_overlaps:
            self._prev_eigvec_max = prev_eigvec_max
337
+
338
+ def reset(self):
339
+ # Don't recalculate the hessian if we have to reset the optimizer
340
+ hessian_init = self.hessian_init
341
+ if (
342
+ (not self.hessian_recalc_reset)
343
+ and hessian_init == "calc"
344
+ and self.geometry.coord_type != "cart"
345
+ ):
346
+ hessian_init = "fischer"
347
+ self.prepare_opt(hessian_init)
348
+
349
+ # def save_hessian(self):
350
+ # # Don't try to save Hessians of analytical potentials
351
+ # if self.geometry.is_analytical_2d:
352
+ # return
353
+
354
+ # h5_fn = self.get_path_for_fn(f"hess_calc_cyc_{self.cur_cycle}.h5")
355
+ # # Save the cartesian hessian, as it is independent of the
356
+ # # actual coordinate system that is used.
357
+ # save_hessian(
358
+ # h5_fn,
359
+ # self.geometry,
360
+ # self.geometry.cart_hessian,
361
+ # self.geometry.energy,
362
+ # self.geometry.calculator.mult,
363
+ # )
364
+ # self.log(f"Wrote calculated cartesian Hessian to '{h5_fn}'")
365
+
366
    def prepare_opt(self, hessian_init=None):
        """Set up the initial Hessian (self.H) and recalculation bookkeeping.

        Parameters
        ----------
        hessian_init
            Overrides self.hessian_init when given; may be a scheme name
            ("calc", "fischer", ...) or a path to a saved Hessian.
        """
        if hessian_init is None:
            hessian_init = self.hessian_init

        self.H, hess_str = get_guess_hessian(self.geometry, hessian_init)
        # Analytical 2D test potentials carry a dummy third coordinate;
        # zero its diagonal entry in model Hessians.
        if self.hessian_init != "calc" and self.geometry.is_analytical_2d:
            assert self.H.shape == (3, 3)
            self.H[2, 2] = 0.0

        msg = f"Using {hess_str} Hessian"
        if hess_str == "saved":
            msg += f" from '{hessian_init}'"
        self.log(msg)

        # # Dump to disk if hessian was calculated
        # if self.hessian_init == "calc":
        #     self.save_hessian()

        if (
            hasattr(self.geometry, "coord_type")
            and self.geometry.coord_type == "dlc"
            # Calculated Hessian is already in DLC
            and hessian_init != "calc"
        ):
            # Transform model Hessian into the delocalized internal basis.
            U = self.geometry.internal.U
            self.H = U.T.dot(self.H).dot(U)

        if self.hessian_recalc_adapt:
            # Reference force norm for the adaptive recalculation scheme.
            self.adapt_norm = np.linalg.norm(self.geometry.forces)

        if self.hessian_recalc:
            # Already subtract one, as we don't do a hessian update in
            # the first cycle.
            self.hessian_recalc_in = self.hessian_recalc - 1
400
+
401
+ def _get_opt_restart_info(self):
402
+ opt_restart_info = {
403
+ "adapt_norm": self.adapt_norm,
404
+ "H": self.H.tolist(),
405
+ "hessian_recalc_in": self.hessian_recalc_in,
406
+ "predicted_energy_changes": self.predicted_energy_changes,
407
+ }
408
+ return opt_restart_info
409
+
410
+ def _set_opt_restart_info(self, opt_restart_info):
411
+ self.adapt_norm = opt_restart_info["adapt_norm"]
412
+ self.H = np.array(opt_restart_info["H"])
413
+ self.hessian_recalc_in = opt_restart_info["hessian_recalc_in"]
414
+ self.predicted_energy_changes = opt_restart_info["predicted_energy_changes"]
415
+
416
    def update_trust_radius(self):
        """Resize the trust radius from predicted vs. actual energy change.

        Uses the predicted change appended at the end of the previous
        cycle's optimize() and delegates the actual resizing to
        set_new_trust_radius(). Raises OptimizationError when the energy
        rose by more than max_energy_incr.
        """
        # The predicted change should be calculated at the end of optimize
        # of the previous cycle.
        assert (
            len(self.predicted_energy_changes) == len(self.forces) - 1
        ), "Did you forget to append to self.predicted_energy_changes?"
        self.log("Trust radius update")
        self.log(f"\tCurrent trust radius: {self.trust_radius:.6f}")
        predicted_change = self.predicted_energy_changes[-1]
        actual_change = self.energies[-1] - self.energies[-2]
        # Only report an unexpected increase if we actually predicted a
        # decrease.
        unexpected_increase = (actual_change > 0) and (predicted_change < 0)
        old_trust = self.trust_radius
        if unexpected_increase:
            self.log(f"Energy increased by {actual_change:.6f} au!")
            if self.max_energy_incr and (actual_change > self.max_energy_incr):
                raise OptimizationError("Actual energy change too high!")
        # Agreement coefficient; negative when the signs disagree.
        coeff = actual_change / predicted_change
        self.log(f"\tPredicted change: {predicted_change:.4e} au")
        self.log(f"\tActual change: {actual_change:.4e} au")
        self.log(f"\tCoefficient: {coeff:.2%}")
        step = self.steps[-1]
        last_step_norm = np.linalg.norm(step)
        self.set_new_trust_radius(coeff, last_step_norm)
        if unexpected_increase:
            self.table.print(
                f"Unexpected energy increase ({actual_change:.6f} au)! "
                f"Trust radius: old={old_trust:.4}, new={self.trust_radius:.4}"
            )
446
+
447
def set_new_trust_radius(self, coeff, last_step_norm):
    """Adjust the trust radius (Nocedal, Numerical Optimization, Alg. 4.1).

    A negative coeff (actual and predicted change with opposite signs)
    also triggers the decrease branch, which is the desired behavior.
    """
    # Only grow when the last step actually exhausted the trust region,
    # i.e. its norm corresponded approximately to the trust radius.
    step_matched_trust = abs(self.trust_radius - last_step_norm) <= 1e-3
    if coeff < 0.25:
        self.trust_radius = max(self.trust_radius / 4, self.trust_min)
        self.log("\tDecreasing trust radius.")
    elif coeff > 0.75 and step_matched_trust:
        self.trust_radius = min(self.trust_radius * 2, self.trust_max)
        self.log("\tIncreasing trust radius.")
    else:
        self.log(f"\tKeeping current trust radius at {self.trust_radius:.6f}")
        return
    self.log(f"\tUpdated trust radius: {self.trust_radius:.6f}")
469
+
470
def update_hessian(self):
    """Recalculate the Hessian or apply a quasi-Newton update to it.

    Recalculation is triggered either by the countdown counter
    (``hessian_recalc_in``) or adaptively, when the current force norm
    dropped below ``adapt_norm / hessian_recalc_adapt``. Otherwise
    ``hessian_update_func`` supplies an additive update ``dH``.
    """
    # Compare current forces to reference forces to see if we shall recalc the
    # hessian.
    try:
        cur_norm = np.linalg.norm(self.forces[-1])
        ref_norm = self.adapt_norm / self.hessian_recalc_adapt
        recalc_adapt = cur_norm <= ref_norm
        self.log(
            "Check for adaptive Hessian recalculation: "
            f"{cur_norm:.6f} <= {ref_norm:.6f}, {recalc_adapt}"
        )
    # Raised when hessian_recalc_adapt is None (adaptive recalc disabled).
    except TypeError:
        recalc_adapt = False

    try:
        self.hessian_recalc_in = max(self.hessian_recalc_in - 1, 0)
        self.log(f"Recalculation of Hessian in {self.hessian_recalc_in} cycle(s).")
    # Raised when hessian_recalc_in is None (counter-based recalc disabled).
    except TypeError:
        self.hessian_recalc_in = None

    # Update reference norm if needed
    # TODO: Decide on whether to update the norm when the recalculation is
    # initiated by 'recalc'.
    if recalc_adapt:
        self.adapt_norm = cur_norm

    recalc = self.hessian_recalc_in == 0

    if recalc or recalc_adapt:
        # Free old Hessian from GPU before recalculating
        H_old = self.H
        self.H = None
        del H_old
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except ImportError:
            pass
        # Use xtb hessian
        self.log("Requested Hessian recalculation.")
        if self.hessian_xtb:
            self.H = xtb_hessian(self.geometry)
            key = "xtb"
        # Calculated hessian at actual level of theory
        else:
            self.H = self.geometry.hessian
            key = "exact"
        # self.save_hessian()
        if self.using_active_dofs:
            # Keep the optimizer Hessian in active-DOF space to avoid
            # shape mismatches during quasi-Newton updates.
            self.H = self.active_hessian(self.H)
        if not (self.cur_cycle == 0):
            self.log(f"Recalculated {key} Hessian in cycle {self.cur_cycle}.")
        # Reset counter. It is also reset when the recalculation was initiated
        # by the adaptive formulation.
        self.hessian_recalc_in = self.hessian_recalc
    # Simple hessian update
    else:
        dx = self.steps[-1]
        # dg is the gradient difference (forces carry the opposite sign).
        dg = -(self.forces[-1] - self.forces[-2])
        H_work = self.H
        if self.using_active_dofs:
            # Project Hessian and difference vectors into active-DOF space.
            H_work = self.active_hessian(self.H)
            dx = self.active_from_full(dx)
            dg = self.active_from_full(dg)
        curv_cond = dx.dot(dg)
        # Negative curvature may spoil positive definiteness of the
        # updated Hessian; it is only logged, the update still proceeds.
        if curv_cond < 0.0:
            self.log(
                f"Curvature condition (s·y = {curv_cond:.4f} < 0) not satisfied!"
            )
        dH, key = self.hessian_update_func(H_work, dx, dg)
        self.H = H_work + dH
        self.log(f"Did {key} Hessian update.")
545
+
546
def solve_rfo(self, rfo_mat, kind="min", prev_eigvec=None):
    """Diagonalize the augmented Hessian and return the scaled RFO step.

    Returns (step, eigenvalue, nu, eigenvector) for the followed mode.
    """
    # The restricted-step RFO variant may yield a non-symmetric matrix,
    # so eigh cannot be used unconditionally.
    use_torch = isinstance(rfo_mat, torch.Tensor)
    xp = torch if use_torch else np

    # Sanitize non-finite entries before diagonalization.
    if use_torch:
        if not torch.isfinite(rfo_mat).all():
            self.log("RFO matrix contains NaN/inf; sanitizing entries.")
            rfo_mat = torch.nan_to_num(rfo_mat, nan=0.0, posinf=1e8, neginf=-1e8)
    elif not np.isfinite(rfo_mat).all():
        self.log("RFO matrix contains NaN/inf; sanitizing entries.")
        rfo_mat = np.nan_to_num(rfo_mat, nan=0.0, posinf=1e8, neginf=-1e8)

    symmetric = xp.allclose(rfo_mat, rfo_mat.T)

    if symmetric:
        try:
            eigenvalues, eigenvectors = xp.linalg.eigh(rfo_mat)
        except (torch._C._LinAlgError, np.linalg.LinAlgError):
            self.log("eigh failed; falling back to eig.")
            eigenvalues, eigenvectors = xp.linalg.eig(rfo_mat)
            eigenvalues = eigenvalues.real
            eigenvectors = eigenvectors.real
    else:
        eigenvalues, eigenvectors = xp.linalg.eig(rfo_mat)
        eigenvalues = eigenvalues.real
        eigenvectors = eigenvectors.real

    self.log("\tdiagonalized augmented Hessian")

    sorted_inds = xp.argsort(eigenvalues)

    # Depending on whether we minimize or maximize along the mode(s) in
    # the RFO matrix, select the smallest or biggest eigenvalue and the
    # corresponding eigenvector.
    first_or_last, verbose = self.rfo_dict[kind]
    if prev_eigvec is None:
        # Smallest ('min') or largest ('max') eigenvalue index.
        ind = sorted_inds[first_or_last]
    else:
        # Mode tracking: follow the eigenvector with maximum overlap
        # to the previously followed one.
        if isinstance(prev_eigvec, torch.Tensor):
            ovlps = prev_eigvec.matmul(eigenvectors)
            ind = torch.argmax(torch.abs(ovlps)).item()
        else:
            ovlps = np.array([prev_eigvec.dot(ev) for ev in eigenvectors.T])
            ind = np.abs(ovlps).argmax()
        naive_ind = sorted_inds[first_or_last]
        self.log(
            f"Overlap: {ind} ({eigenvalues[ind]:.6f}), "
            f"Naive: {naive_ind} ({eigenvalues[naive_ind]:.6f})"
        )
    follow_eigvec = eigenvectors.T[ind]
    if isinstance(follow_eigvec, torch.Tensor):
        step_nu = follow_eigvec.clone()
    else:
        step_nu = follow_eigvec.copy()
    nu = step_nu[-1]
    self.log(f"\tnu_{verbose}={nu:.8e}")
    # Scale eigenvector so that its last element equals 1; the step is
    # the scaled eigenvector without that last element.
    step = step_nu[:-1] / nu
    eigval = eigenvalues[ind]
    self.log(f"\teigenvalue_{verbose}={eigval:.8e}")
    return step, eigval, nu, follow_eigvec
619
+
620
def solve_rfo_secular(self, eigvals, gradient, alpha=1.0, kind="min",
                      prev_eigvec=None, max_iter=50, tol=1e-12):
    """Solve the RFO eigenvalue problem via the secular equation.

    The augmented Hessian has arrowhead structure, so its eigenvalue
    problem reduces to: f(mu) = sum g_i^2/(alpha*mu - lam_i) - mu = 0,
    solvable in O(N) instead of O(N^3).

    Returns (step, eigval, nu, eigvec) on success, None on failure.
    """
    want_torch = isinstance(eigvals, torch.Tensor)

    # Work with plain float64 numpy data for the scalar root search.
    def as_np(x):
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy().copy()
        return np.asarray(x, dtype=np.float64)

    grad_np = as_np(gradient)
    lam = as_np(eigvals)
    alpha = alpha.detach().cpu().item() if isinstance(alpha, torch.Tensor) else float(alpha)

    dim = len(lam)
    grad_sq = grad_np ** 2
    nonzero = grad_sq > 1e-30
    if not nonzero.any():
        # Vanishing gradient: the zero step is the exact solution.
        if want_torch:
            step = torch.zeros(dim, device=eigvals.device, dtype=eigvals.dtype)
            eigvec = torch.zeros(dim + 1, device=eigvals.device, dtype=eigvals.dtype)
        else:
            step = np.zeros(dim)
            eigvec = np.zeros(dim + 1)
        eigvec[-1] = 1.0
        return step, 0.0, 1.0, eigvec

    gsq_nz, lam_nz = grad_sq[nonzero], lam[nonzero]

    # Guard against alpha ≈ 0 (can arise from trust-radius adaptation)
    if abs(alpha) < 1e-14:
        return None

    def secular(mu):
        # Returns (f(mu), f'(mu)) of the secular function.
        diff = alpha * mu - lam_nz
        val = float(np.sum(gsq_nz / diff) - mu)
        deriv = float(-alpha * np.sum(gsq_nz / diff ** 2) - 1.0)
        return val, deriv

    _, verbose = self.rfo_dict[kind]

    # Bracket the root on the proper side of the outermost pole.
    if kind == "min":
        pole = float(lam.min() / alpha)
        mu = pole - max(float(np.sqrt(gsq_nz.sum())) / alpha, 1.0)
        for _ in range(20):
            if secular(mu)[0] > 0:
                break
            mu = pole - 2.0 * (pole - mu)
        else:
            return None
        lo, hi = mu, pole - 1e-15 * max(abs(pole), 1.0)
    elif kind == "max":
        pole = float(lam.max() / alpha)
        mu = pole + max(float(np.sqrt(gsq_nz.sum())) / alpha, 1.0)
        for _ in range(20):
            if secular(mu)[0] < 0:
                break
            mu = pole + 2.0 * (mu - pole)
        else:
            return None
        lo, hi = pole + 1e-15 * max(abs(pole), 1.0), mu
    else:
        return None

    # Newton-Raphson with bisection safeguard
    mu = (lo + hi) / 2.0
    for _ in range(max_iter):
        val, deriv = secular(mu)
        if abs(val) < tol:
            break
        trial = mu - val / deriv if abs(deriv) > 1e-30 else (lo + hi) / 2.0
        if trial <= lo or trial >= hi:
            trial = (lo + hi) / 2.0
        if secular(trial)[0] > 0:
            lo = trial
        else:
            hi = trial
        mu = trial
    else:
        self.log(f"Secular equation did not converge in {max_iter} iters.")
        return None

    self.log(f"\teigenvalue_{verbose}={mu:.8e} (secular)")
    self.log(f"\tnu_{verbose}={1.0:.8e}")

    # Step components: s_i = g_i / (alpha * mu - lam_i)
    denom = alpha * mu - lam
    step = np.where(nonzero, grad_np / denom, 0.0)

    # Normalized arrowhead eigenvector, used for mode tracking.
    eigvec = np.append(step, 1.0)
    eigvec /= np.linalg.norm(eigvec)

    if prev_eigvec is not None:
        if abs(float(np.dot(as_np(prev_eigvec), eigvec))) < 0.3:
            self.log("Secular eigvec overlap too low; falling back.")
            return None

    if want_torch:
        step = torch.tensor(step, device=eigvals.device, dtype=eigvals.dtype)
        eigvec = torch.tensor(eigvec, device=eigvals.device, dtype=eigvals.dtype)

    return step, mu, 1.0, eigvec
724
+
725
def filter_small_eigvals(self, eigvals, eigvecs, mask=False):
    """Drop eigenpairs with |eigenvalue| below ``small_eigval_thresh``.

    When ``mask`` is True, the boolean mask of the removed indices is
    returned as a third element.
    """
    abs_fn = torch.abs if isinstance(eigvals, torch.Tensor) else np.abs
    small_inds = abs_fn(eigvals) < self.small_eigval_thresh
    keep = ~small_inds
    eigvals = eigvals[keep]
    eigvecs = eigvecs[:, keep]
    small_num = sum(small_inds)
    self.log(
        f"Found {small_num} small eigenvalues in Hessian. Removed "
        "corresponding eigenvalues and eigenvectors."
    )
    if mask:
        return eigvals, eigvecs, small_inds
    return eigvals, eigvecs
745
+
746
def log_negative_eigenvalues(self, eigvals, pre_str=""):
    """Log the count and values of negative Hessian eigenvalues."""
    neg_inds = eigvals < -self.small_eigval_thresh
    formatted = array2string(eigvals[neg_inds], precision=6)
    self.log(f"{pre_str}Hessian has {neg_inds.sum()} negative eigenvalue(s).")
    self.log(f"\t{formatted}")
751
+
752
def housekeeping(self):
    """Calculate gradient and energy. Update trust radius and hessian
    if needed. Return energy, gradient and hessian for the current cycle.

    Returns the tuple (energy, gradient, H, eigvals, eigvecs, resetted),
    where ``resetted`` flags cycles in which no trust-radius/Hessian
    update was possible.
    """
    gradient_full = self.geometry.gradient
    energy = self.geometry.energy
    self.energies.append(energy)
    self.log(f" Energy: {energy: >12.6f} au")
    self.log(
        f"norm(grad): {np.linalg.norm(gradient_full): >12.6f} au / bohr (rad)"
    )
    self.log(
        f" rms(grad): {np.sqrt(np.mean(gradient_full**2)): >12.6f} au / bohr (rad)"
    )
    self.forces.append(-gradient_full)

    can_update = (
        # Allows gradient differences
        len(self.forces) > 1
        and (self.forces[-2].shape == gradient_full.shape)
        and len(self.coords) > 1
        # Coordinates may have been rebuilt. Take care of that.
        # NOTE(review): comparing against self.coords[1] (the *second*
        # entry) looks suspicious; self.coords[-1] seems intended — confirm.
        and (self.coords[-2].shape == self.coords[1].shape)
        and len(self.energies) > 1
    )
    if can_update:
        if self.trust_update:
            self.update_trust_radius()
        self.update_hessian()

    # Convert gradient to match H device/dtype AFTER update_hessian(),
    # so that hessian_recalc (which may replace self.H with a new tensor
    # on a different device) is accounted for.
    if isinstance(self.H, torch.Tensor):
        gradient_full = torch.from_numpy(gradient_full).to(
            self.H.device, self.H.dtype
        )

    H = self.H
    if self.geometry.internal:
        # Shift eigenvalues of orthogonal part to high values, so they
        # don't contribute to the actual step.
        H_proj = self.geometry.internal.project_hessian(self.H)
        # Symmetrize hessian, as the projection may break it?!
        H = (H_proj + H_proj.T) / 2

    # Decide whether stepping happens in a reduced "active DOF" space.
    if getattr(self.geometry, "within_partial_hessian", None) is not None:
        use_active = True
    elif (
        H.shape[0] != self.geometry.cart_coords.size
        and self.geometry.coord_type in ("cart", "cartesian", "mwcartesian")
    ):
        # Partial Hessian without explicit metadata: still use active slicing.
        # Only applies to Cartesian coordinate types; for internal coordinates
        # (e.g. DLC), the Hessian is naturally smaller than cart_coords.size.
        use_active = True
    else:
        use_active = (
            len(self.geometry.freeze_atoms) > 0
            and self.geometry.coord_type in ("cart", "cartesian", "mwcartesian")
            and H.shape[0] == self.geometry.cart_coords.size
        )
    self._set_active_dofs(use_active)

    H = self.active_hessian(H)
    # Only slice the gradient when its size still differs from H's.
    if gradient_full.shape[0] == H.shape[0]:
        gradient = gradient_full
    else:
        gradient = self.active_from_full(gradient_full)

    if isinstance(H, torch.Tensor):
        eigvals, eigvecs = torch.linalg.eigh(H)
    else:
        eigvals, eigvecs = np.linalg.eigh(H)
    # Neglect small eigenvalues
    eigvals, eigvecs = self.filter_small_eigvals(eigvals, eigvecs)

    resetted = not can_update
    self.cur_H = H
    return energy, gradient, H, eigvals, eigvecs, resetted
831
+
832
def get_augmented_hessian(self, eigvals, gradient, alpha=1.0):
    """Build the (N+1)x(N+1) scaled augmented Hessian for RS-RFO.

    The leading diagonal block holds eigvals/alpha; the border row
    carries the gradient and the border column carries gradient/alpha.
    """
    if isinstance(gradient, torch.Tensor):
        n = eigvals.size(0)
        H_aug = torch.zeros((n + 1, n + 1), device=gradient.device, dtype=gradient.dtype)
        H_aug[:n, :n] = torch.diag(eigvals / alpha)
    else:
        n = eigvals.size
        H_aug = np.zeros((n + 1, n + 1))
        H_aug[:n, :n] = np.diag(eigvals / alpha)
    # Asymmetric scaling: only the border column is divided by alpha.
    H_aug[-1, :-1] = gradient
    H_aug[:-1, -1] = gradient / alpha
    return H_aug
847
+
848
def get_alpha_step(self, cur_alpha, rfo_eigval, step_norm, eigvals, gradient):
    """Newton step for the RS-RFO scaling parameter alpha.

    Returns 0.0 when the analytic derivative is non-finite or (nearly)
    zero; clamps the update so alpha stays positive.
    """
    # Derivative of the squared step length w.r.t. alpha.
    grad_sq = gradient**2
    shift_cubed = (eigvals - rfo_eigval * cur_alpha) ** 3
    if isinstance(gradient, torch.Tensor):
        deriv_sum = torch.sum(grad_sq / shift_cubed)
    else:
        deriv_sum = np.sum(grad_sq / shift_cubed)
    self.log(f"quot={deriv_sum:.6f}")
    dnorm2_dalpha = 2 * rfo_eigval / (1 + step_norm**2 * cur_alpha) * deriv_sum
    if isinstance(gradient, torch.Tensor):
        deriv_ok = bool(
            torch.isfinite(dnorm2_dalpha)
            & (torch.abs(dnorm2_dalpha) > 1e-12)
        )
    else:
        deriv_ok = np.isfinite(dnorm2_dalpha) and abs(dnorm2_dalpha) > 1e-12
    if not deriv_ok:
        self.log(
            "alpha update skipped due to invalid derivative "
            f"(dstep2_dalpha={dnorm2_dalpha})"
        )
        return 0.0
    self.log(f"analytic deriv.={dnorm2_dalpha:.6f}")
    # Newton update of alpha from the trust-radius residual.
    alpha_step = (
        2 * (self.trust_radius * step_norm - step_norm**2) / dnorm2_dalpha
    )
    self.log(f"alpha_step={alpha_step:.4f}")
    min_alpha = 1e-8
    if (cur_alpha + alpha_step) <= min_alpha:
        self.log(
            "alpha update would make alpha non-positive; "
            f"clamping to min_alpha={min_alpha:.1e}"
        )
        alpha_step = min_alpha - cur_alpha
    return alpha_step
885
+
886
def get_rs_step(self, eigvals, eigvecs, gradient, name="RS"):
    """Restricted-step RFO step, kept within the trust radius.

    Works in the Hessian eigenvector basis. Each micro cycle first tries
    the O(N) secular-equation solver and falls back to the full
    augmented-Hessian diagonalization; alpha is adapted until the step
    norm matches the trust radius. Falls back to a trust-region Newton
    step when RS-RFO fails. Returns the step as a numpy array in the
    original basis.
    """
    # Transform gradient to basis of eigenvectors
    if isinstance(eigvecs, torch.Tensor):
        if not isinstance(gradient, torch.Tensor):
            gradient = torch.as_tensor(
                gradient, device=eigvecs.device, dtype=eigvecs.dtype
            )
        elif gradient.device != eigvecs.device:
            gradient = gradient.to(device=eigvecs.device)
        gradient_ = eigvecs.T @ gradient
    else:
        gradient_ = eigvecs.T.dot(gradient)

    alpha = self.alpha0
    for mu in range(self.max_micro_cycles):
        self.log(f"{name} micro cycle {mu:02d}, alpha={alpha:.6f}")
        # Try secular equation solver first (O(N) vs O(N^3))
        secular_result = self.solve_rfo_secular(
            eigvals, gradient_, alpha, kind="min",
            prev_eigvec=self.prev_eigvec_min,
        )
        if secular_result is not None:
            rfo_step_, eigval_min, nu, self.prev_eigvec_min = secular_result
        else:
            # Fallback to full eigendecomposition
            self.log("Secular solver failed; using full eigendecomposition.")
            H_aug = self.get_augmented_hessian(eigvals, gradient_, alpha)
            rfo_step_, eigval_min, nu, self.prev_eigvec_min = self.solve_rfo(
                H_aug, "min", prev_eigvec=self.prev_eigvec_min
            )
        if isinstance(rfo_step_, torch.Tensor):
            rfo_norm_ = torch.linalg.norm(rfo_step_)
        else:
            rfo_norm_ = np.linalg.norm(rfo_step_)
        self.log(f"norm(rfo step)={rfo_norm_:.6f}")

        # Degenerate zero-length step: give up on RS-RFO right away.
        if rfo_norm_ <= 0:
            self.log(
                "RFO step length is zero; falling back to trust-region Newton step."
            )
            step_ = self.get_newton_step_on_trust(
                eigvals, eigvecs, gradient, transform=False
            )
            break

        # Accept steps inside the trust radius or within 1e-3 of it.
        if (rfo_norm_ < self.trust_radius) or abs(
            rfo_norm_ - self.trust_radius
        ) <= 1e-3:
            step_ = rfo_step_
            break

        alpha_step = self.get_alpha_step(
            alpha, eigval_min, rfo_norm_, eigvals, gradient_
        )
        alpha += alpha_step
        self.log("")
    # Otherwise, use trust region newton step
    else:
        self.log(
            "RS algorithm did not produce a desired step length in "
            f"{self.max_micro_cycles} micro cycles. Trying RFO with α=1.0."
        )
        secular_result = self.solve_rfo_secular(
            eigvals, gradient_, alpha=1.0, kind="min"
        )
        if secular_result is not None:
            rfo_step_, eigval_min, nu, _ = secular_result
        else:
            H_aug = self.get_augmented_hessian(eigvals, gradient_, alpha=1.0)
            rfo_step_, eigval_min, nu, _ = self.solve_rfo(H_aug, "min")
        if isinstance(rfo_step_, torch.Tensor):
            rfo_norm_ = torch.linalg.norm(rfo_step_)
        else:
            rfo_norm_ = np.linalg.norm(rfo_step_)

        # This should always be True if the above algorithm failed but we
        # keep this line nonetheless, to make it more obvious.
        if rfo_norm_ > self.trust_radius:
            self.log(
                f"Proposed RFO step with norm {rfo_norm_:.4f} is outside trust "
                f"radius Δ={self.trust_radius:.4f}. "
            )
            step_ = self.get_newton_step_on_trust(
                eigvals, eigvecs, gradient, transform=False
            )
            # Simple, downscaled RFO step
            # step_ = rfo_step_ / rfo_norm_ * self.trust_radius
        else:
            step_ = rfo_step_

    # Transform step back to original basis
    if isinstance(eigvecs, torch.Tensor):
        if isinstance(step_, torch.Tensor):
            pass
        else:
            step_ = torch.tensor(step_, device=eigvecs.device, dtype=eigvecs.dtype)
        step = eigvecs @ step_
        step = step.cpu().numpy()
    else:
        step = eigvecs.dot(step_)
    # min_step_norm is optional; default 0.0 disables the check.
    min_step_norm = getattr(self, "min_step_norm", 0.0)
    step_norm = np.linalg.norm(step)
    if step_norm <= min_step_norm:
        self.log(
            "RFO step length below minimum threshold; "
            "falling back to trust-region Newton step."
        )
        step = self.get_newton_step_on_trust(eigvals, eigvecs, gradient)
    if not np.isfinite(step).all():
        self.log(
            "RFO step contains NaN/inf; falling back to trust-region Newton step."
        )
        step = self.get_newton_step_on_trust(eigvals, eigvecs, gradient)
        if not np.isfinite(step).all():
            raise ValueError(
                "Fallback Newton step still contains NaN/inf; "
                "aborting to avoid corrupting coordinates."
            )
    return step
1005
+
1006
+ @staticmethod
1007
+ def get_shifted_step_trans(eigvals, gradient_trans, shift):
1008
+ return -gradient_trans / (eigvals + shift)
1009
+
1010
+ @staticmethod
1011
+ def get_newton_step(eigvals, eigvecs, gradient):
1012
+ if isinstance(eigvecs, torch.Tensor):
1013
+ eigvals = eigvals.to(eigvecs.device, dtype=eigvecs.dtype)
1014
+ gradient = gradient.to(eigvecs.device, dtype=eigvecs.dtype)
1015
+ return (eigvecs @ (eigvecs.T @ gradient / eigvals)).cpu().numpy()
1016
+ else:
1017
+ return eigvecs.dot(eigvecs.T.dot(gradient) / eigvals)
1018
+
1019
def get_newton_step_on_trust(self, eigvals, eigvecs, gradient, transform=True):
    """Step on trust-radius.

    See Nocedal 4.3 Iterative solutions of the subproblem

    Determines a level shift so the (shifted) Newton step lies on the
    trust-radius boundary; handles the indefinite and hard cases. With
    ``transform=False`` the step is returned in the eigenvector basis.
    """
    if isinstance(eigvals, torch.Tensor):
        eigvals = eigvals.cpu().numpy()

    min_ind = eigvals.argmin()
    min_eigval = eigvals[min_ind]
    pos_definite = bool((eigvals > 0.0).all())
    if isinstance(eigvecs, torch.Tensor):
        if not isinstance(gradient, torch.Tensor):
            gradient = torch.tensor(
                gradient, device=eigvecs.device, dtype=eigvecs.dtype
            )
        else:
            gradient = gradient.to(device=eigvecs.device, dtype=eigvecs.dtype)
        gradient_trans = eigvecs.T @ gradient
        gradient_trans = gradient_trans.cpu().numpy()
    else:
        gradient_trans = eigvecs.T.dot(gradient)

    # This will be also be True when we come close to a minimizer,
    # but then the Hessian will also be positive definite and a
    # simple Newton step will be used.
    hard_case = abs(gradient_trans[min_ind]) <= 1e-6
    self.log(f"Smallest eigenvalue: {min_eigval:.6f}")
    self.log(f"Positive definite Hessian: {pos_definite}")
    self.log(f"Hard case: {hard_case}")

    def get_step(shift):
        # Level-shifted Newton step in the eigenvector basis.
        return -gradient_trans / (eigvals + shift)

    # Unshifted Newton step
    newton_step_trans = get_step(0.0)
    newton_norm = np.linalg.norm(newton_step_trans)

    def on_trust_radius_lin(step):
        # Root function 1/Δ - 1/|s(shift)|; vanishes when the step
        # length equals the trust radius.
        return 1 / self.trust_radius - 1 / np.linalg.norm(step)

    def finalize_step(shift):
        # Build the shifted step and optionally rotate it back.
        step = get_step(shift)
        if transform:
            if isinstance(eigvecs, torch.Tensor):
                step = torch.tensor(step, device=eigvecs.device, dtype=eigvecs.dtype)
                return (eigvecs @ step).cpu().numpy()
            else:
                return eigvecs.dot(step)
        return step

    # Simplest case. Positive definite Hessian and predicted step is
    # already in trust radius.
    if pos_definite and newton_norm <= self.trust_radius:
        self.log("Using unshifted Newton step.")
        if isinstance(eigvecs, torch.Tensor):
            newton_step_trans = torch.tensor(
                newton_step_trans, device=eigvecs.device, dtype=eigvecs.dtype
            )
            return (eigvecs @ newton_step_trans).cpu().numpy()
        else:
            return eigvecs.dot(newton_step_trans)

    # If the Hessian is not positive definite or if the step is too
    # long we have to determine the shift parameter lambda.
    rs_kwargs = {
        "f": lambda shift: on_trust_radius_lin(get_step(shift)),
        "xtol": 1e-3,
        # Would otherwise be chosen automatically, but we set it
        # here explicitly for verbosity.
        "method": "brentq",
    }

    def root_search(bracket):
        rs_kwargs.update(
            {
                "bracket": bracket,
                "x0": bracket[0] + 1e-3,
            }
        )
        res = root_scalar(**rs_kwargs)
        return res

    BRACKET_END = 1e10
    if not hard_case:
        # Shift must exceed -min_eigval for the step to be well defined.
        bracket_start = 0.0 if pos_definite else -min_eigval + 1e-2
        bracket = (bracket_start, BRACKET_END)
        try:
            res = root_search(bracket)
            assert res.converged
            return finalize_step(res.root)
        # ValueError may be raised when the function values for the
        # initial bracket have the same sign. If so, we continue with
        # treating it as a hard case.
        except ValueError:
            pass

    # Now we would try the bracket (-b2, -b1). The resulting step should have
    # a suitable length, but the (shifted) Hessian would have an incorrect
    # eigenvalue spectrum (not positive definite). To solve this we use a
    # different formula to calculate the step.
    mask = np.ones_like(gradient_trans)
    mask[min_ind] = 0
    mask = mask.astype(bool)
    without_min = gradient_trans[mask] / (eigvals[mask] - min_eigval)
    tau_sq = self.trust_radius**2 - (without_min**2).sum()
    if tau_sq >= 0.0:
        tau = sqrt(tau_sq)
        # NOTE(review): tau is prepended, which places the free component
        # at index 0; this assumes ascending eigvals (min_ind == 0), as
        # produced by eigh upstream — confirm.
        step_trans = [tau] + (-without_min).tolist()
    else:
        # Hard case. Search in open interval (endpoints not included)
        # (-min_eigval, inf).
        bracket = (-min_eigval + 1e-6, BRACKET_END)
        try:
            res = root_search(bracket)
            if res.converged:
                return finalize_step(res.root)
        except ValueError:
            pass
        # Fallback: clamp tau to 0 so the step excludes the
        # minimum-eigenvalue component but remains valid.
        self.log("Hard case fallback: tau clamped to 0.")
        tau = 0.0
        step_trans = [tau] + (-without_min).tolist()

    if not transform:
        return step_trans

    if isinstance(eigvecs, torch.Tensor):
        step_trans = torch.tensor(step_trans, device=eigvecs.device, dtype=eigvecs.dtype)
        return (eigvecs @ step_trans).cpu().numpy()
    else:
        return eigvecs.dot(step_trans)
1152
+
1153
+ @staticmethod
1154
+ def quadratic_model(gradient, hessian, step):
1155
+ if isinstance(gradient, torch.Tensor):
1156
+ step = torch.tensor(step, device=gradient.device, dtype=gradient.dtype)
1157
+ return (step @ gradient + 0.5 * step @ hessian @ step).cpu().numpy()
1158
+ else:
1159
+ step = np.asarray(step).ravel()
1160
+ return step.dot(gradient) + 0.5 * step.dot(hessian).dot(step)
1161
+
1162
@staticmethod
def rfo_model(gradient, hessian, step):
    """RFO model energy: the quadratic model scaled by 1/(1 + |step|^2)."""
    quad = HessianOptimizer.quadratic_model(gradient, hessian, step)
    return quad / (1 + step.dot(step))
1167
+
1168
def get_step_func(self, eigvals, gradient, grad_rms_thresh=1e-2):
    """Pick (step_func, model_func) for the current cycle.

    With ``adapt_step_func`` enabled, a small RMS gradient together with
    a positive-definite Hessian selects the trust-region Newton step and
    quadratic model; otherwise RS-RFO with the RFO model is used.
    """
    hessian_pos_def = (eigvals < 0).sum() == 0
    gradient_is_small = rms(gradient) < grad_rms_thresh
    if self.adapt_step_func and gradient_is_small and hessian_pos_def:
        return self.get_newton_step_on_trust, self.quadratic_model
    # RFO fallback
    return self.get_rs_step, self.rfo_model