off 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- off/__init__.py +23 -0
- off/atom_energies.py +151 -0
- off/config/_config.py +108 -0
- off/dft_distrax/__init__.py +27 -0
- off/dft_distrax/dft_distrax.py +216 -0
- off/flow/__init__.py +29 -0
- off/flow/equiv_flows.py +99 -0
- off/functionals/__init__.py +35 -0
- off/functionals/core_correction.py +84 -0
- off/functionals/exchange_correlation.py +174 -0
- off/functionals/external.py +49 -0
- off/functionals/functional.py +129 -0
- off/functionals/hartree.py +62 -0
- off/functionals/kinetic.py +87 -0
- off/main.py +172 -0
- off/ode_solver/__init__.py +32 -0
- off/ode_solver/eqx_ode.py +76 -0
- off/plot_binding_csv.py +63 -0
- off/plot_pes_ema.py +259 -0
- off/plot_pes_mpl.py +280 -0
- off/promolecular/__init__.py +27 -0
- off/promolecular/promolecular_dist.py +465 -0
- off/quadrature.py +261 -0
- off/quadrature_scan.py +188 -0
- off/scan_pes.py +133 -0
- off/test_fwd_rev.py +290 -0
- off/train/__init__.py +44 -0
- off/train/loop.py +228 -0
- off/train/loss.py +149 -0
- off/train/utils.py +38 -0
- off/utils.py +618 -0
- off-0.1.0.dist-info/METADATA +154 -0
- off-0.1.0.dist-info/RECORD +37 -0
- off-0.1.0.dist-info/WHEEL +5 -0
- off-0.1.0.dist-info/entry_points.txt +3 -0
- off-0.1.0.dist-info/licenses/LICENSE +21 -0
- off-0.1.0.dist-info/top_level.txt +1 -0
off/train/loss.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
from typing import NamedTuple
|
|
3
|
+
from ..functionals.kinetic import tf, weizsacker, tf_weizsacker
|
|
4
|
+
from ..functionals.exchange_correlation import lda, b88, vwn, pw92, lda_b88
|
|
5
|
+
from ..functionals.hartree import CoulombPotential_, CoulombPotential
|
|
6
|
+
from ..functionals.external import NuclearPotential
|
|
7
|
+
from ..functionals.core_correction import KatoCondition, HutcheonCuspCondition
|
|
8
|
+
from ..functionals.functional import FunctionalInputs, EnergyFunctional
|
|
9
|
+
from ..ode_solver.eqx_ode import fwd_ode
|
|
10
|
+
import jax
|
|
11
|
+
# Energy components reported by ``grad_loss`` as the auxiliary output.
# Field order: ``energy`` (the integrated total), then the batch means of the
# kinetic (``kin``), external/nuclear potential (``vnuc``), Hartree (``hart``),
# exchange-correlation (``xc``) and core-correction (``cc``) terms.
# The ``float`` annotations are nominal: ``grad_loss`` fills most fields with
# ``jnp.mean`` results, i.e. JAX scalars.
F_values = NamedTuple(
    "F_values",
    [
        ("energy", float),
        ("kin", float),
        ("vnuc", float),
        ("hart", float),
        ("xc", float),
        ("cc", float),
    ],
)
F_values.__doc__ = "Container for energy components."
|
|
19
|
+
|
|
20
|
+
# Registry mapping job-parameter tags to zero-argument factories.
# ``FUNCTIONAL_CLASSES[tag]()`` yields the functional for ``tag``. Lambda
# entries either return the imported object unchanged (e.g. ``tf``, ``lda``)
# or call its constructor with default arguments (e.g. ``weizsacker()``).
# Class entries are instantiated directly.
# Note: the kinetic tags here use default arguments; ``build_energy_functional``
# goes through ``_build_kinetic`` instead so that ``lam`` is honoured.
FUNCTIONAL_CLASSES = dict(
    # --- kinetic ---
    tf=lambda: tf,
    w=lambda: weizsacker(),
    tf_w=lambda: tf_weizsacker(),
    # --- exchange ---
    lda=lambda: lda,
    b88_x=lambda: b88,
    lda_b88_x=lambda: lda_b88(),
    # --- correlation ---
    vwn_c=lambda: vwn,
    pw92_c=lambda: pw92,
    # --- Hartree ---
    coulomb=CoulombPotential,           # all-pairs (batch²); the main choice
    coulomb_=CoulombPotential_,         # element-wise variant
    coulomb_allpairs=CoulombPotential,  # back-compat alias for old job_params tag
    # --- external ---
    np=NuclearPotential,
    # --- core correction ---
    kato=KatoCondition,
    hutcheon=HutcheonCuspCondition,
)
|
|
47
|
+
|
|
48
|
+
def _build_kinetic(kinetic_name: str, lam: float):
|
|
49
|
+
"""Build the kinetic functional."""
|
|
50
|
+
if kinetic_name == 'w':
|
|
51
|
+
return weizsacker(lam)
|
|
52
|
+
if kinetic_name == 'tf_w':
|
|
53
|
+
return tf_weizsacker(lam)
|
|
54
|
+
return tf
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def build_energy_functional(
    kinetic_name: str = 'tf',
    lam: float = 1.0,
    exchange_name: str = 'lda',
    correlation_name: str = 'none',
    hartree_name: str = 'coulomb',
    external_name: str = 'np',
    core_correction_name: str = 'none',
):
    """Assemble an ``EnergyFunctional`` from registry tags.

    Parameters
    ----------
    kinetic_name : str
        Kinetic tag, resolved by ``_build_kinetic`` (which also receives ``lam``).
    lam : float
        Weizsäcker prefactor λ.
    exchange_name, hartree_name, external_name : str
        Keys of ``FUNCTIONAL_CLASSES``.
    correlation_name, core_correction_name : str
        Keys of ``FUNCTIONAL_CLASSES``, or ``'none'`` to pass ``None``
        for that term.

    Returns
    -------
    EnergyFunctional

    Raises
    ------
    KeyError
        If a tag is not present in ``FUNCTIONAL_CLASSES`` (or is not a valid
        kinetic tag). The message names the offending role and lists the
        valid tags.
    """

    def make(name, role, optional=False):
        # Optional terms accept the sentinel 'none' and are then omitted.
        if optional and name == 'none':
            return None
        try:
            factory = FUNCTIONAL_CLASSES[name]
        except KeyError:
            raise KeyError(
                f"unknown {role} functional {name!r}; "
                f"choose from {sorted(FUNCTIONAL_CLASSES)}"
            ) from None
        return factory()

    # Keyword arguments are evaluated in this order, matching the original
    # construction order of the individual functionals.
    return EnergyFunctional(
        kinetic=_build_kinetic(kinetic_name, lam),
        external=make(external_name, 'external'),
        hartree=make(hartree_name, 'hartree'),
        exchange=make(exchange_name, 'exchange'),
        correlation=make(correlation_name, 'correlation', optional=True),
        core_correction=make(core_correction_name, 'core correction',
                             optional=True),
    )
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def create_loss_function(
    kinetic_name: str = 'tf',
    lam: float = 1.0,
    exchange_name: str = 'lda',
    correlation_name: str = 'none',
    hartree_name: str = 'coulomb',
    external_name: str = 'np',
    core_correction_name: str = 'none'
):
    """
    Factory function to create a loss function with specific functionals.

    The energy functional is built once, here, and captured by the returned
    closure.

    Parameters
    ----------
    kinetic_name : str
        Name of kinetic functional ('tf', 'w', 'tf_w')
    lam : float
        Weizsäcker prefactor λ in TF-λW
    exchange_name : str
        Name of exchange functional ('lda', 'b88_x', 'lda_b88_x')
    correlation_name : str
        Name of correlation functional ('vwn_c', 'pw92_c', 'none')
    hartree_name : str
        Name of Hartree functional ('coulomb', 'coulomb_', 'coulomb_allpairs')
    external_name : str
        Name of external potential functional ('np')
    core_correction_name : str
        Name of core correction functional ('kato', 'hutcheon', 'none')

    Returns
    -------
    grad_loss : callable
        ``grad_loss(model, z_and_logpz, solver, Ne, mol) -> (energy, f_values)``.
        The first element is the scalar objective and the second an
        ``F_values`` record, so the callable fits
        ``filter_value_and_grad(..., has_aux=True)``.
    """
    functional = build_energy_functional(
        kinetic_name, lam, exchange_name, correlation_name,
        hartree_name, external_name, core_correction_name,
    )

    def grad_loss(model, z_and_logpz, solver, Ne, mol):
        """
        Compute the loss function.

        ``model``, ``z_and_logpz`` and ``solver`` are forwarded to ``fwd_ode``.
        It returns samples, their log-density (exponentiated here to give the
        density) and the score. ``Ne`` and ``mol`` are forwarded unchanged
        to ``FunctionalInputs``.

        Returns
        -------
        energy
            Total energy from ``functional._integrate``.
        f_values : F_values
            ``energy`` plus the mean of each per-sample energy term.
        """
        x_all, log_px, score_all = fwd_ode(model, z_and_logpz, solver)

        # The sample batch is split in two halves. The first half supplies
        # x / den / score to FunctionalInputs. Only the positions of the
        # second half are used, as ``xp`` — presumably the partner samples
        # for pairwise terms such as the all-pairs Hartree energy (TODO
        # confirm).
        bs = x_all.shape[0] // 2
        x, xp = x_all[:bs], x_all[bs:]
        den = jnp.exp(log_px[:bs])
        score = score_all[:bs]

        inp = FunctionalInputs(den=den, score=score, x=x, Ne=Ne, mol=mol, xp=xp)
        terms = functional.terms(inp)
        t_e, n_e, h_e = terms["kin"], terms["vnuc"], terms["hart"]
        x_e, c_e, cc_e = terms["x"], terms["c"], terms["cc"]
        xc_e = x_e + c_e

        e = t_e + n_e + h_e + xc_e + cc_e
        # Weight 1/bs: presumably a Monte Carlo average over the bs samples;
        # confirm against EnergyFunctional._integrate.
        energy = functional._integrate(jnp.reshape(e, (-1,)), 1.0 / bs)

        f_values = F_values(
            energy=energy,
            kin=jnp.mean(t_e),
            vnuc=jnp.mean(n_e),
            hart=jnp.mean(h_e),
            xc=jnp.mean(xc_e),
            # cc_e may be a non-array (e.g. a plain number when no core
            # correction is configured — TODO confirm); pass it through as-is.
            cc=jnp.mean(cc_e) if isinstance(cc_e, jnp.ndarray) else cc_e,
        )

        return energy, f_values

    return grad_loss
|
off/train/utils.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
"""Utility functions for training."""
|
|
2
|
+
from collections.abc import Callable
|
|
3
|
+
|
|
4
|
+
import equinox as eqx
|
|
5
|
+
import optax
|
|
6
|
+
from jaxtyping import Array, PyTree, Scalar
|
|
7
|
+
|
|
8
|
+
@eqx.filter_jit
def step(
    flow_model: PyTree,
    batch: Array,
    optimizer: optax.GradientTransformation,
    optimizer_state: PyTree,
    loss_fn: Callable[..., tuple[Scalar, PyTree]],
    *loss_args
):
    """Carry out a training step.

    Args:
        flow_model: Equinox flow model. Its floating-point array leaves are
            differentiated and updated.
        batch: Second positional argument forwarded to ``loss_fn``.
        optimizer: Optax optimizer.
        optimizer_state: Optimizer state.
        loss_fn: Called as ``loss_fn(flow_model, batch, *loss_args)``. It must
            return a ``(loss, aux)`` pair, because it is differentiated with
            ``has_aux=True``.
        *loss_args: Extra positional arguments forwarded to ``loss_fn``.

    Returns:
        tuple: ``((loss, aux), flow_model, optimizer_state)``, with the
        updated model and optimizer state.
    """
    # Compute loss (with auxiliary output) and gradients w.r.t. flow_model.
    loss_and_aux, grads = eqx.filter_value_and_grad(loss_fn, has_aux=True)(
        flow_model, batch, *loss_args
    )

    # Optax optimizers that read `params` (e.g. weight decay) expect a tree
    # matching `grads`. Equinox grads hold None at non-array leaves, so hand
    # over only the array leaves of the model.
    params = eqx.filter(flow_model, eqx.is_array)
    updates, optimizer_state = optimizer.update(grads, optimizer_state, params)
    flow_model = eqx.apply_updates(flow_model, updates)

    return loss_and_aux, flow_model, optimizer_state
|