MultiOptPy 1.20.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (246):
  1. multioptpy/Calculator/__init__.py +0 -0
  2. multioptpy/Calculator/ase_calculation_tools.py +424 -0
  3. multioptpy/Calculator/ase_tools/__init__.py +0 -0
  4. multioptpy/Calculator/ase_tools/fairchem.py +28 -0
  5. multioptpy/Calculator/ase_tools/gamess.py +19 -0
  6. multioptpy/Calculator/ase_tools/gaussian.py +165 -0
  7. multioptpy/Calculator/ase_tools/mace.py +28 -0
  8. multioptpy/Calculator/ase_tools/mopac.py +19 -0
  9. multioptpy/Calculator/ase_tools/nwchem.py +31 -0
  10. multioptpy/Calculator/ase_tools/orca.py +22 -0
  11. multioptpy/Calculator/ase_tools/pygfn0.py +37 -0
  12. multioptpy/Calculator/dxtb_calculation_tools.py +344 -0
  13. multioptpy/Calculator/emt_calculation_tools.py +458 -0
  14. multioptpy/Calculator/gpaw_calculation_tools.py +183 -0
  15. multioptpy/Calculator/lj_calculation_tools.py +314 -0
  16. multioptpy/Calculator/psi4_calculation_tools.py +334 -0
  17. multioptpy/Calculator/pwscf_calculation_tools.py +189 -0
  18. multioptpy/Calculator/pyscf_calculation_tools.py +327 -0
  19. multioptpy/Calculator/sqm1_calculation_tools.py +611 -0
  20. multioptpy/Calculator/sqm2_calculation_tools.py +376 -0
  21. multioptpy/Calculator/tblite_calculation_tools.py +352 -0
  22. multioptpy/Calculator/tersoff_calculation_tools.py +818 -0
  23. multioptpy/Constraint/__init__.py +0 -0
  24. multioptpy/Constraint/constraint_condition.py +834 -0
  25. multioptpy/Coordinate/__init__.py +0 -0
  26. multioptpy/Coordinate/polar_coordinate.py +199 -0
  27. multioptpy/Coordinate/redundant_coordinate.py +638 -0
  28. multioptpy/IRC/__init__.py +0 -0
  29. multioptpy/IRC/converge_criteria.py +28 -0
  30. multioptpy/IRC/dvv.py +544 -0
  31. multioptpy/IRC/euler.py +439 -0
  32. multioptpy/IRC/hpc.py +564 -0
  33. multioptpy/IRC/lqa.py +540 -0
  34. multioptpy/IRC/modekill.py +662 -0
  35. multioptpy/IRC/rk4.py +579 -0
  36. multioptpy/Interpolation/__init__.py +0 -0
  37. multioptpy/Interpolation/adaptive_interpolation.py +283 -0
  38. multioptpy/Interpolation/binomial_interpolation.py +179 -0
  39. multioptpy/Interpolation/geodesic_interpolation.py +785 -0
  40. multioptpy/Interpolation/interpolation.py +156 -0
  41. multioptpy/Interpolation/linear_interpolation.py +473 -0
  42. multioptpy/Interpolation/savitzky_golay_interpolation.py +252 -0
  43. multioptpy/Interpolation/spline_interpolation.py +353 -0
  44. multioptpy/MD/__init__.py +0 -0
  45. multioptpy/MD/thermostat.py +185 -0
  46. multioptpy/MEP/__init__.py +0 -0
  47. multioptpy/MEP/pathopt_bneb_force.py +443 -0
  48. multioptpy/MEP/pathopt_dmf_force.py +448 -0
  49. multioptpy/MEP/pathopt_dneb_force.py +130 -0
  50. multioptpy/MEP/pathopt_ewbneb_force.py +207 -0
  51. multioptpy/MEP/pathopt_gpneb_force.py +512 -0
  52. multioptpy/MEP/pathopt_lup_force.py +113 -0
  53. multioptpy/MEP/pathopt_neb_force.py +225 -0
  54. multioptpy/MEP/pathopt_nesb_force.py +205 -0
  55. multioptpy/MEP/pathopt_om_force.py +153 -0
  56. multioptpy/MEP/pathopt_qsm_force.py +174 -0
  57. multioptpy/MEP/pathopt_qsmv2_force.py +304 -0
  58. multioptpy/ModelFunction/__init__.py +7 -0
  59. multioptpy/ModelFunction/avoiding_model_function.py +29 -0
  60. multioptpy/ModelFunction/binary_image_ts_search_model_function.py +47 -0
  61. multioptpy/ModelFunction/conical_model_function.py +26 -0
  62. multioptpy/ModelFunction/opt_meci.py +50 -0
  63. multioptpy/ModelFunction/opt_mesx.py +47 -0
  64. multioptpy/ModelFunction/opt_mesx_2.py +49 -0
  65. multioptpy/ModelFunction/seam_model_function.py +27 -0
  66. multioptpy/ModelHessian/__init__.py +0 -0
  67. multioptpy/ModelHessian/approx_hessian.py +147 -0
  68. multioptpy/ModelHessian/calc_params.py +227 -0
  69. multioptpy/ModelHessian/fischer.py +236 -0
  70. multioptpy/ModelHessian/fischerd3.py +360 -0
  71. multioptpy/ModelHessian/fischerd4.py +398 -0
  72. multioptpy/ModelHessian/gfn0xtb.py +633 -0
  73. multioptpy/ModelHessian/gfnff.py +709 -0
  74. multioptpy/ModelHessian/lindh.py +165 -0
  75. multioptpy/ModelHessian/lindh2007d2.py +707 -0
  76. multioptpy/ModelHessian/lindh2007d3.py +822 -0
  77. multioptpy/ModelHessian/lindh2007d4.py +1030 -0
  78. multioptpy/ModelHessian/morse.py +106 -0
  79. multioptpy/ModelHessian/schlegel.py +144 -0
  80. multioptpy/ModelHessian/schlegeld3.py +322 -0
  81. multioptpy/ModelHessian/schlegeld4.py +559 -0
  82. multioptpy/ModelHessian/shortrange.py +346 -0
  83. multioptpy/ModelHessian/swartd2.py +496 -0
  84. multioptpy/ModelHessian/swartd3.py +706 -0
  85. multioptpy/ModelHessian/swartd4.py +918 -0
  86. multioptpy/ModelHessian/tshess.py +40 -0
  87. multioptpy/Optimizer/QHAdam.py +61 -0
  88. multioptpy/Optimizer/__init__.py +0 -0
  89. multioptpy/Optimizer/abc_fire.py +83 -0
  90. multioptpy/Optimizer/adabelief.py +58 -0
  91. multioptpy/Optimizer/adabound.py +68 -0
  92. multioptpy/Optimizer/adadelta.py +65 -0
  93. multioptpy/Optimizer/adaderivative.py +56 -0
  94. multioptpy/Optimizer/adadiff.py +68 -0
  95. multioptpy/Optimizer/adafactor.py +70 -0
  96. multioptpy/Optimizer/adam.py +65 -0
  97. multioptpy/Optimizer/adamax.py +62 -0
  98. multioptpy/Optimizer/adamod.py +83 -0
  99. multioptpy/Optimizer/adamw.py +65 -0
  100. multioptpy/Optimizer/adiis.py +523 -0
  101. multioptpy/Optimizer/afire_neb.py +282 -0
  102. multioptpy/Optimizer/block_hessian_update.py +709 -0
  103. multioptpy/Optimizer/c2diis.py +491 -0
  104. multioptpy/Optimizer/component_wise_scaling.py +405 -0
  105. multioptpy/Optimizer/conjugate_gradient.py +82 -0
  106. multioptpy/Optimizer/conjugate_gradient_neb.py +345 -0
  107. multioptpy/Optimizer/coordinate_locking.py +405 -0
  108. multioptpy/Optimizer/dic_rsirfo.py +1015 -0
  109. multioptpy/Optimizer/ediis.py +417 -0
  110. multioptpy/Optimizer/eve.py +76 -0
  111. multioptpy/Optimizer/fastadabelief.py +61 -0
  112. multioptpy/Optimizer/fire.py +77 -0
  113. multioptpy/Optimizer/fire2.py +249 -0
  114. multioptpy/Optimizer/fire_neb.py +92 -0
  115. multioptpy/Optimizer/gan_step.py +486 -0
  116. multioptpy/Optimizer/gdiis.py +609 -0
  117. multioptpy/Optimizer/gediis.py +203 -0
  118. multioptpy/Optimizer/geodesic_step.py +433 -0
  119. multioptpy/Optimizer/gpmin.py +633 -0
  120. multioptpy/Optimizer/gpr_step.py +364 -0
  121. multioptpy/Optimizer/gradientdescent.py +78 -0
  122. multioptpy/Optimizer/gradientdescent_neb.py +52 -0
  123. multioptpy/Optimizer/hessian_update.py +433 -0
  124. multioptpy/Optimizer/hybrid_rfo.py +998 -0
  125. multioptpy/Optimizer/kdiis.py +625 -0
  126. multioptpy/Optimizer/lars.py +21 -0
  127. multioptpy/Optimizer/lbfgs.py +253 -0
  128. multioptpy/Optimizer/lbfgs_neb.py +355 -0
  129. multioptpy/Optimizer/linesearch.py +236 -0
  130. multioptpy/Optimizer/lookahead.py +40 -0
  131. multioptpy/Optimizer/nadam.py +64 -0
  132. multioptpy/Optimizer/newton.py +200 -0
  133. multioptpy/Optimizer/prodigy.py +70 -0
  134. multioptpy/Optimizer/purtubation.py +16 -0
  135. multioptpy/Optimizer/quickmin_neb.py +245 -0
  136. multioptpy/Optimizer/radam.py +75 -0
  137. multioptpy/Optimizer/rfo_neb.py +302 -0
  138. multioptpy/Optimizer/ric_rfo.py +842 -0
  139. multioptpy/Optimizer/rl_step.py +627 -0
  140. multioptpy/Optimizer/rmspropgrave.py +65 -0
  141. multioptpy/Optimizer/rsirfo.py +1647 -0
  142. multioptpy/Optimizer/rsprfo.py +1056 -0
  143. multioptpy/Optimizer/sadam.py +60 -0
  144. multioptpy/Optimizer/samsgrad.py +63 -0
  145. multioptpy/Optimizer/tr_lbfgs.py +678 -0
  146. multioptpy/Optimizer/trim.py +273 -0
  147. multioptpy/Optimizer/trust_radius.py +207 -0
  148. multioptpy/Optimizer/trust_radius_neb.py +121 -0
  149. multioptpy/Optimizer/yogi.py +60 -0
  150. multioptpy/OtherMethod/__init__.py +0 -0
  151. multioptpy/OtherMethod/addf.py +1150 -0
  152. multioptpy/OtherMethod/dimer.py +895 -0
  153. multioptpy/OtherMethod/elastic_image_pair.py +629 -0
  154. multioptpy/OtherMethod/modelfunction.py +456 -0
  155. multioptpy/OtherMethod/newton_traj.py +454 -0
  156. multioptpy/OtherMethod/twopshs.py +1095 -0
  157. multioptpy/PESAnalyzer/__init__.py +0 -0
  158. multioptpy/PESAnalyzer/calc_irc_curvature.py +125 -0
  159. multioptpy/PESAnalyzer/cmds_analysis.py +152 -0
  160. multioptpy/PESAnalyzer/koopman_analysis.py +268 -0
  161. multioptpy/PESAnalyzer/pca_analysis.py +314 -0
  162. multioptpy/Parameters/__init__.py +0 -0
  163. multioptpy/Parameters/atomic_mass.py +20 -0
  164. multioptpy/Parameters/atomic_number.py +22 -0
  165. multioptpy/Parameters/covalent_radii.py +44 -0
  166. multioptpy/Parameters/d2.py +61 -0
  167. multioptpy/Parameters/d3.py +63 -0
  168. multioptpy/Parameters/d4.py +103 -0
  169. multioptpy/Parameters/dreiding.py +34 -0
  170. multioptpy/Parameters/gfn0xtb_param.py +137 -0
  171. multioptpy/Parameters/gfnff_param.py +315 -0
  172. multioptpy/Parameters/gnb.py +104 -0
  173. multioptpy/Parameters/parameter.py +22 -0
  174. multioptpy/Parameters/uff.py +72 -0
  175. multioptpy/Parameters/unit_values.py +20 -0
  176. multioptpy/Potential/AFIR_potential.py +55 -0
  177. multioptpy/Potential/LJ_repulsive_potential.py +345 -0
  178. multioptpy/Potential/__init__.py +0 -0
  179. multioptpy/Potential/anharmonic_keep_potential.py +28 -0
  180. multioptpy/Potential/asym_elllipsoidal_potential.py +718 -0
  181. multioptpy/Potential/electrostatic_potential.py +69 -0
  182. multioptpy/Potential/flux_potential.py +30 -0
  183. multioptpy/Potential/gaussian_potential.py +101 -0
  184. multioptpy/Potential/idpp.py +516 -0
  185. multioptpy/Potential/keep_angle_potential.py +146 -0
  186. multioptpy/Potential/keep_dihedral_angle_potential.py +105 -0
  187. multioptpy/Potential/keep_outofplain_angle_potential.py +70 -0
  188. multioptpy/Potential/keep_potential.py +99 -0
  189. multioptpy/Potential/mechano_force_potential.py +74 -0
  190. multioptpy/Potential/nanoreactor_potential.py +52 -0
  191. multioptpy/Potential/potential.py +896 -0
  192. multioptpy/Potential/spacer_model_potential.py +221 -0
  193. multioptpy/Potential/switching_potential.py +258 -0
  194. multioptpy/Potential/universal_potential.py +34 -0
  195. multioptpy/Potential/value_range_potential.py +36 -0
  196. multioptpy/Potential/void_point_potential.py +25 -0
  197. multioptpy/SQM/__init__.py +0 -0
  198. multioptpy/SQM/sqm1/__init__.py +0 -0
  199. multioptpy/SQM/sqm1/sqm1_core.py +1792 -0
  200. multioptpy/SQM/sqm2/__init__.py +0 -0
  201. multioptpy/SQM/sqm2/calc_tools.py +95 -0
  202. multioptpy/SQM/sqm2/sqm2_basis.py +850 -0
  203. multioptpy/SQM/sqm2/sqm2_bond.py +119 -0
  204. multioptpy/SQM/sqm2/sqm2_core.py +303 -0
  205. multioptpy/SQM/sqm2/sqm2_data.py +1229 -0
  206. multioptpy/SQM/sqm2/sqm2_disp.py +65 -0
  207. multioptpy/SQM/sqm2/sqm2_eeq.py +243 -0
  208. multioptpy/SQM/sqm2/sqm2_overlapint.py +704 -0
  209. multioptpy/SQM/sqm2/sqm2_qm.py +578 -0
  210. multioptpy/SQM/sqm2/sqm2_rep.py +66 -0
  211. multioptpy/SQM/sqm2/sqm2_srb.py +70 -0
  212. multioptpy/Thermo/__init__.py +0 -0
  213. multioptpy/Thermo/normal_mode_analyzer.py +865 -0
  214. multioptpy/Utils/__init__.py +0 -0
  215. multioptpy/Utils/bond_connectivity.py +264 -0
  216. multioptpy/Utils/calc_tools.py +884 -0
  217. multioptpy/Utils/oniom.py +96 -0
  218. multioptpy/Utils/pbc.py +48 -0
  219. multioptpy/Utils/riemann_curvature.py +208 -0
  220. multioptpy/Utils/symmetry_analyzer.py +482 -0
  221. multioptpy/Visualization/__init__.py +0 -0
  222. multioptpy/Visualization/visualization.py +156 -0
  223. multioptpy/WFAnalyzer/MO_analysis.py +104 -0
  224. multioptpy/WFAnalyzer/__init__.py +0 -0
  225. multioptpy/Wrapper/__init__.py +0 -0
  226. multioptpy/Wrapper/autots.py +1239 -0
  227. multioptpy/Wrapper/ieip_wrapper.py +93 -0
  228. multioptpy/Wrapper/md_wrapper.py +92 -0
  229. multioptpy/Wrapper/neb_wrapper.py +94 -0
  230. multioptpy/Wrapper/optimize_wrapper.py +76 -0
  231. multioptpy/__init__.py +5 -0
  232. multioptpy/entrypoints.py +916 -0
  233. multioptpy/fileio.py +660 -0
  234. multioptpy/ieip.py +340 -0
  235. multioptpy/interface.py +1086 -0
  236. multioptpy/irc.py +529 -0
  237. multioptpy/moleculardynamics.py +432 -0
  238. multioptpy/neb.py +1267 -0
  239. multioptpy/optimization.py +1553 -0
  240. multioptpy/optimizer.py +709 -0
  241. multioptpy-1.20.2.dist-info/METADATA +438 -0
  242. multioptpy-1.20.2.dist-info/RECORD +246 -0
  243. multioptpy-1.20.2.dist-info/WHEEL +5 -0
  244. multioptpy-1.20.2.dist-info/entry_points.txt +9 -0
  245. multioptpy-1.20.2.dist-info/licenses/LICENSE +674 -0
  246. multioptpy-1.20.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,578 @@
1
+ import numpy as np
2
+ import torch
3
+ import math
4
+
5
+ from multioptpy.SQM.sqm2.sqm2_overlapint import OverlapCalculator
6
+ from multioptpy.SQM.sqm2.sqm2_basis import BasisSet
7
+ from multioptpy.SQM.sqm2.calc_tools import factorial2, dfactorial
8
+
9
+
10
+ class EHTCalculator:
11
def __init__(self, element_list, charge, spin, param, wf_instance):
    """Set up the vectorised extended-Hückel (EHT) Hamiltonian builder.

    Parameters
    ----------
    element_list : sequence
        Per-atom keys used to index the per-element tables on ``param``.
    charge : number
        Total charge; ``calculation`` subtracts it from the reference
        valence-electron count.
    spin : object
        Stored as ``self.spin`` (not read by ``get_hamiltonian`` or
        ``calculation``).
    param : object
        Parameter container with per-element tables (``paulingEN``,
        ``kQAtom``, ``kQShell``, ``shellPoly``, ``atomicRad``, ``nShell``,
        ``selfEnergy``, ``kCN``, ``slaterExponent``, ``referenceOcc``) and
        scalar EHT K-factors.
    wf_instance : object
        Basis holder exposing ``basis`` (dict with ``number_of_ao``,
        ``shell_amqn_list``, ``atom_shells_map``, ``shell_ao_map``),
        ``ang_shells_list``, ``is_tm_list`` and ``is_g11_element_list``.

    Raises
    ------
    ValueError
        If the total shell count from ``param.nShell`` or from
        ``wf_instance.ang_shells_list`` differs from the number of rows of
        ``basis["shell_ao_map"]``; all shell-indexed maps below assume
        they agree.
    """
    # --- Device and dtype ---
    self.device = "cpu"  # NOTE(review): could be taken from wf_instance/param
    self.dtype = torch.float64

    # --- Core properties ---
    self.element_list = element_list
    self.params = param
    self.wf = wf_instance
    self.basis = wf_instance.basis  # basis dictionary
    self.charge = charge
    self.charge_t = torch.tensor(self.charge, dtype=self.dtype, device=self.device)
    self.spin = spin
    self.n_atoms = len(self.element_list)
    self.n_ao = self.basis["number_of_ao"]

    # --- Constants (non-tensor) ---
    self.PI = math.pi
    self.SQRTPI = math.sqrt(self.PI)
    self.LMAX_INTEGRAL = 8  # max L for 1D integrals (l + l' + 2)

    # --- Overlap integrals ---
    self.overlap_calc = OverlapCalculator(self.element_list, self.params, self.wf)

    # --- Basis bookkeeping as integer tensors ---
    self.amqn_list = torch.tensor(self.basis["shell_amqn_list"], dtype=torch.int64, device=self.device)
    self.atom_shells_map = torch.tensor(self.basis["atom_shells_map"], dtype=torch.int64, device=self.device)
    self.shell_ao_map = torch.tensor(self.basis["shell_ao_map"], dtype=torch.int64, device=self.device)
    self.ang_shells_list = self.wf.ang_shells_list  # list of per-atom lists, flattened below

    # --- Per-atom parameter lookups ---
    paulingEN_list = [param.paulingEN[atn] for atn in self.element_list]
    kQatom_list = [param.kQAtom[atn] for atn in self.element_list]
    kQShell_list = [param.kQShell[atn] for atn in self.element_list]
    shellpoly_list = [param.shellPoly[atn] for atn in self.element_list]
    atomicRad_list = [param.atomicRad[atn] for atn in self.element_list]
    nshells_list = [param.nShell[atn] for atn in self.element_list]
    self_energy_list = [param.selfEnergy[atn] for atn in self.element_list]
    kCN_list = [param.kCN[atn] for atn in self.element_list]
    slaterexponent_list = [param.slaterExponent[atn] for atn in self.element_list]
    referenceOcc = [param.referenceOcc[atn] for atn in self.element_list]

    # --- Total reference valence electrons (neutral system) ---
    self.total_valence_e = torch.tensor(
        sum(sum(occ) for occ in referenceOcc), dtype=self.dtype, device=self.device
    )

    # --- Pad per-atom, per-shell lists to (n_atoms, max_nshell) arrays ---
    max_nshell = max(nshells_list)

    def pad_list_of_lists(lst, max_len, fill=0.0):
        # Copy each row into a fill-initialised array, truncating at max_len.
        padded = np.full((self.n_atoms, max_len), fill, dtype=np.float64)
        for row, sublst in enumerate(lst):
            n_to_copy = min(len(sublst), max_len)
            padded[row, :n_to_copy] = sublst[:n_to_copy]
        return padded

    kQShell_padded = pad_list_of_lists(kQShell_list, max_nshell)
    shellpoly_padded = pad_list_of_lists(shellpoly_list, max_nshell)
    self_energy_padded = pad_list_of_lists(self_energy_list, max_nshell)
    kCN_padded = pad_list_of_lists(kCN_list, max_nshell)
    slaterexponent_padded = pad_list_of_lists(slaterexponent_list, max_nshell)

    # 1D tensors (per atom)
    self.paulingEN_list = torch.tensor(paulingEN_list, dtype=self.dtype, device=self.device)
    self.kQatom_list = torch.tensor(kQatom_list, dtype=self.dtype, device=self.device)
    self.atomicRad_list = torch.tensor(atomicRad_list, dtype=self.dtype, device=self.device)
    self.nshells_list = torch.tensor(nshells_list, dtype=torch.int64, device=self.device)

    # 2D tensors (per atom, per local shell)
    self.kQShell_tensor = torch.tensor(kQShell_padded, dtype=self.dtype, device=self.device)
    self.shellpoly_tensor = torch.tensor(shellpoly_padded, dtype=self.dtype, device=self.device)
    self.self_energy_tensor = torch.tensor(self_energy_padded, dtype=self.dtype, device=self.device)
    self.kCN_tensor = torch.tensor(kCN_padded, dtype=self.dtype, device=self.device)
    self.slaterexponent_tensor = torch.tensor(slaterexponent_padded, dtype=self.dtype, device=self.device)

    # --- EHT K-factors (scalars) ---
    self.k_ss_eht = param.k_ss_eht
    self.k_pp_eht = param.k_pp_eht
    self.k_dd_eht = param.k_dd_eht
    self.k_sp_eht = param.k_sp_eht
    self.k_sd_eht = param.k_sd_eht
    self.k_pd_eht = param.k_pd_eht
    self.k_hh_2s2s = param.k_hh_2s2s
    self.k_hh_2s22 = self.k_hh_2s2s  # legacy (misspelled) name, kept for existing callers
    self.k_ss_en_eht = param.k_ss_en_eht
    self.k_pp_en_eht = param.k_pp_en_eht
    self.k_dd_en_eht = param.k_dd_en_eht
    self.k_sp_en_eht = param.k_sp_en_eht
    self.k_sd_en_eht = param.k_sd_en_eht
    self.k_pd_en_eht = param.k_pd_en_eht
    self.b_en_eht = param.b_en_eht
    self.k_MM_pair = param.k_MM_pair
    self.k_g11_pair = param.k_g11_pair

    # Per-atom boolean flags
    self.is_tm_tensor = torch.tensor(wf_instance.is_tm_list, dtype=torch.bool, device=self.device)
    self.is_g11_tensor = torch.tensor(wf_instance.is_g11_element_list, dtype=torch.bool, device=self.device)

    # --- Static shell properties ---
    self.n_shell = self.shell_ao_map.shape[0]

    # Map: shell_index -> atom_index, e.g. [0, 0, 0, 1, 1, 2, 2, 2, 2, ...]
    self.shell_atom_map = torch.repeat_interleave(
        torch.arange(self.n_atoms, device=self.device),
        self.nshells_list,
    )

    # Map: shell_index -> angular type (0=s, 1=p, 2=d), flattened per atom
    ang_flat = [item for sublist in self.ang_shells_list for item in sublist]
    self.shell_type_map = torch.tensor(ang_flat, dtype=torch.int64, device=self.device)

    # Every shell-indexed map must describe the same shells; a mismatch
    # would otherwise mislabel shells silently or fail later with IndexError.
    n_shell_from_params = int(self.nshells_list.sum())
    if n_shell_from_params != self.n_shell or len(ang_flat) != self.n_shell:
        raise ValueError(
            "Inconsistent shell counts: basis['shell_ao_map'] has "
            f"{self.n_shell} shells, param.nShell sums to {n_shell_from_params}, "
            f"wf_instance.ang_shells_list has {len(ang_flat)}."
        )

    # Map: shell_index -> local shell index within its atom,
    # e.g. [0, 1, 2, 0, 1, 0, 1, 2, 3, ...]: global index minus the index
    # of the atom's first shell (exclusive cumulative sum of nShell).
    atom_first_shell = torch.cumsum(self.nshells_list, dim=0) - self.nshells_list
    self.shell_local_idx_map = (
        torch.arange(self.n_shell, dtype=torch.int64, device=self.device)
        - atom_first_shell[self.shell_atom_map]
    )

    # (n_shell,) gathers from (n_atoms, max_nshell) tensors
    self.shell_poly_const_map = self.shellpoly_tensor[self.shell_atom_map, self.shell_local_idx_map]
    self.shell_slater_exp_map = self.slaterexponent_tensor[self.shell_atom_map, self.shell_local_idx_map]

    # (n_shell,) gathers from (n_atoms,) tensors
    self.shell_rad_map = self.atomicRad_list[self.shell_atom_map]
    self.shell_en_map = self.paulingEN_list[self.shell_atom_map]
    self.shell_is_tm_map = self.is_tm_tensor[self.shell_atom_map]
    self.shell_is_g11_map = self.is_g11_tensor[self.shell_atom_map]

    # --- Strict upper-triangle shell pairs, split by same/different atom ---
    i_indices_upper, j_indices_upper = torch.triu_indices(self.n_shell, self.n_shell, 1, device=self.device)

    iat_pairs = self.shell_atom_map[i_indices_upper]
    jat_pairs = self.shell_atom_map[j_indices_upper]

    off_atom_mask = iat_pairs != jat_pairs
    self.i_off_atom_pairs = i_indices_upper[off_atom_mask]
    self.j_off_atom_pairs = j_indices_upper[off_atom_mask]

    on_atom_mask = ~off_atom_mask
    self.i_on_atom_pairs = i_indices_upper[on_atom_mask]
    self.j_on_atom_pairs = j_indices_upper[on_atom_mask]

    # --- AO slice info: column 0 = first AO, column 1 = number of AOs ---
    self.shell_ao_starts = self.shell_ao_map[:, 0]
    self.shell_ao_nao = self.shell_ao_map[:, 1]

    # Map: ao_index -> shell_index, e.g. [0, 0, 0, 1, 1, 1, 1, 1, 2, 2, ...]
    self.ao_shell_map = torch.repeat_interleave(
        torch.arange(self.n_shell, device=self.device),
        self.shell_ao_nao,
    )

    # Overall scale applied to every K-factor in _get_eht_k_factor_vec
    # (1.0 would disable it).
    self.holistic_k_factor = 1.4
189
+
190
def _get_eht_k_factor(self, ishtyp, jshtyp, iat, jat, delta_en):
    """Return the EHT K-factor for one shell pair (scalar reference version).

    Kept consistent with ``_get_eht_k_factor_vec``, which is the path
    ``get_hamiltonian`` uses:

    - the base and electronegativity K values are picked by the unordered
      shell-type pair;
    - the atom-pair factor is ``k_MM_pair`` for TM-TM pairs, overridden by
      ``k_g11_pair`` for group-11 pairs;
    - the result is scaled by ``holistic_k_factor``.

    Parameters
    ----------
    ishtyp, jshtyp : int or single-element integer tensor
        Shell angular types (0=s, 1=p, 2=d). Any other type falls back to
        ``k_base = 1.0`` and ``k_en = 0.0``.
    iat, jat : int
        Atom indices into ``is_tm_tensor`` and ``is_g11_tensor``.
    delta_en : float or tensor
        Electronegativity difference used in the EN polynomial.

    Returns
    -------
    ``k_base * en_factor * atom_factor * holistic_k_factor``
    """
    # int() also accepts single-element tensors, so the dict key is hashable by value.
    pair = tuple(sorted((int(ishtyp), int(jshtyp))))
    shell_pair_params = {
        (0, 0): (self.k_ss_eht, self.k_ss_en_eht),
        (1, 1): (self.k_pp_eht, self.k_pp_en_eht),
        (2, 2): (self.k_dd_eht, self.k_dd_en_eht),
        (0, 1): (self.k_sp_eht, self.k_sp_en_eht),
        (0, 2): (self.k_sd_eht, self.k_sd_en_eht),
        (1, 2): (self.k_pd_eht, self.k_pd_en_eht),
    }
    k_base, k_en = shell_pair_params.get(pair, (1.0, 0.0))  # f-shells or higher

    en_factor = 1.0 + k_en * delta_en ** 2.0 + k_en * self.b_en_eht * delta_en ** 4.0

    # Same precedence as the vectorised version: the group-11 factor is
    # applied last, so it wins when both flags are set.
    atom_factor = 1.0
    if bool(self.is_tm_tensor[iat]) and bool(self.is_tm_tensor[jat]):
        atom_factor = self.k_MM_pair
    if bool(self.is_g11_tensor[iat]) and bool(self.is_g11_tensor[jat]):
        atom_factor = self.k_g11_pair

    return k_base * en_factor * atom_factor * self.holistic_k_factor
234
+
235
def _get_eht_k_factor_vec(self, ishtyp, jshtyp, is_tm_i, is_tm_j, is_g11_i, is_g11_j, delta_en):
    """Vectorised EHT K-factor for a batch of shell pairs.

    All arguments are 1-D tensors of equal length, one entry per pair.

    - Shell-type pairs are unordered. Types outside {0, 1, 2} keep
      ``k_base = 1`` and ``k_en = 0``.
    - The atom-pair factor is ``k_MM_pair`` where both TM flags are set,
      then ``k_g11_pair`` where both group-11 flags are set. The group-11
      factor is assigned last, so it wins.

    Returns
    -------
    torch.Tensor
        ``k_base * en_factor * atom_factor * holistic_k_factor`` per pair.
    """
    low = torch.minimum(ishtyp, jshtyp)
    high = torch.maximum(ishtyp, jshtyp)

    # (lower type, higher type, K, K_EN); the selections are disjoint.
    shell_rules = (
        (0, 0, self.k_ss_eht, self.k_ss_en_eht),
        (1, 1, self.k_pp_eht, self.k_pp_en_eht),
        (2, 2, self.k_dd_eht, self.k_dd_en_eht),
        (0, 1, self.k_sp_eht, self.k_sp_en_eht),
        (0, 2, self.k_sd_eht, self.k_sd_en_eht),
        (1, 2, self.k_pd_eht, self.k_pd_en_eht),
    )

    base = torch.full_like(delta_en, 1.0)
    en_coeff = torch.zeros_like(delta_en)
    for lo_type, hi_type, k_value, k_en_value in shell_rules:
        selected = (low == lo_type) & (high == hi_type)
        base[selected] = k_value
        en_coeff[selected] = k_en_value

    en_factor = 1.0 + en_coeff * delta_en ** 2.0 + en_coeff * self.b_en_eht * delta_en ** 4.0

    pair_factor = torch.ones_like(delta_en)
    pair_factor[is_tm_i & is_tm_j] = self.k_MM_pair
    pair_factor[is_g11_i & is_g11_j] = self.k_g11_pair

    return base * en_factor * pair_factor * self.holistic_k_factor
274
+
275
def _get_self_energy(self, q, cn):
    """Return environment-shifted shell self-energies, ``(n_atoms, max_nshell)``.

    Each entry is the tabulated self-energy plus three negative shifts:

    - a coordination-number term ``kCN * cn``;
    - a linear charge term ``kQShell * q``;
    - a quadratic per-atom charge term ``kQAtom * q**2``.

    Padded (non-existent) shell columns are computed too; callers only
    gather real shells.

    Parameters
    ----------
    q, cn : torch.Tensor
        Per-atom charges and coordination numbers, shape ``(n_atoms,)``.
    """
    # Column vectors so per-atom values broadcast across the shell axis.
    q2d = q[..., None]
    cn2d = cn[..., None]
    kq_atom = self.kQatom_list[..., None]
    return (
        self.self_energy_tensor
        + (-self.kCN_tensor * cn2d)
        + (-self.kQShell_tensor * q2d)
        + (-kq_atom * q2d ** 2)
    )
296
+
297
def _get_shellpoly_corr(self, iat, jat, ish_local, jsh_local, vec_i, vec_j, rad_ij):
    """Scalar shell-polynomial distance correction for one shell pair.

    Returns ``(1 + 0.01*p_i*sqrt(r/R)) * (1 + 0.01*p_j*sqrt(r/R))``, where:

    - ``p`` is read from ``shellpoly_tensor[atom, local_shell]``;
    - ``r`` is the distance between ``vec_i`` and ``vec_j``, with 1e-20
      added under the square root;
    - ``R = rad_ij``.

    ``get_hamiltonian`` uses ``_get_shellpoly_corr_vec`` instead; this
    version is kept for reference/debugging.
    """
    exponent = 0.5
    diff = vec_i - vec_j
    distance = torch.sqrt(torch.dot(diff, diff) + 1e-20)
    scaled = (distance / rad_ij) ** exponent
    factors = [
        1.0 + (0.01 * self.shellpoly_tensor[atom, shell]) * scaled
        for atom, shell in ((iat, ish_local), (jat, jsh_local))
    ]
    return factors[0] * factors[1]
315
+
316
def _get_shellpoly_corr_vec(self, vec_i_pairs, vec_j_pairs, rad_ij_pairs, poly_i_pairs, poly_j_pairs):
    """Vectorised shell-polynomial distance correction for a batch of pairs.

    Parameters
    ----------
    vec_i_pairs, vec_j_pairs : torch.Tensor
        ``(n_pairs, 3)`` positions of the two atoms of each pair.
    rad_ij_pairs, poly_i_pairs, poly_j_pairs : torch.Tensor
        ``(n_pairs,)`` summed atomic radii and per-shell polynomial constants.

    Returns
    -------
    torch.Tensor
        ``(1 + 0.01*p_i*sqrt(r/R)) * (1 + 0.01*p_j*sqrt(r/R))`` per pair,
        where ``r`` is the Euclidean distance plus 1e-20.
    """
    dist = torch.linalg.norm(vec_i_pairs - vec_j_pairs, dim=1) + 1e-20
    root_ratio = (dist / rad_ij_pairs) ** 0.5
    corr_i = 1.0 + (0.01 * poly_i_pairs) * root_ratio
    corr_j = 1.0 + (0.01 * poly_j_pairs) * root_ratio
    return corr_i * corr_j
331
+
332
def get_hamiltonian(self, xyz, q, cn, sint):
    """Build the extended-Hückel core Hamiltonian ``H0`` in the AO basis.

    Diagonal entries are the self-energy of the shell each AO belongs to.
    Off-diagonal entries are ``Hav[shell(a), shell(b)] * sint[a, b]``. For
    a listed shell pair (i, j)::

        Hav = 0.5 * K_ij * (h_i + h_j) * 2*sqrt(z_i*z_j) / (z_i + z_j)

    Pairs on different atoms are additionally multiplied by the
    shell-polynomial distance correction. ``Hav`` is filled only at the
    listed pairs and then symmetrised. Its shell diagonal stays zero, so
    AO pairs inside one shell get no off-diagonal term.

    Parameters
    ----------
    xyz : torch.Tensor
        Atomic coordinates, one row per atom (gathered per pair).
    q, cn : torch.Tensor
        Per-atom charges / coordination numbers passed to ``_get_self_energy``.
    sint : torch.Tensor
        AO overlap matrix, ``(n_ao, n_ao)``.

    Returns
    -------
    torch.Tensor
        ``(n_ao, n_ao)`` Hamiltonian.
    """
    shell_atom = self.shell_atom_map
    shell_types = self.shell_type_map
    zeta = self.shell_slater_exp_map

    # Per-shell self-energies gathered from the (n_atoms, max_nshell) matrix.
    h_shell_diag = self._get_self_energy(q, cn)[shell_atom, self.shell_local_idx_map]

    def slater_factor(first, second):
        # 2*sqrt(z_i*z_j)/(z_i+z_j): equals 1 for identical exponents.
        z_first, z_second = zeta[first], zeta[second]
        return (2.0 * torch.sqrt(z_first * z_second)) / (z_first + z_second)

    hav = torch.zeros((self.n_shell, self.n_shell), dtype=self.dtype, device=self.device)

    # Same-atom shell pairs: delta_en is zero, both atoms' flags coincide,
    # and no shell-polynomial correction is applied.
    same_i, same_j = self.i_on_atom_pairs, self.j_on_atom_pairs
    if same_i.numel() > 0:
        h_first = h_shell_diag[same_i]
        tm_flags = self.shell_is_tm_map[same_i]
        g11_flags = self.shell_is_g11_map[same_i]
        k_same = self._get_eht_k_factor_vec(
            shell_types[same_i], shell_types[same_j],
            tm_flags, tm_flags,
            g11_flags, g11_flags,
            torch.zeros_like(h_first),
        )
        hav[same_i, same_j] = (
            0.5 * k_same * (h_first + h_shell_diag[same_j]) * slater_factor(same_i, same_j)
        )

    # Shell pairs on different atoms.
    far_i, far_j = self.i_off_atom_pairs, self.j_off_atom_pairs
    if far_i.numel() > 0:
        atom_i = shell_atom[far_i]
        atom_j = shell_atom[far_j]
        k_far = self._get_eht_k_factor_vec(
            shell_types[far_i], shell_types[far_j],
            self.shell_is_tm_map[far_i], self.shell_is_tm_map[far_j],
            self.shell_is_g11_map[far_i], self.shell_is_g11_map[far_j],
            torch.abs(self.shell_en_map[far_i] - self.shell_en_map[far_j]),
        )
        poly = self._get_shellpoly_corr_vec(
            xyz[atom_i], xyz[atom_j],
            self.shell_rad_map[far_i] + self.shell_rad_map[far_j],
            self.shell_poly_const_map[far_i], self.shell_poly_const_map[far_j],
        )
        hav[far_i, far_j] = (
            0.5 * k_far * (h_shell_diag[far_i] + h_shell_diag[far_j])
            * slater_factor(far_i, far_j) * poly
        )

    # Pair lists come from torch.triu_indices(..., 1) in __init__, so only the
    # strict upper triangle was filled; mirror it.
    hav = hav + hav.T

    # Expand shell blocks to AO blocks and combine with the overlap.
    ao_to_shell = self.ao_shell_map
    hav_ao = hav[ao_to_shell][:, ao_to_shell]
    return torch.diag_embed(h_shell_diag[ao_to_shell]) + hav_ao * sint
461
+
462
def calculation(self, xyz, q, cn):
    """Return the extended-Hueckel-type electronic energy for one geometry.

    Steps: overlap matrix S (``self.overlap_calc.overlap_int_torch``),
    Hamiltonian H0 (``self.get_hamiltonian``), then the generalized
    eigenproblem H0 C = S C E solved by Loewdin (S^-1/2) orthogonalization.

    Orbitals are filled in order of increasing energy:
    ``floor(n_elec / 2)`` orbitals are doubly occupied, and any remaining
    electron count (one electron for an odd count) is placed in the next
    orbital instead of being silently dropped.

    Side effects:
        Sets ``self.mo_energy`` (orbital energies, ascending as returned by
        ``torch.linalg.eigh``) and ``self.mo_coeff`` (AO coefficients C).

    Args:
        xyz: torch.Tensor of shape (N, 3) with the coordinates.
        q: torch.Tensor of shape (N,) or (N, 1) with atomic charges.
        cn: torch.Tensor of shape (N,) or (N, 1) with coordination numbers.

    Returns:
        0-d torch.Tensor: occupation-weighted sum of orbital energies.
    """
    # reshape(-1) rather than squeeze(): squeeze() would turn a single-atom
    # (1,) or (1, 1) input into a 0-d tensor.
    q_1d = q.reshape(-1)
    cn_1d = cn.reshape(-1)

    # 1. Overlap matrix
    sint = self.overlap_calc.overlap_int_torch(xyz)

    # 2. Hamiltonian H0
    h0 = self.get_hamiltonian(xyz, q_1d, cn_1d, sint)

    # 3. Generalized eigenproblem H0 C = S C E via S^-1/2.
    #    Directions of S with eigenvalue <= thresh (near-linear dependence)
    #    are projected out.
    w_s, v_s = torch.linalg.eigh(sint)
    thresh = 1e-8
    keep = w_s > thresh
    # torch.where avoids a masked in-place write (friendlier to torch.func
    # transforms). The clamp keeps the discarded branch finite so no NaN can
    # leak into gradients; for kept entries the value is unchanged.
    w_s_inv_sqrt = torch.where(
        keep,
        1.0 / torch.sqrt(torch.clamp(w_s, min=thresh)),
        torch.zeros_like(w_s),
    )

    # V diag(w) V^T computed by column scaling instead of a dense diag() matmul.
    s_inv_sqrt = torch.matmul(v_s * w_s_inv_sqrt, v_s.T)
    f_tilde = torch.matmul(s_inv_sqrt, torch.matmul(h0, s_inv_sqrt))
    f_tilde = 0.5 * (f_tilde + f_tilde.T)  # remove round-off asymmetry
    mo_ene, mo_coeff_ortho = torch.linalg.eigh(f_tilde)

    C = torch.matmul(s_inv_sqrt, mo_coeff_ortho)

    # 4. Aufbau occupation. The electron count does not depend on the
    #    differentiated inputs, so converting it to a Python number is safe.
    n_elec = float(self.total_valence_e - self.charge_t)
    n_occ = max(int(n_elec // 2), 0)
    energy = 2.0 * torch.sum(mo_ene[:n_occ])
    # Leftover electrons (1 for an odd count) occupy the next orbital.
    n_left = n_elec - 2 * n_occ
    if n_left > 0.0 and n_occ < mo_ene.shape[0]:
        energy = energy + n_left * mo_ene[n_occ]

    self.mo_energy = mo_ene
    self.mo_coeff = C

    return energy
505
+
506
def get_mo_energy(self):
    """Return the orbital energies stored by the most recent ``calculation``."""
    stored_energies = self.mo_energy
    return stored_energies
508
+
509
def get_mo_coeff(self):
    """Return the MO coefficient matrix stored by the most recent ``calculation``."""
    stored_coeff = self.mo_coeff
    return stored_coeff
511
+
512
def get_overlap_integral_matrix(self):
    """Delegate to ``self.overlap_calc`` and return its overlap matrix."""
    overlap_calculator = self.overlap_calc
    return overlap_calculator.get_overlap_integral_matrix()
514
+
515
+
516
def energy(self, xyz, q, cn):
    """Evaluate ``self.calculation`` for array-like inputs.

    Each of ``xyz``, ``q`` and ``cn`` is copied into a fresh leaf tensor
    (``self.dtype`` on ``self.device``, ``requires_grad=True``) before the
    call, so the returned tensor stays attached to the autograd graph.
    """
    def as_leaf(values):
        return torch.tensor(values, dtype=self.dtype, device=self.device, requires_grad=True)

    xyz_leaf, q_leaf, cn_leaf = (as_leaf(v) for v in (xyz, q, cn))
    return self.calculation(xyz_leaf, q_leaf, cn_leaf)
523
+
524
def gradient(self, xyz, q, cn, d_eeq_charge, d_cn):
    """Return ``(energy, gradient)`` with the total Cartesian gradient dE/dR.

    E depends on R explicitly and implicitly through q(R) and cn(R):

        dE/dR = dE/dR|explicit + sum_i dE/dq_i dq_i/dR + sum_i dE/dcn_i dcn_i/dR

    Args:
        xyz: array-like (N, 3) coordinates.
        q: array-like (N,) or (N, 1) charges.
        cn: array-like (N,) or (N, 1) coordination numbers.
        d_eeq_charge: array-like (N, N, 3), entry [i, j, k] = dq_i/dR_jk
            (the layout the einsum below requires to match the (N, 3) gradient).
        d_cn: array-like (N, N, 3), entry [i, j, k] = dcn_i/dR_jk.

    Returns:
        tuple: (energy from ``self.energy``, gradient tensor of shape (N, 3)).
    """
    xyz_t = torch.tensor(xyz, dtype=self.dtype, device=self.device, requires_grad=True)
    q_t = torch.tensor(q, dtype=self.dtype, device=self.device, requires_grad=True)
    cn_t = torch.tensor(cn, dtype=self.dtype, device=self.device, requires_grad=True)
    d_eeq_charge_t = torch.tensor(d_eeq_charge, dtype=self.dtype, device=self.device)
    d_cn_t = torch.tensor(d_cn, dtype=self.dtype, device=self.device)

    # A single reverse-mode pass yields all three partial derivatives,
    # instead of re-running the forward model once per argument.
    grad_xyz, grad_q, grad_cn = torch.func.jacrev(self.calculation, argnums=(0, 1, 2))(
        xyz_t, q_t, cn_t
    )

    # reshape(-1) rather than squeeze(): for a single atom squeeze() gives a
    # 0-d tensor and the 'i,ijk->jk' contraction fails.
    gradient = (
        grad_xyz
        + torch.einsum('i,ijk->jk', grad_q.reshape(-1), d_eeq_charge_t)
        + torch.einsum('i,ijk->jk', grad_cn.reshape(-1), d_cn_t)
    )

    # self.calculation stores MO data on self. Values stored while running
    # inside torch.func transforms are transform-internal, so this plain
    # evaluation also refreshes self.mo_energy / self.mo_coeff.
    energy = self.energy(xyz, q, cn)
    return energy, gradient
541
+
542
def hessian(self, xyz, q, cn, d_eeq_charge, dd_eeq_charge, d_cn, dd_cn):
    """Return the total Cartesian Hessian d^2E/dR^2 as a (3N, 3N) tensor.

    E depends on R explicitly and implicitly through q(R) and cn(R). With
    z = (R, q, cn) and J = dz/dR = [I; dq/dR; dcn/dR], the exact result is

        d2E/dR2 = J^T (d2E/dz2) J
                  + sum_i dE/dq_i  d2q_i/dR2
                  + sum_i dE/dcn_i d2cn_i/dR2

    This includes the mixed R-q, R-cn and q-cn second derivatives.

    Args:
        xyz: array-like (N, 3) coordinates.
        q: array-like (N,) or (N, 1) charges.
        cn: array-like (N,) or (N, 1) coordination numbers.
        d_eeq_charge: dq_i/dR in the (N, N, 3) layout ``gradient()`` uses.
            An (N, 3, N) array is also accepted when N != 3.
            NOTE(review): confirm which layout callers pass.
        dd_eeq_charge: array-like (N, 3N, 3N) second derivatives d2q_i/dR2.
        d_cn: dcn_i/dR, same layouts as ``d_eeq_charge``.
        dd_cn: array-like (N, 3N, 3N) second derivatives d2cn_i/dR2.

    Returns:
        torch.Tensor of shape (3N, 3N), rows/columns ordered atom-major
        (index 3*atom + component).
    """
    xyz_t = torch.tensor(xyz, dtype=self.dtype, device=self.device, requires_grad=True)
    q_t = torch.tensor(q, dtype=self.dtype, device=self.device, requires_grad=True)
    cn_t = torch.tensor(cn, dtype=self.dtype, device=self.device, requires_grad=True)

    d_eeq_charge_t = torch.tensor(d_eeq_charge, dtype=self.dtype, device=self.device)
    dd_eeq_charge_t = torch.tensor(dd_eeq_charge, dtype=self.dtype, device=self.device)
    d_cn_t = torch.tensor(d_cn, dtype=self.dtype, device=self.device)
    dd_cn_t = torch.tensor(dd_cn, dtype=self.dtype, device=self.device)

    n_atoms = xyz_t.shape[0]
    n_dim = n_atoms * 3
    sizes = (n_dim, q_t.numel(), cn_t.numel())

    def _as_jacobian(deriv):
        # Return deriv as (n, 3N) with column index 3*j + k: the same
        # atom-major ordering as the flattened coordinate Hessian.
        if deriv.shape[1:] == (n_atoms, 3):
            return deriv.reshape(deriv.shape[0], n_dim)
        # Legacy (n, 3, N) layout, as assumed by the earlier permute.
        return deriv.permute(0, 2, 1).reshape(deriv.shape[0], n_dim)

    dq_dr = _as_jacobian(d_eeq_charge_t)    # (N, 3N)
    dcn_dr = _as_jacobian(d_cn_t)           # (N, 3N)

    # Full second-derivative matrix over z = (R, q, cn). With a tuple
    # argnums, torch.func.hessian returns a 3 x 3 grid of blocks whose
    # shapes are in_a.shape + in_b.shape.
    blocks = torch.func.hessian(self.calculation, argnums=(0, 1, 2))(xyz_t, q_t, cn_t)
    hess_z = torch.cat(
        [
            torch.cat([blocks[a][b].reshape(sizes[a], sizes[b]) for b in range(3)], dim=1)
            for a in range(3)
        ],
        dim=0,
    )

    # J = dz/dR stacked as [I; dq/dR; dcn/dR].
    jac = torch.cat(
        [torch.eye(n_dim, dtype=hess_z.dtype, device=hess_z.device), dq_dr, dcn_dr],
        dim=0,
    )

    # First derivatives w.r.t. q and cn multiply the curvature of q(R), cn(R).
    # reshape(-1) rather than squeeze() so a single atom stays 1-D.
    grad_q, grad_cn = torch.func.jacrev(self.calculation, argnums=(1, 2))(xyz_t, q_t, cn_t)

    hessian = jac.T @ hess_z @ jac
    hessian = hessian + torch.einsum('i,ijk->jk', grad_q.reshape(-1), dd_eeq_charge_t)
    hessian = hessian + torch.einsum('i,ijk->jk', grad_cn.reshape(-1), dd_cn_t)

    # self.calculation stores MO data on self. Values stored inside the
    # torch.func transforms above are transform-internal, so refresh them
    # with one plain evaluation.
    self.calculation(xyz_t.detach(), q_t.detach(), cn_t.detach())

    return hessian
@@ -0,0 +1,66 @@
1
+ import torch
2
+
3
class RepulsionCalculator:
    """Screened pairwise nuclear repulsion energy.

    For every unique atom pair A < B the contribution is

        E_AB = Z_A * Z_B / R_AB * exp(-sqrt(alpha_A * alpha_B * R_AB**3))

    with per-element ``Z`` taken from ``params.repZeff`` and ``alpha`` from
    ``params.repAlpha``. Parameters are stored as float64 tensors.
    """

    def __init__(self, element_list, params):
        # Look up (alpha, zeff) per atom, in atom order.
        per_atom = [(params.repAlpha[elem], params.repZeff[elem]) for elem in element_list]
        self.rep_alpha_list = torch.tensor([a for a, _ in per_atom], dtype=torch.float64)
        self.rep_zeff_list = torch.tensor([z for _, z in per_atom], dtype=torch.float64)

    def calculation(self, xyz):
        """Return the total repulsion energy for the given coordinates.

        Fully vectorized over all atom pairs; only the strict upper triangle
        (A < B) of the pair-energy matrix is summed.

        Args:
            xyz (torch.Tensor): atomic coordinates, shape [n_atoms, 3].

        Returns:
            torch.Tensor: 0-d tensor with the total energy.
        """
        dev = xyz.device
        zeff = self.rep_zeff_list.to(dev)
        alpha = self.rep_alpha_list.to(dev)

        # Pairwise distances. The 1e-12 shift keeps sqrt differentiable on
        # the zero-length diagonal, which is excluded from the sum below.
        disp = xyz[:, None, :] - xyz[None, :, :]
        r = torch.sqrt((disp ** 2).sum(dim=-1) + 1e-12)

        zz = torch.outer(zeff, zeff)
        aa = torch.outer(alpha, alpha)
        screening = torch.exp(-torch.sqrt(aa * r ** 3.0))
        pair_energy = zz * (1.0 / r) * screening

        return torch.triu(pair_energy, diagonal=1).sum()

    def energy(self, xyz):
        """Return the repulsion energy for array-like coordinates."""
        coords = torch.tensor(xyz, dtype=torch.float64, requires_grad=False)
        return self.calculation(coords)

    def gradient(self, xyz):
        """Return ``(energy, gradient)``; the gradient has shape [n_atoms, 3]."""
        coords = torch.tensor(xyz, dtype=torch.float64, requires_grad=True)
        e_total = self.calculation(coords)
        grad = torch.func.jacrev(self.calculation)(coords)
        return e_total, grad

    def hessian(self, xyz):
        """Return ``(energy, hessian)``; the Hessian is [3*n_atoms, 3*n_atoms]."""
        coords = torch.tensor(xyz, dtype=torch.float64, requires_grad=True)
        e_total = self.calculation(coords)
        n_cart = coords.shape[0] * 3
        hess = torch.func.hessian(self.calculation)(coords).reshape(n_cart, n_cart)
        return e_total, hess
65
+
66
+