off 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,84 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import equinox as eqx
4
+ from jaxtyping import Array, Float
5
+ from jax import vmap, lax
6
+
7
+
8
class KatoCondition(eqx.Module):
    r"""
    Kato cusp-condition functional.

    Penalizes deviations from the correct electron-nucleus cusp behaviour, weighting
    the (Weizsacker-like) local kinetic term by a 1s-like envelope
    w_i(r) = exp(-a Z_i |r - R_i|) around each nucleus.

    For a score s = -2 Z_i \hat{r}_i the quantity |s|^2 / 8 equals Z_i^2 / 2, which is
    the reference value the kinetic term is compared against.

    Parameters
    ----------
    a : float, optional
        Envelope decay prefactor, by default 2/3.
    eps : float, optional
        Small constant for numerical stability, by default 1e-5. It is added in
        quadrature to the distance, r = sqrt(|x - R_i|^2 + eps^2).
    """

    a: float
    eps: float

    def __init__(self, a=2.0/3.0, eps=1e-5):
        self.a = a
        self.eps = eps

    def __call__(self, den, score, x, Ne, mol, xp) -> Float[Array, "batch 1"]:
        r"""
        Evaluate the cusp penalty at every point.

        Parameters
        ----------
        score : Array
            Two-dimensional array (points x components) of the score.
        x : Array
            Evaluation points.
            # assumes (batch, d), matching mol['coords'] of shape (n_atoms, d) -- TODO confirm
        Ne : int
            Number of electrons; overall multiplicative factor.
        mol : dict
            ``{'coords': ..., 'z': ...}``; both entries are mapped over their
            leading (atom) axis.

        Notes
        -----
        den, xp are accepted for the shared functional interface but unused here.

        Returns
        -------
        jax.Array
            Penalty summed over nuclei, with a trailing singleton axis
            (``keepdims=True``), i.e. one value per point.
        """
        def _envelope(pts, atom):
            # `atom` is one nucleus' slice of `mol` (vmap strips the leading axis).
            dist = jnp.sqrt(jnp.sum((pts - atom['coords']) ** 2, axis=-1) + self.eps ** 2)
            return jnp.exp(-self.a * atom['z'] * dist)

        # One envelope column per nucleus (out_axes=-1 puts the atom axis last).
        wi = vmap(_envelope, in_axes=(None, 0), out_axes=-1)(x, mol)
        score_sqr = jnp.einsum('ij,ij->i', score, score)
        weizs = (1.0 / 8.0) * score_sqr[:, None]
        # Broadcasts the per-point kinetic term against the per-nucleus reference.
        kinetic = jnp.abs(weizs - (mol['z'] ** 2 / 2))
        return Ne * jnp.sum(kinetic * wi, axis=-1, keepdims=True)
42
+
43
+
44
class HutcheonCuspCondition(eqx.Module):
    r"""
    Exact nuclear-cusp cost functional.

    From Hutcheon & Wibowo-Teale, Phys. Rev. B 110, 195146 (2024), Eqs. (4-5, 7):

        E_{\text{CC}} = \sum_i C_i, \qquad
        C_i(\rho) = \int |\nabla \rho + 2 Z_i \rho\, \hat{r}_i|^2\, w_i(r)\, d^3 r,

    with \hat{r}_i = (r - R_i)/|r - R_i| and
    w_i(r) = (\pi / Z_i^3)^{1/2} e^{-Z_i |r - R_i|} [Eq. (5)].
    Since \nabla \rho = \rho\, s (s the score), the integrand is
    \rho^2 |s + 2 Z_i \hat{r}_i|^2, and the Kato cusp is satisfied exactly when
    s = -2 Z_i \hat{r}_i at R_i. With samples from p = \rho / N_e,

        C_i \approx N_e^2\, \mathbb{E}\left[ \rho(x)\, |s(x) + 2 Z_i \hat{r}_i(x)|^2\, w_i(x) \right].

    NOTE(review): the estimator above carries N_e^2 whereas ``__call__`` multiplies
    by ``Ne`` once -- confirm which normalisation of ``den`` is intended.

    Parameters
    ----------
    eps : float, optional
        Small constant for numerical stability, by default 1e-8. Added in
        quadrature to the electron-nucleus distance.
    """

    eps: float

    def __init__(self, eps=1e-8):
        self.eps = eps

    def __call__(self, den, score, x, Ne, mol, xp) -> Float[Array, "batch 1"]:
        r"""Uses x, den, score, Ne, mol; xp unused. Returns the cusp cost per point."""

        def _atom_term(center, charge):
            # Displacement from this nucleus and its softened length.
            disp = x - center
            dist = jnp.sqrt(jnp.sum(disp ** 2, axis=-1) + self.eps ** 2)
            unit = disp / dist[:, None]
            weight = jnp.sqrt(jnp.pi / charge ** 3) * jnp.exp(-charge * dist)
            # |s + 2 Z r_hat|^2 vanishes when the cusp condition holds.
            mismatch = jnp.sum((score + 2.0 * charge * unit) ** 2, axis=-1)
            return mismatch * weight

        # vmap over nuclei puts the atom axis first; sum it away.
        contributions = jax.vmap(_atom_term)(mol['coords'], mol['z'])
        summed = jnp.sum(contributions, axis=0)
        return (Ne * den.squeeze(-1) * summed)[:, None]
@@ -0,0 +1,174 @@
1
+ import jax.numpy as jnp
2
+ from jax import lax
3
+ from .functional import Functional, CompositeFunctional, unit_coefficient
4
+
5
+
6
def lda_density(den, score, x, Ne, mol, xp):
    r"""
    Local density approximation (LDA) / Dirac exchange functional.

    See eq. 2.72 in "Time-Dependent Density-Functional Theory", Carsten A. Ullrich.

    E_{\text{X}}^{\text{LDA}}[\rho] = -\frac{3}{4}\left(\frac{3}{\pi}\right)^{1/3} \int \rho(\boldsymbol{x})^{4/3} d\boldsymbol{x}
    = -\frac{3}{4}\left(\frac{3}{\pi}\right)^{1/3} N_e^{4/3}\, \mathbb{E}_{\rho_\phi}\left[ \rho_\phi(\boldsymbol{x})^{1/3} \right]

    Parameters
    ----------
    den : Array
        Density.
    Ne : int
        Number of electrons.

    Notes
    -----
    score, x, mol, xp are accepted for the shared functional interface but unused here.

    Returns
    -------
    jax.Array
        LDA exchange energy density (up to the rho factor); same shape as ``den``.
    """
    # Dirac constant combined with the N_e^{4/3} scaling of the expectation form.
    prefactor = -(3 / 4) * (Ne ** (4 / 3)) * (3 / jnp.pi) ** (1 / 3)
    return prefactor * den ** (1 / 3)
33
+
34
+
35
def b88_density(den, score, x, Ne, mol, xp, clip_cte=1e-30, beta=0.0042):
    r"""
    B88 exchange functional.

    See eq. 8 in https://journals.aps.org/pra/abstract/10.1103/PhysRevA.38.3098
    See also https://github.com/ElectronicStructureLibrary/libxc/blob/master/maple/gga_exc/gga_x_b88.mpl

    E_{\text{X}}^{\text{B88}}[\rho] = -\beta \int \frac{X^2}{1 + 6\beta X \sinh^{-1}(X)} \rho(\boldsymbol{x})^{4/3} d\boldsymbol{x},
    \qquad X = \frac{|\nabla \rho|}{\rho^{4/3}}.

    Products and quotients are carried out as sums and differences of base-2
    logarithms and exponentiated once at the end.

    Parameters
    ----------
    den : Array
        Density.
    score : Array
        Gradient of the log-density, s = (nabla rho)/rho; two-dimensional
        (points x components).
    Ne : int
        Number of electrons.
    clip_cte : float, optional
        Small constant for numerical stability, by default 1e-30. Used as the lower
        clip for both the density and the squared gradient norm.
    beta : float, optional
        B88 parameter, by default 0.0042.

    Notes
    -----
    x, mol, xp are accepted for the shared functional interface but unused here.

    NOTE(review): the value returned is -beta * den^{4/3} X^2 / (1 + 6 beta X asinh X)
    times Ne^{2/3}, i.e. the integrand above without dividing by the density, whereas
    ``lda_density`` returns its integrand divided by the density. X is also built
    from ``den`` with no Ne factor. Confirm both against the intended convention.

    Returns
    -------
    jax.Array
        B88 exchange energy density (up to the rho factor).
    """
    rho = jnp.clip(den, clip_cte)
    log2_rho = jnp.log2(rho)

    # |grad rho|^2 = |s|^2 rho^2, then half its log2 gives log2 |grad rho|.
    score_norm_sq = jnp.einsum('ij,ij->i', score, score)
    grad_norm_sq = score_norm_sq[:, None] * rho * rho
    log2_grad_norm = jnp.log2(jnp.clip(grad_norm_sq, clip_cte)) / 2

    # Reduced gradient X = |grad rho| / rho^{4/3}, in log space and in value.
    log2_reduced = log2_grad_norm - 4 / 3.0 * log2_rho
    reduced = 2 ** log2_reduced

    log2_denominator = jnp.log2(1 + 6 * beta * reduced * jnp.arcsinh(reduced))
    exponent = 4 * log2_rho / 3 + 2 * log2_reduced - log2_denominator
    energy = -(beta * 2 ** exponent)
    return energy * Ne ** (2 / 3)
77
+
78
+
79
def vwn_density(den, score, x, Ne, mol, xp, clip_cte=1e-30):
    r"""
    VWN correlation functional.

    See original paper eq. 4.4 in https://cdnsciencepub.com/doi/abs/10.1139/p80-159
    See also the text after eq. 8.9.6.1 in https://www.theoretical-physics.com/dev/quantum/dft.html

    \epsilon_{\text{C}}^{\text{VWN}} = \frac{A}{2}\left\{ \ln\frac{y^2}{Y(y)}
    + \frac{2b}{Q}\arctan\frac{Q}{2y+b}
    - \frac{b y_0}{Y(y_0)}\left[ \ln\frac{(y-y_0)^2}{Y(y)}
    + \frac{2(b+2y_0)}{Q}\arctan\frac{Q}{2y+b}\right] \right\}

    with y = r_s^{1/2}, Y(y) = y^2 + b y + c and Q = (4c - b^2)^{1/2}.

    Parameters
    ----------
    den : Array
        Density.
    Ne : int
        Number of electrons.
    clip_cte : float, optional
        Small constant for numerical stability, by default 1e-30. Densities not
        greater than it are replaced by it before taking the logarithm.

    Notes
    -----
    score, x, mol, xp are accepted for the shared functional interface but unused here.

    Returns
    -------
    jax.Array
        VWN correlation energy density (up to the rho factor).
    """
    A, b, c, y0 = 0.0621814, 3.72744, 12.9352, -0.10498

    # Zero out anything at or below the floor, then lift it to the floor for log2.
    floored = jnp.where(den > clip_cte, den, 0.0)
    log2_rho = jnp.log2(jnp.clip(floored, clip_cte))

    # r_s = (3 / (4 pi rho))^{1/3} and y = sqrt(r_s), both tracked as base-2 logs.
    log2_rs = jnp.log2((3 / (4 * jnp.pi)) ** (1 / 3)) - log2_rho / 3.0
    log2_y = log2_rs / 2
    y = 2. ** log2_y

    Y = 2. ** (2. * log2_y) + 2. ** (log2_y + jnp.log2(b)) + c
    Y0 = y0 ** 2 + b * y0 + c
    Q = jnp.sqrt(4 * c - b ** 2)

    # The same arctangent appears in both the outer and the bracketed term.
    atan_term = jnp.arctan(Q / (2. * y + b))
    bracket = (jnp.log((y - y0) ** 2. / Y)
               + 2. * (2. * y0 + b) / Q * atan_term)
    energy = A / 2. * (
        2. * jnp.log(y) - jnp.log(Y) + 2. * b / Q * atan_term
        - b * y0 / Y0 * bracket
    )
    return Ne * energy
124
+
125
+
126
def pw92_density(den, score, x, Ne, mol, xp, clip_cte=1e-30):
    r"""
    PW92 correlation functional.

    See eq. 10 in https://journals.aps.org/prb/abstract/10.1103/PhysRevB.45.13244

    \epsilon_{\text{C}}^{\text{PW92}} = -2A(1 + \alpha_1 r_s)
    \ln\left[ 1 + \frac{1}{2A(\beta_1 r_s^{1/2} + \beta_2 r_s + \beta_3 r_s^{3/2} + \beta_4 r_s^2)} \right]

    Each power of r_s is formed in base-2 log space (coefficient included) and
    exponentiated once.

    Parameters
    ----------
    den : Array
        Density.
    Ne : int
        Number of electrons.
    clip_cte : float, optional
        Small constant for numerical stability, by default 1e-30 (lower clip of
        the density before the logarithm).

    Notes
    -----
    score, x, mol, xp are accepted for the shared functional interface but unused here.

    Returns
    -------
    jax.Array
        PW92 correlation energy density (up to the rho factor).
    """
    A_, alpha1 = 0.031091, 0.21370
    beta1, beta2, beta3, beta4 = 7.5957, 3.5876, 1.6382, 0.49294

    log2_rho = jnp.log2(jnp.clip(den, clip_cte))
    log2_rs = jnp.log2((3 / (4 * jnp.pi)) ** (1 / 3)) - log2_rho / 3.0

    # beta_k * r_s^{k/2} for k = 1..4, and alpha_1 * r_s.
    half_power = 2 ** (log2_rs / 2 + jnp.log2(beta1))
    alpha_rs = 2 ** (log2_rs + jnp.log2(alpha1))
    first_power = 2 ** (log2_rs + jnp.log2(beta2))
    three_half_power = 2 ** (3 * log2_rs / 2 + jnp.log2(beta3))
    second_power = 2 ** (2 * log2_rs + jnp.log2(beta4))

    series = half_power + first_power + three_half_power + second_power
    energy = -2 * A_ * (1 + alpha_rs) * jnp.log(1 + (1 / (2 * A_)) / series)
    return Ne * energy
164
+
165
+
166
# Ready-made fixed (non-learned) functionals: each pairs one of the energy
# densities defined above with the constant unit coefficient.
lda = Functional(energy_densities=lda_density, coefficients=unit_coefficient)
b88 = Functional(energy_densities=b88_density, coefficients=unit_coefficient)
vwn = Functional(energy_densities=vwn_density, coefficients=unit_coefficient)
pw92 = Functional(energy_densities=pw92_density, coefficients=unit_coefficient)
170
+
171
+
172
def lda_b88():
    r"""
    LDA + B88 exchange.

    Returns
    -------
    CompositeFunctional
        A new composite built from the module-level ``lda`` and ``b88`` functionals.
    """
    combined = CompositeFunctional(lda, b88)
    return combined
@@ -0,0 +1,49 @@
1
+ import jax.numpy as jnp
2
+ import equinox as eqx
3
+ from jaxtyping import Array, Float
4
+ from jax import vmap
5
+
6
+
7
class NuclearPotential(eqx.Module):
    r"""
    External electron-nuclei attraction potential.

    V_{\text{ext}}(\boldsymbol{x}) = -N_e \sum_i \frac{Z_i}{|\boldsymbol{x} - \boldsymbol{R}_i|}

    Parameters
    ----------
    eps : float, optional
        Small constant for numerical stability, by default 1e-5. It is added to the
        distance itself (|x - R_i| + eps), not under the square root.
    """

    eps: float

    def __init__(self, eps=1e-5):
        self.eps = eps

    def __call__(self, den, score, x, Ne, mol, xp) -> Float[Array, "batch 1"]:
        r"""
        Parameters
        ----------
        x : Array
            Points where the potential is evaluated.
        Ne : int
            Number of electrons.
        mol : dict
            Nuclear coordinates and charges, {'coords': ..., 'z': ...}.

        Notes
        -----
        den, score, xp are accepted for the shared functional interface but unused here.

        NOTE(review): because eps sits outside the square root, the derivative of
        the distance is undefined for a point lying exactly on a nucleus -- confirm
        this cannot occur for the sampled points.

        Returns
        -------
        jax.Array
            Electron-nuclei attraction at each point (up to the rho factor), with a
            trailing singleton axis (``keepdims=True``).
        """
        def _attraction(pts, atom):
            # `atom` is one nucleus' slice of `mol` (vmap strips the leading axis).
            dist = jnp.sqrt(jnp.sum((pts - atom['coords']) ** 2, axis=1)) + self.eps
            return atom['z'] / dist

        # One column per nucleus (out_axes=-1), then summed over nuclei.
        per_nucleus = vmap(_attraction, in_axes=(None, 0), out_axes=-1)(x, mol)
        return -Ne * jnp.sum(per_nucleus, axis=-1, keepdims=True)
@@ -0,0 +1,129 @@
1
+ import jax.numpy as jnp
2
+ import equinox as eqx
3
+ from jaxtyping import Array, Float
4
+ from typing import Any, Callable, NamedTuple, Optional
5
+
6
+
7
class FunctionalInputs(NamedTuple):
    r"""
    Bundle of every input a functional might need.

    This is the one object passed to :class:`EnergyFunctional`, which unpacks it
    into explicit arguments for each component functional. Leaf functionals are
    therefore called with plain ``(den, score, x, Ne, mol, xp)`` and never receive
    this container.

    Fields
    ------
    den   : Float[Array, "batch 1"]   density shape factor rho_phi.
    score : Float[Array, "batch d"]   score, (nabla rho)/rho = nabla log rho_phi.
    x     : Float[Array, "batch d"]   sample / grid positions.
    Ne    : int                       number of electrons.
    mol   : dict                      {'coords': ..., 'z': ...} nuclear geometry / charges.
    xp    : Float[Array, "batch d"]   second set of positions for the pairwise Hartree;
                                      optional, defaults to ``None``.
    """
    den: Any
    score: Any
    x: Any
    Ne: Any
    mol: Any
    xp: Any = None
31
+
32
+
33
def unit_coefficient(self, *_):
    r"""
    Constant unit weight (c = 1): turns a Functional into a fixed (non-learned) functional.

    All arguments are ignored.

    Returns
    -------
    jax.Array
        A (1, 1) array holding 1.0.
    """
    return jnp.ones((1, 1))
36
+
37
+
38
class Functional(eqx.Module):
    r"""
    Local density functional, F[\rho] = \int c_\theta[\rho] \cdot e[\rho]\, d\boldsymbol{x}.

    All functionals share the explicit signature ``(den, score, x, Ne, mol, xp)``;
    every input is passed even when a particular functional ignores it. Calling
    the functional returns the per-point energy density of shape (batch, 1).
    It is assembled from the following callables:

        energy_densities(den, score, x, Ne, mol, xp) -> e[rho]
            the energy densities, returned up to the rho factor.
        coefficients(self, cinputs) -> c_theta[rho]
            the weights; :func:`unit_coefficient` (c = 1) for a fixed functional, a
            network for a learned one.
        coefficient_inputs(den, score, x, Ne, mol, xp) -> features   (optional)
            features fed to ``coefficients`` (learned functionals only).
    """

    coefficients: Callable
    energy_densities: Callable
    coefficient_inputs: Optional[Callable] = None

    def __call__(self, den, score, x, Ne, mol, xp) -> Float[Array, "batch 1"]:
        r"""Weight the energy densities by the coefficients and sum over the last axis."""
        args = (den, score, x, Ne, mol, xp)
        densities = self.energy_densities(*args)

        # Features are only built when a feature callable was supplied.
        if self.coefficient_inputs is None:
            features = None
        else:
            features = self.coefficient_inputs(*args)

        weights = self.coefficients(self, features)
        return jnp.sum(weights * densities, axis=-1, keepdims=True)
65
+ # TODO: add a functional, or density.
66
+
67
class CompositeFunctional(eqx.Module):
    r"""
    Sum of several functionals, all sharing the ``(den, score, x, Ne, mol, xp)`` signature.

    Parameters
    ----------
    *functionals
        The member functionals, in evaluation order. They are stored as a tuple.
    """

    functionals: tuple

    def __init__(self, *functionals):
        self.functionals = tuple(functionals)

    def __call__(self, den, score, x, Ne, mol, xp) -> Float[Array, "batch 1"]:
        r"""Evaluate every member on the same arguments and add the results."""
        result = 0.0
        for func in self.functionals:
            result = result + func(den, score, x, Ne, mol, xp)
        return result

    def __add__(self, other):
        r"""
        Return a new composite containing this one's members followed by ``other``.

        If ``other`` is itself a :class:`CompositeFunctional` its members are
        appended individually (flattened); otherwise ``other`` is appended as a
        single member.
        """
        # Members must be unpacked: the constructor takes them as *args, and
        # ``self.functionals`` is a tuple, so it cannot be concatenated with a list.
        if isinstance(other, CompositeFunctional):
            return CompositeFunctional(*self.functionals, *other.functionals)
        return CompositeFunctional(*self.functionals, other)
85
+
86
+
87
class EnergyFunctional(eqx.Module):
    r"""
    High-level OF-DFT energy functional.

    Takes the single :class:`FunctionalInputs` bundle and unpacks it into explicit
    arguments for each component functional (kinetic, external/nuclear, Hartree,
    exchange, and optionally correlation and core-correction). Every component is
    called with the same ``(den, score, x, Ne, mol, xp)`` and uses only what it
    needs. ``correlation`` and ``core_correction`` may be ``None``.

    ``terms`` returns the per-component energy densities (convenient for logging);
    ``__call__`` returns their sum. Integrate with :meth:`_integrate` (Monte-Carlo
    measure 1/N during training, or grid weights w*rho for quadrature).
    """

    kinetic: Any
    external: Any
    hartree: Any
    exchange: Any
    correlation: Any = None
    core_correction: Any = None

    @staticmethod
    def _integrate(energy_density, weights):
        r"""Quadrature: integral ~ sum_i w_i e_i. `weights` is 1/N (Monte-Carlo) or w_i*rho_i (grid)."""
        weighted = weights * energy_density
        return jnp.sum(weighted)

    def terms(self, inp: FunctionalInputs) -> dict:
        r"""
        Unpack ``inp`` and evaluate every component on the same explicit arguments.

        Returns a dict with keys ``"kin"``, ``"vnuc"``, ``"hart"``, ``"x"``,
        ``"c"`` and ``"cc"``; the last two are ``0.0`` when the corresponding
        component is ``None``.
        """
        args = (inp.den, inp.score, inp.x, inp.Ne, inp.mol, inp.xp)

        def _optional(component):
            # Missing optional components contribute a plain zero.
            if component is None:
                return 0.0
            return component(*args)

        return {
            "kin": self.kinetic(*args),
            "vnuc": self.external(*args),
            "hart": self.hartree(*args),
            "x": self.exchange(*args),
            "c": _optional(self.correlation),
            "cc": _optional(self.core_correction),
        }

    def __call__(self, inp: FunctionalInputs):
        r"""Per-point total energy density (sum of all components)."""
        parts = self.terms(inp)
        total = parts["kin"]
        for key in ("vnuc", "hart", "x", "c", "cc"):
            total = total + parts[key]
        return total
@@ -0,0 +1,62 @@
1
+ import jax.numpy as jnp
2
+ import equinox as eqx
3
+ from jaxtyping import Array, Float
4
+
5
+
6
class CoulombPotential_(eqx.Module):
    r"""
    Classical electron-electron repulsion (Hartree) potential, element-wise pairs.

    Pairs each x_i with x'_i (the two half-batches), giving `batch` Monte-Carlo
    pairs:

    V_{\text{Hartree}}(\boldsymbol{x}, \boldsymbol{x}') = \frac{1}{2} N_e^2 \frac{1}{|\boldsymbol{x} - \boldsymbol{x}'|}

    Parameters
    ----------
    eps : float, optional
        Small constant for numerical stability, by default 1e-5. It is added once
        to the squared distance, so 1/|x - x'| is evaluated as
        1/sqrt(|x - x'|^2 + eps) -- the same softening :class:`CoulombPotential` uses.
    """

    eps: float

    def __init__(self, eps=1e-5):
        self.eps = eps

    def __call__(self, den, score, x, Ne, mol, xp) -> Float[Array, "batch 1"]:
        r"""Uses x, xp (paired points) and Ne; den, score, mol unused. Returns ½Ne²/|x-x'| per pair."""
        # eps is added after the sum over coordinates; adding it inside the sum
        # would scale the softening with the number of coordinates.
        sq_dist = jnp.sum((x - xp) ** 2, axis=-1, keepdims=True) + self.eps
        return 0.5 * (Ne ** 2) / jnp.sqrt(sq_dist)
30
+
31
+
32
class CoulombPotential(eqx.Module):
    r"""
    Classical electron-electron repulsion (Hartree) potential, all-pairs estimator.

    Same physics as :class:`CoulombPotential_` (true 1/|x-x'| Coulomb), but averages
    each x_i against *every* x'_j (the full batch x batch double sum), i.e. batch^2
    pairs instead of batch. This is the same estimator the grid quadrature uses and
    has substantially lower Monte-Carlo variance for the same samples, at the cost
    of a (batch, batch) distance matrix.

    V_{\text{Hartree}} = \frac{1}{2} N_e^2 \left\langle \frac{1}{|\boldsymbol{x}_i - \boldsymbol{x}'_j|} \right\rangle_j

    Parameters
    ----------
    eps : float, optional
        Small constant for numerical stability, by default 1e-5. Added to the
        squared distance before the square root.

    Notes
    -----
    NOTE(review): if a caller passes ``xp`` identical to ``x``, every i == j entry
    contributes 1/sqrt(eps) to the mean -- confirm callers pass independent samples.
    """

    eps: float

    def __init__(self, eps=1e-5):
        self.eps = eps

    def __call__(self, den, score, x, Ne, mol, xp) -> Float[Array, "batch 1"]:
        r"""Uses x, xp and Ne; den, score, mol unused. Returns the per-point Hartree potential."""
        # |x_i - x'_j|^2 = |x_i|^2 + |x'_j|^2 - 2 x_i . x'_j
        sq_norm_x = jnp.sum(x * x, axis=-1)
        sq_norm_xp = jnp.sum(xp * xp, axis=-1)
        sq_dist = sq_norm_x[:, None] + sq_norm_xp[None, :] - 2.0 * (x @ xp.T)

        # The expansion can dip slightly below zero through round-off; clamp first.
        sq_dist = jnp.maximum(sq_dist, 0.0) + self.eps
        inverse_dist = 1.0 / jnp.sqrt(sq_dist)
        averaged = jnp.mean(inverse_dist, axis=-1, keepdims=True)
        return 0.5 * (Ne ** 2) * averaged
@@ -0,0 +1,87 @@
1
+ import jax.numpy as jnp
2
+ from jax import lax
3
+ from .functional import Functional, CompositeFunctional, unit_coefficient
4
+
5
# Thomas-Fermi constant, (3/10)(3 pi^2)^(2/3).
C_TF = (3. / 10.) * (3. * jnp.pi ** 2) ** (2 / 3)


def tf_density(den, score, x, Ne, mol, xp, c=C_TF):
    r"""
    Thomas-Fermi kinetic functional.

    See paper eq. 2 in https://pubs.aip.org/aip/jcp/article/114/2/631/184186/Thomas-Fermi-Dirac-von-Weizsacker-models-in-finite

    T_{\text{TF}}[\rho] = c \int \rho(\boldsymbol{x})^{5/3} d\boldsymbol{x}
    = c \int \rho(\boldsymbol{x})^{2/3} \rho(\boldsymbol{x}) d\boldsymbol{x}
    T_{\text{TF}}[\rho] = c\, N_e^{5/3}\, \mathbb{E}_{\rho_\phi}\left[ \rho_\phi(\boldsymbol{x})^{2/3} \right]

    with c = \frac{3}{10}(3\pi^2)^{2/3}.

    Parameters
    ----------
    den : Array
        Density.
    Ne : int
        Number of electrons.
    c : float, optional
        Prefactor, by default (3/10)(3 pi^2)^(2/3).

    Notes
    -----
    score, x, mol, xp are accepted for the shared functional interface but unused here.

    Returns
    -------
    jax.Array
        Thomas-Fermi kinetic energy density (up to the rho factor); same shape as ``den``.
    """
    scale = c * (Ne ** (5 / 3))
    return scale * den ** (2 / 3)
39
+
40
+
41
def weizsacker_density(den, score, x, Ne, mol, xp, lam=0.2):
    r"""
    von Weizsacker gradient correction.

    See paper eq. 3 in https://pubs.aip.org/aip/jcp/article/114/2/631/184186/Thomas-Fermi-Dirac-von-Weizsacker-models-in-finite

    T_{\text{Weizsacker}}[\rho] = \frac{\lambda}{8} \int \frac{(\nabla \rho)^2}{\rho} d\boldsymbol{x}
    = \frac{\lambda}{8} \int \rho \left(\frac{\nabla \rho}{\rho}\right)^2 d\boldsymbol{x}
    T_{\text{Weizsacker}}[\rho] = \frac{\lambda N_e}{8}\, \mathbb{E}_{\rho_\phi}\left[ \left(\frac{\nabla \rho}{\rho}\right)^2 \right]

    Parameters
    ----------
    score : Array
        Gradient of the log-density, s = (nabla rho)/rho; two-dimensional
        (points x components).
    Ne : int
        Number of electrons.
    lam : float, optional (W. Stich, E.K.U. Gross, Z. Physik A 309(1):511, 1982)
        Phenomenological parameter lambda, by default 0.2.

    Notes
    -----
    den, x, mol, xp are accepted for the shared functional interface but unused here.

    Returns
    -------
    jax.Array
        von Weizsacker kinetic energy density (up to the rho factor), one value per
        point with a trailing singleton axis.
    """
    prefactor = lam * Ne / 8.
    # Row-wise squared norm of the score, then restore the trailing axis.
    squared_norm = jnp.einsum('ij,ij->i', score, score)
    return prefactor * squared_norm[:, None]
71
+
72
+
73
# Fixed (non-learned) Thomas-Fermi functional: tf_density with the unit coefficient.
tf = Functional(energy_densities=tf_density, coefficients=unit_coefficient)
74
+
75
+
76
def weizsacker(lam=0.2):
    r"""
    von Weizsacker functional with prefactor ``lam`` (see :func:`weizsacker_density`).

    Parameters
    ----------
    lam : float, optional
        Value forwarded as the ``lam`` argument of :func:`weizsacker_density`,
        by default 0.2.

    Returns
    -------
    Functional
        A fixed functional (unit coefficient) whose energy density has ``lam`` bound.
    """
    def _energy_density(den, score, x, Ne, mol, xp):
        # Adapts weizsacker_density to the shared six-argument signature.
        return weizsacker_density(den, score, x, Ne, mol, xp, lam)

    return Functional(
        coefficients=unit_coefficient,
        energy_densities=_energy_density,
    )
83
+
84
+
85
def tf_weizsacker(lam=0.2):
    r"""
    TF-lambda-W kinetic functional: T = T_TF + lambda * T_W.

    Parameters
    ----------
    lam : float, optional
        Weizsacker prefactor passed to :func:`weizsacker`, by default 0.2.

    Returns
    -------
    CompositeFunctional
        Composite of the module-level ``tf`` functional and ``weizsacker(lam)``.
    """
    gradient_term = weizsacker(lam)
    return CompositeFunctional(tf, gradient_term)