off 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- off/__init__.py +23 -0
- off/atom_energies.py +151 -0
- off/config/_config.py +108 -0
- off/dft_distrax/__init__.py +27 -0
- off/dft_distrax/dft_distrax.py +216 -0
- off/flow/__init__.py +29 -0
- off/flow/equiv_flows.py +99 -0
- off/functionals/__init__.py +35 -0
- off/functionals/core_correction.py +84 -0
- off/functionals/exchange_correlation.py +174 -0
- off/functionals/external.py +49 -0
- off/functionals/functional.py +129 -0
- off/functionals/hartree.py +62 -0
- off/functionals/kinetic.py +87 -0
- off/main.py +172 -0
- off/ode_solver/__init__.py +32 -0
- off/ode_solver/eqx_ode.py +76 -0
- off/plot_binding_csv.py +63 -0
- off/plot_pes_ema.py +259 -0
- off/plot_pes_mpl.py +280 -0
- off/promolecular/__init__.py +27 -0
- off/promolecular/promolecular_dist.py +465 -0
- off/quadrature.py +261 -0
- off/quadrature_scan.py +188 -0
- off/scan_pes.py +133 -0
- off/test_fwd_rev.py +290 -0
- off/train/__init__.py +44 -0
- off/train/loop.py +228 -0
- off/train/loss.py +149 -0
- off/train/utils.py +38 -0
- off/utils.py +618 -0
- off-0.1.0.dist-info/METADATA +154 -0
- off-0.1.0.dist-info/RECORD +37 -0
- off-0.1.0.dist-info/WHEEL +5 -0
- off-0.1.0.dist-info/entry_points.txt +3 -0
- off-0.1.0.dist-info/licenses/LICENSE +21 -0
- off-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import jax.numpy as jnp
|
|
3
|
+
import equinox as eqx
|
|
4
|
+
from jaxtyping import Array, Float
|
|
5
|
+
from jax import vmap, lax
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class KatoCondition(eqx.Module):
    r"""
    Kato cusp-condition penalty.

    For every nucleus i the local Weizsacker-like term |s|^2 / 8 (s the score)
    is compared with Z_i^2 / 2 -- the value it takes when |s| = 2 Z_i -- and the
    absolute deviation is weighted by the 1s-like envelope

        w_i(r) = \exp(-a Z_i \sqrt{|r - R_i|^2 + \text{eps}^2}).

    Parameters
    ----------
    a : float, optional
        Envelope decay prefactor, by default 2/3.
    eps : float, optional
        Small constant for numerical stability, added in quadrature to the
        electron-nucleus distance, by default 1e-5.
    """

    a: float
    eps: float

    def __init__(self, a=2.0/3.0, eps=1e-5):
        self.a = a
        self.eps = eps

    def __call__(self, den, score, x, Ne, mol, xp) -> Float[Array, "batch 1"]:
        r"""
        Evaluate the cusp penalty at every point.

        Parameters
        ----------
        den : Array
            Unused; accepted for the shared functional interface.
        score : Array
            Gradient of the log-density, contracted row-wise (rank 2).
        x : Array
            Points where the penalty is evaluated.
        Ne : int
            Number of electrons (overall prefactor).
        mol : dict
            ``{'coords': ..., 'z': ...}``; every entry is mapped over its
            leading axis, one slice per nucleus.
        xp : Array
            Unused; accepted for the shared functional interface.

        Returns
        -------
        jax.Array
            Penalty summed over nuclei, with a trailing singleton axis
            (``keepdims=True``), i.e. one value per row of ``x``.
        """
        def _envelope(pts, nucleus):
            # Softened distance of every point to this single nucleus.
            dist = jnp.sqrt(
                jnp.sum((pts - nucleus['coords']) ** 2, axis=1) + self.eps ** 2)
            return jnp.exp(-self.a * nucleus['z'] * dist)

        # One envelope column per nucleus: result has the nuclei on the last axis.
        wi = vmap(_envelope, in_axes=(None, 0), out_axes=-1)(x, mol)
        score_sqr = jnp.einsum('ij,ij->i', score, score)
        weizs = (1.0 / 8.0) * lax.expand_dims(score_sqr, (1,))
        # Broadcasts the per-point term against the per-nucleus target Z_i^2 / 2.
        kinetic = jnp.abs(weizs - (mol['z'] ** 2 / 2))
        return Ne * jnp.sum(kinetic * wi, axis=-1, keepdims=True)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class HutcheonCuspCondition(eqx.Module):
    r"""
    Nuclear-cusp cost functional.

    After Hutcheon & Wibowo-Teale, Phys. Rev. B 110, 195146 (2024), Eqs. (4-5, 7):

        E_{\text{CC}} = \sum_i C_i, \qquad
        C_i(\rho) = \int |\nabla \rho + 2 Z_i \rho\, \hat{r}_i|^2\, w_i(r)\, d^3 r,

    with \hat{r}_i = (r - R_i)/|r - R_i| and
    w_i(r) = (\pi / Z_i^3)^{1/2} e^{-Z_i |r - R_i|} [Eq. (5)].

    Writing \nabla \rho = \rho\, s (s the score) turns the integrand into
    \rho^2 |s + 2 Z_i \hat{r}_i|^2 w_i, which vanishes when
    s = -2 Z_i \hat{r}_i. With samples drawn from p = \rho / N_e the quoted
    Monte-Carlo form is

        C_i \approx N_e^2\, \mathbb{E}\left[ \rho(x)\, |s(x) + 2 Z_i \hat{r}_i(x)|^2\, w_i(x) \right].

    NOTE(review): ``__call__`` multiplies by ``Ne`` (first power), not ``Ne ** 2``.
    Whether that matches the estimator above depends on whether ``den`` is the
    full density or the unit-normalised shape factor -- confirm against the caller.

    Parameters
    ----------
    eps : float, optional
        Small constant for numerical stability, added in quadrature to the
        electron-nucleus distance, by default 1e-8.
    """

    eps: float

    def __init__(self, eps=1e-8):
        self.eps = eps

    def __call__(self, den, score, x, Ne, mol, xp) -> Float[Array, "batch 1"]:
        r"""
        Evaluate ``Ne * den * sum_i |s + 2 Z_i r_hat_i|^2 w_i`` at every point.

        Uses ``den``, ``score``, ``x``, ``Ne`` and ``mol['coords']`` / ``mol['z']``;
        ``xp`` is accepted for the shared functional interface but unused.
        The result carries a trailing singleton axis.
        """
        def _atom_cost(center, charge):
            # Displacement of every sample from this nucleus and its softened length.
            disp = x - center
            dist = jnp.sqrt(jnp.sum(disp ** 2, axis=-1) + self.eps ** 2)
            unit = disp / dist[:, None]
            envelope = jnp.sqrt(jnp.pi / charge ** 3) * jnp.exp(-charge * dist)
            mismatch = jnp.sum((score + 2.0 * charge * unit) ** 2, axis=-1)
            return mismatch * envelope

        # Stack the per-nucleus costs on a leading axis, then sum over nuclei.
        per_nucleus = vmap(_atom_cost)(mol['coords'], mol['z'])
        summed = jnp.sum(per_nucleus, axis=0)
        weighted = Ne * den.squeeze(-1) * summed
        return weighted[:, None]
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
from jax import lax
|
|
3
|
+
from .functional import Functional, CompositeFunctional, unit_coefficient
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def lda_density(den, score, x, Ne, mol, xp):
    r"""
    Dirac / LDA exchange energy density.

    See eq. 2.72 in "Time-Dependent Density-Functional Theory", Carsten A. Ullrich.

    E_{\text{X}}^{\text{LDA}}[\rho]
        = -\frac{3}{4}\left(\frac{3}{\pi}\right)^{1/3} \int \rho(\boldsymbol{x})^{4/3} d\boldsymbol{x}
        = -\frac{3}{4}\left(\frac{3}{\pi}\right)^{1/3} N_e^{4/3}\,
          \mathbb{E}_{\rho_\phi}\left[ \rho_\phi(\boldsymbol{x})^{1/3} \right]

    Parameters
    ----------
    den : Array
        Density.
    Ne : int
        Number of electrons.

    Notes
    -----
    ``score``, ``x``, ``mol`` and ``xp`` only exist to satisfy the shared
    functional interface; they are never read.

    Returns
    -------
    jax.Array
        LDA exchange energy density (up to the rho factor), same shape as ``den``.
    """
    electron_scale = Ne ** (4 / 3)
    dirac_root = (3 / jnp.pi) ** (1 / 3)
    prefactor = -(3 / 4) * electron_scale * dirac_root
    return prefactor * den ** (1 / 3)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def b88_density(den, score, x, Ne, mol, xp, clip_cte=1e-30, beta=0.0042):
    r"""
    Becke-88 exchange energy density, evaluated in base-2 log space.

    See eq. 8 in https://journals.aps.org/pra/abstract/10.1103/PhysRevA.38.3098
    See also https://github.com/ElectronicStructureLibrary/libxc/blob/master/maple/gga_exc/gga_x_b88.mpl

    E_{\text{X}}^{\text{B88}}[\rho] = -\beta \int \frac{X^2}{1 + 6\beta X \sinh^{-1}(X)} \rho(\boldsymbol{x})^{4/3} d\boldsymbol{x},
    \qquad X = \frac{|\nabla \rho|}{\rho^{4/3}}.

    Parameters
    ----------
    den : Array
        Density; floored at ``clip_cte`` before any logarithm is taken.
    score : Array
        Gradient of the log-density, s = (nabla rho)/rho (rank 2, contracted
        row-wise).
    Ne : int
        Number of electrons.
    clip_cte : float, optional
        Floor applied to the density and to the squared gradient norm,
        by default 1e-30.
    beta : float, optional
        B88 parameter, by default 0.0042.

    Notes
    -----
    ``x``, ``mol`` and ``xp`` only exist to satisfy the shared functional
    interface; they are never read.

    Returns
    -------
    jax.Array
        B88 exchange energy density (up to the rho factor).
    """
    rho = jnp.clip(den, clip_cte)
    log2_rho = jnp.log2(rho)

    # |nabla rho|^2 = |s|^2 rho^2, kept as a column to line up with ``den``.
    s_norm_sq = jnp.einsum('ij,ij->i', score, score)
    grad_sq = lax.expand_dims(s_norm_sq, (1,)) * rho * rho
    log2_grad = jnp.log2(jnp.clip(grad_sq, clip_cte)) / 2

    # Reduced gradient X = |nabla rho| / rho^(4/3), built from the logs.
    log2_X = log2_grad - 4 / 3.0 * log2_rho
    X = 2 ** log2_X

    log2_denominator = jnp.log2(1 + 6 * beta * X * jnp.arcsinh(X))
    exponent = 4 * log2_rho / 3 + 2 * log2_X - log2_denominator
    energy = -(beta * 2 ** exponent)
    return energy * Ne ** (2 / 3)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def vwn_density(den, score, x, Ne, mol, xp, clip_cte=1e-30):
    r"""
    Vosko-Wilk-Nusair correlation energy density.

    See original paper eq. 4.4 in https://cdnsciencepub.com/doi/abs/10.1139/p80-159
    See also the text after eq. 8.9.6.1 in https://www.theoretical-physics.com/dev/quantum/dft.html

    \epsilon_{\text{C}}^{\text{VWN}} = \frac{A}{2}\left\{ \ln\frac{y^2}{Y(y)}
        + \frac{2b}{Q}\arctan\frac{Q}{2y+b}
        - \frac{b y_0}{Y(y_0)}\left[ \ln\frac{(y-y_0)^2}{Y(y)}
        + \frac{2(b+2y_0)}{Q}\arctan\frac{Q}{2y+b}\right] \right\}

    where, as computed below, y = r_s^{1/2}, Y(y) = y^2 + b y + c and
    Q = (4c - b^2)^{1/2}.

    Parameters
    ----------
    den : Array
        Density; values not above ``clip_cte`` are replaced by ``clip_cte``.
    Ne : int
        Number of electrons.
    clip_cte : float, optional
        Small constant for numerical stability, by default 1e-30.

    Notes
    -----
    ``score``, ``x``, ``mol`` and ``xp`` only exist to satisfy the shared
    functional interface; they are never read.

    Returns
    -------
    jax.Array
        VWN correlation energy density (up to the rho factor).
    """
    A, b, c, y0 = 0.0621814, 3.72744, 12.9352, -0.10498

    # Zero out sub-threshold densities, then floor so the log stays finite.
    floored = jnp.clip(jnp.where(den > clip_cte, den, 0.0), clip_cte)
    log2_den = jnp.log2(floored)
    log2_rs = jnp.log2((3 / (4 * jnp.pi)) ** (1 / 3)) - log2_den / 3.0
    log2_y = log2_rs / 2

    y = 2. ** log2_y
    Y = 2. ** (2. * log2_y) + 2. ** (log2_y + jnp.log2(b)) + c
    Y0 = y0 ** 2 + b * y0 + c
    Q = jnp.sqrt(4 * c - b ** 2)

    # The same arctan appears in both the leading and the y0-correction term.
    atan_term = jnp.arctan(Q / (2. * y + b))
    leading = 2. * jnp.log(y) - jnp.log(Y) + 2. * b / Q * atan_term
    correction = (jnp.log((y - y0) ** 2. / Y)
                  + 2. * (2. * y0 + b) / Q * atan_term)
    e_c = A / 2. * (leading - b * y0 / Y0 * correction)
    return Ne * e_c
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def pw92_density(den, score, x, Ne, mol, xp, clip_cte=1e-30):
    r"""
    Perdew-Wang 92 correlation energy density.

    See eq. 10 in https://journals.aps.org/prb/abstract/10.1103/PhysRevB.45.13244

    \epsilon_{\text{C}}^{\text{PW92}} = -2A(1 + \alpha_1 r_s)
        \ln\left[ 1 + \frac{1}{2A(\beta_1 r_s^{1/2} + \beta_2 r_s + \beta_3 r_s^{3/2} + \beta_4 r_s^2)} \right]

    Every power of r_s is assembled as a power of two from log2(r_s).

    Parameters
    ----------
    den : Array
        Density; floored at ``clip_cte`` before the logarithm.
    Ne : int
        Number of electrons.
    clip_cte : float, optional
        Small constant for numerical stability, by default 1e-30.

    Notes
    -----
    ``score``, ``x``, ``mol`` and ``xp`` only exist to satisfy the shared
    functional interface; they are never read.

    Returns
    -------
    jax.Array
        PW92 correlation energy density (up to the rho factor).
    """
    A_, alpha1 = 0.031091, 0.21370
    beta1, beta2, beta3, beta4 = 7.5957, 3.5876, 1.6382, 0.49294

    log2_den = jnp.log2(jnp.clip(den, clip_cte))
    log2_rs = jnp.log2((3 / (4 * jnp.pi)) ** (1 / 3)) - log2_den / 3.0

    # coefficient * r_s^p  ==  2 ** (p * log2(r_s) + log2(coefficient))
    half_power = 2 ** (log2_rs / 2 + jnp.log2(beta1))
    alpha_rs = 2 ** (log2_rs + jnp.log2(alpha1))
    first_power = 2 ** (log2_rs + jnp.log2(beta2))
    three_half_power = 2 ** (3 * log2_rs / 2 + jnp.log2(beta3))
    second_power = 2 ** (2 * log2_rs + jnp.log2(beta4))

    series = half_power + first_power + three_half_power + second_power
    e_c = -2 * A_ * (1 + alpha_rs) * jnp.log(1 + (1 / (2 * A_)) / series)
    return Ne * e_c
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
# Fixed (unit-coefficient) exchange and correlation functionals, one per
# energy-density function defined in this module.
lda = Functional(
    coefficients=unit_coefficient,
    energy_densities=lda_density,
)
b88 = Functional(
    coefficients=unit_coefficient,
    energy_densities=b88_density,
)
vwn = Functional(
    coefficients=unit_coefficient,
    energy_densities=vwn_density,
)
pw92 = Functional(
    coefficients=unit_coefficient,
    energy_densities=pw92_density,
)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def lda_b88():
    r"""Build the composite exchange functional made of ``lda`` and ``b88``."""
    exchange_parts = (lda, b88)
    return CompositeFunctional(*exchange_parts)
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
import equinox as eqx
|
|
3
|
+
from jaxtyping import Array, Float
|
|
4
|
+
from jax import vmap
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class NuclearPotential(eqx.Module):
    r"""
    External electron-nuclei attraction potential.

    V_{\text{ext}}(\boldsymbol{x}) = -N_e \sum_i \frac{Z_i}{|\boldsymbol{x} - \boldsymbol{R}_i| + \text{eps}}

    Parameters
    ----------
    eps : float, optional
        Small constant for numerical stability, added to every
        electron-nucleus distance, by default 1e-5.
    """

    eps: float

    def __init__(self, eps=1e-5):
        self.eps = eps

    def __call__(self, den, score, x, Ne, mol, xp) -> Float[Array, "batch 1"]:
        r"""
        Parameters
        ----------
        x : Array
            Points where the potential is evaluated.
        Ne : int
            Number of electrons.
        mol : dict
            Nuclear coordinates and charges, {'coords': ..., 'z': ...}; every
            entry is mapped over its leading axis, one slice per nucleus.

        Notes
        -----
        den, score, xp are accepted for the shared functional interface but unused here.

        Returns
        -------
        jax.Array
            Electron-nuclei attraction at each point (up to the rho factor),
            summed over nuclei, with a trailing singleton axis
            (``keepdims=True``).
        """
        def _attraction(pts, nucleus):
            # Distance of every point to this nucleus, offset by eps so the
            # division below cannot hit zero.
            dist = jnp.sqrt(jnp.sum((pts - nucleus['coords']) ** 2, axis=1)) + self.eps
            return nucleus['z'] / dist

        # One column per nucleus on the last axis.
        per_nucleus = vmap(_attraction, in_axes=(None, 0), out_axes=-1)(x, mol)
        return -Ne * jnp.sum(per_nucleus, axis=-1, keepdims=True)
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
import equinox as eqx
|
|
3
|
+
from jaxtyping import Array, Float
|
|
4
|
+
from typing import Any, Callable, NamedTuple, Optional
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class FunctionalInputs(NamedTuple):
    r"""
    Bundle of every quantity a component functional may consume.

    :class:`EnergyFunctional` takes one of these and unpacks it into the
    positional arguments ``(den, score, x, Ne, mol, xp)`` before calling each
    component, so the leaf functionals themselves never handle this object.

    Fields
    ------
    den   : Float[Array, "batch 1"]   density shape factor rho_phi.
    score : Float[Array, "batch d"]   score, (nabla rho)/rho = nabla log rho_phi.
    x     : Float[Array, "batch d"]   sample / grid positions.
    Ne    : int                       number of electrons.
    mol   : dict                      {'coords': ..., 'z': ...} nuclear geometry / charges.
    xp    : Float[Array, "batch d"]   second set of positions for the pairwise
                                      Hartree term; defaults to ``None``.
    """
    den: Any
    score: Any
    x: Any
    Ne: Any
    mol: Any
    xp: Any = None
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def unit_coefficient(self, *_):
    r"""Return the constant weight c = 1 as a ``(1, 1)`` array.

    Used as the ``coefficients`` callable of a fixed (non-learned) functional;
    ``self`` and any further positional arguments are ignored.
    """
    unit_weight = [[1.0]]
    return jnp.array(unit_weight)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class Functional(eqx.Module):
    r"""
    Local density functional, F[\rho] = \int c_\theta[\rho] \cdot e[\rho]\, d\boldsymbol{x}.

    All functionals are called with the same explicit argument list
    ``(den, score, x, Ne, mol, xp)`` -- every input is always supplied, whether
    or not a particular functional reads it. The instance is built from up to
    three callables:

        energy_densities(den, score, x, Ne, mol, xp) -> e[rho]
            the energy densities, returned up to the rho factor.
        coefficients(self, cinputs) -> c_theta[rho]
            the weights; :func:`unit_coefficient` (c = 1) for a fixed
            functional, a network for a learned one.
        coefficient_inputs(den, score, x, Ne, mol, xp) -> features (optional)
            features handed to ``coefficients``; when left as ``None`` the
            weights are requested with ``None`` as their input.

    Calling the instance returns ``sum(c * e)`` over the last axis, keeping
    that axis as a singleton.
    """

    coefficients: Callable
    energy_densities: Callable
    coefficient_inputs: Optional[Callable] = None

    def __call__(self, den, score, x, Ne, mol, xp) -> Float[Array, "batch 1"]:
        r"""Weight the energy densities by the coefficients and sum the last axis."""
        args = (den, score, x, Ne, mol, xp)
        energy = self.energy_densities(*args)
        if self.coefficient_inputs is None:
            features = None
        else:
            features = self.coefficient_inputs(*args)
        # ``coefficients`` is a stored callable, not a bound method, so the
        # instance is handed over explicitly.
        weights = self.coefficients(self, features)
        return jnp.sum(weights * energy, axis=-1, keepdims=True)
|
|
65
|
+
#Add a functional, or density.
|
|
66
|
+
|
|
67
|
+
class CompositeFunctional(eqx.Module):
    r"""
    Sum of several functionals, all sharing the ``(den, score, x, Ne, mol, xp)`` signature.

    Parameters
    ----------
    *functionals
        The component functionals, stored in call order as a tuple.
    """

    functionals: tuple

    def __init__(self, *functionals):
        self.functionals = tuple(functionals)

    def __call__(self, den, score, x, Ne, mol, xp) -> Float[Array, "batch 1"]:
        r"""Evaluate every component on the same arguments and add the results.

        With no components the plain float ``0.0`` is returned.
        """
        result = 0.0
        for func in self.functionals:
            result = result + func(den, score, x, Ne, mol, xp)
        return result

    def __add__(self, other):
        r"""Return a new composite containing this one's components plus ``other``.

        If ``other`` is itself a :class:`CompositeFunctional` its components are
        appended individually (the result stays flat); anything else is appended
        as a single component.
        """
        # The components must be unpacked: the constructor takes them as
        # separate positional arguments, and ``self.functionals`` is a tuple, so
        # neither ``tuple + list`` nor passing one concatenated tuple works.
        if isinstance(other, CompositeFunctional):
            return CompositeFunctional(*self.functionals, *other.functionals)
        return CompositeFunctional(*self.functionals, other)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class EnergyFunctional(eqx.Module):
    r"""
    Top-level OF-DFT energy functional.

    Takes one :class:`FunctionalInputs` bundle, unpacks it into
    ``(den, score, x, Ne, mol, xp)`` and calls every component with exactly
    those arguments: kinetic, external (nuclear), Hartree and exchange are
    always evaluated; ``correlation`` and ``core_correction`` are skipped and
    contribute ``0.0`` when left as ``None``.

    ``terms`` gives the individual per-component energy densities (handy for
    logging) and ``__call__`` their sum. Use :meth:`_integrate` to turn an
    energy density into a number (Monte-Carlo measure 1/N during training, or
    grid weights w*rho for quadrature).
    """

    kinetic: Any
    external: Any
    hartree: Any
    exchange: Any
    correlation: Any = None
    core_correction: Any = None

    @staticmethod
    def _integrate(energy_density, weights):
        r"""Weighted sum ``sum_i w_i e_i``; `weights` is 1/N (Monte-Carlo) or w_i*rho_i (grid)."""
        weighted = weights * energy_density
        return jnp.sum(weighted)

    def terms(self, inp: FunctionalInputs) -> dict:
        r"""Evaluate each component on the unpacked ``inp`` and return them keyed by name.

        Keys: ``"kin"``, ``"vnuc"``, ``"hart"``, ``"x"``, ``"c"``, ``"cc"``.
        """
        args = (inp.den, inp.score, inp.x, inp.Ne, inp.mol, inp.xp)
        out = {}
        out["kin"] = self.kinetic(*args)
        out["vnuc"] = self.external(*args)
        out["hart"] = self.hartree(*args)
        out["x"] = self.exchange(*args)
        # Optional components fall back to a scalar zero.
        if self.correlation is not None:
            out["c"] = self.correlation(*args)
        else:
            out["c"] = 0.0
        if self.core_correction is not None:
            out["cc"] = self.core_correction(*args)
        else:
            out["cc"] = 0.0
        return out

    def __call__(self, inp: FunctionalInputs):
        r"""Total per-point energy density: the sum of every entry of :meth:`terms`."""
        parts = self.terms(inp)
        total = parts["kin"]
        for key in ("vnuc", "hart", "x", "c", "cc"):
            total = total + parts[key]
        return total
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
import equinox as eqx
|
|
3
|
+
from jaxtyping import Array, Float
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class CoulombPotential_(eqx.Module):
    r"""
    Classical electron-electron repulsion (Hartree) potential, element-wise pairs.

    Row i of ``x`` is paired with row i of ``xp`` (the two half-batches), so a
    batch yields `batch` Monte-Carlo pairs:

    V_{\text{Hartree}}(\boldsymbol{x}, \boldsymbol{x}') = \frac{1}{2} N_e^2 \frac{1}{|\boldsymbol{x} - \boldsymbol{x}'|}

    Parameters
    ----------
    eps : float, optional
        Small constant for numerical stability, by default 1e-5. It is added to
        each squared coordinate difference before the sum over the last axis.
    """

    eps: float

    def __init__(self, eps=1e-5):
        self.eps = eps

    def __call__(self, den, score, x, Ne, mol, xp) -> Float[Array, "batch 1"]:
        r"""Return ½Ne²/|x-x'| for each (x, xp) row pair; den, score, mol are unused."""
        separation = x - xp
        softened_sq = jnp.sum(separation ** 2 + self.eps, axis=-1, keepdims=True)
        pair_distance = jnp.sqrt(softened_sq)
        return 0.5 * (Ne ** 2) / pair_distance
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class CoulombPotential(eqx.Module):
    r"""
    Classical electron-electron repulsion (Hartree) potential, all-pairs estimator.

    Same 1/|x-x'| Coulomb kernel as :class:`CoulombPotential_`, but every x_i is
    averaged against *every* x'_j (the full batch x batch double sum), i.e.
    batch^2 pairs instead of batch. This is the same estimator the grid
    quadrature uses and has substantially lower Monte-Carlo variance for the
    same samples, at the cost of a (batch, batch) distance matrix.

    V_{\text{Hartree}} = \frac{1}{2} N_e^2 \left\langle \frac{1}{|\boldsymbol{x}_i - \boldsymbol{x}'_j|} \right\rangle_j

    Parameters
    ----------
    eps : float, optional
        Small constant for numerical stability, added once to every squared
        pair distance, by default 1e-5.
    """

    eps: float

    def __init__(self, eps=1e-5):
        self.eps = eps

    def __call__(self, den, score, x, Ne, mol, xp) -> Float[Array, "batch 1"]:
        r"""Return ½Ne² times the mean of 1/|x_i - x'_j| over j; den, score, mol are unused."""
        # |x_i - x'_j|^2 expanded as |x_i|^2 + |x'_j|^2 - 2 x_i . x'_j
        norm_x = jnp.sum(x * x, axis=-1)
        norm_xp = jnp.sum(xp * xp, axis=-1)
        dist_sq = norm_x[:, None] + norm_xp[None, :] - 2.0 * (x @ xp.T)
        # Round-off can push the expansion slightly negative; clamp, then soften.
        dist_sq = jnp.maximum(dist_sq, 0.0) + self.eps
        inverse_distance = 1.0 / jnp.sqrt(dist_sq)
        mean_inverse = jnp.mean(inverse_distance, axis=-1, keepdims=True)
        return 0.5 * (Ne ** 2) * mean_inverse
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
from jax import lax
|
|
3
|
+
from .functional import Functional, CompositeFunctional, unit_coefficient
|
|
4
|
+
|
|
5
|
+
# Thomas-Fermi prefactor c = (3/10) (3 pi^2)^(2/3).
C_TF = (3. / 10.) * (3. * jnp.pi ** 2) ** (2 / 3)


def tf_density(den, score, x, Ne, mol, xp, c=C_TF):
    r"""
    Thomas-Fermi kinetic energy density.

    See paper eq. 2 in https://pubs.aip.org/aip/jcp/article/114/2/631/184186/Thomas-Fermi-Dirac-von-Weizsacker-models-in-finite

    T_{\text{TF}}[\rho] = c \int \rho(\boldsymbol{x})^{5/3} d\boldsymbol{x}
                        = c \int \rho(\boldsymbol{x})^{2/3} \rho(\boldsymbol{x}) d\boldsymbol{x}
                        = c\, N_e^{5/3}\, \mathbb{E}_{\rho_\phi}\left[ \rho_\phi(\boldsymbol{x})^{2/3} \right]

    with c = \frac{3}{10}(3\pi^2)^{2/3}.

    Parameters
    ----------
    den : Array
        Density.
    Ne : int
        Number of electrons.
    c : float, optional
        Prefactor, by default (3/10)(3 pi^2)^(2/3).

    Notes
    -----
    ``score``, ``x``, ``mol`` and ``xp`` only exist to satisfy the shared
    functional interface; they are never read.

    Returns
    -------
    jax.Array
        Thomas-Fermi kinetic energy density (up to the rho factor), same shape
        as ``den``.
    """
    electron_scale = Ne ** (5 / 3)
    prefactor = c * electron_scale
    return prefactor * den ** (2 / 3)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def weizsacker_density(den, score, x, Ne, mol, xp, lam=0.2):
    r"""
    von Weizsacker gradient-correction energy density.

    See paper eq. 3 in https://pubs.aip.org/aip/jcp/article/114/2/631/184186/Thomas-Fermi-Dirac-von-Weizsacker-models-in-finite

    T_{\text{Weizsacker}}[\rho] = \frac{\lambda}{8} \int \frac{(\nabla \rho)^2}{\rho} d\boldsymbol{x}
                                = \frac{\lambda}{8} \int \rho \left(\frac{\nabla \rho}{\rho}\right)^2 d\boldsymbol{x}
                                = \frac{\lambda N_e}{8}\, \mathbb{E}_{\rho_\phi}\left[ \left(\frac{\nabla \rho}{\rho}\right)^2 \right]

    Parameters
    ----------
    score : Array
        Gradient of the log-density, s = (nabla rho)/rho (rank 2, contracted
        row-wise).
    Ne : int
        Number of electrons.
    lam : float, optional (W. Stich, E.K.U. Gross, Z. Physik A 309(1):511, 1982)
        Phenomenological parameter lambda, by default 0.2.

    Notes
    -----
    ``den``, ``x``, ``mol`` and ``xp`` only exist to satisfy the shared
    functional interface; they are never read.

    Returns
    -------
    jax.Array
        von Weizsacker kinetic energy density (up to the rho factor), one value
        per row of ``score`` with a trailing singleton axis.
    """
    prefactor = lam * Ne / 8.
    squared_norm = jnp.einsum('ij,ij->i', score, score)
    as_column = lax.expand_dims(squared_norm, (1,))
    return prefactor * as_column
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# Fixed (unit-coefficient) Thomas-Fermi kinetic functional.
tf = Functional(
    coefficients=unit_coefficient,
    energy_densities=tf_density,
)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def weizsacker(lam=0.2):
    r"""Build a fixed von Weizsacker functional with prefactor ``lam``.

    The returned :class:`Functional` forwards its six arguments to
    :func:`weizsacker_density` with ``lam`` bound to the value given here.
    """
    def _density_with_lam(den, score, x, Ne, mol, xp):
        return weizsacker_density(den, score, x, Ne, mol, xp, lam)

    return Functional(
        coefficients=unit_coefficient,
        energy_densities=_density_with_lam,
    )
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def tf_weizsacker(lam=0.2):
    r"""Build the TF-lambda-W kinetic functional, T = T_TF + lambda * T_W."""
    gradient_correction = weizsacker(lam)
    return CompositeFunctional(tf, gradient_correction)
|