MultiOptPy 1.20.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (246)
  1. multioptpy/Calculator/__init__.py +0 -0
  2. multioptpy/Calculator/ase_calculation_tools.py +424 -0
  3. multioptpy/Calculator/ase_tools/__init__.py +0 -0
  4. multioptpy/Calculator/ase_tools/fairchem.py +28 -0
  5. multioptpy/Calculator/ase_tools/gamess.py +19 -0
  6. multioptpy/Calculator/ase_tools/gaussian.py +165 -0
  7. multioptpy/Calculator/ase_tools/mace.py +28 -0
  8. multioptpy/Calculator/ase_tools/mopac.py +19 -0
  9. multioptpy/Calculator/ase_tools/nwchem.py +31 -0
  10. multioptpy/Calculator/ase_tools/orca.py +22 -0
  11. multioptpy/Calculator/ase_tools/pygfn0.py +37 -0
  12. multioptpy/Calculator/dxtb_calculation_tools.py +344 -0
  13. multioptpy/Calculator/emt_calculation_tools.py +458 -0
  14. multioptpy/Calculator/gpaw_calculation_tools.py +183 -0
  15. multioptpy/Calculator/lj_calculation_tools.py +314 -0
  16. multioptpy/Calculator/psi4_calculation_tools.py +334 -0
  17. multioptpy/Calculator/pwscf_calculation_tools.py +189 -0
  18. multioptpy/Calculator/pyscf_calculation_tools.py +327 -0
  19. multioptpy/Calculator/sqm1_calculation_tools.py +611 -0
  20. multioptpy/Calculator/sqm2_calculation_tools.py +376 -0
  21. multioptpy/Calculator/tblite_calculation_tools.py +352 -0
  22. multioptpy/Calculator/tersoff_calculation_tools.py +818 -0
  23. multioptpy/Constraint/__init__.py +0 -0
  24. multioptpy/Constraint/constraint_condition.py +834 -0
  25. multioptpy/Coordinate/__init__.py +0 -0
  26. multioptpy/Coordinate/polar_coordinate.py +199 -0
  27. multioptpy/Coordinate/redundant_coordinate.py +638 -0
  28. multioptpy/IRC/__init__.py +0 -0
  29. multioptpy/IRC/converge_criteria.py +28 -0
  30. multioptpy/IRC/dvv.py +544 -0
  31. multioptpy/IRC/euler.py +439 -0
  32. multioptpy/IRC/hpc.py +564 -0
  33. multioptpy/IRC/lqa.py +540 -0
  34. multioptpy/IRC/modekill.py +662 -0
  35. multioptpy/IRC/rk4.py +579 -0
  36. multioptpy/Interpolation/__init__.py +0 -0
  37. multioptpy/Interpolation/adaptive_interpolation.py +283 -0
  38. multioptpy/Interpolation/binomial_interpolation.py +179 -0
  39. multioptpy/Interpolation/geodesic_interpolation.py +785 -0
  40. multioptpy/Interpolation/interpolation.py +156 -0
  41. multioptpy/Interpolation/linear_interpolation.py +473 -0
  42. multioptpy/Interpolation/savitzky_golay_interpolation.py +252 -0
  43. multioptpy/Interpolation/spline_interpolation.py +353 -0
  44. multioptpy/MD/__init__.py +0 -0
  45. multioptpy/MD/thermostat.py +185 -0
  46. multioptpy/MEP/__init__.py +0 -0
  47. multioptpy/MEP/pathopt_bneb_force.py +443 -0
  48. multioptpy/MEP/pathopt_dmf_force.py +448 -0
  49. multioptpy/MEP/pathopt_dneb_force.py +130 -0
  50. multioptpy/MEP/pathopt_ewbneb_force.py +207 -0
  51. multioptpy/MEP/pathopt_gpneb_force.py +512 -0
  52. multioptpy/MEP/pathopt_lup_force.py +113 -0
  53. multioptpy/MEP/pathopt_neb_force.py +225 -0
  54. multioptpy/MEP/pathopt_nesb_force.py +205 -0
  55. multioptpy/MEP/pathopt_om_force.py +153 -0
  56. multioptpy/MEP/pathopt_qsm_force.py +174 -0
  57. multioptpy/MEP/pathopt_qsmv2_force.py +304 -0
  58. multioptpy/ModelFunction/__init__.py +7 -0
  59. multioptpy/ModelFunction/avoiding_model_function.py +29 -0
  60. multioptpy/ModelFunction/binary_image_ts_search_model_function.py +47 -0
  61. multioptpy/ModelFunction/conical_model_function.py +26 -0
  62. multioptpy/ModelFunction/opt_meci.py +50 -0
  63. multioptpy/ModelFunction/opt_mesx.py +47 -0
  64. multioptpy/ModelFunction/opt_mesx_2.py +49 -0
  65. multioptpy/ModelFunction/seam_model_function.py +27 -0
  66. multioptpy/ModelHessian/__init__.py +0 -0
  67. multioptpy/ModelHessian/approx_hessian.py +147 -0
  68. multioptpy/ModelHessian/calc_params.py +227 -0
  69. multioptpy/ModelHessian/fischer.py +236 -0
  70. multioptpy/ModelHessian/fischerd3.py +360 -0
  71. multioptpy/ModelHessian/fischerd4.py +398 -0
  72. multioptpy/ModelHessian/gfn0xtb.py +633 -0
  73. multioptpy/ModelHessian/gfnff.py +709 -0
  74. multioptpy/ModelHessian/lindh.py +165 -0
  75. multioptpy/ModelHessian/lindh2007d2.py +707 -0
  76. multioptpy/ModelHessian/lindh2007d3.py +822 -0
  77. multioptpy/ModelHessian/lindh2007d4.py +1030 -0
  78. multioptpy/ModelHessian/morse.py +106 -0
  79. multioptpy/ModelHessian/schlegel.py +144 -0
  80. multioptpy/ModelHessian/schlegeld3.py +322 -0
  81. multioptpy/ModelHessian/schlegeld4.py +559 -0
  82. multioptpy/ModelHessian/shortrange.py +346 -0
  83. multioptpy/ModelHessian/swartd2.py +496 -0
  84. multioptpy/ModelHessian/swartd3.py +706 -0
  85. multioptpy/ModelHessian/swartd4.py +918 -0
  86. multioptpy/ModelHessian/tshess.py +40 -0
  87. multioptpy/Optimizer/QHAdam.py +61 -0
  88. multioptpy/Optimizer/__init__.py +0 -0
  89. multioptpy/Optimizer/abc_fire.py +83 -0
  90. multioptpy/Optimizer/adabelief.py +58 -0
  91. multioptpy/Optimizer/adabound.py +68 -0
  92. multioptpy/Optimizer/adadelta.py +65 -0
  93. multioptpy/Optimizer/adaderivative.py +56 -0
  94. multioptpy/Optimizer/adadiff.py +68 -0
  95. multioptpy/Optimizer/adafactor.py +70 -0
  96. multioptpy/Optimizer/adam.py +65 -0
  97. multioptpy/Optimizer/adamax.py +62 -0
  98. multioptpy/Optimizer/adamod.py +83 -0
  99. multioptpy/Optimizer/adamw.py +65 -0
  100. multioptpy/Optimizer/adiis.py +523 -0
  101. multioptpy/Optimizer/afire_neb.py +282 -0
  102. multioptpy/Optimizer/block_hessian_update.py +709 -0
  103. multioptpy/Optimizer/c2diis.py +491 -0
  104. multioptpy/Optimizer/component_wise_scaling.py +405 -0
  105. multioptpy/Optimizer/conjugate_gradient.py +82 -0
  106. multioptpy/Optimizer/conjugate_gradient_neb.py +345 -0
  107. multioptpy/Optimizer/coordinate_locking.py +405 -0
  108. multioptpy/Optimizer/dic_rsirfo.py +1015 -0
  109. multioptpy/Optimizer/ediis.py +417 -0
  110. multioptpy/Optimizer/eve.py +76 -0
  111. multioptpy/Optimizer/fastadabelief.py +61 -0
  112. multioptpy/Optimizer/fire.py +77 -0
  113. multioptpy/Optimizer/fire2.py +249 -0
  114. multioptpy/Optimizer/fire_neb.py +92 -0
  115. multioptpy/Optimizer/gan_step.py +486 -0
  116. multioptpy/Optimizer/gdiis.py +609 -0
  117. multioptpy/Optimizer/gediis.py +203 -0
  118. multioptpy/Optimizer/geodesic_step.py +433 -0
  119. multioptpy/Optimizer/gpmin.py +633 -0
  120. multioptpy/Optimizer/gpr_step.py +364 -0
  121. multioptpy/Optimizer/gradientdescent.py +78 -0
  122. multioptpy/Optimizer/gradientdescent_neb.py +52 -0
  123. multioptpy/Optimizer/hessian_update.py +433 -0
  124. multioptpy/Optimizer/hybrid_rfo.py +998 -0
  125. multioptpy/Optimizer/kdiis.py +625 -0
  126. multioptpy/Optimizer/lars.py +21 -0
  127. multioptpy/Optimizer/lbfgs.py +253 -0
  128. multioptpy/Optimizer/lbfgs_neb.py +355 -0
  129. multioptpy/Optimizer/linesearch.py +236 -0
  130. multioptpy/Optimizer/lookahead.py +40 -0
  131. multioptpy/Optimizer/nadam.py +64 -0
  132. multioptpy/Optimizer/newton.py +200 -0
  133. multioptpy/Optimizer/prodigy.py +70 -0
  134. multioptpy/Optimizer/purtubation.py +16 -0
  135. multioptpy/Optimizer/quickmin_neb.py +245 -0
  136. multioptpy/Optimizer/radam.py +75 -0
  137. multioptpy/Optimizer/rfo_neb.py +302 -0
  138. multioptpy/Optimizer/ric_rfo.py +842 -0
  139. multioptpy/Optimizer/rl_step.py +627 -0
  140. multioptpy/Optimizer/rmspropgrave.py +65 -0
  141. multioptpy/Optimizer/rsirfo.py +1647 -0
  142. multioptpy/Optimizer/rsprfo.py +1056 -0
  143. multioptpy/Optimizer/sadam.py +60 -0
  144. multioptpy/Optimizer/samsgrad.py +63 -0
  145. multioptpy/Optimizer/tr_lbfgs.py +678 -0
  146. multioptpy/Optimizer/trim.py +273 -0
  147. multioptpy/Optimizer/trust_radius.py +207 -0
  148. multioptpy/Optimizer/trust_radius_neb.py +121 -0
  149. multioptpy/Optimizer/yogi.py +60 -0
  150. multioptpy/OtherMethod/__init__.py +0 -0
  151. multioptpy/OtherMethod/addf.py +1150 -0
  152. multioptpy/OtherMethod/dimer.py +895 -0
  153. multioptpy/OtherMethod/elastic_image_pair.py +629 -0
  154. multioptpy/OtherMethod/modelfunction.py +456 -0
  155. multioptpy/OtherMethod/newton_traj.py +454 -0
  156. multioptpy/OtherMethod/twopshs.py +1095 -0
  157. multioptpy/PESAnalyzer/__init__.py +0 -0
  158. multioptpy/PESAnalyzer/calc_irc_curvature.py +125 -0
  159. multioptpy/PESAnalyzer/cmds_analysis.py +152 -0
  160. multioptpy/PESAnalyzer/koopman_analysis.py +268 -0
  161. multioptpy/PESAnalyzer/pca_analysis.py +314 -0
  162. multioptpy/Parameters/__init__.py +0 -0
  163. multioptpy/Parameters/atomic_mass.py +20 -0
  164. multioptpy/Parameters/atomic_number.py +22 -0
  165. multioptpy/Parameters/covalent_radii.py +44 -0
  166. multioptpy/Parameters/d2.py +61 -0
  167. multioptpy/Parameters/d3.py +63 -0
  168. multioptpy/Parameters/d4.py +103 -0
  169. multioptpy/Parameters/dreiding.py +34 -0
  170. multioptpy/Parameters/gfn0xtb_param.py +137 -0
  171. multioptpy/Parameters/gfnff_param.py +315 -0
  172. multioptpy/Parameters/gnb.py +104 -0
  173. multioptpy/Parameters/parameter.py +22 -0
  174. multioptpy/Parameters/uff.py +72 -0
  175. multioptpy/Parameters/unit_values.py +20 -0
  176. multioptpy/Potential/AFIR_potential.py +55 -0
  177. multioptpy/Potential/LJ_repulsive_potential.py +345 -0
  178. multioptpy/Potential/__init__.py +0 -0
  179. multioptpy/Potential/anharmonic_keep_potential.py +28 -0
  180. multioptpy/Potential/asym_elllipsoidal_potential.py +718 -0
  181. multioptpy/Potential/electrostatic_potential.py +69 -0
  182. multioptpy/Potential/flux_potential.py +30 -0
  183. multioptpy/Potential/gaussian_potential.py +101 -0
  184. multioptpy/Potential/idpp.py +516 -0
  185. multioptpy/Potential/keep_angle_potential.py +146 -0
  186. multioptpy/Potential/keep_dihedral_angle_potential.py +105 -0
  187. multioptpy/Potential/keep_outofplain_angle_potential.py +70 -0
  188. multioptpy/Potential/keep_potential.py +99 -0
  189. multioptpy/Potential/mechano_force_potential.py +74 -0
  190. multioptpy/Potential/nanoreactor_potential.py +52 -0
  191. multioptpy/Potential/potential.py +896 -0
  192. multioptpy/Potential/spacer_model_potential.py +221 -0
  193. multioptpy/Potential/switching_potential.py +258 -0
  194. multioptpy/Potential/universal_potential.py +34 -0
  195. multioptpy/Potential/value_range_potential.py +36 -0
  196. multioptpy/Potential/void_point_potential.py +25 -0
  197. multioptpy/SQM/__init__.py +0 -0
  198. multioptpy/SQM/sqm1/__init__.py +0 -0
  199. multioptpy/SQM/sqm1/sqm1_core.py +1792 -0
  200. multioptpy/SQM/sqm2/__init__.py +0 -0
  201. multioptpy/SQM/sqm2/calc_tools.py +95 -0
  202. multioptpy/SQM/sqm2/sqm2_basis.py +850 -0
  203. multioptpy/SQM/sqm2/sqm2_bond.py +119 -0
  204. multioptpy/SQM/sqm2/sqm2_core.py +303 -0
  205. multioptpy/SQM/sqm2/sqm2_data.py +1229 -0
  206. multioptpy/SQM/sqm2/sqm2_disp.py +65 -0
  207. multioptpy/SQM/sqm2/sqm2_eeq.py +243 -0
  208. multioptpy/SQM/sqm2/sqm2_overlapint.py +704 -0
  209. multioptpy/SQM/sqm2/sqm2_qm.py +578 -0
  210. multioptpy/SQM/sqm2/sqm2_rep.py +66 -0
  211. multioptpy/SQM/sqm2/sqm2_srb.py +70 -0
  212. multioptpy/Thermo/__init__.py +0 -0
  213. multioptpy/Thermo/normal_mode_analyzer.py +865 -0
  214. multioptpy/Utils/__init__.py +0 -0
  215. multioptpy/Utils/bond_connectivity.py +264 -0
  216. multioptpy/Utils/calc_tools.py +884 -0
  217. multioptpy/Utils/oniom.py +96 -0
  218. multioptpy/Utils/pbc.py +48 -0
  219. multioptpy/Utils/riemann_curvature.py +208 -0
  220. multioptpy/Utils/symmetry_analyzer.py +482 -0
  221. multioptpy/Visualization/__init__.py +0 -0
  222. multioptpy/Visualization/visualization.py +156 -0
  223. multioptpy/WFAnalyzer/MO_analysis.py +104 -0
  224. multioptpy/WFAnalyzer/__init__.py +0 -0
  225. multioptpy/Wrapper/__init__.py +0 -0
  226. multioptpy/Wrapper/autots.py +1239 -0
  227. multioptpy/Wrapper/ieip_wrapper.py +93 -0
  228. multioptpy/Wrapper/md_wrapper.py +92 -0
  229. multioptpy/Wrapper/neb_wrapper.py +94 -0
  230. multioptpy/Wrapper/optimize_wrapper.py +76 -0
  231. multioptpy/__init__.py +5 -0
  232. multioptpy/entrypoints.py +916 -0
  233. multioptpy/fileio.py +660 -0
  234. multioptpy/ieip.py +340 -0
  235. multioptpy/interface.py +1086 -0
  236. multioptpy/irc.py +529 -0
  237. multioptpy/moleculardynamics.py +432 -0
  238. multioptpy/neb.py +1267 -0
  239. multioptpy/optimization.py +1553 -0
  240. multioptpy/optimizer.py +709 -0
  241. multioptpy-1.20.2.dist-info/METADATA +438 -0
  242. multioptpy-1.20.2.dist-info/RECORD +246 -0
  243. multioptpy-1.20.2.dist-info/WHEEL +5 -0
  244. multioptpy-1.20.2.dist-info/entry_points.txt +9 -0
  245. multioptpy-1.20.2.dist-info/licenses/LICENSE +674 -0
  246. multioptpy-1.20.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,704 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+ import numpy as np
5
+
6
+
7
+ class OverlapCalculator:
8
+ def __init__(self, element_list=None, param=None, wf=None):
9
+ """
10
+ Initializes the OverlapCalculator with the given parameters.
11
+ """
12
+ self.element_list = element_list
13
+ self.param = param
14
+ self.wf = wf
15
+ self.basis = wf.basis
16
+
17
+ # --- Constants Initialization ---
18
+ self.MAXL = 6
19
+ self.MAXL2 = self.MAXL * 2
20
+
21
+ # --- Cartesian Exponent Lookup Tables (Numpy arrays) ---
22
+ self.LX = np.array([0,
23
+ 1,0,0,
24
+ 2,0,0,1,1,0,
25
+ 3,0,0,2,2,1,0,1,0,1,
26
+ 4,0,0,3,3,1,0,1,0,2,2,0,2,1,1,
27
+ 5,0,0,3,3,2,2,0,0,4,4,1,0,0,1,1,3,1,2,2,1,
28
+ 6,0,0,3,3,0,5,5,1,0,0,1,4,4,2,0,2,0,3,3,1,2,2,1,4,1,1,2], dtype=int)
29
+ self.LY = np.array([0,
30
+ 0,1,0,
31
+ 0,2,0,1,0,1,
32
+ 0,3,0,1,0,2,2,0,1,1,
33
+ 0,4,0,1,0,3,3,0,1,2,0,2,1,2,1,
34
+ 0,5,0,2,0,3,0,3,2,1,0,4,4,1,0,1,1,3,2,1,2,
35
+ 0,6,0,3,0,3,1,0,0,1,5,5,2,0,0,2,4,4,2,1,3,1,3,2,1,4,1,2], dtype=int)
36
+ self.LZ = np.array([0,
37
+ 0,0,1,
38
+ 0,0,2,0,1,1,
39
+ 0,0,3,0,1,0,1,2,2,1,
40
+ 0,0,4,0,1,0,1,3,3,0,2,2,1,1,2,
41
+ 0,0,5,0,2,0,3,2,3,0,1,0,1,4,4,3,1,1,1,2,2,
42
+ 0,0,6,0,3,3,0,1,5,5,1,0,0,2,4,4,0,2,1,2,2,3,1,3,1,1,4,2], dtype=int)
43
+ self.LXYZ_TYPE_NP = np.stack([self.LX, self.LY, self.LZ], axis=0).T
44
+ self.LXYZ_TYPE_LIST = self.LXYZ_TYPE_NP.tolist()
45
+
46
+ tmp_LXYZ = self.basis["ao_type_id_list"]
47
+
48
+ self.LXYZ_LIST = []
49
+ for tmp in tmp_LXYZ:
50
+ self.LXYZ_LIST.append(self.LXYZ_TYPE_LIST[tmp - 1])
51
+
52
+ self.LXYZ_TENSOR = torch.tensor(self.LXYZ_LIST, dtype=torch.int64)
53
+ self.LXYZ_TYPE_TENSOR = torch.tensor(self.LXYZ_TYPE_LIST, dtype=torch.int64)
54
+
55
+ # --- Shell Start Index Offsets (Tensor) ---
56
+ # L=0, 1, 2, 3 (s, p, d, f)
57
+ self.ITT = torch.tensor([0, 1, 4, 10], dtype=torch.int64)
58
+
59
+ # --- Cartesian d -> Spherical d Transformation Matrix (Tensor) ---
60
+ s5 = math.sqrt(1.0 / 5.0)
61
+ s3 = math.sqrt(3.0)
62
+ self.TRAFO_NP = np.array([
63
+ [s5, s5, s5, 0.0, 0.0, 0.0],
64
+ [0.5 * s3, -0.5 * s3, 0.0, 0.0, 0.0, 0.0],
65
+ [0.5, 0.5, -1.0, 0.0, 0.0, 0.0],
66
+ [0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
67
+ [0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
68
+ [0.0, 0.0, 0.0, 0.0, 0.0, 1.0]], dtype=np.float64)
69
+
70
+ self.TRAFO = None
71
+ self.TRAFO_T = None
72
+
73
+ self.LLAO = np.array([1, 3, 6, 10], dtype=np.int64) # Cartesian counts
74
+ self.LLAO2 = np.array([1, 3, 5, 7], dtype=np.int64) # Spherical counts
75
+
76
+ # --- Double Factorial Table for 1D overlap integral ---
77
+ # DFTR(lh) = (2*lh - 1)!! for lh = 0, 1, 2, ...
78
+ # This matches Fortran: dftr(0:7) = [1, 1, 3, 15, 105, 945, 10395, 135135]
79
+ # where dftr(lh) = (2*lh - 1)!!
80
+ # lh=0: (-1)!! = 1
81
+ # lh=1: 1!! = 1
82
+ # lh=2: 3!! = 3
83
+ # lh=3: 5!! = 15
84
+ # lh=4: 7!! = 105
85
+ # lh=5: 9!! = 945
86
+ # lh=6: 11!! = 10395
87
+ # lh=7: 13!! = 135135
88
+ self.DFTR = torch.tensor([
89
+ 1.0, 1.0, 3.0, 15.0, 105.0, 945.0, 10395.0, 135135.0
90
+ ], dtype=torch.float64)
91
+
92
+ # --- Binomial Coefficients ---
93
+ # self.BINO[l, m] = l! / (m! * (l-m)!)
94
+ m = torch.arange(self.MAXL + 1, dtype=torch.float64)
95
+ l_fact = torch.lgamma(m + 1)
96
+ l_vec = m.view(-1, 1)
97
+ m_vec = m.view(1, -1)
98
+
99
+ log_bino = l_fact[l_vec.long()] - l_fact[m_vec.long()] - l_fact[(l_vec - m_vec).long()]
100
+ bino = torch.exp(log_bino).round()
101
+ self.BINO = torch.where(l_vec < m_vec, 0.0, bino).to(torch.float64)
102
+
103
+
104
+ def calc_1d_overlap_constants(self, l, gamma):
105
+ """
106
+ Calculates the 1D overlap integral factor (double factorial part).
107
+ Corresponds to Fortran: olapp(l, gama)
108
+
109
+ Returns 0 for odd l, otherwise returns (0.5/gamma)^(l/2) * (l-1)!!
110
+ """
111
+ if l % 2 != 0:
112
+ s = 0.0
113
+ else:
114
+ lh = l // 2
115
+ gm = 0.5 / gamma
116
+ # Use DFTR[lh] which gives (2*lh - 1)!! = (l - 1)!!
117
+ if self.DFTR.device != (gamma.device if isinstance(gamma, torch.Tensor) else 'cpu'):
118
+ self.DFTR = self.DFTR.to(gamma.device if isinstance(gamma, torch.Tensor) else 'cpu')
119
+ dftr_val = self.DFTR[lh]
120
+ s = (gm ** lh) * dftr_val
121
+ if isinstance(gamma, torch.Tensor) and not isinstance(s, torch.Tensor):
122
+ s = torch.tensor(s, dtype=gamma.dtype, device=gamma.device)
123
+ return s
124
+
125
+
126
+ def calc_1d_overlap_constants_vec(self, l: int, gamma: torch.Tensor) -> torch.Tensor:
127
+ """
128
+ Vectorized version of calc_1d_overlap_constants.
129
+ Corresponds to Fortran: olapp(l, gama)
130
+ """
131
+ if l % 2 != 0:
132
+ return torch.zeros_like(gamma)
133
+ else:
134
+ lh = l // 2
135
+ gm = 0.5 / gamma
136
+
137
+ # Ensure DFTR is on the correct device
138
+ if self.DFTR.device != gamma.device:
139
+ self.DFTR = self.DFTR.to(gamma.device)
140
+
141
+ dftr_val = self.DFTR[lh] # (2*lh - 1)!! = (l - 1)!!
142
+ return (gm ** lh) * dftr_val
143
+
144
+ def calc_product_exponent_and_overlap_vec(self, alpha_i, r_i, alpha_j, r_j):
145
+ """
146
+ Vectorized: Calculates product exponent gamma and s-orbital overlap Kab.
147
+ Corresponds to Fortran: build_kab(ra, alp, rb, bet, gama, kab)
148
+ """
149
+ # alpha_i: (Ni, 1), alpha_j: (1, Nj)
150
+ # r_i, r_j: (3,)
151
+ gamma = alpha_i + alpha_j # (Ni, Nj)
152
+ inv_gamma = 1.0 / gamma
153
+
154
+ r_ij = r_i - r_j
155
+ r_ij_2 = torch.dot(r_ij, r_ij) # scalar
156
+
157
+ est = r_ij_2 * alpha_i * alpha_j * inv_gamma # (Ni, Nj)
158
+ # kab = exp(-est) * (sqrt(pi) * sqrt(1/gamma))^3
159
+ k_ab = torch.exp(-est) * (torch.pi * inv_gamma) ** 1.5 # (Ni, Nj)
160
+ return gamma, k_ab
161
+
162
+ def calc_gau_product_center_vec(self, alpha_i, r_i, alpha_j, r_j, gamma):
163
+ """
164
+ Vectorized: Calculates Gaussian product center P = (aA + bB) / (a + b).
165
+ Corresponds to Fortran: gpcenter(alp, ra, bet, rb)
166
+ """
167
+ # alpha_i: (Ni, 1, 1), alpha_j: (1, Nj, 1)
168
+ # r_i, r_j: (1, 1, 3)
169
+ # gamma: (Ni, Nj, 1)
170
+ r_gpc = (alpha_i * r_i + alpha_j * r_j) / gamma # (Ni, Nj, 3)
171
+ return r_gpc
172
+
173
+ def _primitive_norm(self, alphas: torch.Tensor, lmn_mat: torch.Tensor) -> torch.Tensor:
174
+ """
175
+ Vectorized normalization factor for Cartesian Gaussian primitives.
176
+ N = (2*alpha/pi)^(3/4) * (4*alpha)^(L/2) / sqrt((2lx-1)!! * (2ly-1)!! * (2lz-1)! !)
177
+ """
178
+ # (Ncomp,)
179
+ lx, ly, lz = lmn_mat[:, 0], lmn_mat[:, 1], lmn_mat[:, 2]
180
+ L = lx + ly + lz # (Ncomp,)
181
+
182
+ # Ensure DFTR is on the correct device
183
+ if self.DFTR.device != alphas.device:
184
+ self.DFTR = self.DFTR.to(alphas.device)
185
+
186
+ # DFTR[lx] = (2*lx - 1)!! (with DFTR[0] = (-1)!! = 1)
187
+ dftr_x = self.DFTR[lx.long()] # (Ncomp,)
188
+ dftr_y = self.DFTR[ly.long()]
189
+ dftr_z = self.DFTR[lz.long()]
190
+
191
+ # Add epsilon to avoid sqrt(0) -> NaN gradients
192
+ den = torch.sqrt(dftr_x * dftr_y * dftr_z + 1e-30) # (Ncomp,)
193
+
194
+ pre = (2.0 * alphas / math.pi) ** 0.75 # (Ni,)
195
+
196
+ # (Ni, 1) ** (1, Ncomp) -> (Ni, Ncomp)
197
+ pow_term = (4.0 * alphas[:, None]) ** (L[None, :] / 2.0)
198
+
199
+ # (Ni, 1) * (Ni, Ncomp) / (1, Ncomp) -> (Ni, Ncomp)
200
+ return (pre[:, None] * pow_term) / den[None, :]
201
+
202
+ def _ensure_trafo_device(self, device):
203
+ """ Helper to move TRAFO matrix to the correct device once. """
204
+ if self.TRAFO is None or self.TRAFO.device != device:
205
+ self.TRAFO = torch.tensor(self.TRAFO_NP, dtype=torch.float64, device=device).contiguous()
206
+ self.TRAFO_T = self.TRAFO.T.contiguous()
207
+
208
+ def transform_d_cartesian_to_spherical(self, int_mat, l_i, l_j):
209
+ """
210
+ Transforms d-orbital block from Cartesian to Spherical basis.
211
+ Corresponds to Fortran: dtrf2(s, li, lj)
212
+ """
213
+ if l_i < 2 and l_j < 2:
214
+ return int_mat
215
+
216
+ self._ensure_trafo_device(int_mat.device)
217
+ trafo = self.TRAFO
218
+ trafo_t = self.TRAFO_T
219
+
220
+ if l_i == 0: # s-d
221
+ s_out = torch.matmul(trafo, int_mat[0:6, 0:1])
222
+ elif l_i == 1: # p-d
223
+ s_out = torch.matmul(trafo, int_mat[0:6, 0:3])
224
+ elif l_j == 0: # d-s
225
+ s_out = torch.matmul(int_mat[0:1, 0:6], trafo_t)
226
+ elif l_j == 1: # d-p
227
+ s_out = torch.matmul(int_mat[0:3, 0:6], trafo_t)
228
+ else: # d-d
229
+ dum = torch.matmul(trafo, int_mat[0:6, 0:6])
230
+ s_out = torch.matmul(dum, trafo_t)
231
+
232
+ return s_out
233
+
234
+ def calculate_center_shift_coefficients_vec(self, cfs_in, a, e, l):
235
+ """
236
+ Vectorized: Calculates coefficients for center shift using binomial expansion.
237
+ Corresponds to Fortran: build_hshift2(cfs, a, e, l) / horizontal_shift
238
+
239
+ (x - a)^l = ((x - e) + (e - a))^l = sum_m [ C(l,m) * (e-a)^(l-m) * (x-e)^m ]
240
+ Returns coefficients c_m = C(l,m) * (e-a)^(l-m)
241
+ """
242
+ if l > self.MAXL:
243
+ raise NotImplementedError(f"calculate_center_shift_coefficients_vec not implemented for l = {l} > MAXL = {self.MAXL}")
244
+
245
+ ae = e - a # (Ni, Nj)
246
+ val_l = cfs_in[l] # scalar (1.0)
247
+
248
+ if l == 0:
249
+ return val_l.expand_as(ae).unsqueeze(-1) # (Ni, Nj, 1)
250
+
251
+ # Ensure BINO is on the correct device
252
+ if self.BINO.device != ae.device:
253
+ self.BINO = self.BINO.to(ae.device)
254
+
255
+ # m = 0, 1, ..., l
256
+ m = torch.arange(l + 1, device=ae.device) # (l+1,)
257
+
258
+ # Binomial coefficients C(l, m)
259
+ bino_coeffs = self.BINO[l, m].view(1, 1, -1) # (1, 1, l+1)
260
+
261
+ # Powers (e-a)^(l-m)
262
+ l_minus_m = (l - m).view(1, 1, -1) # (1, 1, l+1)
263
+ ae_b = ae.unsqueeze(-1) # (Ni, Nj, 1)
264
+ ae_powers = ae_b ** l_minus_m # (Ni, Nj, l+1)
265
+
266
+ # c_m = C(l,m) * (e-a)^(l-m) * val_l (where val_l is 1.0)
267
+ coeffs = bino_coeffs * ae_powers * val_l
268
+
269
+ return coeffs # (Ni, Nj, l+1)
270
+
271
+ def calculate_center_shift_coefficients_batch_l(
272
+ self, l_vec: torch.Tensor, a: torch.Tensor, e: torch.Tensor
273
+ ) -> torch.Tensor:
274
+ """
275
+ Vectorized: Calculates center shift coefficients for a *batch* of l values.
276
+ """
277
+ Ni, Nj = e.shape[0:2]
278
+ Ncomp = l_vec.shape[0]
279
+
280
+ if Ncomp == 0:
281
+ return torch.empty((Ni, Nj, 0, 0), dtype=e.dtype, device=e.device)
282
+
283
+ if l_vec.max().item() > self.MAXL:
284
+ raise NotImplementedError(f"l_vec max {l_vec.max()} > MAXL {self.MAXL}")
285
+
286
+ ae = e - a # (Ni, Nj, 3) or (Ni, Nj)
287
+ if ae.dim() == 2:
288
+ ae = ae.unsqueeze(-1) # (Ni, Nj, 1)
289
+
290
+ L_max = l_vec.max().item()
291
+ if L_max < 0:
292
+ return torch.empty((Ni, Nj, Ncomp, 0), dtype=e.dtype, device=e.device)
293
+
294
+ m = torch.arange(L_max + 1, device=ae.device) # (L_max+1,)
295
+
296
+ # Ensure BINO is on the correct device
297
+ if self.BINO.device != ae.device:
298
+ self.BINO = self.BINO.to(ae.device)
299
+
300
+ # C(l, m) -> (Ncomp, L_max+1)
301
+ bino_coeffs = self.BINO[l_vec.long()][:, :L_max+1]
302
+
303
+ # (e-a)^(l-m)
304
+ l_vec_b = l_vec.view(Ncomp, 1) # (Ncomp, 1)
305
+ m_b = m.view(1, L_max+1) # (1, L_max+1)
306
+
307
+ l_minus_m_raw = l_vec_b - m_b # (Ncomp, L_max+1)
308
+
309
+ # Create mask for m > l, where l_minus_m is negative
310
+ mask = m_b > l_vec_b # (Ncomp, L_max+1)
311
+
312
+ l_minus_m = torch.where(mask, 0.0, l_minus_m_raw)
313
+
314
+ ae_b = ae.unsqueeze(-1) # (Ni, Nj, 3 or 1, 1)
315
+ ae_powers_raw = ae_b ** l_minus_m[None, None, :, :]
316
+ ae_powers = torch.where(mask[None, None, :, :], 0.0, ae_powers_raw)
317
+
318
+ # C(l,m) * (e-a)^(l-m)
319
+ coeffs = bino_coeffs[None, None, :, :] * ae_powers
320
+
321
+ return coeffs # (Ni, Nj, Ncomp, L_max+1)
322
+
323
+ def compute_1d_product_coefficients_vec(self, coeff_a, coeff_b, la, lb):
324
+ """
325
+ Vectorized 1D convolution of coefficient arrays using F.conv1d.
326
+ Corresponds to Fortran: form_product / prod3
327
+ """
328
+ Ni, Nj, Na = coeff_a.shape
329
+ _, _, Nb = coeff_b.shape
330
+ output_size = la + lb + 1
331
+
332
+ N_batch = Ni * Nj
333
+ if N_batch == 0:
334
+ return torch.empty((Ni, Nj, output_size), dtype=coeff_a.dtype, device=coeff_a.device)
335
+
336
+ coeff_a_flat = coeff_a.view(1, N_batch, Na)
337
+ coeff_b_flat_rev = torch.flip(coeff_b, dims=[2]).view(N_batch, 1, Nb)
338
+
339
+ conv_out = F.conv1d(
340
+ coeff_a_flat,
341
+ coeff_b_flat_rev,
342
+ padding=lb,
343
+ groups=N_batch
344
+ )
345
+
346
+ return conv_out.view(Ni, Nj, output_size)
347
+
348
+ def compute_1d_product_coefficients_batch_outer(
349
+ self, coeff_a: torch.Tensor, coeff_b: torch.Tensor,
350
+ la_vec: torch.Tensor, lb_vec: torch.Tensor
351
+ ) -> torch.Tensor:
352
+ """
353
+ Vectorized 1D convolution for an outer product of coefficient arrays.
354
+ """
355
+ Ni, Nj, naoi, Na_max = coeff_a.shape
356
+ _, _, naoj, Nb_max = coeff_b.shape
357
+
358
+ max_la = la_vec.max().item() if naoi > 0 else -1
359
+ max_lb = lb_vec.max().item() if naoj > 0 else -1
360
+ N_out_max = max_la + max_lb + 1
361
+
362
+ if Ni*Nj*naoi*naoj == 0:
363
+ return torch.empty((Ni, Nj, naoi, naoj, N_out_max), dtype=coeff_a.dtype, device=coeff_a.device)
364
+
365
+ a_b = coeff_a.unsqueeze(3).expand(-1, -1, -1, naoj, -1)
366
+ b_b = coeff_b.unsqueeze(2).expand(-1, -1, naoi, -1, -1)
367
+
368
+ N_batch = Ni * Nj * naoi * naoj
369
+ a_flat = a_b.reshape(1, N_batch, Na_max)
370
+ b_flat_rev = b_b.reshape(N_batch, 1, Nb_max).flip(dims=[-1])
371
+
372
+ conv_out = F.conv1d(
373
+ a_flat,
374
+ b_flat_rev,
375
+ padding=Nb_max - 1,
376
+ groups=N_batch
377
+ )
378
+
379
+ return conv_out.view(Ni, Nj, naoi, naoj, -1)[..., 0:N_out_max]
380
+
381
+ def _build_s1d_factors_vec(self, max_l: int, gama_mat: torch.Tensor) -> torch.Tensor:
382
+ """
383
+ Vectorized: Creates array of 1D integral factors (olapp results).
384
+ """
385
+ s1d_factors = []
386
+ for l in range(max_l + 1):
387
+ factor = self.calc_1d_overlap_constants_vec(l, gama_mat)
388
+ s1d_factors.append(factor)
389
+ return torch.stack(s1d_factors, dim=-1) # (Ni, Nj, max_l+1)
390
+
391
    def _assemble_3d_multipole_factors_vec(
        self,
        ri: torch.Tensor, rj: torch.Tensor, rp: torch.Tensor, rc: torch.Tensor,
        li_mat: torch.Tensor, lj_mat: torch.Tensor,
        s1d_factors: torch.Tensor
    ) -> torch.Tensor:
        """
        Vectorized Python (PyTorch) version of multipole_3d subroutine.
        Corresponds to Fortran: multipole_3d(ri, rj, rc, rp, ai, aj, li, lj, s1d, s3d)

        For every primitive pair and every (i-component, j-component) pair,
        builds the overlap, first-moment and second-moment factors as
        products of per-axis 1D integrals.

        Parameters
        ----------
        ri, rj : torch.Tensor
            Shell centres, indexed per axis as ``ri[k]`` / ``rj[k]``.
        rp : torch.Tensor
            Gaussian product centres, shape (Ni, Nj, 3).
        rc : torch.Tensor
            Origin for the moment terms, indexed as ``rc[k]``.
        li_mat, lj_mat : torch.Tensor
            Cartesian exponent rows [lx, ly, lz], shapes (naoi, 3) / (naoj, 3).
        s1d_factors : torch.Tensor
            Shape (Ni, Nj, n); entry m is the 1D factor for order m (the caller
            in this file builds it with _build_s1d_factors_vec). It must cover
            orders up to max(li) + max(lj) + 2 per axis, because the
            second-moment term reads two orders beyond the product degree.

        Returns
        -------
        torch.Tensor
            Shape (Ni, Nj, naoi, naoj, 10), last axis ordered as
            [S, Mx, My, Mz, Qxx, Qyy, Qzz, MxMy, MxMz, MyMz], where S is the
            product of per-axis overlaps and M/Q replace an axis's overlap by
            its first/second moment about ``rc``.
        """
        Ni, Nj, _ = rp.shape
        naoi = li_mat.shape[0]
        naoj = lj_mat.shape[0]
        tensor_opts = {'dtype': ri.dtype, 'device': ri.device}

        if naoi == 0 or naoj == 0:
            return torch.empty((Ni, Nj, naoi, naoj, 10), **tensor_opts)

        # Per-axis 1D results, last index = axis k (0: x, 1: y, 2: z).
        val_S_k = torch.zeros((Ni, Nj, naoi, naoj, 3), **tensor_opts)
        val_M_k = torch.zeros((Ni, Nj, naoi, naoj, 3), **tensor_opts)
        val_Q_k = torch.zeros((Ni, Nj, naoi, naoj, 3), **tensor_opts)

        for k in range(3):
            l_i_k_vec = li_mat[:, k] # (naoi,)
            l_j_k_vec = lj_mat[:, k] # (naoj,)

            max_l_i_k = l_i_k_vec.max().item()
            max_l_j_k = l_j_k_vec.max().item()

            center_i_k = ri[k]
            center_j_k = rj[k]
            center_p_k = rp[..., k]
            center_c_k = rc[k]
            # Offset of the product centre P from the moment origin C.
            rpc_k = center_p_k - center_c_k

            # Re-expand (x - A)^li and (x - B)^lj as polynomials in (x - P).
            vi_shifted = self.calculate_center_shift_coefficients_batch_l(
                l_i_k_vec, center_i_k, center_p_k)
            vj_shifted = self.calculate_center_shift_coefficients_batch_l(
                l_j_k_vec, center_j_k, center_p_k)

            l_total_max = max_l_i_k + max_l_j_k
            pad_out = l_total_max + 1

            # Coefficients of the product polynomial in (x - P) for every
            # (i-component, j-component) pair: (Ni, Nj, naoi, naoj, pad_out).
            vv = self.compute_1d_product_coefficients_batch_outer(
                vi_shifted, vj_shifted, l_i_k_vec, l_j_k_vec)

            l_total_mat = l_i_k_vec[:, None] + l_j_k_vec[None, :]

            vv_subset_raw = vv[..., 0:pad_out]
            # 1D factors for orders m, m + 1 and m + 2.
            s1d_L = s1d_factors[..., 0:pad_out]
            s1d_Lplus1 = s1d_factors[..., 1:pad_out + 1]
            s1d_Lplus2 = s1d_factors[..., 2:pad_out + 2]

            # Zero any coefficient whose order exceeds li + lj for that pair.
            m_vec = torch.arange(pad_out, device=ri.device)[None, None, :]
            l_total_b = l_total_mat.unsqueeze(-1)
            mask = m_vec > l_total_b
            mask_b = mask[None, None, :, :, :]

            vv_subset = torch.where(mask_b, 0.0, vv_subset_raw)

            # Broadcast per-primitive-pair quantities over the component axes.
            rpc_b = rpc_k[..., None, None, None]
            s1d_L_b = s1d_L[..., None, None, :]
            s1d_Lplus1_b = s1d_Lplus1[..., None, None, :]
            s1d_Lplus2_b = s1d_Lplus2[..., None, None, :]

            # Overlap: sum_m c_m * s1d[m].
            val_S_k[..., k] = torch.sum(vv_subset * s1d_L_b, dim=-1)

            # (x - C) = (x - P) + (P - C)  =>  s1d[m + 1] + rpc * s1d[m].
            dipole_s1d_vec = s1d_Lplus1_b + rpc_b * s1d_L_b
            val_M_k[..., k] = torch.sum(vv_subset * dipole_s1d_vec, dim=-1)

            # (x - C)^2 = (x - P)^2 + 2 rpc (x - P) + rpc^2.
            quad_s1d_vec = s1d_Lplus2_b + 2.0 * rpc_b * s1d_Lplus1_b + (rpc_b**2) * s1d_L_b
            val_Q_k[..., k] = torch.sum(vv_subset * quad_s1d_vec, dim=-1)

        S_x, S_y, S_z = val_S_k[..., 0], val_S_k[..., 1], val_S_k[..., 2]
        M_x, M_y, M_z = val_M_k[..., 0], val_M_k[..., 1], val_M_k[..., 2]
        Q_x, Q_y, Q_z = val_Q_k[..., 0], val_Q_k[..., 1], val_Q_k[..., 2]

        # Combine axes: [S, Mx, My, Mz, Qxx, Qyy, Qzz, MxMy, MxMz, MyMz].
        s3d_factors = torch.stack([
            S_x * S_y * S_z, M_x * S_y * S_z, S_x * M_y * S_z, S_x * S_y * M_z,
            Q_x * S_y * S_z, S_x * Q_y * S_z, S_x * S_y * Q_z,
            M_x * M_y * S_z, M_x * S_y * M_z, S_x * M_y * M_z
        ], dim=-1)

        return s3d_factors
475
+
476
+
477
+ def compute_contracted_overlap_matrix(
478
+ self,
479
+ ish: int, jsh: int, naoi_cart: int, naoj_cart: int,
480
+ ishtyp: int, jshtyp: int, ri: torch.Tensor, rj: torch.Tensor,
481
+ point: torch.Tensor, intcut: float,
482
+ basis: dict, alp_tensor: torch.Tensor, cont_tensor: torch.Tensor,
483
+ prim_indices_i: list, prim_indices_j: list,
484
+ itt_device: torch.Tensor, lxyz_type_device: torch.Tensor,
485
+ device: torch.device = None
486
+ ) -> torch.Tensor:
487
+ """
488
+ Computes the overlap matrix block between two *Shells* (ish, jsh).
489
+ Corresponds to Fortran: get_overlap(...)
490
+ """
491
+ if device is None:
492
+ device = ri.device
493
+ tensor_opts = {'dtype': ri.dtype, 'device': device}
494
+
495
+ rij = ri - rj
496
+ rij2 = torch.dot(rij, rij)
497
+ max_r2 = 2000.0
498
+ if rij2 > max_r2:
499
+ return torch.zeros((naoj_cart, naoi_cart), **tensor_opts)
500
+
501
+ iptyp = itt_device[ishtyp].item()
502
+ jptyp = itt_device[jshtyp].item()
503
+
504
+ Ni = len(prim_indices_i)
505
+ Nj = len(prim_indices_j)
506
+ if Ni == 0 or Nj == 0:
507
+ return torch.zeros((naoj_cart, naoi_cart), **tensor_opts)
508
+
509
+ alp_i_vec = alp_tensor[prim_indices_i]
510
+ cont_i_vec = cont_tensor[prim_indices_i]
511
+ alp_j_vec = alp_tensor[prim_indices_j]
512
+ cont_j_vec = cont_tensor[prim_indices_j]
513
+
514
+ alp_i_b = alp_i_vec.unsqueeze(1)
515
+ alp_j_b = alp_j_vec.unsqueeze(0)
516
+
517
+ gamma_mat, kab_mat = self.calc_product_exponent_and_overlap_vec(alp_i_b, ri, alp_j_b, rj)
518
+
519
+ rp_mat = self.calc_gau_product_center_vec(
520
+ alp_i_b.unsqueeze(2), ri.view(1, 1, 3),
521
+ alp_j_b.unsqueeze(2), rj.view(1, 1, 3),
522
+ gamma_mat.unsqueeze(2)
523
+ )
524
+
525
+ max_l_k_prim = ishtyp + jshtyp
526
+ max_l_needed_prim = max_l_k_prim + 2
527
+ t_factors_mat = self._build_s1d_factors_vec(max_l_needed_prim, gamma_mat)
528
+
529
+ li_mat = lxyz_type_device[iptyp:iptyp + naoi_cart]
530
+ lj_mat = lxyz_type_device[jptyp:jptyp + naoj_cart]
531
+
532
+ n_i_mat = self._primitive_norm(alp_i_vec, li_mat)
533
+ n_j_mat = self._primitive_norm(alp_j_vec, lj_mat)
534
+
535
+ c_n_i = cont_i_vec[:, None] * n_i_mat
536
+ c_n_j = cont_j_vec[:, None] * n_j_mat
537
+
538
+ cc_full_mat = c_n_i[:, None, :, None] * c_n_j[None, :, None, :]
539
+
540
+ saw_factors_mat = self._assemble_3d_multipole_factors_vec(
541
+ ri, rj, rp_mat, point,
542
+ li_mat, lj_mat,
543
+ t_factors_mat
544
+ )
545
+
546
+ s_prim_full = kab_mat[..., None, None] * saw_factors_mat[..., 0]
547
+
548
+ sint_block_cart_T = torch.sum(s_prim_full * cc_full_mat, dim=(0, 1))
549
+
550
+ sint_block_cart = sint_block_cart_T.T
551
+
552
+ return sint_block_cart
553
+
554
+
555
def calculate_overlap_matrix_full(self, xyz: torch.Tensor, intcut) -> torch.Tensor:
    """Calculate the full molecular overlap matrix S of shape (nao, nao).

    Loops over the lower triangle of shell pairs (jsh <= ish). For each pair:
    the Cartesian block comes from ``compute_contracted_overlap_matrix``, is
    passed through ``transform_d_cartesian_to_spherical``, and is scattered
    into S. Off-diagonal pairs also receive the transpose. The assembled
    matrix is finally rescaled by ``normalize_overlap_matrix``.

    Args:
        xyz: Nuclear coordinates. An input whose first dimension is 3 and
            second is not 3 is treated as (3, n_atoms) and transposed.
        intcut: Integral cutoff, forwarded unchanged to
            ``compute_contracted_overlap_matrix``.

    Returns:
        The normalized overlap matrix, with the dtype/device of ``xyz``.
    """
    if xyz.shape[0] == 3 and xyz.shape[1] != 3:
        xyz = xyz.T
    current_device = xyz.device
    tensor_opts = {'dtype': xyz.dtype, 'device': current_device}

    def to_int_list(values):
        # The integer bookkeeping below is only ever consumed as Python ints.
        # Converting it once on the CPU avoids a host/device sync for every
        # element that indexing a device tensor + .item() would incur.
        return torch.as_tensor(values, dtype=torch.int64).tolist()

    basis = self.basis
    n_ao = basis['number_of_ao']
    n_shells = basis['number_of_shells']
    sint = torch.zeros((n_ao, n_ao), **tensor_opts)
    point = torch.zeros(3, **tensor_opts)

    llao = to_int_list(self.LLAO)
    shell_amqn_list = to_int_list(basis['shell_amqn_list'])
    shell_atom_list = to_int_list(basis['shell_atom_list'])
    shell_ao_map = to_int_list(basis['shell_ao_map'])
    shell_cgf_map = to_int_list(basis['shell_cgf_map'])
    cgf_primitive_count_list = to_int_list(basis['cgf_primitive_count_list'])
    cgf_primitive_start_idx_list = to_int_list(basis['cgf_primitive_start_idx_list'])

    alp_tensor = torch.tensor(basis.get('primitive_alpha_list', []), **tensor_opts)
    cont_tensor = torch.tensor(basis.get('primitive_coeff_list', []), **tensor_opts)

    itt_device = self.ITT.to(current_device)
    lxyz_type_device = self.LXYZ_TYPE_TENSOR.to(current_device)
    # Side effect on the instance: keep the cached DFTR tensor on the working device.
    self.DFTR = self.DFTR.to(current_device)

    # Flattened primitive indices for every shell. Each shell spans
    # `icgf_count` contracted functions starting at `icgf_start`, and each of
    # those spans a contiguous run of primitives.
    shell_prim_indices = []
    for ish in range(n_shells):
        icgf_start, icgf_count = shell_cgf_map[ish]
        indices = []
        for icgf in range(icgf_start, icgf_start + icgf_count):
            start = cgf_primitive_start_idx_list[icgf]
            indices.extend(range(start, start + cgf_primitive_count_list[icgf]))
        shell_prim_indices.append(indices)

    for ish in range(n_shells):
        ri = xyz[shell_atom_list[ish]]
        ishtyp = shell_amqn_list[ish]
        naoi_cart = llao[ishtyp]
        iao_start, naoi_spher = shell_ao_map[ish]
        iao_slice = slice(iao_start, iao_start + naoi_spher)
        # For d shells (type 2) the useful spherical block starts at offset 1 of
        # the transformed output; presumably a padding/contaminant component
        # occupies index 0 -- confirm against transform_d_cartesian_to_spherical.
        i_off = 1 if ishtyp == 2 else 0
        prim_indices_i = shell_prim_indices[ish]

        for jsh in range(ish + 1):
            rj = xyz[shell_atom_list[jsh]]
            jshtyp = shell_amqn_list[jsh]
            naoj_cart = llao[jshtyp]
            jao_start, naoj_spher = shell_ao_map[jsh]
            jao_slice = slice(jao_start, jao_start + naoj_spher)
            j_off = 1 if jshtyp == 2 else 0

            ss_cart_shell = self.compute_contracted_overlap_matrix(
                ish=ish, jsh=jsh,
                naoi_cart=naoi_cart, naoj_cart=naoj_cart,
                ishtyp=ishtyp, jshtyp=jshtyp,
                ri=ri, rj=rj, point=point, intcut=intcut,
                basis=basis,
                alp_tensor=alp_tensor, cont_tensor=cont_tensor,
                prim_indices_i=prim_indices_i,
                prim_indices_j=shell_prim_indices[jsh],
                itt_device=itt_device, lxyz_type_device=lxyz_type_device,
                device=current_device,
            )

            ss_spher_padded = self.transform_d_cartesian_to_spherical(ss_cart_shell, ishtyp, jshtyp)
            # Block is laid out as (j AOs, i AOs).
            ss_spher_shell = ss_spher_padded[j_off:j_off + naoj_spher, i_off:i_off + naoi_spher]

            # Out-of-place adds keep the assembly differentiable.
            sint[jao_slice, iao_slice] = sint[jao_slice, iao_slice] + ss_spher_shell
            if ish != jsh:
                # Only the lower triangle is computed; mirror it to keep S symmetric.
                sint[iao_slice, jao_slice] = sint[iao_slice, jao_slice] + ss_spher_shell.T

    return self.normalize_overlap_matrix(sint)
648
+
649
def normalize_overlap_matrix(self, sint: torch.Tensor) -> torch.Tensor:
    """Rescale the overlap matrix S so that its diagonal becomes one.

    Computes ``S'_ij = d_i * S_ij * d_j`` with ``d_i = 1 / sqrt(S_ii + 1e-20)``.
    For any row/column whose diagonal element is not greater than 1e-12,
    ``d_i`` is 0, so that row/column is zeroed rather than divided by (near)
    zero.

    Args:
        sint: Square overlap matrix (nao, nao).

    Returns:
        The normalized matrix, with the shape/dtype/device of ``sint``.
    """
    diag = torch.diagonal(sint)
    keep = diag > 1e-12
    # Substitute a harmless value for the masked entries *before* the sqrt.
    # Otherwise a non-positive diagonal element yields NaN in the discarded
    # branch of torch.where, and that NaN still poisons the gradient
    # (0 * NaN = NaN) when differentiating with jacrev/hessian.
    diag_safe = torch.where(keep, diag, torch.ones_like(diag)) + 1e-20
    inv_sqrt_diag = torch.where(keep, 1.0 / torch.sqrt(diag_safe), torch.zeros_like(diag))

    # Same as diag(d) @ (S @ diag(d)) with the same multiplication order, but
    # O(n^2) work and no dense diagonal matrix.
    return inv_sqrt_diag[:, None] * (sint * inv_sqrt_diag[None, :])
659
+
660
+
661
def calculation(self, xyz: torch.Tensor, intcut=40):
    """Compute the overlap matrix plus zero-filled multipole placeholders.

    Only S is actually evaluated. The dipole (3, nao, nao) and quadrupole
    (6, nao, nao) arrays are all-zero tensors with S's dtype and device. The
    computed S is also stored on ``self.sint``.

    Returns:
        Tuple ``(sint, dpint, qpint)``.
    """
    overlap = self.calculate_overlap_matrix_full(xyz, intcut)
    nao = overlap.shape[0]
    like_overlap = {'dtype': overlap.dtype, 'device': overlap.device}
    dipole = torch.zeros((3, nao, nao), **like_overlap)
    quadrupole = torch.zeros((6, nao, nao), **like_overlap)

    self.sint = overlap
    return overlap, dipole, quadrupole
670
+
671
def get_overlap_integral_matrix(self) -> torch.Tensor:
    """Return the overlap matrix stored by the last ``calculation`` call.

    If no matrix has been stored yet, a warning is printed and None is
    returned.
    """
    try:
        return self.sint
    except AttributeError:
        print("Warning: Overlap integral matrix S has not been computed yet. Returning None.")
        return None
677
+
678
def overlap_int(self, xyz: np.ndarray) -> torch.Tensor:
    """Return the overlap matrix for coordinates given as a NumPy array.

    The coordinates are copied into a float64 tensor without gradient
    tracking and passed to ``calculation``. Only S is returned.
    """
    coords = torch.tensor(xyz, dtype=torch.float64, requires_grad=False)
    overlap, _dipole, _quadrupole = self.calculation(coords)
    return overlap
683
+
684
def overlap_int_torch(self, xyz: torch.Tensor) -> torch.Tensor:
    """Return the overlap matrix for coordinates already held in a tensor.

    No copy or dtype conversion is performed here. ``xyz`` is handed directly
    to ``_get_sint_only``.
    """
    return self._get_sint_only(xyz)
688
+
689
def _get_sint_only(self, xyz: torch.Tensor) -> torch.Tensor:
    """Return only S from ``calculation``.

    Because it returns a single tensor, this method is passed directly as the
    function argument to ``torch.func.jacrev`` / ``torch.func.hessian`` in the
    derivative helpers.
    """
    return self.calculation(xyz)[0]
693
+
694
def d_overlap_int_dxyz(self, xyz: np.ndarray) -> torch.Tensor:
    """Return dS/dR, the Jacobian of the overlap matrix w.r.t. nuclear coordinates.

    The coordinates are copied into a float64 tensor. The Jacobian is obtained
    with reverse-mode ``torch.func.jacrev``, so its shape is
    ``S.shape + xyz.shape``.
    """
    coords = torch.tensor(xyz, dtype=torch.float64, requires_grad=True)
    jacobian_fn = torch.func.jacrev(self._get_sint_only, argnums=0)
    return jacobian_fn(coords)
699
+
700
def d2_overlap_int_dxyz2(self, xyz: np.ndarray) -> torch.Tensor:
    """Return d2S/dR2, the Hessian of the overlap matrix w.r.t. nuclear coordinates.

    The coordinates are copied into a float64 tensor. The second derivatives
    come from ``torch.func.hessian``, so the shape is
    ``S.shape + xyz.shape + xyz.shape``.
    """
    coords = torch.tensor(xyz, dtype=torch.float64, requires_grad=True)
    hessian_fn = torch.func.hessian(self._get_sint_only, argnums=0)
    return hessian_fn(coords)