off 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
off/train/loss.py ADDED
@@ -0,0 +1,149 @@
1
+ import jax.numpy as jnp
2
+ from typing import NamedTuple
3
+ from ..functionals.kinetic import tf, weizsacker, tf_weizsacker
4
+ from ..functionals.exchange_correlation import lda, b88, vwn, pw92, lda_b88
5
+ from ..functionals.hartree import CoulombPotential_, CoulombPotential
6
+ from ..functionals.external import NuclearPotential
7
+ from ..functionals.core_correction import KatoCondition, HutcheonCuspCondition
8
+ from ..functionals.functional import FunctionalInputs, EnergyFunctional
9
+ from ..ode_solver.eqx_ode import fwd_ode
10
+ import jax
11
class F_values(NamedTuple):
    """Energy breakdown reported next to the scalar loss.

    ``grad_loss`` (built by ``create_loss_function`` in this module) returns
    one of these as its auxiliary output.  The ``float`` annotations are
    nominal: the fields hold whatever ``grad_loss`` computes for them
    (e.g. the results of ``jnp.mean``).
    """

    # Integrated total energy; the same value grad_loss returns as the loss.
    energy: float
    # Batch mean of the "kin" term.
    kin: float
    # Batch mean of the "vnuc" term.
    vnuc: float
    # Batch mean of the "hart" term.
    hart: float
    # Batch mean of the summed "x" and "c" terms.
    xc: float
    # Batch mean of the "cc" term, or the raw term when it is not an array.
    cc: float
19
+
20
# Registry mapping a functional name to a zero-argument factory.
# build_energy_functional looks a name up here and *calls* the entry, so:
#   * entries that stand for a plain function are wrapped in a lambda that
#     hands the function back (or calls its builder with default arguments);
#   * class entries are stored directly and get instantiated by the call.
FUNCTIONAL_CLASSES = {
    # --- kinetic ---------------------------------------------------------
    'tf': lambda: tf,
    'w': lambda: weizsacker(),
    'tf_w': lambda: tf_weizsacker(),

    # --- exchange --------------------------------------------------------
    'lda': lambda: lda,
    'b88_x': lambda: b88,
    'lda_b88_x': lambda: lda_b88(),

    # --- correlation -----------------------------------------------------
    'vwn_c': lambda: vwn,
    'pw92_c': lambda: pw92,

    # --- Hartree ---------------------------------------------------------
    'coulomb': CoulombPotential,            # main variant: all pairs (batch²)
    'coulomb_': CoulombPotential_,          # element-wise variant
    'coulomb_allpairs': CoulombPotential,   # legacy tag from old job_params

    # --- external potential ----------------------------------------------
    'np': NuclearPotential,

    # --- core correction -------------------------------------------------
    'kato': KatoCondition,
    'hutcheon': HutcheonCuspCondition,
}
47
+
48
def _build_kinetic(kinetic_name: str, lam: float):
    """Return the kinetic functional selected by ``kinetic_name``.

    Parameters
    ----------
    kinetic_name : str
        ``'w'`` selects ``weizsacker(lam)`` and ``'tf_w'`` selects
        ``tf_weizsacker(lam)``.  Every other value falls back to ``tf``.
    lam : float
        Prefactor forwarded to the Weizsäcker-based builders; unused when
        the fallback ``tf`` is returned.
    """
    # Anything that is not one of the two lam-dependent names yields plain tf.
    if kinetic_name not in ('w', 'tf_w'):
        return tf

    make_kinetic = weizsacker if kinetic_name == 'w' else tf_weizsacker
    return make_kinetic(lam)
55
+
56
+
57
def build_energy_functional(
    kinetic_name: str = 'tf',
    lam: float = 1.0,
    exchange_name: str = 'lda',
    correlation_name: str = 'none',
    hartree_name: str = 'coulomb',
    external_name: str = 'np',
    core_correction_name: str = 'none',
):
    """Assemble an ``EnergyFunctional`` from functional names.

    The kinetic part comes from ``_build_kinetic(kinetic_name, lam)``; every
    other part is looked up in ``FUNCTIONAL_CLASSES`` and the registered
    factory is called with no arguments.

    Parameters
    ----------
    kinetic_name : str
        Kinetic functional name, handled by ``_build_kinetic``.
    lam : float
        Prefactor forwarded to ``_build_kinetic``.
    exchange_name, hartree_name, external_name : str
        Keys of ``FUNCTIONAL_CLASSES``.
    correlation_name, core_correction_name : str
        Keys of ``FUNCTIONAL_CLASSES``, or ``'none'`` to pass ``None`` for
        that component.

    Returns
    -------
    EnergyFunctional

    Raises
    ------
    KeyError
        If a name that is looked up is not a key of ``FUNCTIONAL_CLASSES``.
    """
    def _required(name):
        # Mandatory component: look it up and call the factory.
        return FUNCTIONAL_CLASSES[name]()

    def _optional(name):
        # 'none' switches the component off instead of hitting the registry.
        if name != 'none':
            return FUNCTIONAL_CLASSES[name]()
        return None

    # Components are created in the order the keywords are passed below.
    kinetic = _build_kinetic(kinetic_name, lam)
    external = _required(external_name)
    hartree = _required(hartree_name)
    exchange = _required(exchange_name)
    correlation = _optional(correlation_name)
    core_correction = _optional(core_correction_name)

    return EnergyFunctional(
        kinetic=kinetic,
        external=external,
        hartree=hartree,
        exchange=exchange,
        correlation=correlation,
        core_correction=core_correction,
    )
76
+
77
+
78
def create_loss_function(
    kinetic_name: str = 'tf',
    lam: float = 1.0,
    exchange_name: str = 'lda',
    correlation_name: str = 'none',
    hartree_name: str = 'coulomb',
    external_name: str = 'np',
    core_correction_name: str = 'none',
):
    """
    Factory function to create a loss function with specific functionals.

    All arguments are forwarded unchanged to ``build_energy_functional``.

    Parameters
    ----------
    kinetic_name : str
        Name of kinetic functional ('tf', 'w', 'tf_w')
    lam : float
        Weizsäcker prefactor λ in TF-λW
    exchange_name : str
        Name of exchange functional ('lda', 'b88_x', 'lda_b88_x')
    correlation_name : str
        Name of correlation functional ('vwn_c', 'pw92_c', 'none')
    hartree_name : str
        Name of Hartree functional ('coulomb', 'coulomb_',
        'coulomb_allpairs')
    external_name : str
        Name of external potential functional ('np')
    core_correction_name : str
        Name of core correction functional ('kato', 'hutcheon', 'none')

    Returns
    -------
    grad_loss : callable
        ``grad_loss(model, z_and_logpz, solver, Ne, mol)`` returning the
        pair ``(energy, F_values)``.  The ``(value, aux)`` shape is what a
        value-and-grad wrapper called with ``has_aux=True`` expects.
    """
    functional = build_energy_functional(
        kinetic_name, lam, exchange_name, correlation_name,
        hartree_name, external_name, core_correction_name,
    )

    def grad_loss(model, z_and_logpz, solver, Ne, mol):
        """
        Compute the total energy and its per-term breakdown.

        ``model``, ``z_and_logpz`` and ``solver`` are passed straight to
        ``fwd_ode``; ``Ne`` and ``mol`` are forwarded to ``FunctionalInputs``.
        Returns ``(energy, F_values)``.
        """
        x_all, log_px, score_all = fwd_ode(model, z_and_logpz, solver)

        # The leading dimension is split into two halves.  Floor division
        # keeps this in integer arithmetic instead of int(float).
        # NOTE(review): an odd leading dimension makes the second half one
        # row longer than the first — confirm callers always pass an even
        # batch.
        bs = x_all.shape[0] // 2

        # Density and score are only consumed for the first half; from the
        # second half only the positions (xp) are used.
        den = jnp.exp(log_px)[:bs]
        score = score_all[:bs]
        x, xp = x_all[:bs], x_all[bs:]

        inp = FunctionalInputs(den=den, score=score, x=x, Ne=Ne, mol=mol, xp=xp)
        terms = functional.terms(inp)
        t_e, n_e, h_e = terms["kin"], terms["vnuc"], terms["hart"]
        x_e, c_e, cc_e = terms["x"], terms["c"], terms["cc"]
        xc_e = x_e + c_e

        # Sum every contribution, flatten it, and let the functional
        # integrate it with weight 1 / bs.
        e = t_e + n_e + h_e + xc_e + cc_e
        energy = functional._integrate(jnp.reshape(e, (-1,)), 1.0 / bs)

        f_values = F_values(
            energy=energy,
            kin=jnp.mean(t_e),
            vnuc=jnp.mean(n_e),
            hart=jnp.mean(h_e),
            xc=jnp.mean(xc_e),
            # The "cc" term may not be an array; pass it through unchanged
            # in that case.
            cc=jnp.mean(cc_e) if isinstance(cc_e, jnp.ndarray) else cc_e,
        )

        return energy, f_values

    return grad_loss
off/train/utils.py ADDED
@@ -0,0 +1,38 @@
1
+ """Utility functions for training."""
2
+ from collections.abc import Callable
3
+
4
+ import equinox as eqx
5
+ import optax
6
+ from jaxtyping import Array, PyTree, Scalar
7
+
8
@eqx.filter_jit
def step(
    flow_model: PyTree,
    batch: Array,
    optimizer: optax.GradientTransformation,
    optimizer_state: PyTree,
    loss_fn: Callable[..., tuple[Scalar, PyTree]],
    *loss_args
) -> tuple[tuple[Scalar, PyTree], PyTree, PyTree]:
    """Carry out a training step.

    Args:
        flow_model: Model to optimise. It is the first argument of
            ``loss_fn``, i.e. the one gradients are taken with respect to.
        batch: Second positional argument passed to ``loss_fn``.
        optimizer: Optax optimizer.
        optimizer_state: Optimizer state.
        loss_fn: The loss function, called as
            ``loss_fn(flow_model, batch, *loss_args)``. Because gradients are
            requested with ``has_aux=True`` it must return a pair
            ``(loss, aux)``.
        *loss_args: Extra positional arguments appended to the ``loss_fn``
            call.

    Returns:
        tuple: ``((loss, aux), flow_model, optimizer_state)`` — the loss
        function's output pair, the updated model and the updated optimizer
        state.
    """
    # Compute loss (plus auxiliary output) and gradients w.r.t. the model.
    (loss_val, aux), grads = eqx.filter_value_and_grad(loss_fn, has_aux=True)(
        flow_model, batch, *loss_args
    )

    # Update the model parameters.
    # NOTE(review): the whole model is passed as `params`. If it contains
    # non-array leaves, optimizers that read params may need
    # eqx.filter(flow_model, eqx.is_inexact_array) here instead — confirm
    # against the optimizers in use.
    updates, optimizer_state = optimizer.update(grads, optimizer_state, flow_model)
    flow_model = eqx.apply_updates(flow_model, updates)

    return (loss_val, aux), flow_model, optimizer_state