dftax 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dftax/__init__.py +88 -0
- dftax/basis/__init__.py +5 -0
- dftax/basis/data/cart2sph.npz +0 -0
- dftax/basis/loader.py +160 -0
- dftax/energy/__init__.py +1 -0
- dftax/energy/boys.py +71 -0
- dftax/energy/grid.py +37 -0
- dftax/energy/gto.py +526 -0
- dftax/energy/hartree.py +53 -0
- dftax/energy/hybrid.py +58 -0
- dftax/energy/jax_df_integrals.py +256 -0
- dftax/energy/orbitals.py +50 -0
- dftax/energy/potentials.py +35 -0
- dftax/energy/xc.py +578 -0
- dftax/grid/__init__.py +12 -0
- dftax/grid/becke.py +86 -0
- dftax/grid/data/lebedev.npz +0 -0
- dftax/grid/grid.py +64 -0
- dftax/grid/lebedev.py +40 -0
- dftax/integrals/__init__.py +35 -0
- dftax/integrals/coulomb_potential.py +209 -0
- dftax/integrals/eri2c.py +325 -0
- dftax/integrals/eri3c.py +323 -0
- dftax/integrals/eri4c.py +478 -0
- dftax/integrals/multipole.py +118 -0
- dftax/integrals/nuclear_attraction.py +443 -0
- dftax/integrals/nuclear_repulsion.py +30 -0
- dftax/integrals/overlap.py +591 -0
- dftax/integrals/shell_pairs.py +151 -0
- dftax/ks/__init__.py +38 -0
- dftax/ks/batched.py +154 -0
- dftax/ks/driver.py +149 -0
- dftax/ks/energy.py +582 -0
- dftax/ks/energy_uks.py +363 -0
- dftax/ks/forces.py +91 -0
- dftax/ks/forces_uks.py +93 -0
- dftax/ks/implicit.py +160 -0
- dftax/ks/minimize.py +121 -0
- dftax/ks/minimize_uks.py +100 -0
- dftax/ks/properties.py +302 -0
- dftax/ks/scf.py +201 -0
- dftax/ks/scf_uks.py +150 -0
- dftax/py.typed +0 -0
- dftax/system/__init__.py +5 -0
- dftax/system/molecule.py +91 -0
- dftax/utils/__init__.py +3 -0
- dftax/utils/energy_aux.py +44 -0
- dftax/utils/vmap.py +275 -0
- dftax-0.1.0.dist-info/METADATA +257 -0
- dftax-0.1.0.dist-info/RECORD +52 -0
- dftax-0.1.0.dist-info/WHEEL +4 -0
- dftax-0.1.0.dist-info/licenses/LICENSE +201 -0
dftax/__init__.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
"""dftax: a self-contained, pure-JAX/Equinox Kohn-Sham DFT engine.
|
|
2
|
+
|
|
3
|
+
The package exposes a differentiable KS-DFT toolkit with no PySCF runtime
|
|
4
|
+
dependency for the core compute path:
|
|
5
|
+
|
|
6
|
+
- ``dftax.integrals``: analytical one- and two-electron integral matrices
|
|
7
|
+
(overlap S, kinetic T, nuclear attraction V, nuclear repulsion, ERIs 2c/3c/4c)
|
|
8
|
+
via the Obara-Saika recurrence, all jit/vmap/grad-friendly.
|
|
9
|
+
- ``dftax.energy``: GTO basis evaluation, the Boys function, density fitting,
|
|
10
|
+
Hartree and hybrid exchange, exchange-correlation functionals, real-space
|
|
11
|
+
grids, and pointwise potentials.
|
|
12
|
+
- ``dftax.utils``: chunked ``vmap`` helpers and shared types.
|
|
13
|
+
- ``dftax.ks``: Kohn-Sham drivers (energy functional + DIIS SCF) for both
|
|
14
|
+
closed-shell (RKS) and open-shell (UKS) systems, plus the unified ``run_ks``.
|
|
15
|
+
|
|
16
|
+
A few of the most common entry points are re-exported here for convenience;
|
|
17
|
+
import the submodules directly for the full surface.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from dftax.energy.gto import BasisData, extract_basis_data, eval_gto
|
|
21
|
+
from dftax.integrals import (
|
|
22
|
+
overlap_matrix,
|
|
23
|
+
kinetic_matrix,
|
|
24
|
+
nuclear_attraction_matrix,
|
|
25
|
+
nuclear_repulsion,
|
|
26
|
+
eri2c_matrix,
|
|
27
|
+
eri3c_matrix,
|
|
28
|
+
)
|
|
29
|
+
from dftax.ks import (
|
|
30
|
+
RKS, run_rks, rks_scf, rks_minimize, rks_forces, SCFResult,
|
|
31
|
+
UKS, run_uks, uks_scf, uks_minimize, uks_forces, UKSResult,
|
|
32
|
+
run_ks,
|
|
33
|
+
run_ks_batched, run_rks_batched, run_uks_batched, BatchedResult,
|
|
34
|
+
dipole, polarizability, hessian, vibrations, ir_spectrum, raman_spectrum,
|
|
35
|
+
alchemical_deriv, implicit_density,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
from importlib import metadata as _metadata

try:
    __version__ = _metadata.version("dftax")
except _metadata.PackageNotFoundError:
    # Not installed (e.g. imported straight from a source checkout): fall back
    # to a local-version placeholder instead of failing the package import.
    __version__ = "0.0.0+local"

# Keep the public namespace clean; only ``__version__`` survives.
del _metadata
|
|
45
|
+
|
|
46
|
+
__all__ = [
    "__version__",
    # basis data and GTO evaluation
    "BasisData", "extract_basis_data", "eval_gto",
    # one- and two-electron integral matrices
    "overlap_matrix", "kinetic_matrix", "nuclear_attraction_matrix",
    "nuclear_repulsion", "eri2c_matrix", "eri3c_matrix",
    # restricted (closed-shell) Kohn-Sham driver
    "RKS", "run_rks", "rks_scf", "rks_minimize", "rks_forces", "SCFResult",
    # unrestricted (open-shell) Kohn-Sham driver
    "UKS", "run_uks", "uks_scf", "uks_minimize", "uks_forces", "UKSResult",
    # unified RKS/UKS dispatcher
    "run_ks",
    # batched drivers (vmap over geometries)
    "run_ks_batched", "run_rks_batched", "run_uks_batched", "BatchedResult",
    # response properties
    "dipole", "polarizability", "hessian", "vibrations",
    "ir_spectrum", "raman_spectrum", "alchemical_deriv",
    # implicit-differentiation SCF response
    "implicit_density",
]
|
dftax/basis/__init__.py
ADDED
|
Binary file
|
dftax/basis/loader.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
"""Build :class:`~dftax.energy.gto.BasisData` from Basis Set Exchange data.
|
|
2
|
+
|
|
3
|
+
This is the PySCF-free counterpart of ``extract_basis_data``: it fetches the
|
|
4
|
+
contracted-GTO definition for each element from the ``basis_set_exchange``
|
|
5
|
+
library and assembles the same normalized Cartesian-AO arrays the integral
|
|
6
|
+
engine consumes. Normalization matches ``extract_basis_data`` exactly
|
|
7
|
+
(primitive ``gto_norm`` times, for l<=1, a contracted norm; libcint's
|
|
8
|
+
per-shell convention for l>=2).
|
|
9
|
+
|
|
10
|
+
Emits Cartesian GTOs by default (``cart2sph=None``); with ``spherical=True`` it
|
|
11
|
+
also builds the Cartesian->spherical transform (vendored per-l blocks under
|
|
12
|
+
``data/cart2sph.npz``) so spherical d/f bases (cc-pVDZ, cc-pVTZ, def2-*) match a
|
|
13
|
+
spherical PySCF reference. For l<=1 bases the two spans coincide.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import math
|
|
19
|
+
from functools import lru_cache
|
|
20
|
+
from importlib.resources import files
|
|
21
|
+
|
|
22
|
+
import numpy as np
|
|
23
|
+
import jax.numpy as jnp
|
|
24
|
+
import basis_set_exchange as bse
|
|
25
|
+
|
|
26
|
+
from dftax.energy.gto import BasisData, _CART_COMPONENTS, _contracted_norm
|
|
27
|
+
from dftax.system.molecule import symbol_to_Z
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@lru_cache(maxsize=1)
def _cart2sph_blocks() -> dict[int, np.ndarray]:
    """Load the vendored per-l Cartesian->spherical transform blocks.

    The archive ``dftax/basis/data/cart2sph.npz`` holds one array per angular
    momentum. Each key is a single-character prefix followed by the integer
    ``l`` (the prefix is stripped with ``k[1:]``). Each block is
    (ncart x nsph).

    Returns:
        ``{l: block}`` with every block as a float64 NumPy array. The dict is
        cached by ``lru_cache`` and shared between callers, so do not mutate it.
    """
    path = files("dftax.basis").joinpath("data", "cart2sph.npz")
    # ``np.load`` on an .npz is lazy: a member is only read when it is indexed.
    # Materialize every block while both the file handle and the NpzFile are
    # still open, then let both context managers close them deterministically.
    with path.open("rb") as fh, np.load(fh) as npz:
        return {
            int(k[1:]): np.asarray(npz[k], dtype=np.float64) for k in npz.files
        }
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _block_diag(blocks: list[np.ndarray]) -> np.ndarray:
    """Place 2-D ``blocks`` along the diagonal of a zero float64 matrix.

    This is a dependency-free stand-in for ``scipy.linalg.block_diag``. Each
    block starts at the row and column right after the previous block ends.
    Everything outside the blocks is zero. An empty list yields a (0, 0) array,
    and block values are cast to float64 on assignment.
    """
    row_sizes = [blk.shape[0] for blk in blocks]
    col_sizes = [blk.shape[1] for blk in blocks]
    result = np.zeros((sum(row_sizes), sum(col_sizes)), dtype=np.float64)
    row0 = col0 = 0
    for blk, n_rows, n_cols in zip(blocks, row_sizes, col_sizes):
        row1, col1 = row0 + n_rows, col0 + n_cols
        result[row0:row1, col0:col1] = blk
        row0, col0 = row1, col1
    return result
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def gto_norm(l: int, alpha: float) -> float:
    """Return the radial normalization constant of a primitive GTO.

    This mirrors ``pyscf.gto.gto_norm``: the returned ``N`` gives the
    primitive ``N r^l exp(-alpha r^2)`` unit self-overlap in the
    radial/spherical sense.

    In closed form ``N = sqrt(2 (2 alpha)^(l + 3/2) / Gamma(l + 3/2))``. Below,
    the half-integer Gamma is written out with factorials and ``sqrt(pi)``.
    """
    numerator = 2 ** (2 * l + 3) * math.factorial(l + 1) * (2 * alpha) ** (l + 1.5)
    denominator = math.factorial(2 * l + 2) * math.sqrt(math.pi)
    return math.sqrt(numerator / denominator)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _iter_contractions(shell):
    """Yield ``(l, coeff_vector)`` for each contracted function in a BSE shell.

    ``shell`` is one entry of a Basis Set Exchange ``electron_shells`` list.
    Only its ``"angular_momentum"`` and ``"coefficients"`` entries are read.

    - single angular momentum: each coefficient block is an independent
      (general) contraction of that L;
    - multiple angular momenta (e.g. an ``sp`` shell): block ``i`` belongs to
      ``angular_momentum[i]``.

    Each coefficient block is converted with
    ``np.asarray(..., dtype=np.float64)``, which also accepts numeric strings.

    Raises:
        ValueError: for a multi-angular-momentum shell whose number of
            coefficient blocks differs from its number of angular momenta.
            Raised when the generator is iterated.
    """
    ams = shell["angular_momentum"]
    cblocks = shell["coefficients"]
    if len(ams) == 1:
        for c in cblocks:
            yield ams[0], np.asarray(c, dtype=np.float64)
        return
    # A plain zip() would silently drop unmatched blocks and lose contracted
    # functions, so require a one-to-one pairing instead.
    if len(cblocks) != len(ams):
        raise ValueError(
            f"BSE shell has {len(ams)} angular momenta {list(ams)} but "
            f"{len(cblocks)} coefficient blocks; expected one block per "
            "angular momentum"
        )
    for l, c in zip(ams, cblocks):
        yield l, np.asarray(c, dtype=np.float64)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def build_basis_data(
    symbols: list[str],
    coords_bohr: np.ndarray,
    basis_name: str,
    *,
    spherical: bool = False,
    return_atom_index: bool = False,
):
    """Assemble :class:`BasisData` for a molecule from a named basis set.

    Every contracted shell of every atom becomes one AO row per entry of
    ``_CART_COMPONENTS[l]``. Each row stores:

    - its center;
    - its primitive exponents and contraction coefficients, both zero-padded
      to the largest primitive count in the molecule;
    - its ``_CART_COMPONENTS[l]`` entry (presumably the Cartesian exponent
      triple; verify in ``dftax.energy.gto``).

    Args:
        symbols: element symbols, one per atom.
        coords_bohr: nuclear coordinates in Bohr, shape (n_atoms, 3).
        basis_name: a Basis Set Exchange basis name (e.g. ``"sto-3g"``).
        spherical: if True, attach the Cartesian->spherical transform so the
            basis uses (2l+1) spherical harmonics (standard for cc-pVXZ/def2);
            otherwise emit Cartesian GTOs (``cart2sph=None``). For l<=1 the two
            coincide.
        return_atom_index: if True, also return an int array mapping each
            Cartesian AO to its owning atom (needed to rebuild differentiable
            centers from nuclear coordinates, e.g. for forces).

    Returns:
        The ``BasisData``. If ``return_atom_index`` is True, instead returns
        ``(basis, atom_index)``, where ``atom_index`` is an int64 array with
        one entry per Cartesian AO.

    Raises:
        ValueError: if ``symbols`` is empty, if the number of coordinate rows
            differs from the number of symbols, or if the basis defines no
            electron shells for the molecule's elements.
    """
    coords = np.asarray(coords_bohr, dtype=np.float64).reshape(-1, 3)
    n_atoms = len(symbols)
    if n_atoms == 0:
        raise ValueError(
            "build_basis_data: 'symbols' is empty; at least one atom is required"
        )
    if coords.shape[0] != n_atoms:
        # zip(Zs, coords) below would otherwise silently drop unmatched atoms.
        raise ValueError(
            f"build_basis_data: got {n_atoms} symbols but {coords.shape[0]} "
            "coordinate rows; coords_bohr must have shape (n_atoms, 3)"
        )

    Zs = [symbol_to_Z(s) for s in symbols]
    bdata = bse.get_basis(basis_name, elements=sorted(set(Zs)), header=False)
    elements = bdata["elements"]

    # Gather every contracted shell: (l, exponents, raw coefficients, center, atom).
    raw_shells: list[tuple[int, np.ndarray, np.ndarray, np.ndarray, int]] = []
    for atom_idx, (Z, center) in enumerate(zip(Zs, coords)):
        shells = elements[str(Z)]["electron_shells"]
        for shell in shells:
            exps = np.asarray(shell["exponents"], dtype=np.float64)
            for l, c_raw in _iter_contractions(shell):
                raw_shells.append((l, exps, c_raw, center, atom_idx))

    if not raw_shells:
        raise ValueError(
            f"build_basis_data: basis {basis_name!r} defines no electron shells "
            "for the requested elements"
        )

    # Common padded width so every AO row has the same primitive count.
    max_prim = max(len(exps) for (_, exps, _, _, _) in raw_shells)

    centers, all_exps, all_coeffs, angular, atom_index = [], [], [], [], []
    for l, exps, c_raw, center, atom_idx in raw_shells:
        prim_norms = np.array([gto_norm(l, e) for e in exps])
        c_prim = c_raw * prim_norms
        for ang in _CART_COMPONENTS[l]:
            if l <= 1:
                # s/p: additionally normalize the contracted function.
                c_final = c_prim * _contracted_norm(exps, c_prim, ang)
            else:
                # l >= 2: keep primitive-normalized coefficients (the
                # per-shell convention described in the module docstring).
                c_final = c_prim.copy()
            # Zero-padded primitives contribute nothing (coefficient 0).
            pad_e = np.zeros(max_prim, dtype=np.float64)
            pad_c = np.zeros(max_prim, dtype=np.float64)
            pad_e[: len(exps)] = exps
            pad_c[: len(c_final)] = c_final
            centers.append(center)
            all_exps.append(pad_e)
            all_coeffs.append(pad_c)
            angular.append(ang)
            atom_index.append(atom_idx)

    cart2sph = None
    if spherical:
        # One cart2sph(l) block per contracted shell, in build order, which
        # matches the order the Cartesian AO rows were emitted above.
        blocks = _cart2sph_blocks()
        cart2sph = jnp.asarray(
            _block_diag([blocks[l] for (l, _, _, _, _) in raw_shells])
        )

    basis = BasisData(
        centers=jnp.asarray(np.array(centers, dtype=np.float64)),
        exponents=jnp.asarray(np.array(all_exps, dtype=np.float64)),
        coefficients=jnp.asarray(np.array(all_coeffs, dtype=np.float64)),
        angular=jnp.asarray(np.array(angular, dtype=np.int32)),
        cart2sph=cart2sph,
        max_l=int(max(l for (l, _, _, _, _) in raw_shells)),
    )
    if return_atom_index:
        return basis, np.array(atom_index, dtype=np.int64)
    return basis
|
dftax/energy/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Energy evaluation modules."""
|
dftax/energy/boys.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
"""Boys function F_n(t) in pure JAX.
|
|
2
|
+
|
|
3
|
+
The Boys function is defined as:
|
|
4
|
+
F_n(t) = integral_0^1 u^{2n} exp(-t u^2) du
|
|
5
|
+
|
|
6
|
+
It is related to the lower incomplete gamma function by:
|
|
7
|
+
F_n(t) = Gamma(n+0.5) * P(n+0.5, t) / (2 * t^{n+0.5})
|
|
8
|
+
where P(a, x) = gammainc(a, x) is the regularized lower incomplete gamma.
|
|
9
|
+
|
|
10
|
+
Key properties:
|
|
11
|
+
F_n(0) = 1 / (2n + 1)
|
|
12
|
+
dF_n/dt = -F_{n+1}(t)
|
|
13
|
+
F_0(t) = sqrt(pi) / (2*sqrt(t)) * erf(sqrt(t)) for t > 0
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import math
|
|
17
|
+
|
|
18
|
+
import jax
|
|
19
|
+
import jax.numpy as jnp
|
|
20
|
+
from jax.scipy.special import gammainc, gammaln
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def boys(n: int, t: jax.Array) -> jax.Array:
    """Evaluate the Boys function F_n(t) = integral_0^1 u^{2n} exp(-t u^2) du.

    Two evaluation routes are blended with ``jnp.where``:

    * ``t < 1``: the alternating Taylor series
      ``sum_k (-t)^k / (k! (2n + 2k + 1))``, using the terms k = 0..27;
    * ``t >= 1``: the closed form ``Gamma(a) P(a, t) / (2 t^a)`` with
      ``a = n + 1/2``, evaluated in log space to avoid overflow.

    Claimed accuracy is better than 1e-8 against numerical integration for
    n = 0..10, t >= 0. The function is differentiable, with
    ``jax.grad(boys(n, .))(t) == -boys(n+1, t)``, and works under jit and vmap.

    Args:
        n: order, a static Python int (it builds trace-time constants).
        t: non-negative argument, a JAX scalar or array. Non-floating inputs
            are cast to ``jnp.float64``; JAX may narrow this when x64 is
            disabled.

    Returns:
        F_n(t) with the shape of ``t`` and its floating dtype.
    """
    x = jnp.asarray(t)
    if jnp.issubdtype(x.dtype, jnp.floating):
        dtype = x.dtype
    else:
        dtype = jnp.float64
    x = x.astype(dtype)
    a = float(n) + 0.5
    below_one = x < 1.0

    # Closed-form route. The argument is clamped to >= 1 so that log, pow and
    # gammainc (and their derivatives) stay finite everywhere. Autodiff
    # differentiates both where-branches, and 0 * NaN would leak through.
    x_big = jnp.where(below_one, 1.0, x)
    log_p = jnp.log(jnp.maximum(gammainc(a, x_big), jnp.finfo(dtype).tiny))
    via_gamma = 0.5 * jnp.exp(gammaln(a) + log_p - a * jnp.log(x_big))

    # Taylor route. The constant k = 0 term 1/(2n+1) is added separately
    # because (-x)**0 at x == 0 has a NaN gradient in JAX.
    n_terms = 28
    coeffs = jnp.array(
        [1.0 / (math.factorial(j) * (2 * n + 2 * j + 1)) for j in range(1, n_terms)],
        dtype=dtype,
    )
    powers = jnp.arange(1, n_terms, dtype=dtype)
    leading = 1.0 / (2 * n + 1)
    via_taylor = leading + jnp.sum(coeffs * ((-x[..., None]) ** powers), axis=-1)

    # The two routes agree to machine precision near t ~ 1.
    return jnp.where(below_one, via_taylor, via_gamma)
|
dftax/energy/grid.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""Grid-based exchange-correlation energy integration."""
|
|
2
|
+
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
from jaxtyping import Array, Float, Scalar
|
|
5
|
+
|
|
6
|
+
from dftax.energy.xc import XCFunctional
|
|
7
|
+
from dftax.energy.potentials import xc_potential
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def xc_energy(
    xc: XCFunctional,
    rho: Float[Array, "n"],  # type: ignore
    weights: Float[Array, "n"],  # type: ignore
    chunk_size: int | None = None,
    grad_rho: Float[Array, "n 3"] | None = None,  # type: ignore
    rho_thresh: float = 1e-10,
) -> Scalar:
    """Integrate the exchange-correlation energy ``∫ ε_xc(ρ, ∇ρ) ρ dr`` on a grid.

    Args:
        xc: the functional. Its ``xc_type`` decides whether ``grad_rho`` is
            masked; only ``"GGA"`` triggers masking.
        rho: density sampled at each grid point.
        weights: quadrature weight of each grid point.
        chunk_size: passed through unchanged to ``xc_potential``.
        grad_rho: density gradient at each grid point (for GGA), or None.
        rho_thresh: grid points with ``ρ <= rho_thresh`` contribute nothing.

    Returns:
        ``Σ_i w_i ε_i ρ_i`` over the retained points, where ``ε_i`` is what
        ``xc_potential`` returns for point ``i``.

    Masking is nan-safe (a "double where"):

    - Dropped points are evaluated at a clamped density of 1.0 and, for GGA,
      a zero gradient.
    - Their contribution is then replaced by 0.0.

    As a result, neither the value nor its gradient picks up the divergent
    ρ→0 behaviour of GGA reduced-gradient terms on the vanishing-density tail
    of an unpruned grid.
    """
    keep = rho > rho_thresh
    rho_eval = jnp.where(keep, rho, 1.0)
    grad_eval = grad_rho
    if xc.xc_type == "GGA" and grad_rho is not None:
        grad_eval = jnp.where(keep[:, None], grad_rho, 0.0)
    eps = xc_potential(xc, rho_eval, chunk_size, grad_rho=grad_eval)
    per_point = weights * eps * rho
    return jnp.sum(jnp.where(keep, per_point, 0.0))
|