dftax 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. dftax/__init__.py +88 -0
  2. dftax/basis/__init__.py +5 -0
  3. dftax/basis/data/cart2sph.npz +0 -0
  4. dftax/basis/loader.py +160 -0
  5. dftax/energy/__init__.py +1 -0
  6. dftax/energy/boys.py +71 -0
  7. dftax/energy/grid.py +37 -0
  8. dftax/energy/gto.py +526 -0
  9. dftax/energy/hartree.py +53 -0
  10. dftax/energy/hybrid.py +58 -0
  11. dftax/energy/jax_df_integrals.py +256 -0
  12. dftax/energy/orbitals.py +50 -0
  13. dftax/energy/potentials.py +35 -0
  14. dftax/energy/xc.py +578 -0
  15. dftax/grid/__init__.py +12 -0
  16. dftax/grid/becke.py +86 -0
  17. dftax/grid/data/lebedev.npz +0 -0
  18. dftax/grid/grid.py +64 -0
  19. dftax/grid/lebedev.py +40 -0
  20. dftax/integrals/__init__.py +35 -0
  21. dftax/integrals/coulomb_potential.py +209 -0
  22. dftax/integrals/eri2c.py +325 -0
  23. dftax/integrals/eri3c.py +323 -0
  24. dftax/integrals/eri4c.py +478 -0
  25. dftax/integrals/multipole.py +118 -0
  26. dftax/integrals/nuclear_attraction.py +443 -0
  27. dftax/integrals/nuclear_repulsion.py +30 -0
  28. dftax/integrals/overlap.py +591 -0
  29. dftax/integrals/shell_pairs.py +151 -0
  30. dftax/ks/__init__.py +38 -0
  31. dftax/ks/batched.py +154 -0
  32. dftax/ks/driver.py +149 -0
  33. dftax/ks/energy.py +582 -0
  34. dftax/ks/energy_uks.py +363 -0
  35. dftax/ks/forces.py +91 -0
  36. dftax/ks/forces_uks.py +93 -0
  37. dftax/ks/implicit.py +160 -0
  38. dftax/ks/minimize.py +121 -0
  39. dftax/ks/minimize_uks.py +100 -0
  40. dftax/ks/properties.py +302 -0
  41. dftax/ks/scf.py +201 -0
  42. dftax/ks/scf_uks.py +150 -0
  43. dftax/py.typed +0 -0
  44. dftax/system/__init__.py +5 -0
  45. dftax/system/molecule.py +91 -0
  46. dftax/utils/__init__.py +3 -0
  47. dftax/utils/energy_aux.py +44 -0
  48. dftax/utils/vmap.py +275 -0
  49. dftax-0.1.0.dist-info/METADATA +257 -0
  50. dftax-0.1.0.dist-info/RECORD +52 -0
  51. dftax-0.1.0.dist-info/WHEEL +4 -0
  52. dftax-0.1.0.dist-info/licenses/LICENSE +201 -0
dftax/__init__.py ADDED
@@ -0,0 +1,88 @@
1
+ """dftax: a self-contained, pure-JAX/Equinox Kohn-Sham DFT engine.
2
+
3
+ The package exposes a differentiable KS-DFT toolkit with no PySCF runtime
4
+ dependency for the core compute path:
5
+
6
+ - ``dftax.integrals``: analytical one- and two-electron integral matrices
7
+ (overlap S, kinetic T, nuclear attraction V, nuclear repulsion, ERIs 2c/3c/4c)
8
+ via the Obara-Saika recurrence, all jit/vmap/grad-friendly.
9
+ - ``dftax.energy``: GTO basis evaluation, the Boys function, density fitting,
10
+ Hartree and hybrid exchange, exchange-correlation functionals, real-space
11
+ grids, and pointwise potentials.
12
+ - ``dftax.utils``: chunked ``vmap`` helpers and shared types.
13
+ - ``dftax.ks``: Kohn-Sham drivers (energy functional + DIIS SCF) for both
14
+ closed-shell (RKS) and open-shell (UKS) systems, plus the unified ``run_ks``.
15
+
16
+ A few of the most common entry points are re-exported here for convenience;
17
+ import the submodules directly for the full surface.
18
+ """
19
+
20
+ from dftax.energy.gto import BasisData, extract_basis_data, eval_gto
21
+ from dftax.integrals import (
22
+ overlap_matrix,
23
+ kinetic_matrix,
24
+ nuclear_attraction_matrix,
25
+ nuclear_repulsion,
26
+ eri2c_matrix,
27
+ eri3c_matrix,
28
+ )
29
+ from dftax.ks import (
30
+ RKS, run_rks, rks_scf, rks_minimize, rks_forces, SCFResult,
31
+ UKS, run_uks, uks_scf, uks_minimize, uks_forces, UKSResult,
32
+ run_ks,
33
+ run_ks_batched, run_rks_batched, run_uks_batched, BatchedResult,
34
+ dipole, polarizability, hessian, vibrations, ir_spectrum, raman_spectrum,
35
+ alchemical_deriv, implicit_density,
36
+ )
37
+
38
+ from importlib.metadata import PackageNotFoundError, version as _pkg_version
39
+
40
+ try:
41
+ __version__ = _pkg_version("dftax")
42
+ except PackageNotFoundError: # running from a source tree without an install
43
+ __version__ = "0.0.0+local"
44
+ del _pkg_version, PackageNotFoundError
45
+
46
+ __all__ = [
47
+ "__version__",
48
+ "BasisData",
49
+ "extract_basis_data",
50
+ "eval_gto",
51
+ "overlap_matrix",
52
+ "kinetic_matrix",
53
+ "nuclear_attraction_matrix",
54
+ "nuclear_repulsion",
55
+ "eri2c_matrix",
56
+ "eri3c_matrix",
57
+ # restricted Kohn-Sham driver
58
+ "RKS",
59
+ "run_rks",
60
+ "rks_scf",
61
+ "rks_minimize",
62
+ "rks_forces",
63
+ "SCFResult",
64
+ # unrestricted (open-shell) Kohn-Sham driver
65
+ "UKS",
66
+ "run_uks",
67
+ "uks_scf",
68
+ "uks_minimize",
69
+ "uks_forces",
70
+ "UKSResult",
71
+ # unified dispatcher
72
+ "run_ks",
73
+ # batched (vmap over geometries)
74
+ "run_ks_batched",
75
+ "run_rks_batched",
76
+ "run_uks_batched",
77
+ "BatchedResult",
78
+ # response properties
79
+ "dipole",
80
+ "polarizability",
81
+ "hessian",
82
+ "vibrations",
83
+ "ir_spectrum",
84
+ "raman_spectrum",
85
+ "alchemical_deriv",
86
+ # implicit-differentiation SCF response
87
+ "implicit_density",
88
+ ]
@@ -0,0 +1,5 @@
1
+ """PySCF-free basis-set loading via Basis Set Exchange."""
2
+
3
+ from dftax.basis.loader import build_basis_data, gto_norm
4
+
5
+ __all__ = ["build_basis_data", "gto_norm"]
Binary file
dftax/basis/loader.py ADDED
@@ -0,0 +1,160 @@
1
+ """Build :class:`~dftax.energy.gto.BasisData` from Basis Set Exchange data.
2
+
3
+ This is the PySCF-free counterpart of ``extract_basis_data``: it fetches the
4
+ contracted-GTO definition for each element from the ``basis_set_exchange``
5
+ library and assembles the same normalized Cartesian-AO arrays the integral
6
+ engine consumes. Normalization matches ``extract_basis_data`` exactly
7
+ (primitive ``gto_norm`` times, for l<=1, a contracted norm; libcint's
8
+ per-shell convention for l>=2).
9
+
10
+ Emits Cartesian GTOs by default (``cart2sph=None``); with ``spherical=True`` it
11
+ also builds the Cartesian->spherical transform (vendored per-l blocks under
12
+ ``data/cart2sph.npz``) so spherical d/f bases (cc-pVDZ, cc-pVTZ, def2-*) match a
13
+ spherical PySCF reference. For l<=1 bases the two spans coincide.
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import math
19
+ from functools import lru_cache
20
+ from importlib.resources import files
21
+
22
+ import numpy as np
23
+ import jax.numpy as jnp
24
+ import basis_set_exchange as bse
25
+
26
+ from dftax.energy.gto import BasisData, _CART_COMPONENTS, _contracted_norm
27
+ from dftax.system.molecule import symbol_to_Z
28
+
29
+
30
+ @lru_cache(maxsize=1)
31
+ def _cart2sph_blocks() -> dict[int, np.ndarray]:
32
+ """Vendored per-l Cartesian->spherical transform blocks (ncart x nsph)."""
33
+ path = files("dftax.basis").joinpath("data", "cart2sph.npz")
34
+ with path.open("rb") as fh:
35
+ npz = np.load(fh)
36
+ return {int(k[1:]): np.asarray(npz[k], dtype=np.float64) for k in npz.files}
37
+
38
+
39
+ def _block_diag(blocks: list[np.ndarray]) -> np.ndarray:
40
+ """Assemble a block-diagonal matrix from a list of 2-D blocks (no SciPy)."""
41
+ nr = sum(b.shape[0] for b in blocks)
42
+ nc = sum(b.shape[1] for b in blocks)
43
+ out = np.zeros((nr, nc), dtype=np.float64)
44
+ r = c = 0
45
+ for b in blocks:
46
+ out[r : r + b.shape[0], c : c + b.shape[1]] = b
47
+ r += b.shape[0]
48
+ c += b.shape[1]
49
+ return out
50
+
51
+
52
+ def gto_norm(l: int, alpha: float) -> float:
53
+ """Radial normalization of a primitive GTO (matches ``pyscf.gto.gto_norm``).
54
+
55
+ ``N`` such that the primitive ``N r^l e^{-alpha r^2}`` has unit self-overlap
56
+ in the radial/spherical sense.
57
+ """
58
+ f = (
59
+ 2 ** (2 * l + 3)
60
+ * math.factorial(l + 1)
61
+ * (2 * alpha) ** (l + 1.5)
62
+ / (math.factorial(2 * l + 2) * math.sqrt(math.pi))
63
+ )
64
+ return math.sqrt(f)
65
+
66
+
67
+ def _iter_contractions(shell):
68
+ """Yield ``(l, coeff_vector)`` for each contracted function in a BSE shell.
69
+
70
+ - single angular momentum: each coefficient block is an independent
71
+ (general) contraction of that L;
72
+ - multiple angular momenta (e.g. an ``sp`` shell): block ``i`` belongs to
73
+ ``angular_momentum[i]``.
74
+ """
75
+ ams = shell["angular_momentum"]
76
+ cblocks = shell["coefficients"]
77
+ if len(ams) == 1:
78
+ for c in cblocks:
79
+ yield ams[0], np.asarray(c, dtype=np.float64)
80
+ else:
81
+ for l, c in zip(ams, cblocks):
82
+ yield l, np.asarray(c, dtype=np.float64)
83
+
84
+
85
+ def build_basis_data(
86
+ symbols: list[str],
87
+ coords_bohr: np.ndarray,
88
+ basis_name: str,
89
+ *,
90
+ spherical: bool = False,
91
+ return_atom_index: bool = False,
92
+ ):
93
+ """Assemble :class:`BasisData` for a molecule from a named basis set.
94
+
95
+ Args:
96
+ symbols: element symbols, one per atom.
97
+ coords_bohr: nuclear coordinates in Bohr, shape (n_atoms, 3).
98
+ basis_name: a Basis Set Exchange basis name (e.g. ``"sto-3g"``).
99
+ spherical: if True, attach the Cartesian->spherical transform so the
100
+ basis uses (2l+1) spherical harmonics (standard for cc-pVXZ/def2);
101
+ otherwise emit Cartesian GTOs (``cart2sph=None``). For l<=1 the two
102
+ coincide.
103
+ return_atom_index: if True, also return an int array mapping each
104
+ Cartesian AO to its owning atom (needed to rebuild differentiable
105
+ centers from nuclear coordinates, e.g. for forces).
106
+ """
107
+ coords = np.asarray(coords_bohr, dtype=np.float64).reshape(-1, 3)
108
+ Zs = [symbol_to_Z(s) for s in symbols]
109
+ bdata = bse.get_basis(basis_name, elements=sorted(set(Zs)), header=False)
110
+ elements = bdata["elements"]
111
+
112
+ # Gather every contracted shell: (l, exponents, raw coefficients, center, atom).
113
+ raw_shells: list[tuple[int, np.ndarray, np.ndarray, np.ndarray, int]] = []
114
+ for atom_idx, (Z, center) in enumerate(zip(Zs, coords)):
115
+ shells = elements[str(Z)]["electron_shells"]
116
+ for shell in shells:
117
+ exps = np.asarray(shell["exponents"], dtype=np.float64)
118
+ for l, c_raw in _iter_contractions(shell):
119
+ raw_shells.append((l, exps, c_raw, center, atom_idx))
120
+
121
+ max_prim = max(len(exps) for (_, exps, _, _, _) in raw_shells)
122
+
123
+ centers, all_exps, all_coeffs, angular, atom_index = [], [], [], [], []
124
+ for l, exps, c_raw, center, atom_idx in raw_shells:
125
+ prim_norms = np.array([gto_norm(l, e) for e in exps])
126
+ c_prim = c_raw * prim_norms
127
+ for ang in _CART_COMPONENTS[l]:
128
+ if l <= 1:
129
+ c_final = c_prim * _contracted_norm(exps, c_prim, ang)
130
+ else:
131
+ c_final = c_prim.copy()
132
+ pad_e = np.zeros(max_prim, dtype=np.float64)
133
+ pad_c = np.zeros(max_prim, dtype=np.float64)
134
+ pad_e[: len(exps)] = exps
135
+ pad_c[: len(c_final)] = c_final
136
+ centers.append(center)
137
+ all_exps.append(pad_e)
138
+ all_coeffs.append(pad_c)
139
+ angular.append(ang)
140
+ atom_index.append(atom_idx)
141
+
142
+ cart2sph = None
143
+ if spherical:
144
+ # One cart2sph(l) block per contracted shell, in build order.
145
+ blocks = _cart2sph_blocks()
146
+ cart2sph = jnp.asarray(
147
+ _block_diag([blocks[l] for (l, _, _, _, _) in raw_shells])
148
+ )
149
+
150
+ basis = BasisData(
151
+ centers=jnp.asarray(np.array(centers, dtype=np.float64)),
152
+ exponents=jnp.asarray(np.array(all_exps, dtype=np.float64)),
153
+ coefficients=jnp.asarray(np.array(all_coeffs, dtype=np.float64)),
154
+ angular=jnp.asarray(np.array(angular, dtype=np.int32)),
155
+ cart2sph=cart2sph,
156
+ max_l=int(max(l for (l, _, _, _, _) in raw_shells)),
157
+ )
158
+ if return_atom_index:
159
+ return basis, np.array(atom_index, dtype=np.int64)
160
+ return basis
@@ -0,0 +1 @@
1
+ """Energy evaluation modules."""
dftax/energy/boys.py ADDED
@@ -0,0 +1,71 @@
1
+ """Boys function F_n(t) in pure JAX.
2
+
3
+ The Boys function is defined as:
4
+ F_n(t) = integral_0^1 u^{2n} exp(-t u^2) du
5
+
6
+ It is related to the lower incomplete gamma function by:
7
+ F_n(t) = Gamma(n+0.5) * P(n+0.5, t) / (2 * t^{n+0.5})
8
+ where P(a, x) = gammainc(a, x) is the regularized lower incomplete gamma.
9
+
10
+ Key properties:
11
+ F_n(0) = 1 / (2n + 1)
12
+ dF_n/dt = -F_{n+1}(t)
13
+ F_0(t) = sqrt(pi) / (2*sqrt(t)) * erf(sqrt(t)) for t > 0
14
+ """
15
+
16
+ import math
17
+
18
+ import jax
19
+ import jax.numpy as jnp
20
+ from jax.scipy.special import gammainc, gammaln
21
+
22
+
23
+ def boys(n: int, t: jax.Array) -> jax.Array:
24
+ """Boys function F_n(t) = integral_0^1 u^{2n} exp(-t u^2) du.
25
+
26
+ Accurate to better than 1e-8 vs numerical integration for n=0..10, t >= 0.
27
+ Fully differentiable: jax.grad(boys(n, .))(t) == -boys(n+1, t).
28
+ Works under jit and vmap.
29
+
30
+ Args:
31
+ n: Order (Python int, not a traced value).
32
+ t: Argument, a JAX scalar or array with t >= 0.
33
+
34
+ Returns:
35
+ F_n(t) as a JAX array of the same shape and dtype as t.
36
+ """
37
+ t = jnp.asarray(t)
38
+ dtype = t.dtype if jnp.issubdtype(t.dtype, jnp.floating) else jnp.float64
39
+ t = t.astype(dtype)
40
+
41
+ a = float(n) + 0.5
42
+
43
+ # --- Large-t branch: via incomplete gamma ---
44
+ # F_n(t) = Gamma(n+0.5) * gammainc(n+0.5, t) / (2 * t^{n+0.5})
45
+ # In log-space to avoid overflow:
46
+ # log F_n = gammaln(a) + log(gammainc(a, t)) - a*log(t) - log(2)
47
+ # Use t_safe >= 1.0 to guarantee finite gradients (log, pow, gammainc).
48
+ # When t < 1.0 the result is discarded by jnp.where, but JAX evaluates
49
+ # both branches under autodiff, so 0 * NaN = NaN must be avoided.
50
+ safe_t = jnp.where(t < 1.0, 1.0, t)
51
+ log_gamma_inc = jnp.log(jnp.maximum(gammainc(a, safe_t), jnp.finfo(dtype).tiny))
52
+ large_t = 0.5 * jnp.exp(gammaln(a) + log_gamma_inc - a * jnp.log(safe_t))
53
+
54
+ # --- Small-t branch: Taylor series ---
55
+ # F_n(t) = sum_{k=0}^{K} (-t)^k / (k! * (2n + 2k + 1))
56
+ # The k=0 term is the constant 1/(2n+1). We separate it to avoid
57
+ # computing (-t)^0 = 0^0 at t=0, whose JAX gradient is NaN.
58
+ K = 28
59
+ inv_denom_0 = 1.0 / (2 * n + 1)
60
+ inv_denom_rest = jnp.array(
61
+ [1.0 / (math.factorial(k) * (2 * n + 2 * k + 1)) for k in range(1, K)],
62
+ dtype=dtype,
63
+ )
64
+ k_vals_rest = jnp.arange(1, K, dtype=dtype)
65
+ small_t = inv_denom_0 + jnp.sum(
66
+ inv_denom_rest * ((-t[..., None]) ** k_vals_rest), axis=-1
67
+ )
68
+
69
+ # Use Taylor series for t < 1 (very accurate there), incomplete gamma elsewhere.
70
+ # The two branches agree to machine precision at t ~ 1.
71
+ return jnp.where(t < 1.0, small_t, large_t)
dftax/energy/grid.py ADDED
@@ -0,0 +1,37 @@
1
+ """Grid-based exchange-correlation energy integration."""
2
+
3
+ import jax.numpy as jnp
4
+ from jaxtyping import Array, Float, Scalar
5
+
6
+ from dftax.energy.xc import XCFunctional
7
+ from dftax.energy.potentials import xc_potential
8
+
9
+
10
+ def xc_energy(
11
+ xc: XCFunctional,
12
+ rho: Float[Array, "n"], # type: ignore
13
+ weights: Float[Array, "n"], # type: ignore
14
+ chunk_size: int | None = None,
15
+ grad_rho: Float[Array, "n 3"] | None = None, # type: ignore
16
+ rho_thresh: float = 1e-10,
17
+ ) -> Scalar:
18
+ """Exchange-correlation energy ``∫ ε_xc(ρ, ∇ρ) ρ dr`` on a quadrature grid.
19
+
20
+ ``rho``/``grad_rho`` are the density (and, for GGA, its gradient) sampled at
21
+ the grid points; ``weights`` are the quadrature weights.
22
+
23
+ Grid points with ``ρ < rho_thresh`` are masked out with a nan-safe
24
+ double-``where``: the functional is evaluated on a clamped density and the
25
+ contribution (and its gradient) is forced to zero. This keeps the GGA
26
+ reduced-gradient terms, whose derivatives diverge as ρ→0, from producing
27
+ NaNs on the far, vanishing-density tail of an unpruned grid.
28
+ """
29
+ mask = rho > rho_thresh
30
+ safe_rho = jnp.where(mask, rho, 1.0)
31
+ if xc.xc_type == "GGA" and grad_rho is not None:
32
+ safe_grad = jnp.where(mask[:, None], grad_rho, 0.0)
33
+ eps = xc_potential(xc, safe_rho, chunk_size, grad_rho=safe_grad)
34
+ else:
35
+ eps = xc_potential(xc, safe_rho, chunk_size, grad_rho=grad_rho)
36
+ contrib = jnp.where(mask, weights * eps * rho, 0.0)
37
+ return jnp.sum(contrib)