mlmm-toolkit 0.2.2.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (372) hide show
  1. hessian_ff/__init__.py +50 -0
  2. hessian_ff/analytical_hessian.py +609 -0
  3. hessian_ff/constants.py +46 -0
  4. hessian_ff/forcefield.py +339 -0
  5. hessian_ff/loaders.py +608 -0
  6. hessian_ff/native/Makefile +8 -0
  7. hessian_ff/native/__init__.py +28 -0
  8. hessian_ff/native/analytical_hessian.py +88 -0
  9. hessian_ff/native/analytical_hessian_ext.cpp +258 -0
  10. hessian_ff/native/bonded.py +82 -0
  11. hessian_ff/native/bonded_ext.cpp +640 -0
  12. hessian_ff/native/loader.py +349 -0
  13. hessian_ff/native/nonbonded.py +118 -0
  14. hessian_ff/native/nonbonded_ext.cpp +1150 -0
  15. hessian_ff/prmtop_parmed.py +23 -0
  16. hessian_ff/system.py +107 -0
  17. hessian_ff/terms/__init__.py +14 -0
  18. hessian_ff/terms/angle.py +73 -0
  19. hessian_ff/terms/bond.py +44 -0
  20. hessian_ff/terms/cmap.py +406 -0
  21. hessian_ff/terms/dihedral.py +141 -0
  22. hessian_ff/terms/nonbonded.py +209 -0
  23. hessian_ff/tests/__init__.py +0 -0
  24. hessian_ff/tests/conftest.py +75 -0
  25. hessian_ff/tests/data/small/complex.parm7 +1346 -0
  26. hessian_ff/tests/data/small/complex.pdb +125 -0
  27. hessian_ff/tests/data/small/complex.rst7 +63 -0
  28. hessian_ff/tests/test_coords_input.py +44 -0
  29. hessian_ff/tests/test_energy_force.py +49 -0
  30. hessian_ff/tests/test_hessian.py +137 -0
  31. hessian_ff/tests/test_smoke.py +18 -0
  32. hessian_ff/tests/test_validation.py +40 -0
  33. hessian_ff/workflows.py +889 -0
  34. mlmm/__init__.py +36 -0
  35. mlmm/__main__.py +7 -0
  36. mlmm/_version.py +34 -0
  37. mlmm/add_elem_info.py +374 -0
  38. mlmm/advanced_help.py +91 -0
  39. mlmm/align_freeze_atoms.py +601 -0
  40. mlmm/all.py +3535 -0
  41. mlmm/bond_changes.py +231 -0
  42. mlmm/bool_compat.py +223 -0
  43. mlmm/cli.py +574 -0
  44. mlmm/cli_utils.py +166 -0
  45. mlmm/default_group.py +337 -0
  46. mlmm/defaults.py +467 -0
  47. mlmm/define_layer.py +526 -0
  48. mlmm/dft.py +1041 -0
  49. mlmm/energy_diagram.py +253 -0
  50. mlmm/extract.py +2213 -0
  51. mlmm/fix_altloc.py +464 -0
  52. mlmm/freq.py +1406 -0
  53. mlmm/harmonic_constraints.py +140 -0
  54. mlmm/hessian_cache.py +44 -0
  55. mlmm/hessian_calc.py +174 -0
  56. mlmm/irc.py +638 -0
  57. mlmm/mlmm_calc.py +2262 -0
  58. mlmm/mm_parm.py +945 -0
  59. mlmm/oniom_export.py +1983 -0
  60. mlmm/oniom_import.py +457 -0
  61. mlmm/opt.py +1742 -0
  62. mlmm/path_opt.py +1353 -0
  63. mlmm/path_search.py +2299 -0
  64. mlmm/preflight.py +88 -0
  65. mlmm/py.typed +1 -0
  66. mlmm/pysis_runner.py +45 -0
  67. mlmm/scan.py +1047 -0
  68. mlmm/scan2d.py +1226 -0
  69. mlmm/scan3d.py +1265 -0
  70. mlmm/scan_common.py +184 -0
  71. mlmm/summary_log.py +736 -0
  72. mlmm/trj2fig.py +448 -0
  73. mlmm/tsopt.py +2871 -0
  74. mlmm/utils.py +2309 -0
  75. mlmm/xtb_embedcharge_correction.py +475 -0
  76. mlmm_toolkit-0.2.2.dev0.dist-info/METADATA +1159 -0
  77. mlmm_toolkit-0.2.2.dev0.dist-info/RECORD +372 -0
  78. mlmm_toolkit-0.2.2.dev0.dist-info/WHEEL +5 -0
  79. mlmm_toolkit-0.2.2.dev0.dist-info/entry_points.txt +2 -0
  80. mlmm_toolkit-0.2.2.dev0.dist-info/licenses/LICENSE +674 -0
  81. mlmm_toolkit-0.2.2.dev0.dist-info/top_level.txt +4 -0
  82. pysisyphus/Geometry.py +1667 -0
  83. pysisyphus/LICENSE +674 -0
  84. pysisyphus/TableFormatter.py +63 -0
  85. pysisyphus/TablePrinter.py +74 -0
  86. pysisyphus/__init__.py +12 -0
  87. pysisyphus/calculators/AFIR.py +452 -0
  88. pysisyphus/calculators/AnaPot.py +20 -0
  89. pysisyphus/calculators/AnaPot2.py +48 -0
  90. pysisyphus/calculators/AnaPot3.py +12 -0
  91. pysisyphus/calculators/AnaPot4.py +20 -0
  92. pysisyphus/calculators/AnaPotBase.py +337 -0
  93. pysisyphus/calculators/AnaPotCBM.py +25 -0
  94. pysisyphus/calculators/AtomAtomTransTorque.py +154 -0
  95. pysisyphus/calculators/CFOUR.py +250 -0
  96. pysisyphus/calculators/Calculator.py +844 -0
  97. pysisyphus/calculators/CerjanMiller.py +24 -0
  98. pysisyphus/calculators/Composite.py +123 -0
  99. pysisyphus/calculators/ConicalIntersection.py +171 -0
  100. pysisyphus/calculators/DFTBp.py +430 -0
  101. pysisyphus/calculators/DFTD3.py +66 -0
  102. pysisyphus/calculators/DFTD4.py +84 -0
  103. pysisyphus/calculators/Dalton.py +61 -0
  104. pysisyphus/calculators/Dimer.py +681 -0
  105. pysisyphus/calculators/Dummy.py +20 -0
  106. pysisyphus/calculators/EGO.py +76 -0
  107. pysisyphus/calculators/EnergyMin.py +224 -0
  108. pysisyphus/calculators/ExternalPotential.py +264 -0
  109. pysisyphus/calculators/FakeASE.py +35 -0
  110. pysisyphus/calculators/FourWellAnaPot.py +28 -0
  111. pysisyphus/calculators/FreeEndNEBPot.py +39 -0
  112. pysisyphus/calculators/Gaussian09.py +18 -0
  113. pysisyphus/calculators/Gaussian16.py +726 -0
  114. pysisyphus/calculators/HardSphere.py +159 -0
  115. pysisyphus/calculators/IDPPCalculator.py +49 -0
  116. pysisyphus/calculators/IPIClient.py +133 -0
  117. pysisyphus/calculators/IPIServer.py +234 -0
  118. pysisyphus/calculators/LEPSBase.py +24 -0
  119. pysisyphus/calculators/LEPSExpr.py +139 -0
  120. pysisyphus/calculators/LennardJones.py +80 -0
  121. pysisyphus/calculators/MOPAC.py +219 -0
  122. pysisyphus/calculators/MullerBrownSympyPot.py +51 -0
  123. pysisyphus/calculators/MultiCalc.py +85 -0
  124. pysisyphus/calculators/NFK.py +45 -0
  125. pysisyphus/calculators/OBabel.py +87 -0
  126. pysisyphus/calculators/ONIOMv2.py +1129 -0
  127. pysisyphus/calculators/ORCA.py +893 -0
  128. pysisyphus/calculators/ORCA5.py +6 -0
  129. pysisyphus/calculators/OpenMM.py +88 -0
  130. pysisyphus/calculators/OpenMolcas.py +281 -0
  131. pysisyphus/calculators/OverlapCalculator.py +908 -0
  132. pysisyphus/calculators/Psi4.py +218 -0
  133. pysisyphus/calculators/PyPsi4.py +37 -0
  134. pysisyphus/calculators/PySCF.py +341 -0
  135. pysisyphus/calculators/PyXTB.py +73 -0
  136. pysisyphus/calculators/QCEngine.py +106 -0
  137. pysisyphus/calculators/Rastrigin.py +22 -0
  138. pysisyphus/calculators/Remote.py +76 -0
  139. pysisyphus/calculators/Rosenbrock.py +15 -0
  140. pysisyphus/calculators/SocketCalc.py +97 -0
  141. pysisyphus/calculators/TIP3P.py +111 -0
  142. pysisyphus/calculators/TransTorque.py +161 -0
  143. pysisyphus/calculators/Turbomole.py +965 -0
  144. pysisyphus/calculators/VRIPot.py +37 -0
  145. pysisyphus/calculators/WFOWrapper.py +333 -0
  146. pysisyphus/calculators/WFOWrapper2.py +341 -0
  147. pysisyphus/calculators/XTB.py +418 -0
  148. pysisyphus/calculators/__init__.py +81 -0
  149. pysisyphus/calculators/cosmo_data.py +139 -0
  150. pysisyphus/calculators/parser.py +150 -0
  151. pysisyphus/color.py +19 -0
  152. pysisyphus/config.py +133 -0
  153. pysisyphus/constants.py +65 -0
  154. pysisyphus/cos/AdaptiveNEB.py +230 -0
  155. pysisyphus/cos/ChainOfStates.py +725 -0
  156. pysisyphus/cos/FreeEndNEB.py +25 -0
  157. pysisyphus/cos/FreezingString.py +103 -0
  158. pysisyphus/cos/GrowingChainOfStates.py +71 -0
  159. pysisyphus/cos/GrowingNT.py +309 -0
  160. pysisyphus/cos/GrowingString.py +508 -0
  161. pysisyphus/cos/NEB.py +189 -0
  162. pysisyphus/cos/SimpleZTS.py +64 -0
  163. pysisyphus/cos/__init__.py +22 -0
  164. pysisyphus/cos/stiffness.py +199 -0
  165. pysisyphus/drivers/__init__.py +17 -0
  166. pysisyphus/drivers/afir.py +855 -0
  167. pysisyphus/drivers/barriers.py +271 -0
  168. pysisyphus/drivers/birkholz.py +138 -0
  169. pysisyphus/drivers/cluster.py +318 -0
  170. pysisyphus/drivers/diabatization.py +133 -0
  171. pysisyphus/drivers/merge.py +368 -0
  172. pysisyphus/drivers/merge_mol2.py +322 -0
  173. pysisyphus/drivers/opt.py +375 -0
  174. pysisyphus/drivers/perf.py +91 -0
  175. pysisyphus/drivers/pka.py +52 -0
  176. pysisyphus/drivers/precon_pos_rot.py +669 -0
  177. pysisyphus/drivers/rates.py +480 -0
  178. pysisyphus/drivers/replace.py +219 -0
  179. pysisyphus/drivers/scan.py +212 -0
  180. pysisyphus/drivers/spectrum.py +166 -0
  181. pysisyphus/drivers/thermo.py +31 -0
  182. pysisyphus/dynamics/Gaussian.py +103 -0
  183. pysisyphus/dynamics/__init__.py +20 -0
  184. pysisyphus/dynamics/colvars.py +136 -0
  185. pysisyphus/dynamics/driver.py +297 -0
  186. pysisyphus/dynamics/helpers.py +256 -0
  187. pysisyphus/dynamics/lincs.py +105 -0
  188. pysisyphus/dynamics/mdp.py +364 -0
  189. pysisyphus/dynamics/rattle.py +121 -0
  190. pysisyphus/dynamics/thermostats.py +128 -0
  191. pysisyphus/dynamics/wigner.py +266 -0
  192. pysisyphus/elem_data.py +3473 -0
  193. pysisyphus/exceptions.py +2 -0
  194. pysisyphus/filtertrj.py +69 -0
  195. pysisyphus/helpers.py +623 -0
  196. pysisyphus/helpers_pure.py +649 -0
  197. pysisyphus/init_logging.py +50 -0
  198. pysisyphus/intcoords/Bend.py +69 -0
  199. pysisyphus/intcoords/Bend2.py +25 -0
  200. pysisyphus/intcoords/BondedFragment.py +32 -0
  201. pysisyphus/intcoords/Cartesian.py +41 -0
  202. pysisyphus/intcoords/CartesianCoords.py +140 -0
  203. pysisyphus/intcoords/Coords.py +56 -0
  204. pysisyphus/intcoords/DLC.py +197 -0
  205. pysisyphus/intcoords/DistanceFunction.py +34 -0
  206. pysisyphus/intcoords/DummyImproper.py +70 -0
  207. pysisyphus/intcoords/DummyTorsion.py +72 -0
  208. pysisyphus/intcoords/LinearBend.py +105 -0
  209. pysisyphus/intcoords/LinearDisplacement.py +80 -0
  210. pysisyphus/intcoords/OutOfPlane.py +59 -0
  211. pysisyphus/intcoords/PrimTypes.py +286 -0
  212. pysisyphus/intcoords/Primitive.py +137 -0
  213. pysisyphus/intcoords/RedundantCoords.py +659 -0
  214. pysisyphus/intcoords/RobustTorsion.py +59 -0
  215. pysisyphus/intcoords/Rotation.py +147 -0
  216. pysisyphus/intcoords/Stretch.py +31 -0
  217. pysisyphus/intcoords/Torsion.py +101 -0
  218. pysisyphus/intcoords/Torsion2.py +25 -0
  219. pysisyphus/intcoords/Translation.py +45 -0
  220. pysisyphus/intcoords/__init__.py +61 -0
  221. pysisyphus/intcoords/augment_bonds.py +126 -0
  222. pysisyphus/intcoords/derivatives.py +10512 -0
  223. pysisyphus/intcoords/eval.py +80 -0
  224. pysisyphus/intcoords/exceptions.py +37 -0
  225. pysisyphus/intcoords/findiffs.py +48 -0
  226. pysisyphus/intcoords/generate_derivatives.py +414 -0
  227. pysisyphus/intcoords/helpers.py +235 -0
  228. pysisyphus/intcoords/logging_conf.py +10 -0
  229. pysisyphus/intcoords/mp_derivatives.py +10836 -0
  230. pysisyphus/intcoords/setup.py +962 -0
  231. pysisyphus/intcoords/setup_fast.py +176 -0
  232. pysisyphus/intcoords/update.py +272 -0
  233. pysisyphus/intcoords/valid.py +89 -0
  234. pysisyphus/interpolate/Geodesic.py +93 -0
  235. pysisyphus/interpolate/IDPP.py +55 -0
  236. pysisyphus/interpolate/Interpolator.py +116 -0
  237. pysisyphus/interpolate/LST.py +70 -0
  238. pysisyphus/interpolate/Redund.py +152 -0
  239. pysisyphus/interpolate/__init__.py +9 -0
  240. pysisyphus/interpolate/helpers.py +34 -0
  241. pysisyphus/io/__init__.py +22 -0
  242. pysisyphus/io/aomix.py +178 -0
  243. pysisyphus/io/cjson.py +24 -0
  244. pysisyphus/io/crd.py +101 -0
  245. pysisyphus/io/cube.py +220 -0
  246. pysisyphus/io/fchk.py +184 -0
  247. pysisyphus/io/hdf5.py +49 -0
  248. pysisyphus/io/hessian.py +72 -0
  249. pysisyphus/io/mol2.py +146 -0
  250. pysisyphus/io/molden.py +293 -0
  251. pysisyphus/io/orca.py +189 -0
  252. pysisyphus/io/pdb.py +269 -0
  253. pysisyphus/io/psf.py +79 -0
  254. pysisyphus/io/pubchem.py +31 -0
  255. pysisyphus/io/qcschema.py +34 -0
  256. pysisyphus/io/sdf.py +29 -0
  257. pysisyphus/io/xyz.py +61 -0
  258. pysisyphus/io/zmat.py +175 -0
  259. pysisyphus/irc/DWI.py +108 -0
  260. pysisyphus/irc/DampedVelocityVerlet.py +134 -0
  261. pysisyphus/irc/Euler.py +22 -0
  262. pysisyphus/irc/EulerPC.py +345 -0
  263. pysisyphus/irc/GonzalezSchlegel.py +187 -0
  264. pysisyphus/irc/IMKMod.py +164 -0
  265. pysisyphus/irc/IRC.py +878 -0
  266. pysisyphus/irc/IRCDummy.py +10 -0
  267. pysisyphus/irc/Instanton.py +307 -0
  268. pysisyphus/irc/LQA.py +53 -0
  269. pysisyphus/irc/ModeKill.py +136 -0
  270. pysisyphus/irc/ParamPlot.py +53 -0
  271. pysisyphus/irc/RK4.py +36 -0
  272. pysisyphus/irc/__init__.py +31 -0
  273. pysisyphus/irc/initial_displ.py +219 -0
  274. pysisyphus/linalg.py +411 -0
  275. pysisyphus/line_searches/Backtracking.py +88 -0
  276. pysisyphus/line_searches/HagerZhang.py +184 -0
  277. pysisyphus/line_searches/LineSearch.py +232 -0
  278. pysisyphus/line_searches/StrongWolfe.py +108 -0
  279. pysisyphus/line_searches/__init__.py +9 -0
  280. pysisyphus/line_searches/interpol.py +15 -0
  281. pysisyphus/modefollow/NormalMode.py +40 -0
  282. pysisyphus/modefollow/__init__.py +10 -0
  283. pysisyphus/modefollow/davidson.py +199 -0
  284. pysisyphus/modefollow/lanczos.py +95 -0
  285. pysisyphus/optimizers/BFGS.py +99 -0
  286. pysisyphus/optimizers/BacktrackingOptimizer.py +113 -0
  287. pysisyphus/optimizers/ConjugateGradient.py +98 -0
  288. pysisyphus/optimizers/CubicNewton.py +75 -0
  289. pysisyphus/optimizers/FIRE.py +113 -0
  290. pysisyphus/optimizers/HessianOptimizer.py +1176 -0
  291. pysisyphus/optimizers/LBFGS.py +228 -0
  292. pysisyphus/optimizers/LayerOpt.py +411 -0
  293. pysisyphus/optimizers/MicroOptimizer.py +169 -0
  294. pysisyphus/optimizers/NCOptimizer.py +90 -0
  295. pysisyphus/optimizers/Optimizer.py +1084 -0
  296. pysisyphus/optimizers/PreconLBFGS.py +260 -0
  297. pysisyphus/optimizers/PreconSteepestDescent.py +7 -0
  298. pysisyphus/optimizers/QuickMin.py +74 -0
  299. pysisyphus/optimizers/RFOptimizer.py +181 -0
  300. pysisyphus/optimizers/RSA.py +99 -0
  301. pysisyphus/optimizers/StabilizedQNMethod.py +248 -0
  302. pysisyphus/optimizers/SteepestDescent.py +23 -0
  303. pysisyphus/optimizers/StringOptimizer.py +173 -0
  304. pysisyphus/optimizers/__init__.py +41 -0
  305. pysisyphus/optimizers/closures.py +301 -0
  306. pysisyphus/optimizers/cls_map.py +58 -0
  307. pysisyphus/optimizers/exceptions.py +6 -0
  308. pysisyphus/optimizers/gdiis.py +280 -0
  309. pysisyphus/optimizers/guess_hessians.py +311 -0
  310. pysisyphus/optimizers/hessian_updates.py +355 -0
  311. pysisyphus/optimizers/poly_fit.py +285 -0
  312. pysisyphus/optimizers/precon.py +153 -0
  313. pysisyphus/optimizers/restrict_step.py +24 -0
  314. pysisyphus/pack.py +172 -0
  315. pysisyphus/peakdetect.py +948 -0
  316. pysisyphus/plot.py +1031 -0
  317. pysisyphus/run.py +2106 -0
  318. pysisyphus/socket_helper.py +74 -0
  319. pysisyphus/stocastic/FragmentKick.py +132 -0
  320. pysisyphus/stocastic/Kick.py +81 -0
  321. pysisyphus/stocastic/Pipeline.py +303 -0
  322. pysisyphus/stocastic/__init__.py +21 -0
  323. pysisyphus/stocastic/align.py +127 -0
  324. pysisyphus/testing.py +96 -0
  325. pysisyphus/thermo.py +156 -0
  326. pysisyphus/trj.py +824 -0
  327. pysisyphus/tsoptimizers/RSIRFOptimizer.py +56 -0
  328. pysisyphus/tsoptimizers/RSPRFOptimizer.py +182 -0
  329. pysisyphus/tsoptimizers/TRIM.py +59 -0
  330. pysisyphus/tsoptimizers/TSHessianOptimizer.py +463 -0
  331. pysisyphus/tsoptimizers/__init__.py +23 -0
  332. pysisyphus/wavefunction/Basis.py +239 -0
  333. pysisyphus/wavefunction/DIIS.py +76 -0
  334. pysisyphus/wavefunction/__init__.py +25 -0
  335. pysisyphus/wavefunction/build_ext.py +42 -0
  336. pysisyphus/wavefunction/cart2sph.py +190 -0
  337. pysisyphus/wavefunction/diabatization.py +304 -0
  338. pysisyphus/wavefunction/excited_states.py +435 -0
  339. pysisyphus/wavefunction/gen_ints.py +1811 -0
  340. pysisyphus/wavefunction/helpers.py +104 -0
  341. pysisyphus/wavefunction/ints/__init__.py +0 -0
  342. pysisyphus/wavefunction/ints/boys.py +193 -0
  343. pysisyphus/wavefunction/ints/boys_table_N_64_xasym_27.1_step_0.01.npy +0 -0
  344. pysisyphus/wavefunction/ints/cart_gto3d.py +176 -0
  345. pysisyphus/wavefunction/ints/coulomb3d.py +25928 -0
  346. pysisyphus/wavefunction/ints/diag_quadrupole3d.py +10036 -0
  347. pysisyphus/wavefunction/ints/dipole3d.py +8762 -0
  348. pysisyphus/wavefunction/ints/int2c2e3d.py +7198 -0
  349. pysisyphus/wavefunction/ints/int3c2e3d_sph.py +65040 -0
  350. pysisyphus/wavefunction/ints/kinetic3d.py +8240 -0
  351. pysisyphus/wavefunction/ints/ovlp3d.py +3777 -0
  352. pysisyphus/wavefunction/ints/quadrupole3d.py +15054 -0
  353. pysisyphus/wavefunction/ints/self_ovlp3d.py +198 -0
  354. pysisyphus/wavefunction/localization.py +458 -0
  355. pysisyphus/wavefunction/multipole.py +159 -0
  356. pysisyphus/wavefunction/normalization.py +36 -0
  357. pysisyphus/wavefunction/pop_analysis.py +134 -0
  358. pysisyphus/wavefunction/shells.py +1171 -0
  359. pysisyphus/wavefunction/wavefunction.py +504 -0
  360. pysisyphus/wrapper/__init__.py +11 -0
  361. pysisyphus/wrapper/exceptions.py +2 -0
  362. pysisyphus/wrapper/jmol.py +120 -0
  363. pysisyphus/wrapper/mwfn.py +169 -0
  364. pysisyphus/wrapper/packmol.py +71 -0
  365. pysisyphus/xyzloader.py +168 -0
  366. pysisyphus/yaml_mods.py +45 -0
  367. thermoanalysis/LICENSE +674 -0
  368. thermoanalysis/QCData.py +244 -0
  369. thermoanalysis/__init__.py +0 -0
  370. thermoanalysis/config.py +3 -0
  371. thermoanalysis/constants.py +20 -0
  372. thermoanalysis/thermo.py +1011 -0
@@ -0,0 +1,311 @@
1
+ # [1] https://www.sciencedirect.com/science/article/pii/000926149500646L
2
+ # Lindh, 1995
3
+ # [2] https://pubs.acs.org/doi/pdf/10.1021/j100203a036
4
+ # Fischer, Almlöf, 1992
5
+ # [3] https://onlinelibrary.wiley.com/doi/full/10.1002/qua.21049
6
+ # Swart, Bickelhaupt, 2006
7
+ # [4] http://dx.doi.org/10.1063/1.4952956
8
+ # Lee-Ping Wang 2016
9
+
10
+ from math import exp
11
+ import itertools as it
12
+ from typing import Literal, Optional
13
+
14
+ import h5py
15
+ import numpy as np
16
+ from numpy.typing import ArrayLike
17
+ from scipy.spatial.distance import pdist, squareform
18
+
19
+ from pysisyphus.Geometry import Geometry
20
+ from pysisyphus.intcoords.PrimTypes import PrimTypes as PT, Bonds, Bends, Dihedrals
21
+ from pysisyphus.intcoords.setup import get_pair_covalent_radii
22
+ from pysisyphus.io.hessian import save_hessian
23
+
24
+ import torch
25
+
26
# Keywords accepted for the initial Hessian selection.
HessInit = Literal[
    "calc",
    "unit",
    "fischer",
    "lindh",
    "simple",
    "swart",
    "xtb",
    "xtb1",
    "xtbff",
]


# See [4], last sentences of III.
CART_F = 0.05
# Default diagonal force constants, keyed by primitive internal type.
# Built from (key, value) pairs; insertion order matches the listing below.
DEFAULT_F = dict(
    (
        (PT.BOND, 0.5),
        (PT.AUX_BOND, 0.1),
        (PT.HYDROGEN_BOND, 0.1),
        (PT.INTERFRAG_BOND, 0.1),
        (PT.AUX_INTERFRAG_BOND, 0.05),
        (PT.BEND, 0.2),
        (PT.LINEAR_BEND, 0.1),
        (PT.LINEAR_BEND_COMPLEMENT, 0.1),
        (PT.PROPER_DIHEDRAL, 0.1),
        (PT.IMPROPER_DIHEDRAL, 0.1),
        (PT.OUT_OF_PLANE, 0.1),
        (PT.LINEAR_DISPLACEMENT, 0.1),
        (PT.LINEAR_DISPLACEMENT_COMPLEMENT, 0.1),
        (PT.TRANSLATION_X, CART_F),
        (PT.TRANSLATION_Y, CART_F),
        (PT.TRANSLATION_Z, CART_F),
        (PT.ROTATION_A, CART_F),
        (PT.ROTATION_B, CART_F),
        (PT.ROTATION_C, CART_F),
        (PT.CARTESIAN_X, CART_F),
        (PT.CARTESIAN_Y, CART_F),
        (PT.CARTESIAN_Z, CART_F),
        (PT.BONDED_FRAGMENT, CART_F),
        (PT.DUMMY_TORSION, 0.1),
        (PT.DISTANCE_FUNCTION, 0.1),
        (PT.DUMMY_IMPROPER, 0.1),
    )
)
62
+
63
+
64
def simple_guess(geom):
    """Return a diagonal model Hessian built from the default force constants.

    One diagonal entry is looked up in DEFAULT_F for every typed primitive
    in ``geom.internal.typed_prims``; all off-diagonal entries are zero.
    """
    force_constants = []
    for typed_prim in geom.internal.typed_prims:
        # Only the leading primitive type selects the force constant;
        # the trailing atom indices are irrelevant here.
        prim_type, *_unused = typed_prim
        force_constants.append(DEFAULT_F[prim_type])
    return np.diag(force_constants)
68
+
69
+
70
def improved_guess(geom, bond_func, bend_func, dihedral_func):
    """Refine the simple diagonal guess with per-primitive force constants.

    Starts from ``simple_guess(geom)`` and overwrites the diagonal entry of
    every bond, bend and dihedral primitive with the value returned by the
    matching callable. Each callable receives the list of atom indices of
    the primitive. If a callable raises ValueError, the default force
    constant from DEFAULT_F is used for that primitive. Primitives of any
    other type keep their default value.
    """
    H_guess = simple_guess(geom)
    # Checked in this order; the first matching group wins.
    dispatch = (
        (Bonds, bond_func),
        (Bends, bend_func),
        (Dihedrals, dihedral_func),
    )
    for row, (prim_type, *atom_inds) in enumerate(geom.internal.typed_prims):
        for group, f_func in dispatch:
            if prim_type in group:
                break
        else:
            # Neither bond, bend nor dihedral: keep the default entry.
            continue

        try:
            force_constant = f_func(atom_inds)
        except ValueError:
            force_constant = DEFAULT_F[prim_type]
        H_guess[row, row] = force_constant
    return H_guess
87
+
88
+
89
def fischer_guess(geom):
    """Diagonal model Hessian with distance-dependent force constants.

    Presumably the Fischer/Almlöf model referenced as [2] in the file header.
    Force constants for bonds, bends and dihedrals are computed from the
    interatomic distances and the pairwise covalent radii and passed to
    ``improved_guess``.
    """
    internal = geom.internal
    condensed_dists = pdist(geom.coords3d)
    condensed_cov_radii = get_pair_covalent_radii(geom.atoms)

    # Square, symmetric lookup tables indexed by atom pairs.
    dist_mat = squareform(condensed_dists)
    pair_cov_radii_mat = squareform(condensed_cov_radii)
    bond_mat = squareform(
        condensed_dists <= (condensed_cov_radii * internal.bond_factor)
    )

    # For the dihedral force constants we also have to count the number
    # of bonds formed with the central atoms of the dihedral.
    # Subtract 2, as we don't want the bond between a and b, but this bond
    # will be present in both rows of the bond_mat.
    tors_atom_bonds = {
        (a, b): bond_mat[a].sum() + bond_mat[b].sum() - 2
        for a, b in (inds[1:3] for inds in internal.dihedral_atom_indices)
    }

    def h_bond(indices):
        """Bond force constant from the deviation of r_ab from r_ab_cov."""
        a, b = indices[:2]
        r_ab = dist_mat[a, b]
        r_ab_cov = pair_cov_radii_mat[a, b]
        return 0.3601 * exp(-1.944 * (r_ab - r_ab_cov))

    def h_bend(indices):
        """Bend force constant; the central atom is the second index."""
        b, a, c = indices
        r_ab = dist_mat[a, b]
        r_ac = dist_mat[a, c]
        r_ab_cov = pair_cov_radii_mat[a, b]
        r_ac_cov = pair_cov_radii_mat[a, c]
        return 0.089 + 0.11 / (r_ab_cov * r_ac_cov) ** (-0.42) * exp(
            -0.44 * (r_ab + r_ac - r_ab_cov - r_ac_cov)
        )

    def h_dihedral(indices):
        """Dihedral force constant; a and b are the two central atoms."""
        _, a, b, _ = indices
        r_ab = dist_mat[a, b]
        r_ab_cov = pair_cov_radii_mat[a, b]
        # Clamp at zero so the fractional power below stays well defined.
        bond_sum = max(tors_atom_bonds[(a, b)], 0)
        return 0.0015 + 14.0 * bond_sum ** 0.57 / (r_ab * r_ab_cov) ** 4.0 * exp(
            -2.85 * (r_ab - r_ab_cov)
        )

    return improved_guess(
        geom, bond_func=h_bond, bend_func=h_bend, dihedral_func=h_dihedral
    )
137
+
138
+
139
def lindh_style_guess(geom, ks, rhos):
    """Approximate force constants according to Lindh.[1]

    Bonds:     k_ij = k_r * rho_ij
    Bends:     k_ijk = k_b * rho_ij * rho_jk
    Dihedrals: k_ijkl = k_d * rho_ij * rho_jk * rho_kl

    ``ks`` maps the number of atoms of a primitive (2, 3 or 4) to its
    prefactor, ``rhos`` is indexed by pairs of atom indices.
    """

    def k_func(indices):
        # Multiply the rho values of all consecutive atom pairs.
        rho_product = 1
        for first, second in zip(indices[:-1], indices[1:]):
            rho_product *= rhos[first, second]
        return ks[len(indices)] * rho_product

    return improved_guess(
        geom, bond_func=k_func, bend_func=k_func, dihedral_func=k_func
    )
158
+
159
+
160
def get_lindh_alpha(atom1, atom2):
    """Return the Lindh alpha parameter for a pair of (lowercase) atoms.

    The value depends only on how many of the two atoms belong to the
    first period ("h" or "he"): both -> 1.0, exactly one -> 0.3949,
    none -> 0.28.
    """
    first_period = ("h", "he")
    n_first_period = sum(1 for atom in (atom1, atom2) if atom in first_period)
    # Index = number of first-period atoms in the pair.
    return (0.28, 0.3949, 1.0)[n_first_period]
168
+
169
+
170
def lindh_guess(geom):
    """Slightly modified Lindh model hessian as described in [1].

    Instead of using the tabulated r_ref,ij values from [1] we will use the
    'true' covalent radii as pyberny. The tabulated r_ref,ij value for two
    carbons (2nd period) is 2.87 Bohr. Carbon's covalent radius is ~ 1.44 Bohr,
    so 2*1.44 Bohr = 2.88 Bohr which fits nicely with the tabulated value.
    If values for elements > 3rd are requested the alpha values for the 3rd
    period will be (re)used.
    """
    lowered_atoms = [atom.lower() for atom in geom.atoms]
    # One alpha per atom pair, in the same condensed order as pdist.
    alphas = list(it.starmap(get_lindh_alpha, it.combinations(lowered_atoms, 2)))
    pair_cov_radii = get_pair_covalent_radii(geom.atoms)
    cdm = pdist(geom.coords3d)
    exponents = alphas * (pair_cov_radii ** 2 - cdm ** 2)
    rhos = squareform(np.exp(exponents))

    # Prefactors keyed by the number of atoms in the primitive.
    ks = {
        2: 0.45,  # Stretches/bonds
        3: 0.15,  # Bends/angles
        4: 0.005,  # Torsions/dihedrals
    }
    return lindh_style_guess(geom, ks, rhos)
193
+
194
+
195
def swart_guess(geom):
    """Lindh-style model Hessian using rho_ij = exp(-r_ij / r_cov,ij + 1).

    Presumably the Swart/Bickelhaupt model referenced in the file header.
    """
    condensed_cov_radii = get_pair_covalent_radii(geom.atoms)
    condensed_dists = pdist(geom.coords3d)
    rho_condensed = np.exp(-condensed_dists / condensed_cov_radii + 1)
    rhos = squareform(rho_condensed)
    # Prefactors for bonds (2 atoms), bends (3 atoms) and dihedrals (4 atoms).
    ks = {2: 0.35, 3: 0.15, 4: 0.005}
    return lindh_style_guess(geom, ks, rhos)
205
+
206
+
207
def xtb_hessian(geom, gfn=None):
    """Calculate a Hessian with an XTB calculator on a copy of ``geom``.

    Charge, multiplicity and ``pal`` are taken over from the calculator
    attached to ``geom``. ``gfn`` is only forwarded to XTB when it is not
    None. The original geometry and its calculator are left untouched, as
    the calculation is done on ``geom.copy()``.
    """
    from pysisyphus.calculators.XTB import XTB  # Lazy import

    source_calc = geom.calculator
    xtb_kwargs = dict(
        charge=source_calc.charge,
        mult=source_calc.mult,
        pal=source_calc.pal,
    )
    if gfn is not None:
        xtb_kwargs["gfn"] = gfn
    xtb_calc = XTB(**xtb_kwargs)

    geom_copy = geom.copy()
    geom_copy.set_calculator(xtb_calc)
    return geom_copy.hessian
217
+
218
+
219
def get_guess_hessian(
    geometry: Geometry,
    hessian_init: HessInit,
    int_gradient: Optional[ArrayLike] = None,
    cart_gradient: Optional[ArrayLike] = None,
    h5_fn: Optional[str] = None,
):
    """Obtain/calculate a (model) Hessian.

    Parameters
    ----------
    geometry
        Geometry for which the Hessian is requested. For the model Hessians
        ("fischer", "lindh", "simple", "swart") a Cartesian geometry is
        copied with coord_type="redund"; the model Hessian is then
        backtransformed, so the caller receives a Hessian matching the
        original Cartesian coord_type.
    hessian_init
        One of the known keywords, or the path of a file containing a
        Cartesian Hessian (HDF5 file ending in ".h5" with a "hessian"
        dataset, a CFOUR "FCMFINAL" file, or anything np.loadtxt can read).
    int_gradient
        Gradient in internal coordinates, forwarded to the backtransformation
        of a model Hessian. Overwritten when cart_gradient is given.
    cart_gradient
        Cartesian gradient; if given it is transformed to internal
        coordinates and used instead of int_gradient.
    h5_fn
        Currently unused; kept so existing callers keep working.

    Returns
    -------
    tuple
        (H, hess_str), with the Hessian and a short label describing its
        origin, e.g. "Fischer", "unit" or "saved".
    """
    model_hessian = hessian_init in ("fischer", "lindh", "simple", "swart")
    target_coord_type = geometry.coord_type

    # Recreate geometry with internal coordinates if needed
    if model_hessian and (geometry.coord_type == "cart"):
        geometry = geometry.copy(coord_type="redund")

    hess_funcs = {
        # Calculate true hessian
        "calc": lambda: (geometry.hessian, "calculated exact"),
        # Unit hessian
        "unit": lambda: (np.eye(geometry.coords.size), "unit"),
        # Fischer model hessian
        "fischer": lambda: (fischer_guess(geometry), "Fischer"),
        # Lindh model hessian
        "lindh": lambda: (lindh_guess(geometry), "Lindh"),
        # Simple (0.5, 0.2, 0.1) model hessian
        "simple": lambda: (simple_guess(geometry), "simple"),
        # Swart model hessian
        "swart": lambda: (swart_guess(geometry), "Swart"),
        # XTB hessian using GFN-2
        "xtb": lambda: (xtb_hessian(geometry, gfn=2), "GFN2-XTB"),
        # XTB hessian using GFN-1
        "xtb1": lambda: (xtb_hessian(geometry, gfn=1), "GFN1-XTB"),
        # XTB hessian using GFN-FF
        "xtbff": lambda: (xtb_hessian(geometry, gfn="ff"), "GFN-FF"),
    }
    # Only the lookup is guarded. A KeyError raised while a Hessian is
    # actually being generated is a real error and must not be mistaken
    # for "unknown keyword, try to load a file".
    try:
        hess_func = hess_funcs[hessian_init]
    except KeyError:
        # Not a known keyword, so treat hessian_init as a file name.
        # Only cartesian hessians can be loaded.
        fn = str(hessian_init)
        if fn.endswith(".h5"):
            with h5py.File(hessian_init, "r") as handle:
                cart_hessian = handle["hessian"][:]
        # CFOUR Hessians in the form we need are always named "FCMFINAL"
        elif fn.endswith("FCMFINAL"):
            raw_cart_hessian = np.loadtxt(hessian_init, skiprows=1)
            nindices = int(np.sqrt(raw_cart_hessian.size))
            cart_hessian = raw_cart_hessian.reshape((nindices, nindices), order="C")
        else:
            cart_hessian = np.loadtxt(hessian_init)
        geometry.cart_hessian = cart_hessian
        # Use the previously set hessian in whatever coordinate system we
        # actually employ.
        H = geometry.hessian
        hess_str = "saved"
    else:
        H, hess_str = hess_func()

    # Model Hessians are built in internal coordinates; convert them back
    # when the caller handed us a Cartesian geometry.
    if model_hessian and target_coord_type == "cart":
        if cart_gradient is not None:
            int_gradient = geometry.internal.transform_forces(cart_gradient)
        H = geometry.internal.backtransform_hessian(H, int_gradient=int_gradient)

    return H, hess_str
288
+
289
+
290
def ts_hessian(hessian, coord_inds, damp=0.25):
    """Build a diagonal-based TS guess Hessian, according to [3].

    The diagonal of ``hessian`` is kept, except for the entries listed in
    ``coord_inds`` (the reaction coordinates), which are multiplied by
    ``-damp``. For every pair of reaction coordinates the symmetric
    off-diagonal element is set to -sqrt(2 * f_i * f_j), with f_i and f_j
    being the already sign-flipped, damped diagonal values. All other
    off-diagonal elements are zero. The input ``hessian`` is not modified.
    """
    reaction_inds = list(coord_inds)

    # np.diag returns a read-only view for 2d input, so work on a copy.
    diag = np.diag(hessian).copy()

    # Reverse sign of reaction coordinates and damp them
    sign_and_damp = -1 * damp
    diag[reaction_inds] = sign_and_damp * diag[reaction_inds]
    ts_hess = np.diag(diag)

    # Couple all pairs of reaction coordinates symmetrically.
    for row, col in it.combinations(reaction_inds, 2):
        coupling = -((2 * diag[row] * diag[col]) ** 0.5)
        ts_hess[[row, col], [col, row]] = coupling
    return ts_hess
@@ -0,0 +1,355 @@
1
+ # [1] https://link.springer.com/article/10.1007/s00214-016-1847-3
2
+ # Birkholz, 2016
3
+ # [2] Geometry optimization in Cartesian coordinates: Constrained optimization
4
+ # Baker, 1992
5
+ # [3] https://epubs.siam.org/doi/pdf/10.1137/S1052623496306450
6
+ # BFGS WITH UPDATE SKIPPING AND VARYING MEMORY
7
+ # Kolda, 1998
8
+ # [4] https://link.springer.com/article/10.1186/1029-242X-2012-241
9
+ # New cautious BFGS algorithm based on modified Armijo-type line search
10
+ # Wan, 2012
11
+ # [5] Numerical optimization, 2nd ed.
12
+ # Nocedal, Wright
13
+ # [6] https://arxiv.org/abs/2006.08877
14
+ # Goldfarb, 2020
15
+ # [7] https://pubs.acs.org/doi/10.1021/acs.jctc.9b00869
16
+ # Hermes, Zádor, 2019
17
+ # [8] https://doi.org/10.1002/(SICI)1096-987X(199802)19:3<349::AID-JCC8>3.0.CO;2-T
18
+ # Bofill, 1998
19
+ # [9] http://dx.doi.org/10.1016/S0166-1280(02)00209-9
20
+ # Bungay, Poirier
21
+
22
+
23
+ import numpy as np
24
+
25
+ from pysisyphus.optimizers.closures import bfgs_multiply
26
+
27
+ import torch
28
+
29
+
30
def _outer(a, b):
    """Outer product of two vectors, dispatching on the type of ``a``.

    If ``a`` is a torch tensor, ``b`` is converted to a tensor with the dtype
    and device of ``a`` and ``torch.outer`` is used. In every other case the
    computation is delegated to ``np.outer``.
    """
    if not isinstance(a, torch.Tensor):
        return np.outer(a, b)
    b_like_a = torch.as_tensor(b, dtype=a.dtype, device=a.device)
    return torch.outer(a, b_like_a)
36
+
37
def _dot(a, b):
    """Dot product of two vectors, dispatching on the type of ``a``.

    If ``a`` is a torch tensor, ``b`` is converted to a tensor with the dtype
    and device of ``a`` and ``torch.dot`` is used. In every other case the
    computation is delegated to ``np.dot``.
    """
    if not isinstance(a, torch.Tensor):
        return np.dot(a, b)
    b_like_a = torch.as_tensor(b, dtype=a.dtype, device=a.device)
    return torch.dot(a, b_like_a)
43
+
44
+
45
def bfgs_update(H, dx, dg):
    """Standard BFGS update for the Hessian approximation ``H``.

    Evaluates ``dg dg^T / (dg . dx) - (H dx)(H dx)^T / (dx . H dx)``. When
    ``H`` is a torch tensor, ``dx`` and ``dg`` are first converted to tensors
    with the dtype and device of ``H``.

    Returns the tuple ``(update, "BFGS")``. Only the correction term is
    returned; ``H`` itself is not modified.
    """
    on_torch = isinstance(H, torch.Tensor)
    if on_torch:
        dx = torch.as_tensor(dx, dtype=H.dtype, device=H.device)
        dg = torch.as_tensor(dg, dtype=H.dtype, device=H.device)

    H_dx = H @ dx
    gradient_part = _outer(dg, dg) / _dot(dg, dx)
    hessian_part = _outer(H_dx, H_dx) / _dot(dx, H_dx)
    delta = gradient_part - hessian_part
    if on_torch:
        delta = delta.to(dtype=H.dtype, device=H.device)
    return delta, "BFGS"
58
+
59
+
60
def damped_bfgs_update(H, dx, dg):
    """Damped BFGS update for the Hessian approximation ``H``. See [5].

    If ``dx . dg`` is smaller than ``0.2 * dx^T H dx``, the gradient
    difference ``dg`` is replaced by the interpolation
    ``r = theta * dg + (1 - theta) * H dx`` with
    ``theta = 0.8 * dx^T H dx / (dx^T H dx - dx . dg)``. Otherwise
    ``theta = 1`` and ``r`` equals ``dg``.

    Parameters
    ----------
    H : np.array, shape (N, N)
        Current Hessian approximation.
    dx : np.array, shape (N, )
        Coordinate differences/step.
    dg : np.array, shape (N, )
        Gradient differences.

    Returns
    -------
    update : np.array, shape (N, N)
        Correction term ``r r^T / (r . dx) - H dx dx^T H / (dx^T H dx)``.
        ``H`` itself is not modified.
    key : str
        Always "damped BFGS".
    """
    # Both matrix-vector products are needed several times below.
    Hdx = H.dot(dx)
    dxH = dx.dot(H)

    dxdg = dx.dot(dg)
    dxHdx = dxH.dot(dx)
    theta = 1
    if dxdg < 0.2 * dxHdx:
        theta = 0.8 * dxHdx / (dxHdx - dxdg)
    r = theta * dg + (1 - theta) * Hdx

    first_term = np.outer(r, r) / r.dot(dx)
    # H (dx dx^T) H equals the rank-one matrix (H dx)(dx^T H), so an outer
    # product of the two vectors replaces two (N, N) matrix products.
    second_term = np.outer(Hdx, dxH) / dxHdx
    return first_term - second_term, "damped BFGS"
87
+
88
+
89
def double_damp(
    s, y, H=None, s_list=None, y_list=None, mu_1=0.2, mu_2=0.2, logger=None
):
    """Double damped step 's' and gradient differences 'y'.

    H is the inverse Hessian!
    See [6]. Potentially updates s and y. y is only
    updated if mu_2 is not None.

    Parameters
    ----------
    s : np.array, shape (N, ), floats
        Coordinate differences/step.
    y : np.array, shape (N, ), floats
        Gradient differences
    H : np.array, shape (N, N), floats, optional
        Inverse Hessian.
    s_list : list of nd.array, shape (K, N), optional
        List of K previous steps. If no H is supplied, s_list and y_list
        are passed to bfgs_multiply to calculate the matrix-vector product
        Hy, as in LBFGS.
    y_list : list of nd.array, shape (K, N), optional
        List of K previous gradient differences. See s_list.
    mu_1 : float, optional
        Parameter for 's' damping.
    mu_2 : float, optional
        Parameter for 'y' damping. None disables the damping of 'y'.
    logger : logging.Logger, optional
        Logger to be used.

    Returns
    -------
    s : np.array, shape (N, ), floats
        Damped coordinate differences/step.
    y : np.array, shape (N, ), floats
        Damped gradient differences
    """
    sy = s.dot(y)
    # Calculate Hy directly
    if H is not None:
        Hy = H.dot(y)
    # Calculate Hy via BFGS_multiply as in LBFGS
    else:
        Hy = bfgs_multiply(s_list, y_list, y, logger=logger)
    yHy = y.dot(Hy)

    theta_1 = 1
    damped_s = ""
    if sy < mu_1 * yHy:
        theta_1 = (1 - mu_1) * yHy / (yHy - sy)
        s = theta_1 * s + (1 - theta_1) * Hy
        if theta_1 < 1.0:
            damped_s = ", damped s"
    msg = f"damped BFGS\n\ttheta_1={theta_1:.4f} {damped_s}"

    # Double damping; uses the (potentially already damped) step s.
    damped_y = ""
    if mu_2 is not None:
        sy = s.dot(y)
        ss = s.dot(s)
        theta_2 = 1
        if sy < mu_2 * ss:
            theta_2 = (1 - mu_2) * ss / (ss - sy)
            y = theta_2 * y + (1 - theta_2) * s
            if theta_2 < 1.0:
                damped_y = ", damped y"
        msg = "double " + msg + f"\n\ttheta_2={theta_2:.4f} {damped_y}"

    if logger is not None:
        # str.capitalize() would also lowercase the remainder of the message
        # ("BFGS" -> "bfgs"), so only the first character is upper-cased.
        logger.debug(msg[0].upper() + msg[1:])

    return s, y
184
+
185
+
186
def sr1_update(z, dx):
    """Symmetric rank-one (SR1) update ``z z^T / (z . dx)``.

    When ``z`` is a torch tensor, ``dx`` is converted to a tensor with the
    dtype and device of ``z`` first.

    Returns the tuple ``(update, "SR1")``.
    """
    on_torch = isinstance(z, torch.Tensor)
    if on_torch:
        dx = torch.as_tensor(dx, dtype=z.dtype, device=z.device)
    rank_one = _outer(z, z) / _dot(z, dx)
    if on_torch:
        rank_one = rank_one.to(dtype=z.dtype, device=z.device)
    return rank_one, "SR1"
194
+
195
+
196
def psb_update(z, dx):
    """Powell Symmetric Broyden (PSB) update.

    Evaluates
    ``(dx z^T + z dx^T) / (dx . dx) - (dx . z) dx dx^T / (dx . dx)**2``.
    When ``z`` is a torch tensor, ``dx`` is converted to a tensor with the
    dtype and device of ``z`` first.

    Returns the tuple ``(update, "PSB")``.
    """
    on_torch = isinstance(z, torch.Tensor)
    if on_torch:
        dx = torch.as_tensor(dx, dtype=z.dtype, device=z.device)

    symmetric_part = (_outer(dx, z) + _outer(z, dx)) / _dot(dx, dx)
    projection_part = _dot(dx, z) * _outer(dx, dx) / (_dot(dx, dx) ** 2)
    delta = symmetric_part - projection_part
    if on_torch:
        delta = delta.to(dtype=z.dtype, device=z.device)
    return delta, "PSB"
206
+
207
+
208
def flowchart_update(H, dx, dg):
    """Select between SR1, BFGS and PSB updates. See [1], Sec. 2, Eqs. 1 to 3.

    With ``z = dg - H dx`` two normalized overlaps are evaluated:
    ``z . dx / (|z| |dx|)`` and ``dg . dx / (|dg| |dx|)``. If the first one
    is below -0.1 the SR1 update is used; otherwise, if the second one is
    above 0.1, the BFGS update is used; otherwise the PSB update is used.

    Returns the ``(update, key)`` tuple of the selected update function.
    """
    z = dg - H.dot(dx)
    z_overlap = z.dot(dx) / (np.linalg.norm(z) * np.linalg.norm(dx))
    dg_overlap = dg.dot(dx) / (np.linalg.norm(dg) * np.linalg.norm(dx))

    if z_overlap < -0.1:
        chosen = sr1_update(z, dx)
    elif dg_overlap > 0.1:
        chosen = bfgs_update(H, dx, dg)
    else:
        chosen = psb_update(z, dx)
    update, key = chosen
    return update, key
236
+
237
+
238
def mod_flowchart_update(H, dx, dg):
    """Modified flowchart update, selecting between BFGS, SR1 and PSB.

    This version seems to work too ... at least for minimizations starting
    from a geometry near a transition state. Interesting.

    With ``z = dg - H dx`` the normalized overlap ``z . dx / (|z| |dx|)`` is
    evaluated. Below -0.1 the BFGS update is used, above 0.1 the SR1 update
    is used and in between the PSB update is used.

    Returns the ``(update, key)`` tuple of the selected update function.
    """
    z = dg - H.dot(dx)
    overlap = z.dot(dx) / (np.linalg.norm(z) * np.linalg.norm(dx))

    if overlap < -0.1:
        chosen = bfgs_update(H, dx, dg)
    elif overlap > 0.1:
        chosen = sr1_update(z, dx)
    else:
        chosen = psb_update(z, dx)
    update, key = chosen
    return update, key
250
+
251
+
252
def bofill_update(H, dx, dg):
    """Bofill's combination of the SR1 and PSB updates.

    With ``z = dg - H dx`` both updates are mixed as
    ``mix * SR1 + (1 - mix) * PSB``, using the mixing factor
    ``mix = (z . dx)**2 / ((z . z) (dx . dx))``. When ``H`` is a torch
    tensor, ``dx`` and ``dg`` are first converted to tensors with the dtype
    and device of ``H``.

    Returns the tuple ``(update, "Bofill")``.
    """
    on_torch = isinstance(H, torch.Tensor)
    if on_torch:
        dx = torch.as_tensor(dx, dtype=H.dtype, device=H.device)
        dg = torch.as_tensor(dg, dtype=H.dtype, device=H.device)

    z = dg - H @ dx
    sr1_part, _ = sr1_update(z, dx)
    psb_part, _ = psb_update(z, dx)
    weight = (_dot(z, dx) ** 2) / (_dot(z, z) * _dot(dx, dx))
    delta = weight * sr1_part + (1 - weight) * psb_part
    if on_torch:
        delta = delta.to(dtype=H.dtype, device=H.device)
    return delta, "Bofill"
266
+
267
+
268
def ts_bfgs_update(H, dx, dg):
    """TS-BFGS update, as described in [7].

    ``dx`` and ``dg`` are 1d arrays that are treated as column vectors. With
    ``j = dg - H dx`` and ``u = M dx / (dx^T M dx)``, where
    ``M = dg dg^T + (|H| dx)(|H| dx)^T``, the update is
    ``j u^T + u j^T - (j^T dx) u u^T``. ``|H|`` is built from the
    eigendecomposition of ``H`` using the absolute eigenvalues.

    Returns the tuple ``(update, "TS-BFGS")``.
    """
    step = dx[:, None]
    grad_diff = dg[:, None]
    resid = grad_diff - H @ step
    resid_dot_step = resid.T @ step

    # Diagonalize the Hessian to apply its positive definite version to dx.
    eigvals, eigvecs = np.linalg.eigh(H)
    abs_H_step = (np.abs(eigvals) * eigvecs) @ (eigvecs.T @ step)

    M = grad_diff @ grad_diff.T + abs_H_step @ abs_H_step.T
    step_M = step.T @ M
    u = np.linalg.solve(step_M @ step, step_M).T

    cross = resid @ u.T
    update = cross + cross.T - (resid_dot_step * u) @ u.T
    return update, "TS-BFGS"
298
+
299
+
300
def ts_bfgs_update_org(H, dx, dg):
    """Do not use! Implemented as described in the 1998 Bofill paper [8].

    This does not seem to work too well.

    ``dx`` and ``dg`` are 1d arrays that are treated as column vectors. With
    ``j = dg - H dx`` and ``phi = (j^T dx)**2 / (dx^T dx) / (j^T j)`` the
    vector ``u`` is the phi-weighted mix of ``dg (dg^T dx)`` and ``|H| dx``,
    normalized by ``u^T dx``. ``|H|`` is built from the eigendecomposition of
    ``H`` using the absolute eigenvalues.

    Returns the tuple ``(update, "TS-BFGS")`` with
    ``update = j u^T + u j^T - (j^T dx) u u^T``.
    """
    step = dx[:, None]
    grad_diff = dg[:, None]
    resid = grad_diff - H @ step

    # All scalar products are kept as (1, 1) arrays.
    resid_dot_step = resid.T @ step
    step_sq = step.T @ step
    resid_sq = resid.T @ resid
    phi = resid_dot_step ** 2 / step_sq / resid_sq

    u = ((phi * grad_diff) * grad_diff.T) @ step
    eigvals, eigvecs = np.linalg.eigh(H)
    u = u + (((1 - phi) * np.abs(eigvals)) * eigvecs) @ (eigvecs.T @ step)
    u = u / (u.T @ step)

    cross = resid @ u.T
    update = cross + cross.T - (resid_dot_step * u) @ u.T
    return update, "TS-BFGS"
323
+
324
+
325
def ts_bfgs_update_revised(H, dx, dg):
    """TS-BFGS update as described in [9].

    Better than the original formula of Bofill, worse than the implementation
    in [7]. a is calculated as described in footnote 1 on page 38. Eq. (8)
    looks suspicious as it contains the inverse of a vector?! As also outlined
    in the paper, abs(a) is used (|a| in the paper).

    ``dx`` and ``dg`` are 1d arrays that are treated as column vectors.
    Returns the tuple ``(update, "TS-BFGS")``.
    """
    step = dx[:, None]
    grad_diff = dg[:, None]

    # All scalar products are kept as (1, 1) arrays.
    grad_sq = grad_diff.T @ grad_diff
    grad_dot_step = grad_diff.T @ step
    a = abs((grad_sq - grad_diff.T @ H @ step) / (grad_sq * grad_dot_step))

    # Diagonalize the Hessian to apply its positive definite version to dx.
    eigvals, eigvecs = np.linalg.eigh(H)
    abs_H_step = (np.abs(eigvals) * eigvecs) @ (eigvecs.T @ step)

    # Mixing factor
    resid = grad_diff - H @ step
    resid_dot_step = resid.T @ step
    step_sq = step.T @ step
    resid_sq = resid.T @ resid
    phi = resid_dot_step ** 2 / step_sq / resid_sq

    numerator = (1 - phi) * abs_H_step + a * phi * grad_dot_step * grad_diff
    denominator = (
        ((1 - phi) * step.T) @ abs_H_step + phi * a * grad_dot_step ** 2
    )
    u = numerator / denominator

    cross = resid @ u.T
    update = cross + cross.T - (resid_dot_step * u) @ u.T
    return update, "TS-BFGS"