mlmm-toolkit 0.2.2.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (372) hide show
  1. hessian_ff/__init__.py +50 -0
  2. hessian_ff/analytical_hessian.py +609 -0
  3. hessian_ff/constants.py +46 -0
  4. hessian_ff/forcefield.py +339 -0
  5. hessian_ff/loaders.py +608 -0
  6. hessian_ff/native/Makefile +8 -0
  7. hessian_ff/native/__init__.py +28 -0
  8. hessian_ff/native/analytical_hessian.py +88 -0
  9. hessian_ff/native/analytical_hessian_ext.cpp +258 -0
  10. hessian_ff/native/bonded.py +82 -0
  11. hessian_ff/native/bonded_ext.cpp +640 -0
  12. hessian_ff/native/loader.py +349 -0
  13. hessian_ff/native/nonbonded.py +118 -0
  14. hessian_ff/native/nonbonded_ext.cpp +1150 -0
  15. hessian_ff/prmtop_parmed.py +23 -0
  16. hessian_ff/system.py +107 -0
  17. hessian_ff/terms/__init__.py +14 -0
  18. hessian_ff/terms/angle.py +73 -0
  19. hessian_ff/terms/bond.py +44 -0
  20. hessian_ff/terms/cmap.py +406 -0
  21. hessian_ff/terms/dihedral.py +141 -0
  22. hessian_ff/terms/nonbonded.py +209 -0
  23. hessian_ff/tests/__init__.py +0 -0
  24. hessian_ff/tests/conftest.py +75 -0
  25. hessian_ff/tests/data/small/complex.parm7 +1346 -0
  26. hessian_ff/tests/data/small/complex.pdb +125 -0
  27. hessian_ff/tests/data/small/complex.rst7 +63 -0
  28. hessian_ff/tests/test_coords_input.py +44 -0
  29. hessian_ff/tests/test_energy_force.py +49 -0
  30. hessian_ff/tests/test_hessian.py +137 -0
  31. hessian_ff/tests/test_smoke.py +18 -0
  32. hessian_ff/tests/test_validation.py +40 -0
  33. hessian_ff/workflows.py +889 -0
  34. mlmm/__init__.py +36 -0
  35. mlmm/__main__.py +7 -0
  36. mlmm/_version.py +34 -0
  37. mlmm/add_elem_info.py +374 -0
  38. mlmm/advanced_help.py +91 -0
  39. mlmm/align_freeze_atoms.py +601 -0
  40. mlmm/all.py +3535 -0
  41. mlmm/bond_changes.py +231 -0
  42. mlmm/bool_compat.py +223 -0
  43. mlmm/cli.py +574 -0
  44. mlmm/cli_utils.py +166 -0
  45. mlmm/default_group.py +337 -0
  46. mlmm/defaults.py +467 -0
  47. mlmm/define_layer.py +526 -0
  48. mlmm/dft.py +1041 -0
  49. mlmm/energy_diagram.py +253 -0
  50. mlmm/extract.py +2213 -0
  51. mlmm/fix_altloc.py +464 -0
  52. mlmm/freq.py +1406 -0
  53. mlmm/harmonic_constraints.py +140 -0
  54. mlmm/hessian_cache.py +44 -0
  55. mlmm/hessian_calc.py +174 -0
  56. mlmm/irc.py +638 -0
  57. mlmm/mlmm_calc.py +2262 -0
  58. mlmm/mm_parm.py +945 -0
  59. mlmm/oniom_export.py +1983 -0
  60. mlmm/oniom_import.py +457 -0
  61. mlmm/opt.py +1742 -0
  62. mlmm/path_opt.py +1353 -0
  63. mlmm/path_search.py +2299 -0
  64. mlmm/preflight.py +88 -0
  65. mlmm/py.typed +1 -0
  66. mlmm/pysis_runner.py +45 -0
  67. mlmm/scan.py +1047 -0
  68. mlmm/scan2d.py +1226 -0
  69. mlmm/scan3d.py +1265 -0
  70. mlmm/scan_common.py +184 -0
  71. mlmm/summary_log.py +736 -0
  72. mlmm/trj2fig.py +448 -0
  73. mlmm/tsopt.py +2871 -0
  74. mlmm/utils.py +2309 -0
  75. mlmm/xtb_embedcharge_correction.py +475 -0
  76. mlmm_toolkit-0.2.2.dev0.dist-info/METADATA +1159 -0
  77. mlmm_toolkit-0.2.2.dev0.dist-info/RECORD +372 -0
  78. mlmm_toolkit-0.2.2.dev0.dist-info/WHEEL +5 -0
  79. mlmm_toolkit-0.2.2.dev0.dist-info/entry_points.txt +2 -0
  80. mlmm_toolkit-0.2.2.dev0.dist-info/licenses/LICENSE +674 -0
  81. mlmm_toolkit-0.2.2.dev0.dist-info/top_level.txt +4 -0
  82. pysisyphus/Geometry.py +1667 -0
  83. pysisyphus/LICENSE +674 -0
  84. pysisyphus/TableFormatter.py +63 -0
  85. pysisyphus/TablePrinter.py +74 -0
  86. pysisyphus/__init__.py +12 -0
  87. pysisyphus/calculators/AFIR.py +452 -0
  88. pysisyphus/calculators/AnaPot.py +20 -0
  89. pysisyphus/calculators/AnaPot2.py +48 -0
  90. pysisyphus/calculators/AnaPot3.py +12 -0
  91. pysisyphus/calculators/AnaPot4.py +20 -0
  92. pysisyphus/calculators/AnaPotBase.py +337 -0
  93. pysisyphus/calculators/AnaPotCBM.py +25 -0
  94. pysisyphus/calculators/AtomAtomTransTorque.py +154 -0
  95. pysisyphus/calculators/CFOUR.py +250 -0
  96. pysisyphus/calculators/Calculator.py +844 -0
  97. pysisyphus/calculators/CerjanMiller.py +24 -0
  98. pysisyphus/calculators/Composite.py +123 -0
  99. pysisyphus/calculators/ConicalIntersection.py +171 -0
  100. pysisyphus/calculators/DFTBp.py +430 -0
  101. pysisyphus/calculators/DFTD3.py +66 -0
  102. pysisyphus/calculators/DFTD4.py +84 -0
  103. pysisyphus/calculators/Dalton.py +61 -0
  104. pysisyphus/calculators/Dimer.py +681 -0
  105. pysisyphus/calculators/Dummy.py +20 -0
  106. pysisyphus/calculators/EGO.py +76 -0
  107. pysisyphus/calculators/EnergyMin.py +224 -0
  108. pysisyphus/calculators/ExternalPotential.py +264 -0
  109. pysisyphus/calculators/FakeASE.py +35 -0
  110. pysisyphus/calculators/FourWellAnaPot.py +28 -0
  111. pysisyphus/calculators/FreeEndNEBPot.py +39 -0
  112. pysisyphus/calculators/Gaussian09.py +18 -0
  113. pysisyphus/calculators/Gaussian16.py +726 -0
  114. pysisyphus/calculators/HardSphere.py +159 -0
  115. pysisyphus/calculators/IDPPCalculator.py +49 -0
  116. pysisyphus/calculators/IPIClient.py +133 -0
  117. pysisyphus/calculators/IPIServer.py +234 -0
  118. pysisyphus/calculators/LEPSBase.py +24 -0
  119. pysisyphus/calculators/LEPSExpr.py +139 -0
  120. pysisyphus/calculators/LennardJones.py +80 -0
  121. pysisyphus/calculators/MOPAC.py +219 -0
  122. pysisyphus/calculators/MullerBrownSympyPot.py +51 -0
  123. pysisyphus/calculators/MultiCalc.py +85 -0
  124. pysisyphus/calculators/NFK.py +45 -0
  125. pysisyphus/calculators/OBabel.py +87 -0
  126. pysisyphus/calculators/ONIOMv2.py +1129 -0
  127. pysisyphus/calculators/ORCA.py +893 -0
  128. pysisyphus/calculators/ORCA5.py +6 -0
  129. pysisyphus/calculators/OpenMM.py +88 -0
  130. pysisyphus/calculators/OpenMolcas.py +281 -0
  131. pysisyphus/calculators/OverlapCalculator.py +908 -0
  132. pysisyphus/calculators/Psi4.py +218 -0
  133. pysisyphus/calculators/PyPsi4.py +37 -0
  134. pysisyphus/calculators/PySCF.py +341 -0
  135. pysisyphus/calculators/PyXTB.py +73 -0
  136. pysisyphus/calculators/QCEngine.py +106 -0
  137. pysisyphus/calculators/Rastrigin.py +22 -0
  138. pysisyphus/calculators/Remote.py +76 -0
  139. pysisyphus/calculators/Rosenbrock.py +15 -0
  140. pysisyphus/calculators/SocketCalc.py +97 -0
  141. pysisyphus/calculators/TIP3P.py +111 -0
  142. pysisyphus/calculators/TransTorque.py +161 -0
  143. pysisyphus/calculators/Turbomole.py +965 -0
  144. pysisyphus/calculators/VRIPot.py +37 -0
  145. pysisyphus/calculators/WFOWrapper.py +333 -0
  146. pysisyphus/calculators/WFOWrapper2.py +341 -0
  147. pysisyphus/calculators/XTB.py +418 -0
  148. pysisyphus/calculators/__init__.py +81 -0
  149. pysisyphus/calculators/cosmo_data.py +139 -0
  150. pysisyphus/calculators/parser.py +150 -0
  151. pysisyphus/color.py +19 -0
  152. pysisyphus/config.py +133 -0
  153. pysisyphus/constants.py +65 -0
  154. pysisyphus/cos/AdaptiveNEB.py +230 -0
  155. pysisyphus/cos/ChainOfStates.py +725 -0
  156. pysisyphus/cos/FreeEndNEB.py +25 -0
  157. pysisyphus/cos/FreezingString.py +103 -0
  158. pysisyphus/cos/GrowingChainOfStates.py +71 -0
  159. pysisyphus/cos/GrowingNT.py +309 -0
  160. pysisyphus/cos/GrowingString.py +508 -0
  161. pysisyphus/cos/NEB.py +189 -0
  162. pysisyphus/cos/SimpleZTS.py +64 -0
  163. pysisyphus/cos/__init__.py +22 -0
  164. pysisyphus/cos/stiffness.py +199 -0
  165. pysisyphus/drivers/__init__.py +17 -0
  166. pysisyphus/drivers/afir.py +855 -0
  167. pysisyphus/drivers/barriers.py +271 -0
  168. pysisyphus/drivers/birkholz.py +138 -0
  169. pysisyphus/drivers/cluster.py +318 -0
  170. pysisyphus/drivers/diabatization.py +133 -0
  171. pysisyphus/drivers/merge.py +368 -0
  172. pysisyphus/drivers/merge_mol2.py +322 -0
  173. pysisyphus/drivers/opt.py +375 -0
  174. pysisyphus/drivers/perf.py +91 -0
  175. pysisyphus/drivers/pka.py +52 -0
  176. pysisyphus/drivers/precon_pos_rot.py +669 -0
  177. pysisyphus/drivers/rates.py +480 -0
  178. pysisyphus/drivers/replace.py +219 -0
  179. pysisyphus/drivers/scan.py +212 -0
  180. pysisyphus/drivers/spectrum.py +166 -0
  181. pysisyphus/drivers/thermo.py +31 -0
  182. pysisyphus/dynamics/Gaussian.py +103 -0
  183. pysisyphus/dynamics/__init__.py +20 -0
  184. pysisyphus/dynamics/colvars.py +136 -0
  185. pysisyphus/dynamics/driver.py +297 -0
  186. pysisyphus/dynamics/helpers.py +256 -0
  187. pysisyphus/dynamics/lincs.py +105 -0
  188. pysisyphus/dynamics/mdp.py +364 -0
  189. pysisyphus/dynamics/rattle.py +121 -0
  190. pysisyphus/dynamics/thermostats.py +128 -0
  191. pysisyphus/dynamics/wigner.py +266 -0
  192. pysisyphus/elem_data.py +3473 -0
  193. pysisyphus/exceptions.py +2 -0
  194. pysisyphus/filtertrj.py +69 -0
  195. pysisyphus/helpers.py +623 -0
  196. pysisyphus/helpers_pure.py +649 -0
  197. pysisyphus/init_logging.py +50 -0
  198. pysisyphus/intcoords/Bend.py +69 -0
  199. pysisyphus/intcoords/Bend2.py +25 -0
  200. pysisyphus/intcoords/BondedFragment.py +32 -0
  201. pysisyphus/intcoords/Cartesian.py +41 -0
  202. pysisyphus/intcoords/CartesianCoords.py +140 -0
  203. pysisyphus/intcoords/Coords.py +56 -0
  204. pysisyphus/intcoords/DLC.py +197 -0
  205. pysisyphus/intcoords/DistanceFunction.py +34 -0
  206. pysisyphus/intcoords/DummyImproper.py +70 -0
  207. pysisyphus/intcoords/DummyTorsion.py +72 -0
  208. pysisyphus/intcoords/LinearBend.py +105 -0
  209. pysisyphus/intcoords/LinearDisplacement.py +80 -0
  210. pysisyphus/intcoords/OutOfPlane.py +59 -0
  211. pysisyphus/intcoords/PrimTypes.py +286 -0
  212. pysisyphus/intcoords/Primitive.py +137 -0
  213. pysisyphus/intcoords/RedundantCoords.py +659 -0
  214. pysisyphus/intcoords/RobustTorsion.py +59 -0
  215. pysisyphus/intcoords/Rotation.py +147 -0
  216. pysisyphus/intcoords/Stretch.py +31 -0
  217. pysisyphus/intcoords/Torsion.py +101 -0
  218. pysisyphus/intcoords/Torsion2.py +25 -0
  219. pysisyphus/intcoords/Translation.py +45 -0
  220. pysisyphus/intcoords/__init__.py +61 -0
  221. pysisyphus/intcoords/augment_bonds.py +126 -0
  222. pysisyphus/intcoords/derivatives.py +10512 -0
  223. pysisyphus/intcoords/eval.py +80 -0
  224. pysisyphus/intcoords/exceptions.py +37 -0
  225. pysisyphus/intcoords/findiffs.py +48 -0
  226. pysisyphus/intcoords/generate_derivatives.py +414 -0
  227. pysisyphus/intcoords/helpers.py +235 -0
  228. pysisyphus/intcoords/logging_conf.py +10 -0
  229. pysisyphus/intcoords/mp_derivatives.py +10836 -0
  230. pysisyphus/intcoords/setup.py +962 -0
  231. pysisyphus/intcoords/setup_fast.py +176 -0
  232. pysisyphus/intcoords/update.py +272 -0
  233. pysisyphus/intcoords/valid.py +89 -0
  234. pysisyphus/interpolate/Geodesic.py +93 -0
  235. pysisyphus/interpolate/IDPP.py +55 -0
  236. pysisyphus/interpolate/Interpolator.py +116 -0
  237. pysisyphus/interpolate/LST.py +70 -0
  238. pysisyphus/interpolate/Redund.py +152 -0
  239. pysisyphus/interpolate/__init__.py +9 -0
  240. pysisyphus/interpolate/helpers.py +34 -0
  241. pysisyphus/io/__init__.py +22 -0
  242. pysisyphus/io/aomix.py +178 -0
  243. pysisyphus/io/cjson.py +24 -0
  244. pysisyphus/io/crd.py +101 -0
  245. pysisyphus/io/cube.py +220 -0
  246. pysisyphus/io/fchk.py +184 -0
  247. pysisyphus/io/hdf5.py +49 -0
  248. pysisyphus/io/hessian.py +72 -0
  249. pysisyphus/io/mol2.py +146 -0
  250. pysisyphus/io/molden.py +293 -0
  251. pysisyphus/io/orca.py +189 -0
  252. pysisyphus/io/pdb.py +269 -0
  253. pysisyphus/io/psf.py +79 -0
  254. pysisyphus/io/pubchem.py +31 -0
  255. pysisyphus/io/qcschema.py +34 -0
  256. pysisyphus/io/sdf.py +29 -0
  257. pysisyphus/io/xyz.py +61 -0
  258. pysisyphus/io/zmat.py +175 -0
  259. pysisyphus/irc/DWI.py +108 -0
  260. pysisyphus/irc/DampedVelocityVerlet.py +134 -0
  261. pysisyphus/irc/Euler.py +22 -0
  262. pysisyphus/irc/EulerPC.py +345 -0
  263. pysisyphus/irc/GonzalezSchlegel.py +187 -0
  264. pysisyphus/irc/IMKMod.py +164 -0
  265. pysisyphus/irc/IRC.py +878 -0
  266. pysisyphus/irc/IRCDummy.py +10 -0
  267. pysisyphus/irc/Instanton.py +307 -0
  268. pysisyphus/irc/LQA.py +53 -0
  269. pysisyphus/irc/ModeKill.py +136 -0
  270. pysisyphus/irc/ParamPlot.py +53 -0
  271. pysisyphus/irc/RK4.py +36 -0
  272. pysisyphus/irc/__init__.py +31 -0
  273. pysisyphus/irc/initial_displ.py +219 -0
  274. pysisyphus/linalg.py +411 -0
  275. pysisyphus/line_searches/Backtracking.py +88 -0
  276. pysisyphus/line_searches/HagerZhang.py +184 -0
  277. pysisyphus/line_searches/LineSearch.py +232 -0
  278. pysisyphus/line_searches/StrongWolfe.py +108 -0
  279. pysisyphus/line_searches/__init__.py +9 -0
  280. pysisyphus/line_searches/interpol.py +15 -0
  281. pysisyphus/modefollow/NormalMode.py +40 -0
  282. pysisyphus/modefollow/__init__.py +10 -0
  283. pysisyphus/modefollow/davidson.py +199 -0
  284. pysisyphus/modefollow/lanczos.py +95 -0
  285. pysisyphus/optimizers/BFGS.py +99 -0
  286. pysisyphus/optimizers/BacktrackingOptimizer.py +113 -0
  287. pysisyphus/optimizers/ConjugateGradient.py +98 -0
  288. pysisyphus/optimizers/CubicNewton.py +75 -0
  289. pysisyphus/optimizers/FIRE.py +113 -0
  290. pysisyphus/optimizers/HessianOptimizer.py +1176 -0
  291. pysisyphus/optimizers/LBFGS.py +228 -0
  292. pysisyphus/optimizers/LayerOpt.py +411 -0
  293. pysisyphus/optimizers/MicroOptimizer.py +169 -0
  294. pysisyphus/optimizers/NCOptimizer.py +90 -0
  295. pysisyphus/optimizers/Optimizer.py +1084 -0
  296. pysisyphus/optimizers/PreconLBFGS.py +260 -0
  297. pysisyphus/optimizers/PreconSteepestDescent.py +7 -0
  298. pysisyphus/optimizers/QuickMin.py +74 -0
  299. pysisyphus/optimizers/RFOptimizer.py +181 -0
  300. pysisyphus/optimizers/RSA.py +99 -0
  301. pysisyphus/optimizers/StabilizedQNMethod.py +248 -0
  302. pysisyphus/optimizers/SteepestDescent.py +23 -0
  303. pysisyphus/optimizers/StringOptimizer.py +173 -0
  304. pysisyphus/optimizers/__init__.py +41 -0
  305. pysisyphus/optimizers/closures.py +301 -0
  306. pysisyphus/optimizers/cls_map.py +58 -0
  307. pysisyphus/optimizers/exceptions.py +6 -0
  308. pysisyphus/optimizers/gdiis.py +280 -0
  309. pysisyphus/optimizers/guess_hessians.py +311 -0
  310. pysisyphus/optimizers/hessian_updates.py +355 -0
  311. pysisyphus/optimizers/poly_fit.py +285 -0
  312. pysisyphus/optimizers/precon.py +153 -0
  313. pysisyphus/optimizers/restrict_step.py +24 -0
  314. pysisyphus/pack.py +172 -0
  315. pysisyphus/peakdetect.py +948 -0
  316. pysisyphus/plot.py +1031 -0
  317. pysisyphus/run.py +2106 -0
  318. pysisyphus/socket_helper.py +74 -0
  319. pysisyphus/stocastic/FragmentKick.py +132 -0
  320. pysisyphus/stocastic/Kick.py +81 -0
  321. pysisyphus/stocastic/Pipeline.py +303 -0
  322. pysisyphus/stocastic/__init__.py +21 -0
  323. pysisyphus/stocastic/align.py +127 -0
  324. pysisyphus/testing.py +96 -0
  325. pysisyphus/thermo.py +156 -0
  326. pysisyphus/trj.py +824 -0
  327. pysisyphus/tsoptimizers/RSIRFOptimizer.py +56 -0
  328. pysisyphus/tsoptimizers/RSPRFOptimizer.py +182 -0
  329. pysisyphus/tsoptimizers/TRIM.py +59 -0
  330. pysisyphus/tsoptimizers/TSHessianOptimizer.py +463 -0
  331. pysisyphus/tsoptimizers/__init__.py +23 -0
  332. pysisyphus/wavefunction/Basis.py +239 -0
  333. pysisyphus/wavefunction/DIIS.py +76 -0
  334. pysisyphus/wavefunction/__init__.py +25 -0
  335. pysisyphus/wavefunction/build_ext.py +42 -0
  336. pysisyphus/wavefunction/cart2sph.py +190 -0
  337. pysisyphus/wavefunction/diabatization.py +304 -0
  338. pysisyphus/wavefunction/excited_states.py +435 -0
  339. pysisyphus/wavefunction/gen_ints.py +1811 -0
  340. pysisyphus/wavefunction/helpers.py +104 -0
  341. pysisyphus/wavefunction/ints/__init__.py +0 -0
  342. pysisyphus/wavefunction/ints/boys.py +193 -0
  343. pysisyphus/wavefunction/ints/boys_table_N_64_xasym_27.1_step_0.01.npy +0 -0
  344. pysisyphus/wavefunction/ints/cart_gto3d.py +176 -0
  345. pysisyphus/wavefunction/ints/coulomb3d.py +25928 -0
  346. pysisyphus/wavefunction/ints/diag_quadrupole3d.py +10036 -0
  347. pysisyphus/wavefunction/ints/dipole3d.py +8762 -0
  348. pysisyphus/wavefunction/ints/int2c2e3d.py +7198 -0
  349. pysisyphus/wavefunction/ints/int3c2e3d_sph.py +65040 -0
  350. pysisyphus/wavefunction/ints/kinetic3d.py +8240 -0
  351. pysisyphus/wavefunction/ints/ovlp3d.py +3777 -0
  352. pysisyphus/wavefunction/ints/quadrupole3d.py +15054 -0
  353. pysisyphus/wavefunction/ints/self_ovlp3d.py +198 -0
  354. pysisyphus/wavefunction/localization.py +458 -0
  355. pysisyphus/wavefunction/multipole.py +159 -0
  356. pysisyphus/wavefunction/normalization.py +36 -0
  357. pysisyphus/wavefunction/pop_analysis.py +134 -0
  358. pysisyphus/wavefunction/shells.py +1171 -0
  359. pysisyphus/wavefunction/wavefunction.py +504 -0
  360. pysisyphus/wrapper/__init__.py +11 -0
  361. pysisyphus/wrapper/exceptions.py +2 -0
  362. pysisyphus/wrapper/jmol.py +120 -0
  363. pysisyphus/wrapper/mwfn.py +169 -0
  364. pysisyphus/wrapper/packmol.py +71 -0
  365. pysisyphus/xyzloader.py +168 -0
  366. pysisyphus/yaml_mods.py +45 -0
  367. thermoanalysis/LICENSE +674 -0
  368. thermoanalysis/QCData.py +244 -0
  369. thermoanalysis/__init__.py +0 -0
  370. thermoanalysis/config.py +3 -0
  371. thermoanalysis/constants.py +20 -0
  372. thermoanalysis/thermo.py +1011 -0
@@ -0,0 +1,301 @@
1
+ # [1] https://arxiv.org/pdf/2101.04413.pdf
2
+ # A Regularized Limited Memory BFGS method for Large-Scale Unconstrained
3
+ # Optimization and its efficient Implementations
4
+ # Tankaria, Sugimoto, Yamashita, 2021
5
+ # [2] Regularization of Limited Memory Quasi-Newton Methods for Large-Scale
6
+ # Nonconvex Minimization
7
+ # https://arxiv.org/pdf/1911.04584.pdf
8
+ # Kanzow, Steck 2021
9
+
10
+ from collections import deque
11
+
12
+ import numpy as np
13
+ from scipy.sparse.linalg import spsolve
14
+
15
+
16
def get_update_mu_reg(
    mu_min=1e-3, gamma_1=0.1, gamma_2=5.0, eta_1=0.01, eta_2=0.9, logger=None
):
    """Build the update function for the regularization parameter μ_reg.

    See section 5.1 in [1]. The parameters must satisfy
    0 < mu_min, 0 < gamma_1 <= 1 < gamma_2 and 0 < eta_1 < eta_2 <= 1;
    violating this raises an AssertionError.

    Returns
    -------
    update_mu_reg
        Callable (mu, cur_energy, trial_energy, cur_grad, step) returning
        the tuple (mu_updated, recompute_step).
    """
    assert 0.0 < mu_min
    assert 0.0 < gamma_1 <= 1.0 < gamma_2
    assert 0.0 < eta_1 < eta_2 <= 1.0

    def log(msg):
        # Logging is optional; silently skip when no logger was supplied.
        if logger is None:
            return
        logger.debug(msg)

    def update_mu_reg(mu, cur_energy, trial_energy, cur_grad, step):
        """Update μ_reg from the ratio r of actual to predicted reduction.

        See section 2.3 in [1].

            r = actual_reduction / predicted_reduction
              = (f(x) - f(x + step)) / (f(x) - model_func(step, mu))

        with the quadratic model

            model_func(step, mu) = f(x) + df(x)^T step + 1/2 step^T H(mu)^-1 step

        Assuming the step satisfies H(mu)^-1 step = -df(x), the predicted
        reduction simplifies to

            f(x) - model_func(step, mu)
                = -df(x)^T step - 1/2 step^T H(mu)^-1 step
                = -df(x)^T step + 1/2 step^T df(x)
                = -1/2 df(x)^T step

        so that r = (f(x) - f(x + step)) / (-1/2 df(x)^T step).

        Returns
        -------
        (mu_updated, recompute_step)
            r < eta_1:  μ is multiplied by gamma_2 and a new step is requested.
            r >= eta_2: μ is multiplied by gamma_1 (bounded below by mu_min)
                        and the step is accepted.
            otherwise:  μ is kept and the step is accepted.
        """
        log(f" Cur energy: {cur_energy:.8f}")
        log(f"Trial energy: {trial_energy:.8f}")
        act_change = cur_energy - trial_energy
        log(f" Actual change: {act_change: .8f}")
        pred_change = -0.5 * cur_grad.dot(step)
        log(f"Predicted change: {pred_change: .8f}")
        ratio = act_change / pred_change
        log(f"Ratio r={ratio:.8f}")

        if ratio < eta_1:
            # Poor actual reduction: increase the shift parameter and ask
            # for a new step.
            mu_updated = mu * gamma_2
            log(f"Increased μ_reg to {mu_updated:.6f}.")
            return mu_updated, True
        if ratio >= eta_2:
            # Significant actual reduction: decrease the shift parameter and
            # accept the step.
            mu_updated = max(mu_min, mu * gamma_1)
            log(f"Decreased μ_reg to {mu_updated:.6f}.")
            return mu_updated, False
        # eta_1 <= r < eta_2 (or r is NaN): keep μ and accept the step.
        return mu, False

    return update_mu_reg
71
+
72
+
73
def bfgs_multiply(
    s_list,
    y_list,
    vector,
    beta=1,
    P=None,
    logger=None,
    gamma_mult=True,
    mu_reg=None,
    inds=None,
    cur_size=None,
):
    """Matrix-vector product H·v with an L-BFGS inverse Hessian.

    Multiplies the given vector with the inverse Hessian approximation
    obtained from repeated BFGS updates, using the steps in 's_list' and the
    gradient differences in 'y_list'.

    Based on the two-loop recursion, algorithm 7.4 in Nocedal & Wright,
    Numerical Optimization, p. 178.

    Parameters
    ----------
    s_list, y_list
        Indexable sequences of equal length with previous steps and gradient
        differences. Pair i is (s_list[i], y_list[i]); the last pair is the
        one used for the gamma scaling.
    vector
        Vector to multiply. It is copied and not modified.
    beta
        Scaling of the initial inverse Hessian. Used when no preconditioner
        'P' is given and the gamma scaling does not apply.
    P
        Optional (sparse) preconditioner. The initial product is obtained by
        solving P r = q with scipy's spsolve.
    logger
        Optional logger. A summary line is emitted at DEBUG level.
    gamma_mult
        If True and at least one pair is present, scale the initial inverse
        Hessian by gamma = sᵀy / yᵀy of the last pair.
    mu_reg
        Optional positive regularization parameter μ. The gradient
        differences are replaced by ŷ = y + μs (safeguarded), see [1].
    inds, cur_size
        Only used when a pair does not match the shape of 'vector', so that
        the dot product or update raises ValueError. Then 'inds[i]' selects
        the rows of 'vector.reshape(cur_size, -1)' that pair i refers to.

    Returns
    -------
    r
        Product of the inverse Hessian approximation with 'vector'.
    """

    assert len(s_list) == len(y_list), (
        "lengths of step list 's_list' and gradient list 'y_list' differ!"
    )

    cycles = len(s_list)
    q = vector.copy()
    alphas = list()
    rhos = list()

    if mu_reg is not None:
        # Regularized L-BFGS with ŷ = y + μs, see [1]
        assert mu_reg > 0.0
        y_list_reg = list()
        for si, yi in zip(s_list, y_list):
            y_hat = yi + mu_reg * si
            if y_hat.dot(si) <= 0:
                # Safeguard from section 2 in [1]:
                #   ŷ = y + (max(0, -sᵀy / sᵀs) + μ) s,
                # which guarantees sᵀŷ = max(sᵀy, 0) + μ sᵀs > 0.
                y_hat = yi + (max(0, -si.dot(yi) / si.dot(si)) + mu_reg) * si
            y_list_reg.append(y_hat)
        y_list = y_list_reg

    # First loop, from the last pair to the first. Store rho and alpha, as
    # both are needed again in the second loop.
    for i in reversed(range(cycles)):
        s = s_list[i]
        y = y_list[i]
        rho = 1 / y.dot(s)
        rhos.append(rho)
        try:
            alpha = rho * s.dot(q)
            q -= alpha * y
        except ValueError:
            # Pair i only covers a subset of the rows of q; operate on them.
            inds_i = inds[i]
            q_ = q.reshape(cur_size, -1)
            alpha = rho * s.dot(q_[inds_i].flatten())
            # This also modifies q!
            q_[inds_i] -= alpha * y.reshape(len(inds_i), -1)
        alphas.append(alpha)

    # Restore original order, so that rho[i] = 1/s_list[i].dot(y_list[i]) etc.
    alphas = alphas[::-1]
    rhos = rhos[::-1]

    if P is not None:
        r = spsolve(P, q)
        msg = "preconditioner"
    elif gamma_mult and (cycles > 0):
        s = s_list[-1]
        y = y_list[-1]
        gamma = s.dot(y) / y.dot(y)
        r = gamma * q
        msg = f"gamma={gamma:.4f}"
    else:
        r = beta * q
        msg = f"beta={beta:.4f}"

    if mu_reg is not None:
        msg += f" and μ_reg={mu_reg:.6f}"

    if logger is not None:
        msg = f"BFGS multiply using {cycles} previous cycles with {msg}."
        if cycles == 0:
            msg += "\nProduced simple SD step."
        logger.debug(msg)

    # Second loop, from the first pair to the last. 'beta_i' is deliberately
    # not named 'beta' so the parameter is not shadowed.
    for i in range(cycles):
        s = s_list[i]
        y = y_list[i]
        try:
            beta_i = rhos[i] * y.dot(r)
            r += s * (alphas[i] - beta_i)
        except ValueError:
            inds_i = inds[i]
            r_ = r.reshape(cur_size, -1)
            beta_i = rhos[i] * y.dot(r_[inds_i].flatten())
            # This also modifies r!
            r_[inds_i] += s.reshape(len(inds_i), -1) * (alphas[i] - beta_i)

    return r
172
+
173
+
174
def lbfgs_closure(force_getter, M=10, beta=1, restrict_step=None):
    """Return a stateful L-BFGS step function.

    The returned function ``lbfgs(x, *getter_args)`` does the following:

    1. Evaluates ``force = force_getter(x, *getter_args)``.
    2. From the second call on, records the pair s = x - prev_x and
       y = prev_force - force.
    3. Returns ``(step, force)``, where the step is the L-BFGS product of
       the stored pairs with the force, passed through
       ``restrict_step(x, step)``.

    Parameters
    ----------
    force_getter
        Callable returning the forces at x.
    M
        Number of (s, y) pairs kept after each call.
    beta
        Initial inverse Hessian scaling forwarded to bfgs_multiply.
    restrict_step
        Optional callable (x, step) -> step. Defaults to the identity.
    """
    s_list = list()
    y_list = list()
    # Only the previous coordinates and forces are needed to form the next
    # (s, y) pair. Keeping the full history would grow without bound.
    prev_x = None
    prev_force = None
    cur_cycle = 0

    if restrict_step is None:
        restrict_step = lambda x, dx: dx

    def lbfgs(x, *getter_args):
        nonlocal cur_cycle
        nonlocal s_list
        nonlocal y_list
        nonlocal prev_x
        nonlocal prev_force

        force = force_getter(x, *getter_args)
        if cur_cycle > 0:
            s_list.append(x - prev_x)
            y_list.append(prev_force - force)
        prev_x = x
        prev_force = force

        step = bfgs_multiply(s_list, y_list, force, beta=beta)
        step = restrict_step(x, step)
        # Only keep the last M pairs. The product above may have used up to
        # M + 1 pairs: M kept from before plus the newest one.
        s_list = s_list[-M:]
        y_list = y_list[-M:]
        cur_cycle += 1
        return step, force

    return lbfgs
209
+
210
+
211
def modified_broyden_closure(force_getter, M=5, beta=1, restrict_step=None):
    """Return a stateful modified-Broyden step function.

    See https://doi.org/10.1006/jcph.1996.0059. In that paper F is the
    residual, which here is the gradient. The force returned by
    force_getter is therefore multiplied by -1.

    The returned function ``modified_broyden(x, *getter_args)`` returns
    ``(step, forces)``. At most the last M steps and residual differences
    are kept.
    """
    step_hist = list()
    res_diff_hist = list()
    residual = None
    # The next line is used in SciPy, but then beta strongly depends
    # on the magnitude of x, which is quite weird. Disregarding the
    # analytical potentials the electronic energy is usually
    # invariant under translation and rotation and doesn't depend on
    # the magnitude of x.
    # beta = 0.5 * max(np.linalg.norm(x), 1) / np.linalg.norm(F)
    overlap = None

    if restrict_step is None:
        restrict_step = lambda x, dx: dx

    def modified_broyden(x, *getter_args):
        nonlocal step_hist, res_diff_hist, residual, overlap

        new_residual = -force_getter(x, *getter_args)
        if residual is not None:
            res_diff_hist.append(new_residual - residual)
            res_diff_hist = res_diff_hist[-M:]
            # Overlap matrix of the residual differences.
            overlap = np.array(
                [[d_k.dot(d_m) for d_m in res_diff_hist] for d_k in res_diff_hist]
            )
        residual = new_residual
        dx = -beta * residual

        if step_hist:
            # Expansion coefficients; least squares avoids crashes with
            # singular overlap matrices.
            rhs = [d_k.dot(residual) for d_k in res_diff_hist]
            gammas, *_ = np.linalg.lstsq(overlap, rhs, rcond=None)
            correction = np.array(step_hist) - beta * np.array(res_diff_hist)
            # Subtract the step correction
            dx = dx - np.sum(gammas[:, None] * correction, axis=0)

        dx = restrict_step(x, dx)
        step_hist.append(dx)
        # Keep only information from the last M cycles
        step_hist = step_hist[-M:]

        return dx, -residual

    return modified_broyden
267
+
268
+
269
def small_lbfgs_closure(history=5, gamma_mult=True):
    """Compact LBFGS closure.

    The returned function takes two arguments: the forces at the current
    iterate and 'prev_step', the step that led to the current iterate. This
    lets step restriction or a line search happen outside of the lbfgs
    function. At most 'history' steps and gradient differences are kept.
    """

    prev_forces = None  # lgtm [py/unused-local-variable]
    grad_diffs = deque(maxlen=history)
    steps = deque(maxlen=history)
    cur_cycle = 0

    def lbfgs(forces, prev_step=None):
        nonlocal cur_cycle, prev_forces

        if prev_step is not None:
            steps.append(prev_step)

        if cur_cycle == 0:
            # First cycle: plain steepest descent, the step is the force.
            step = forces
        else:
            # Gradient difference g_new - g_old = (-forces) - (-prev_forces).
            grad_diffs.append(prev_forces - forces)
            step = bfgs_multiply(steps, grad_diffs, forces, gamma_mult=gamma_mult)

        prev_forces = forces
        cur_cycle += 1
        return step

    return lbfgs
@@ -0,0 +1,58 @@
1
+ from pysisyphus.optimizers import (
2
+ BFGS,
3
+ CubicNewton,
4
+ ConjugateGradient,
5
+ FIRE,
6
+ LayerOpt,
7
+ LBFGS,
8
+ MicroOptimizer,
9
+ NCOptimizer,
10
+ PreconLBFGS,
11
+ PreconSteepestDescent,
12
+ QuickMin,
13
+ RFOptimizer,
14
+ SteepestDescent,
15
+ StabilizedQNMethod,
16
+ StringOptimizer,
17
+ )
18
+ from pysisyphus.tsoptimizers import (
19
+ RSPRFOptimizer,
20
+ TRIM,
21
+ RSIRFOptimizer,
22
+ )
23
+
24
+
25
# Transition-state optimizers, selectable by key. "prfo" and "irfo" are
# shortcuts for the entries registered under "rsprfo" and "rsirfo".
TSOPT_DICT = {
    "rsprfo": RSPRFOptimizer,
    "prfo": RSPRFOptimizer,  # Shortcut
    "trim": TRIM,
    "rsirfo": RSIRFOptimizer,
    "irfo": RSIRFOptimizer,  # Shortcut
}

# All optimizers, selectable by key: the entries below followed by every
# entry of TSOPT_DICT.
# NOTE(review): most entries reach the class through its module
# (e.g. BFGS.BFGS), but "cnewton", "micro" and the TS entries use the
# imported name directly. Presumably those names are exported as classes
# by their packages; confirm.
OPT_DICT = {
    "bfgs": BFGS.BFGS,
    "cg": ConjugateGradient.ConjugateGradient,
    "cnewton": CubicNewton,
    "fire": FIRE.FIRE,
    "layers": LayerOpt.LayerOpt,
    "lbfgs": LBFGS.LBFGS,
    "micro": MicroOptimizer,
    "nc": NCOptimizer.NCOptimizer,
    "plbfgs": PreconLBFGS.PreconLBFGS,
    "psd": PreconSteepestDescent.PreconSteepestDescent,
    "qm": QuickMin.QuickMin,
    "rfo": RFOptimizer.RFOptimizer,
    "sd": SteepestDescent.SteepestDescent,
    "sqnm": StabilizedQNMethod.StabilizedQNMethod,
    "string": StringOptimizer.StringOptimizer,
    **TSOPT_DICT,
}
51
+
52
+
53
def get_opt_cls(opt_key):
    """Return the optimizer class registered under 'opt_key' in OPT_DICT.

    Raises
    ------
    KeyError
        If 'opt_key' is unknown. The message lists all valid keys.
    """
    try:
        return OPT_DICT[opt_key]
    except KeyError:
        valid = ", ".join(sorted(OPT_DICT))
        raise KeyError(
            f"Unknown optimizer '{opt_key}'! Valid keys are: {valid}."
        ) from None
55
+
56
+
57
def key_is_tsopt(opt_key):
    """Tell whether 'opt_key' is one of the keys in TSOPT_DICT."""
    is_ts_key = opt_key in TSOPT_DICT.keys()
    return is_ts_key
@@ -0,0 +1,6 @@
1
class ZeroStepLength(Exception):
    """Exception type for a zero step length.

    NOTE(review): presumably raised by optimizers when a computed step has
    vanishing length; confirm at the raising sites.
    """
3
+
4
+
5
class OptimizationError(Exception):
    """Generic exception type for failed optimizations."""
@@ -0,0 +1,280 @@
1
+ # [1] https://doi.org/10.1016/S0022-2860(84)87198-7
2
+ # Pulay, 1984
3
+ # [2] https://pubs.rsc.org/en/content/articlehtml/2002/cp/b108658h
4
+ # Stabilized GDIIS
5
+ # Farkas, Schlegel, 2002
6
+ # [3] https://pubs.acs.org/doi/abs/10.1021/ct050275a
7
+ # GEDIIS/Hybrid method
8
+ # Li, Frisch, 2006
9
+ # [4] https://aip.scitation.org/doi/10.1063/1.2977735
10
+ # Sim-GEDIIS using hessian information
11
+ # Moss, Li, 2008
12
+
13
+ from collections import namedtuple
14
+ import logging
15
+
16
+ import autograd.numpy as anp
17
+ from autograd import grad
18
+ import numpy as np
19
+ from scipy.optimize import minimize
20
+
21
+ from pysisyphus.helpers import array2string
22
+ import torch
23
+
24
# Minimum cosine between the GDIIS step and the reference step that
# valid_diis_direction() accepts, keyed by the number of error vectors used
# (2 through 9). The entries for 2 and 3 vectors are looser than the values
# published in [2], which were 0.97 and 0.84.
COS_CUTOFFS = dict(
    zip(range(2, 10), (0.80, 0.75, 0.71, 0.67, 0.62, 0.56, 0.49, 0.41))
)

# Outcome of a DIIS interpolation: the coefficients, the interpolated
# coordinates/forces, an optional energy estimate, the number ``N`` of
# combined cycles and a ``type`` label such as "GDIIS" or "GEDIIS".
DIISResult = namedtuple(
    "DIISResult", ["coeffs", "coords", "forces", "energy", "N", "type"]
)

# Shared with the optimizers, which log under the same name.
logger = logging.getLogger("optimizer")
42
+
43
+
44
def log(msg):
    """Emit *msg* on the module logger at DEBUG level."""
    logger.log(logging.DEBUG, msg)
46
+
47
+
48
def valid_diis_direction(diis_step, ref_step, use):
    """Check that a DIIS step points sufficiently along a reference step.

    Both steps are normalized and their cosine is compared against the
    cutoff ``COS_CUTOFFS[use]`` for a DIIS combination of *use* vectors.
    Torch tensors and numpy arrays are supported; the backend is chosen from
    the type of *diis_step*.

    NOTE(review): COS_CUTOFFS only defines keys 2-9, so any other *use*
    raises KeyError.

    Returns a truthy/falsy comparison result (a numpy bool or a torch bool
    tensor, depending on the backend).
    """
    is_tensor = isinstance(diis_step, torch.Tensor)
    norm = torch.linalg.norm if is_tensor else np.linalg.norm

    unit_ref = ref_step / norm(ref_step)
    unit_diis = diis_step / norm(diis_step)
    if is_tensor:
        cos = torch.dot(unit_diis, unit_ref)
    else:
        cos = unit_diis @ unit_ref

    cutoff = COS_CUTOFFS[use]
    return (cos >= cutoff) and (cos >= 0)
58
+
59
+
60
def from_coeffs(vec, coeffs):
    """Linearly combine the most recent rows of *vec* with *coeffs*.

    ``coeffs[0]`` weights the last row of *vec*, ``coeffs[1]`` the one
    before it, and so on; only ``len(coeffs)`` rows are used. Works for
    torch tensors and for numpy arrays (or sequences of arrays).
    """
    n_used = len(coeffs)
    if not isinstance(vec, torch.Tensor):
        newest_first = vec[::-1][:n_used]
        return (coeffs[:, None] * newest_first).sum(axis=0)
    newest_first = vec.flip(0)[:n_used]
    return (coeffs[:, None] * newest_first).sum(dim=0)
65
+
66
+
67
def diis_result(coeffs, coords, forces, energy=None, prefix=""):
    """Package DIIS coefficients and the interpolated quantities.

    Coordinates and forces are combined with ``from_coeffs`` (most recent
    entry weighted by ``coeffs[0]``). The result's ``type`` is
    ``prefix + "DIIS"``, e.g. "GDIIS" or "GEDIIS".
    """
    n_vecs = len(coeffs)
    result = DIISResult(
        coeffs=coeffs,
        coords=from_coeffs(coords, coeffs),
        forces=from_coeffs(forces, coeffs),
        energy=energy,
        N=n_vecs,
        type=f"{prefix}DIIS",
    )
    log(f"\tUsed {n_vecs} error vectors for {prefix}DIIS.")
    log("")
    return result
81
+
82
+
83
def gdiis(err_vecs, coords, forces, ref_step, max_vecs=5, test_direction=True):
    """Geometry DIIS (GDIIS) interpolation with the stability checks of [2].

    Successively tries to combine the 2, 3, ..., ``min(max_vecs, len(err_vecs))``
    most recent cycles. For each trial the DIIS system ``A c = 1`` (``A`` being
    the overlap matrix of the error vectors) is solved and the coefficients are
    rescaled to sum to one. A trial is accepted only if

    * the coefficient norm before rescaling is <= 1e8 (error vectors are not
      nearly linearly dependent),
    * the positive and the negative coefficients each sum to at most 15 in
      magnitude (limits inter-/extrapolation),
    * the GDIIS step is at most 10 times as long as *ref_step*, and
    * if *test_direction* is set, ``valid_diis_direction`` accepts it.

    The first rejected trial ends the search; the coefficients of the last
    accepted trial are used.

    Parameters
    ----------
    err_vecs : numpy array-like or torch.Tensor
        One error vector per cycle. The last row is treated as the current
        cycle; rows are combined newest first.
    coords, forces : sequence
        Coordinates and forces of the same cycles, in the same order. When
        *err_vecs* is a tensor they are converted via numpy to tensors on its
        device and dtype.
    ref_step : array or tensor
        Step the GDIIS step (relative to ``coords[-1]``) is compared against
        in length and, optionally, direction.
    max_vecs : int
        Maximum number of cycles to combine.
    test_direction : bool
        Whether the direction check is applied.

    Returns
    -------
    DIISResult or None
        ``None`` if not even the two-cycle combination passed all checks or
        the first linear solve failed.
    """
    is_torch = isinstance(err_vecs, torch.Tensor)
    if is_torch:
        coords = torch.from_numpy(np.array(coords)).to(err_vecs.device, dtype=err_vecs.dtype)
        forces = torch.from_numpy(np.array(forces)).to(err_vecs.device, dtype=err_vecs.dtype)
        norms = torch.linalg.norm(err_vecs, dim=1)
    else:
        norms = np.linalg.norm(err_vecs, axis=1)
    # Scale error vectors so the smallest norm is 1
    err_vecs = err_vecs / norms.min()

    valid_coeffs = None
    for use in range(2, min(max_vecs, len(err_vecs)) + 1):
        log(f"Trying GDIIS with {use} previous cycles.")
        if is_torch:
            use_vecs = err_vecs.flip(0)[:use]
            A = torch.einsum("ij,kj->ik", use_vecs, use_vecs)
        else:
            use_vecs = np.array(err_vecs[::-1][:use])
            A = np.einsum("ij,kj->ik", use_vecs, use_vecs)
        try:
            if is_torch:
                # The right-hand side must share A's dtype: torch.linalg.solve
                # does not promote mixed dtypes (e.g. float64 A and the default
                # float32 ones) and raises instead.
                rhs = torch.ones(use, device=err_vecs.device, dtype=err_vecs.dtype)
                coeffs = torch.linalg.solve(A, rhs)
            else:
                coeffs = np.linalg.solve(A, np.ones(use))
        except np.linalg.LinAlgError:
            log("LinAlgError when solving GDIIS matrix.")
            break
        except torch.linalg.LinAlgError:
            log("Torch LinAlgError when solving GDIIS matrix.")
            break

        # A huge coefficient norm indicates (nearly) linearly dependent
        # error vectors.
        if is_torch:
            coeffs_norm = torch.linalg.norm(coeffs)
        else:
            coeffs_norm = np.linalg.norm(coeffs)
        valid_coeffs_norm = coeffs_norm <= 1e8
        log(f"\tError vectors are linearly independent: {valid_coeffs_norm}")
        # Scale coeffs so that their sum equals 1
        if is_torch:
            coeffs /= torch.sum(coeffs)
        else:
            coeffs /= np.sum(coeffs)
        coeffs_str = array2string(coeffs, precision=4)
        log(f"\tGDIIS coefficients: {coeffs_str}")

        # Check degree of extra- and interpolation.
        pos_sum = abs(coeffs[coeffs > 0].sum())
        neg_sum = abs(coeffs[coeffs < 0].sum())
        valid_sums = (pos_sum <= 15) and (neg_sum <= 15)
        log(f"\tSum of positive coefficients: {pos_sum:.2f}")
        log(f"\tSum of negative coefficients: {neg_sum:.2f}")
        log(f"\tSums are valid: {valid_sums}")

        # Calculate GDIIS step for comparison to the reference step
        diis_coords = from_coeffs(coords, coeffs)
        diis_step = diis_coords - coords[-1]
        if is_torch:
            valid_length = torch.linalg.norm(diis_step) <= (10 * torch.linalg.norm(ref_step))
        else:
            valid_length = np.linalg.norm(diis_step) <= (10 * np.linalg.norm(ref_step))
        log(f"\tGDIIS step has valid length: {valid_length}")

        # Compare directions of GDIIS- and reference step
        valid_direction = (
            True
            if (not test_direction)
            else valid_diis_direction(diis_step, ref_step, use)
        )
        log(f"\tGDIIS step has valid direction: {valid_direction}")

        gdiis_valid = (
            valid_sums and valid_coeffs_norm and valid_direction and valid_length
        )
        log(f"\tGDIIS step is valid: {gdiis_valid}")
        if not gdiis_valid:
            break
        # Update valid DIIS coefficients
        valid_coeffs = coeffs
        log("")

    if valid_coeffs is None:
        return None

    return diis_result(valid_coeffs, coords, forces, prefix="G")
175
+
176
+
177
def gediis(coords, energies, forces, hessian=None, max_vecs=3):
    """Energy-represented DIIS (GEDIIS), optionally using a Hessian.

    Combines the ``min(len(coords), max_vecs)`` most recent cycles. The
    coefficients are parametrized as ``c_i = x_i**2 / sum(x**2)`` so they are
    non-negative and sum to one, and are found by minimizing

    * without *hessian*: the interpolated energy of Eq. (6) in [3];
    * with *hessian*: the Hessian-based expression of Eq. (5) in [4], whose
      first term uses ``g_k^T H^-1 g_k`` (a pseudo-inverse of *hessian* is
      used).

    Parameters
    ----------
    coords, energies, forces : sequence
        Per-cycle data; the last entry is treated as the current cycle.
    hessian : numpy array, torch.Tensor or None
        If a torch tensor, all data are moved to its device/dtype and the
        returned coefficients are a tensor as well.
    max_vecs : int
        Maximum number of cycles to combine.

    Returns
    -------
    DIISResult or None
        ``None`` if the coefficient optimization does not succeed, or, when
        no Hessian is given, if the optimized energy is not below the energy
        of the current cycle.
    """
    use = min(len(coords), max_vecs)
    if isinstance(hessian, torch.Tensor):
        coords = torch.from_numpy(np.array(coords)).to(device=hessian.device, dtype=hessian.dtype)
        energies = torch.tensor(energies, device=hessian.device, dtype=hessian.dtype)
        forces = torch.from_numpy(np.array(forces)).to(device=hessian.device, dtype=hessian.dtype)
        R = coords.flip(0)[:use]
        E = energies.flip(0)[:use].reshape(-1)
        f = forces.flip(0)[:use]
    else:
        R = coords[::-1][:use]
        E = np.ravel(energies[::-1][:use])
        f = forces[::-1][:use]
    assert len(R) == len(E) == len(f)
    log(f"Trying GEDIIS with {use} previous cycles.")
    # Precompute values so they can be reused in fun(); always numpy, as
    # fun() is differentiated with autograd.
    if isinstance(R, torch.Tensor):
        Rifi = torch.einsum("ik,ik->i", R, f).cpu().numpy()
        Rjfi = torch.einsum("jk,ik->ji", R, f).cpu().numpy()
    else:
        Rifi = np.einsum("ik,ik->i", R, f)
        Rjfi = np.einsum("jk,ik->ji", R, f)

    def x2c(x):
        """Map unconstrained parameters to non-negative weights summing to 1."""
        return x ** 2 / (x ** 2).sum()

    if hessian is None:

        def fun(xs):
            """Eq. (6) from [3]."""
            cs = x2c(xs)
            return (
                anp.sum(cs * E) - anp.sum(anp.outer(cs, cs) * Rjfi) + (cs * Rifi).sum()
            )

    else:
        # gHig[k] = g_k^T H^-1 g_k, a proper quadratic form over both indices.
        # It doesn't matter if we use forces or gradients, as the signs cancel.
        if isinstance(hessian, torch.Tensor):
            hessian_inv = torch.linalg.pinv(hessian, rcond=1e-6)
            gHig = torch.einsum("ki,ij,kj->k", f, hessian_inv, f).cpu().numpy()
        else:
            hessian_inv = np.linalg.pinv(hessian, rcond=1e-6)
            gHig = np.einsum("ki,ij,kj->k", f, hessian_inv, f)

        def fun(xs):
            """Eq. (5) from [4]."""
            cs = x2c(xs)
            # Consider the hessian in the first term
            return (
                0.5 * anp.sum(cs * gHig)
                - anp.sum(anp.outer(cs, cs) * Rjfi)
                + (cs * Rifi).sum()
            )

    jac = grad(fun)

    x0 = np.ones(use) / use
    res = minimize(fun, x0=x0, jac=jac)

    # Without converged coefficients there is nothing to interpolate.
    if not res.success:
        log("\tGEDIIS coefficient optimization did not converge.")
        return None

    coeffs = x2c(res.x)
    en_ = res.fun
    if isinstance(hessian, torch.Tensor):
        coeffs = torch.from_numpy(coeffs).to(device=hessian.device, dtype=hessian.dtype)
    log("\tOptimization converged!")
    coeff_str = array2string(coeffs, precision=4)
    log(f"\tCoefficients: {coeff_str}")
    # E[0] is the energy of the current (most recent) cycle.
    if (hessian is None) and (en_ >= E[0]):
        print(
            "GEDIIS converged, but proposed energy is above current energy! Returning None"
        )
        return None
    return diis_result(coeffs, coords, forces, energy=en_, prefix="GE")