off 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- off/__init__.py +23 -0
- off/atom_energies.py +151 -0
- off/config/_config.py +108 -0
- off/dft_distrax/__init__.py +27 -0
- off/dft_distrax/dft_distrax.py +216 -0
- off/flow/__init__.py +29 -0
- off/flow/equiv_flows.py +99 -0
- off/functionals/__init__.py +35 -0
- off/functionals/core_correction.py +84 -0
- off/functionals/exchange_correlation.py +174 -0
- off/functionals/external.py +49 -0
- off/functionals/functional.py +129 -0
- off/functionals/hartree.py +62 -0
- off/functionals/kinetic.py +87 -0
- off/main.py +172 -0
- off/ode_solver/__init__.py +32 -0
- off/ode_solver/eqx_ode.py +76 -0
- off/plot_binding_csv.py +63 -0
- off/plot_pes_ema.py +259 -0
- off/plot_pes_mpl.py +280 -0
- off/promolecular/__init__.py +27 -0
- off/promolecular/promolecular_dist.py +465 -0
- off/quadrature.py +261 -0
- off/quadrature_scan.py +188 -0
- off/scan_pes.py +133 -0
- off/test_fwd_rev.py +290 -0
- off/train/__init__.py +44 -0
- off/train/loop.py +228 -0
- off/train/loss.py +149 -0
- off/train/utils.py +38 -0
- off/utils.py +618 -0
- off-0.1.0.dist-info/METADATA +154 -0
- off-0.1.0.dist-info/RECORD +37 -0
- off-0.1.0.dist-info/WHEEL +5 -0
- off-0.1.0.dist-info/entry_points.txt +3 -0
- off-0.1.0.dist-info/licenses/LICENSE +21 -0
- off-0.1.0.dist-info/top_level.txt +1 -0
off/test_fwd_rev.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CNF analysis script: normalization, energy, binding energy, density plot.
|
|
3
|
+
|
|
4
|
+
Usage
|
|
5
|
+
-----
|
|
6
|
+
# Single molecule (energy + density):
|
|
7
|
+
python test_fwd_rev.py \
|
|
8
|
+
--results_dir of_flows/Results/H2/tf_w_lam0.2_none_lda_none_dopri8_promolecular_sched_MIX/bl_3.0000
|
|
9
|
+
|
|
10
|
+
# With binding energy (needs H atom result dir):
|
|
11
|
+
python test_fwd_rev.py \
|
|
12
|
+
--results_dir of_flows/Results/H2/tf_w_lam0.2_none_lda_none_dopri8_promolecular_sched_MIX/bl_3.0000 \
|
|
13
|
+
--atom_results_dir of_flows/Results/H/tf_w_lam0.2_none_lda_none_dopri8_promolecular_sched_MIX/bl_0.0000
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import sys, os
|
|
17
|
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "of_flows"))
|
|
18
|
+
|
|
19
|
+
import argparse
|
|
20
|
+
import glob
|
|
21
|
+
import json
|
|
22
|
+
import re
|
|
23
|
+
from pathlib import Path
|
|
24
|
+
|
|
25
|
+
import jax
|
|
26
|
+
import jax.numpy as jnp
|
|
27
|
+
import jax.random as jrnd
|
|
28
|
+
import equinox as eqx
|
|
29
|
+
import matplotlib.pyplot as plt
|
|
30
|
+
import numpy as np
|
|
31
|
+
from pyscf import gto, dft
|
|
32
|
+
from atomdb import make_promolecule
|
|
33
|
+
|
|
34
|
+
jax.config.update("jax_enable_x64", True)
|
|
35
|
+
|
|
36
|
+
from flow.equiv_flows import CNF
|
|
37
|
+
from ode_solver.eqx_ode import fwd_ode, rev_ode
|
|
38
|
+
from utils import one_hot_encode, coordinates, get_solver
|
|
39
|
+
from promolecular.promolecular_dist import ProMolecularDensity, AtomDBDistribution, SIRDistribution
|
|
40
|
+
from train.loss import FUNCTIONAL_CLASSES, _build_kinetic
|
|
41
|
+
from functionals.functional import FunctionalInputs
|
|
42
|
+
|
|
43
|
+
# ── CLI ───────────────────────────────────────────────────────────────────────
|
|
44
|
+
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this analysis script."""
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--results_dir",
        type=str,
        required=True,
        help="Path to bl_X.XXXX result directory (contains job_params.json)",
    )
    cli.add_argument(
        "--atom_results_dir",
        type=str,
        default=None,
        help="Path to the H atom bl_0.0000 result directory (for binding energy)",
    )
    cli.add_argument("--bs", type=int, default=256, help="Grid chunk size")
    cli.add_argument("--grid_level", type=int, default=3, help="PySCF grid level")
    return cli


parser = _build_arg_parser()
args = parser.parse_args()
|
|
52
|
+
|
|
53
|
+
# ── helpers ───────────────────────────────────────────────────────────────────
|
|
54
|
+
def load_results(results_dir: str):
    """Load job parameters and restore the latest CNF checkpoint of a run.

    Reads ``<results_dir>/job_params.json`` and picks the checkpoint with the
    highest integer epoch among ``Checkpoints/checkpoint_<epoch>.eqx``. Files
    matching ``checkpoint_*.eqx`` without a purely numeric epoch are ignored.
    The function then rebuilds a CNF with the stored hidden size and
    deserialises the checkpoint leaves into it.

    Parameters
    ----------
    results_dir : str
        Run directory (``bl_X.XXXX``) containing ``job_params.json`` and a
        ``Checkpoints/`` sub-directory.

    Returns
    -------
    tuple
        ``(p, model, solver, Ne, atoms, z, coords)``:

        - ``p`` is the parsed job-parameter dict.
        - ``model`` is the restored CNF.
        - ``solver`` comes from ``get_solver(p['solver'])``.
        - The remaining entries come from
          ``coordinates(p['mol_name'], p['bond_length'])``.

    Raises
    ------
    FileNotFoundError
        If no checkpoint with a numeric epoch suffix exists.
    """
    rdir = Path(results_dir).resolve()

    with open(rdir / "job_params.json", encoding="utf-8") as f:
        p = json.load(f)

    # Map every numerically-suffixed checkpoint to its epoch. Match on the
    # file name only; a non-numeric name (e.g. "checkpoint_best.eqx") is
    # skipped instead of crashing the epoch lookup.
    ckpt_re = re.compile(r"checkpoint_(\d+)\.eqx$")
    epochs = {}
    for path in (rdir / "Checkpoints").glob("checkpoint_*.eqx"):
        match = ckpt_re.search(path.name)
        if match is not None:
            epochs[str(path)] = int(match.group(1))
    if not epochs:
        raise FileNotFoundError(f"No checkpoints found in {rdir}/Checkpoints/")
    last_ckpt = max(epochs, key=epochs.get)
    print(f" Loading checkpoint: {last_ckpt}")

    Ne, atoms, z, coords = coordinates(p['mol_name'], p['bond_length'])
    z_one_hot = one_hot_encode(z)
    # The init key only shapes the template; all leaves are overwritten below.
    key = jrnd.PRNGKey(0)
    model = CNF(din=3, dim=p['hidden_layer'], mu=coords, one_hot=z_one_hot, key=key)
    model = eqx.tree_deserialise_leaves(last_ckpt, model)
    solver = get_solver(p['solver'])

    return p, model, solver, Ne, atoms, z, coords
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def build_prior(p, z, coords, Ne):
    """Return the base distribution that matches the run's ``p['prior']``.

    ``'db_sir'`` selects an ``AtomDBDistribution`` built from an AtomDB
    ``"slater"`` promolecule. Any other value returns the
    ``ProMolecularDensity`` built from ``z`` and ``coords``.
    """
    promolecular = ProMolecularDensity(z.ravel(), coords)
    if p['prior'] != 'db_sir':
        return promolecular
    # Direct AtomDB sampling via per-atom inverse-CDF (no SIR needed)
    atomdb_promolecule = make_promolecule(atnums=z, coords=coords, dataset="slater")
    return AtomDBDistribution(db_prior=atomdb_promolecule, z=z, coords=coords, Ne=Ne)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def build_pyscf_mol(atoms, coords, Ne):
    """Build a quiet PySCF molecule (6-31G(d,p), coordinates in Bohr).

    The molecule is only used here to generate an integration grid.
    PySCF's ``spin`` is N_alpha - N_beta, so ``Ne % 2`` selects the lowest
    spin allowed by the electron count.
    """
    atom_lines = [
        f"{symbol} {xyz[0]:.8f} {xyz[1]:.8f} {xyz[2]:.8f}"
        for symbol, xyz in zip(atoms, coords)
    ]
    lowest_spin = int(Ne) % 2
    return gto.M(
        atom="; ".join(atom_lines),
        basis="6-31G(d,p)",
        unit="B",
        verbose=0,
        spin=lowest_spin,
    )
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def compute_rho_on_grid(model, solver, sampling_dist, grid_coords, chunk):
    """Evaluate the flow's density and score at every point of ``grid_coords``.

    Points are processed ``chunk`` at a time:

    1. Each point is integrated backwards with ``rev_ode`` to the base space.
    2. The base distribution's log-density and score are evaluated there.
    3. The augmented state is integrated forwards again with ``fwd_ode``.

    Returns
    -------
    tuple of np.ndarray
        ``(x, rho, score)`` concatenated over chunks. ``x`` holds the
        forward-mapped positions and ``rho = exp(log p)`` is flattened to 1-D.
    """
    positions, densities, scores = [], [], []
    n_points = grid_coords.shape[0]
    for start in range(0, n_points, chunk):
        pts = grid_coords[start:start + chunk]
        m = pts.shape[0]
        # Reverse state layout: position | log-density (zeroed) | score (zeroed, 3 comps).
        rev_state = jnp.concatenate([pts, jnp.zeros((m, 1)), jnp.zeros((m, 3))], axis=1)
        base_pts, _ = rev_ode(model, rev_state, solver)
        fwd_state = jnp.concatenate(
            [base_pts, sampling_dist.log_prob(base_pts), sampling_dist.score(base_pts)],
            axis=1,
        )
        x_out, logp_out, score_out = fwd_ode(model, fwd_state, solver)
        positions.append(np.array(x_out))
        densities.append(np.array(jnp.exp(logp_out)).ravel())
        scores.append(np.array(score_out))
    return np.concatenate(positions), np.concatenate(densities), np.concatenate(scores)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _local_energy_densities(p, x_np, rho_np, score_np, mol_dict, Ne, chunk):
    """Evaluate the semi-local functionals chunk-wise at every grid point.

    Returns ``(t_e, x_e, n_e, c_e, cc_e)``, each a 1-D array with one entry
    per grid point. The correlation and core-correction arrays stay zero when
    the matching entry of ``p`` is ``'none'``.
    """
    rho_col = rho_np[:, None]

    t_func = _build_kinetic(p['kinetic'], p['lam'])
    x_func = FUNCTIONAL_CLASSES[p['exchange']]()
    n_func = FUNCTIONAL_CLASSES[p['external']]()
    c_func = FUNCTIONAL_CLASSES[p['correlation']]() if p['correlation'] != 'none' else None
    cc_func = FUNCTIONAL_CLASSES[p['core_correction']]() if p['core_correction'] != 'none' else None

    G = rho_np.shape[0]
    t_e = np.zeros(G)
    x_e = np.zeros(G)
    n_e = np.zeros(G)
    c_e = np.zeros(G)
    cc_e = np.zeros(G)

    for i in range(0, G, chunk):
        sl = slice(i, min(i + chunk, G))
        # Functionals are evaluated at the forward-mapped positions x_np.
        inp = FunctionalInputs(den=jnp.array(rho_col[sl]), score=jnp.array(score_np[sl]),
                               x=jnp.array(x_np[sl]), Ne=Ne, mol=mol_dict, xp=None)
        t_e[sl] = np.array(t_func(inp)).ravel()
        x_e[sl] = np.array(x_func(inp)).ravel()
        n_e[sl] = np.array(n_func(inp)).ravel()
        if c_func is not None:
            c_e[sl] = np.array(c_func(inp)).ravel()
        if cc_func is not None:
            cc_e[sl] = np.array(cc_func(inp)).ravel()

    return t_e, x_e, n_e, c_e, cc_e


def _hartree_energy(grid_coords, w, rho_np, Ne, chunk):
    """Return the Hartree energy as an O(G^2) quadrature double sum.

    Computes ``0.5 * Ne**2 * sum_{j,k} w_j rho_j w_k rho_k / |r_j - r_k|``.
    Pairs of coincident points (including j == k) are excluded by
    treating 1/r as 0.
    """
    coords_H = np.array(grid_coords)
    wrho = w * rho_np
    G = rho_np.shape[0]
    v_coulomb = np.zeros(G)
    for i in range(0, G, chunk):
        xi = coords_H[i:i + chunk]
        diff = coords_H[None, :, :] - xi[:, None, :]
        r2 = np.sum(diff ** 2, axis=-1)
        # r == 0 -> inf so that 1/r contributes 0 instead of a singularity.
        safe_r = np.sqrt(np.where(r2 == 0., np.inf, r2))
        v_coulomb[i:i + chunk] = np.dot(1. / safe_r, wrho)
    return float(0.5 * Ne ** 2 * np.dot(wrho, v_coulomb))


def _nuclear_repulsion(mol_dict):
    """Return sum_{I<J} Z_I Z_J / |R_I - R_J| over the nuclei in ``mol_dict``."""
    coords_np = np.array(mol_dict['coords'])
    z_arr = np.array(mol_dict['z']).ravel()
    n_nuc = len(coords_np)
    e_nn = 0.0
    for I in range(n_nuc):
        for J in range(I + 1, n_nuc):
            e_nn += (float(z_arr[I]) * float(z_arr[J])
                     / float(np.linalg.norm(coords_np[I] - coords_np[J])))
    return e_nn


def compute_energy(p, x_np, rho_np, score_np, grid_coords, grid_weights, mol_dict, Ne, chunk):
    """Quadrature estimates of all energy components of the flow density.

    Parameters
    ----------
    p : dict
        Job parameters; the functionals are selected by ``'kinetic'``,
        ``'lam'``, ``'exchange'``, ``'external'``, ``'correlation'`` and
        ``'core_correction'``.
    x_np, rho_np, score_np : np.ndarray
        Output of ``compute_rho_on_grid``: mapped positions, normalised
        density (1-D) and score at each grid point.
    grid_coords, grid_weights : array-like
        Quadrature points and weights. The Hartree term uses the raw grid
        points.
    mol_dict : dict
        ``{'coords': ..., 'z': ...}`` describing the nuclei.
    Ne : int
        Number of electrons.
    chunk : int
        Number of grid points per evaluation batch.

    Returns
    -------
    dict
        Keys ``T, V_N, V_H, E_X, E_C, E_CC, E_NN, E_total`` as Python floats.
    """
    w = np.array(grid_weights)
    t_e, x_e, n_e, c_e, cc_e = _local_energy_densities(
        p, x_np, rho_np, score_np, mol_dict, Ne, chunk)

    # Each semi-local term is the weighted quadrature sum of e(x) * rho(x).
    T = float(np.dot(w, t_e * rho_np))
    E_X = float(np.dot(w, x_e * rho_np))
    V_N = float(np.dot(w, n_e * rho_np))
    E_C = float(np.dot(w, c_e * rho_np))
    E_CC = float(np.dot(w, cc_e * rho_np))

    V_H = _hartree_energy(grid_coords, w, rho_np, Ne, chunk)
    E_NN = _nuclear_repulsion(mol_dict)

    E_total = T + V_N + V_H + E_X + E_C + E_CC + E_NN
    return dict(T=T, V_N=V_N, V_H=V_H, E_X=E_X, E_C=E_C, E_CC=E_CC, E_NN=E_NN,
                E_total=E_total)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def run_analysis(results_dir: str, chunk: int, grid_level: int):
    """Run the full analysis for one result directory and print a report.

    Steps:

    1. Load the latest checkpoint and build the matching base distribution.
    2. Build a PySCF integration grid at ``grid_level``.
    3. Evaluate the flow density on that grid (rev→fwd).
    4. Print the round-trip position error and the normalisation check.
    5. Compute and print every energy component.

    Parameters
    ----------
    results_dir : str
        Directory containing ``job_params.json`` and ``Checkpoints/``.
    chunk : int
        Number of grid points per ODE / quadrature batch.
    grid_level : int
        Value assigned to PySCF ``Grids.level``.

    Returns
    -------
    tuple
        ``(p, model, solver, Ne, atoms, coords, grid_coords, grid_weights,
        rho_np, score_np, sampling_dist, energies)``. ``energies`` is the dict
        produced by ``compute_energy`` and is the *last* element; the function
        does not return the dict on its own.
    """
    print(f"\n{'='*60}")
    print(f"Analysing: {results_dir}")
    p, model, solver, Ne, atoms, z, coords = load_results(results_dir)
    mol_dict = {'coords': coords, 'z': z}
    sampling_dist = build_prior(p, z, coords, Ne)

    mol_pyscf = build_pyscf_mol(atoms, coords, Ne)
    grid = dft.gen_grid.Grids(mol_pyscf)
    grid.level = grid_level
    grid.build()
    grid_coords = jnp.array(grid.coords, dtype=jnp.float64)
    grid_weights = jnp.array(grid.weights, dtype=jnp.float64)
    print(f" Grid: {grid_coords.shape[0]} points (level={grid_level})")

    print(" Computing ρ via rev→fwd ...")
    x_np, rho_np, score_np = compute_rho_on_grid(
        model, solver, sampling_dist, grid_coords, chunk)

    # Sanity checks:
    # - the rev→fwd round trip should map grid points back onto themselves;
    # - Ne * rho should integrate to Ne on the quadrature grid.
    pos_err = float(np.max(np.abs(x_np - np.array(grid_coords))))
    Ne_est = float(np.dot(np.array(grid_weights), Ne * rho_np))
    print(f" Round-trip error : {pos_err:.3e}")
    print(f" ∫ρ_M dx : {Ne_est:.6f} (should be {Ne})")

    energies = compute_energy(
        p, x_np, rho_np, score_np,
        grid_coords, grid_weights, mol_dict, Ne, chunk)

    print(f"\n === ENERGY ({p['kinetic']} / λ={p['lam']} / {p['exchange']}) ===")
    print(f" T = {energies['T']:+.6f} Ha")
    print(f" V_N = {energies['V_N']:+.6f} Ha")
    print(f" V_H = {energies['V_H']:+.6f} Ha")
    print(f" E_X = {energies['E_X']:+.6f} Ha")
    if p['correlation'] != 'none':
        print(f" E_C = {energies['E_C']:+.6f} Ha")
    if p['core_correction'] != 'none':
        print(f" E_CC = {energies['E_CC']:+.6f} Ha")
    if energies['E_NN'] != 0.0:
        print(f" E_NN = {energies['E_NN']:+.6f} Ha")
    print(" ─────────────────────")
    print(f" E_tot = {energies['E_total']:+.6f} Ha")
    if p['mol_name'] == 'H':
        print(" (exact H = -0.500000 Ha)")

    return (p, model, solver, Ne, atoms, coords, grid_coords, grid_weights,
            rho_np, score_np, sampling_dist, energies)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
# ── Main molecule ─────────────────────────────────────────────────────────────
|
|
220
|
+
(p, model, solver, Ne, atoms, coords, grid_coords, grid_weights,
 rho_np, score_np, sampling_dist, energies) = run_analysis(
    args.results_dir, args.bs, args.grid_level)

# ── Atom reference for binding energy ─────────────────────────────────────────
if args.atom_results_dir is not None:
    # Only the energy dict (last element of run_analysis's tuple) is needed
    # from the atom run.
    energies_H = run_analysis(args.atom_results_dir, args.bs, args.grid_level)[-1]

    E_mol, E_atom = energies['E_total'], energies_H['E_total']
    # NOTE(review): the reference is hard-coded to two atoms (H2 + 2 H);
    # negative E_bind means bound, positive D_e means stable.
    E_bind = E_mol - 2.0 * E_atom
    D_e = -E_bind

    print("\n=== BINDING ENERGY ===")
    print(f" E({p['mol_name']}, R={p['bond_length']:.4f} Bohr) = {E_mol:+.6f} Ha")
    print(f" E(H atom) = {E_atom:+.6f} Ha")
    print(f" E_bind = E(mol) - 2·E(H) = {E_bind:+.6f} Ha")
    print(f" D_e = 2·E(H) - E(mol) = {D_e:+.6f} Ha ({D_e*27.2114:.4f} eV)")
|
|
239
|
+
|
|
240
|
+
# ── Density plot along z-axis ─────────────────────────────────────────────────
|
|
241
|
+
# Sample the density on a line through the nuclei along z, padded by 3 Bohr
# on each side.
z_min = float(coords[:, 2].min()) - 3.0
z_max = float(coords[:, 2].max()) + 3.0
zt = np.linspace(z_min, z_max, 300)
line_pts = jnp.array(np.stack([np.zeros_like(zt),
                               np.zeros_like(zt),
                               zt], axis=1), dtype=jnp.float64)

print("\nComputing density along z-axis ...")
# Same rev→fwd evaluation as on the quadrature grid; only rho is needed here.
_, rho_pred, _ = compute_rho_on_grid(model, solver, sampling_dist, line_pts, args.bs)

R = float(jnp.linalg.norm(coords[0] - coords[-1])) if len(coords) > 1 else 0.0

fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(zt, Ne * rho_pred, color='tab:blue',
        label=rf"$N_e\,\rho_{{NF}}(z)$, R={R:.3f} Bohr")
ax.set_xlabel("z [Bohr]")
ax.set_ylabel(r"$\rho(z)$ [Bohr$^{-3}$]")
ax.set_title(f"{p['mol_name']} | {p['kinetic']} λ={p['lam']} | {p['exchange']}")
ax.legend()
fig.tight_layout()

out_dir = Path(args.results_dir).resolve()
fig.savefig(out_dir / "density.svg", transparent=True)
fig.savefig(out_dir / "density.png", dpi=150)
plt.close(fig)  # release the figure once both files are written
print(f"Density plot saved → {out_dir}/density.png")
|
|
277
|
+
|
|
278
|
+
# ── Save energy summary ───────────────────────────────────────────────────────
|
|
279
|
+
# All energy terms plus run metadata and the normalisation check
# (quadrature integral of Ne * rho).
summary = {
    **energies,
    'mol_name': p['mol_name'],
    'bond_length': p['bond_length'],
    'Ne_integral': float(np.dot(np.array(grid_weights), Ne * rho_np)),
}
if args.atom_results_dir is not None:
    summary.update(E_atom=energies_H['E_total'], E_bind=E_bind, D_e=D_e)

summary_path = out_dir / "energy_summary.json"
with open(summary_path, "w") as fh:
    json.dump(summary, fh, indent=4)
print(f"Energy summary saved → {out_dir}/energy_summary.json")
|
off/train/__init__.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# MIT License
|
|
2
|
+
|
|
3
|
+
# Copyright (c) 2025 AlexandreDeCamargo
|
|
4
|
+
|
|
5
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
# in the Software without restriction, including without limitation the rights
|
|
8
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
# furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
# copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
# SOFTWARE.
|
|
22
|
+
|
|
23
|
+
# Version string exposed by the ``off.train`` subpackage (same as the 0.1.0 wheel).
__version__ = "0.1.0"
|
|
24
|
+
|
|
25
|
+
from .loop import (
|
|
26
|
+
setup_molecule,
|
|
27
|
+
setup_model,
|
|
28
|
+
setup_optimizer,
|
|
29
|
+
setup_ema,
|
|
30
|
+
log_metrics,
|
|
31
|
+
training,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
from .loss import (
|
|
35
|
+
create_loss_function
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
from .utils import (
|
|
39
|
+
step
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
from ..config._config import (
|
|
43
|
+
Config
|
|
44
|
+
)
|
off/train/loop.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import jax.numpy as jnp
|
|
3
|
+
import jax.random as jrnd
|
|
4
|
+
import equinox as eqx
|
|
5
|
+
import optax
|
|
6
|
+
from optax import ema
|
|
7
|
+
import pandas as pd
|
|
8
|
+
import time
|
|
9
|
+
from typing import Optional
|
|
10
|
+
|
|
11
|
+
from ..flow.equiv_flows import CNF
|
|
12
|
+
from ..utils import one_hot_encode, coordinates, batch_generator, get_solver, get_scheduler
|
|
13
|
+
from ..promolecular.promolecular_dist import AtomDBDistribution,SIRDistribution,ProMolecularDensity
|
|
14
|
+
from .utils import step
|
|
15
|
+
from .loss import create_loss_function, F_values
|
|
16
|
+
from ..config._config import Config
|
|
17
|
+
|
|
18
|
+
jax.config.update("jax_enable_x64", True)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def setup_molecule(mol_name: str, bond_length: float = 0.74144):
    """Describe a molecule for training.

    Thin wrapper around ``coordinates`` that also packs the nuclear positions
    and charges into the ``mol`` dict (``{'coords', 'z'}``) which
    ``training`` passes to ``step``.

    NOTE(review): this default ``bond_length`` (0.74144) differs from
    ``training``'s default (1.4008538753); confirm which unit is intended.

    Returns ``(Ne, atoms, z, coords, mol)``.
    """
    n_electrons, atoms, z, coords = coordinates(mol_name, bond_length)
    return n_electrons, atoms, z, coords, {'coords': coords, 'z': z}
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def setup_model(coords, z, hidden_layer: int, key):
    """Construct a fresh CNF acting on 3-D positions.

    ``coords`` is passed to the CNF as ``mu``, and the one-hot encoding of the
    atomic numbers ``z`` conditions the network.
    """
    # Positional CNF arguments: data dimension (3), hidden size, mu, one-hot, key.
    return CNF(3, hidden_layer, coords, one_hot_encode(z), key)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def setup_optimizer(flow_model, epochs: int, lr: float, scheduler_type: str):
    """Build the optimiser and its initial state for ``flow_model``.

    The transform chain is:

    1. Global-norm gradient clipping at 1.0.
    2. AdamW (weight decay 1e-5), driven by the learning-rate schedule
       returned by ``get_scheduler``.

    Only the array leaves of the model are optimised.

    Returns ``(optimizer, optimizer_state)``.
    """
    schedule = get_scheduler(epochs=epochs, sched_type=scheduler_type, lr=lr)
    transforms = (
        optax.clip_by_global_norm(1.0),
        optax.adamw(schedule, weight_decay=1e-5),
    )
    optimizer = optax.chain(*transforms)
    trainable = eqx.filter(flow_model, eqx.is_array)
    return optimizer, optimizer.init(trainable)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def setup_ema():
    """Create the exponential moving average used to smooth energy terms.

    Returns ``(energies_ema, energies_state)``. The first element is an
    ``optax.ema`` transform with decay 0.99; the second is its state,
    initialised from an ``F_values`` whose fields are all ``0.``.
    """
    tracker = ema(decay=0.99)
    zeros = {field: jnp.array(0.) for field in ("energy", "kin", "vnuc", "hart", "xc", "cc")}
    return tracker, tracker.init(F_values(**zeros))
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def log_metrics(itr: int, loss_epoch: float, losses: F_values,
                energies_i_ema: F_values, elapsed_time: float):
    """Build one logging row for the instantaneous and one for the EMA values.

    Both rows have the keys ``epoch, E, T, V, H, XC, CC, t``. ``'E'`` is the
    total (``loss_epoch``, or the EMA ``energy``) minus the core-correction
    term.

    Returns ``(r_instant, r_ema)``.
    """
    def _row(total, terms):
        return {
            'epoch': itr,
            'E': total - terms.cc,
            'T': terms.kin,
            'V': terms.vnuc,
            'H': terms.hart,
            'XC': terms.xc,
            'CC': terms.cc,
            't': elapsed_time,
        }

    return _row(loss_epoch, losses), _row(energies_i_ema.energy, energies_i_ema)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def training(mol_name: str,
             bond_length: float = 1.4008538753,
             tw_kin: str = 'tf_w',
             lam: float = 1.0,
             n_pot: str = 'np',
             h_pot: str = 'coulomb',
             x_pot: str = 'lda',
             c_pot: str = 'vwn_c',
             cc_pot: str = 'kato',
             batch_size: int = 256,
             hidden_layer: int = 64,
             epochs: int = 100,
             lr: float = 1e-5,
             scheduler_type: str = 'ones',
             solver_type: str = 'tsit5',
             prior_type: str = 'promolecular',
             prior_dist: Optional[ProMolecularDensity] = None,
             checkpoint_dir: str = './checkpoints',
             checkpoint_freq: int = 50,
             ):
    """
    Main training loop.

    Runs ``epochs + 1`` optimisation steps (epochs 0..epochs inclusive). After
    every step, instantaneous and EMA metrics are appended to
    ``{Config.results_dir}/training_metrics.csv`` and
    ``training_metrics_ema.csv``. A checkpoint is written every
    ``checkpoint_freq`` epochs and after the last epoch.

    Parameters
    ----------
    mol_name : str
        Name of molecule
    bond_length: float
        Bond length in a.u.
    tw_kin : str
        Kinetic functional name
    lam : float
        Passed to ``create_loss_function`` as ``lam`` (presumably the
        von Weizsäcker weight of the kinetic functional; confirm).
    n_pot : str
        External potential functional name
    h_pot : str
        Hartree functional name
    x_pot : str
        Exchange functional name
    c_pot : str
        Correlation functional name
    cc_pot : str
        Core correction functional name
    batch_size : int
        Batch size for training
    hidden_layer : int
        Hidden layer size for neural network
    epochs : int
        Number of training epochs
    lr : float
        Learning rate
    scheduler_type : str
        Type of learning rate scheduler
    solver_type : str
        ODE solver type
    prior_type: str
        Type of prior distribution for sampling. ``'db_sir'`` samples from an
        AtomDB promolecule; any other value samples from ``prior_dist``.
    prior_dist : ProMolecularDensity, optional
        Base distribution used for sampling when ``prior_type`` is not
        ``'db_sir'``. Defaults to a ``ProMolecularDensity`` built from the
        molecule's nuclei.
    checkpoint_dir : str
        Directory to save checkpoints (created if missing)
    checkpoint_freq : int
        Frequency of checkpoint saving. A value <= 0 saves only the final
        checkpoint.

    Returns
    -------
    flow_model : CNF
        Trained flow model
    df : pd.DataFrame
        Training metrics
    df_ema : pd.DataFrame
        EMA training metrics
    """
    from pathlib import Path

    # Setup
    Ne, atoms, z, coords, mol = setup_molecule(mol_name, bond_length)

    key = jrnd.PRNGKey(0)
    _, key = jrnd.split(key)
    # Independent keys for weight initialisation and batch sampling: a JAX key
    # must not be reused for two different random draws.
    model_key, batch_key = jrnd.split(key)

    flow_model = setup_model(coords, z, hidden_layer, model_key)
    solver = get_solver(solver_type)
    optimizer, optimizer_state = setup_optimizer(flow_model, epochs, lr, scheduler_type)
    energies_ema, energies_state = setup_ema()
    # Use a caller-supplied prior as-is; build the default only when none is given.
    if prior_dist is None:
        prior_dist = ProMolecularDensity(z.ravel(), coords)

    if prior_type == 'db_sir':
        from atomdb import make_promolecule  # optional dep — only needed for db_sir
        db_prior = make_promolecule(atnums=z, coords=coords, dataset="slater")
        sampling_dist = AtomDBDistribution(
            db_prior=db_prior, z=z, coords=coords, Ne=Ne
        )
    else:
        sampling_dist = prior_dist

    gen_batches = batch_generator(batch_key, batch_size, sampling_dist)

    grad_loss_fn = create_loss_function(
        kinetic_name=tw_kin,
        lam=lam,
        exchange_name=x_pot,
        correlation_name=c_pot,
        hartree_name=h_pot,
        external_name=n_pot,
        core_correction_name=cc_pot
    )

    # Make sure checkpoint writes cannot fail on a missing directory.
    Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)

    # Training loop
    df = pd.DataFrame()
    df_ema = pd.DataFrame()

    for itr in range(epochs + 1):
        start_time = time.time()

        batch = next(gen_batches)

        loss, flow_model, optimizer_state = step(
            flow_model, batch, optimizer, optimizer_state,
            grad_loss_fn, solver, Ne, mol
        )

        elapsed_time = time.time() - start_time

        loss_epoch, losses = loss

        # Update EMA
        energies_i_ema, energies_state = energies_ema.update(losses, energies_state)

        # Log metrics
        r_instant, r_ema = log_metrics(itr, loss_epoch, losses, energies_i_ema, elapsed_time)

        df = pd.concat([df, pd.DataFrame([r_instant])], ignore_index=True)
        df_ema = pd.concat([df_ema, pd.DataFrame([r_ema])], ignore_index=True)

        print(f"Epoch {itr}: {r_ema}")

        df.to_csv(f"{Config.results_dir}/training_metrics.csv", index=False)
        df_ema.to_csv(f"{Config.results_dir}/training_metrics_ema.csv", index=False)

        # Save checkpoint (guard against checkpoint_freq <= 0, which would
        # otherwise raise ZeroDivisionError).
        periodic = checkpoint_freq > 0 and itr % checkpoint_freq == 0
        if periodic or itr == epochs:
            eqx.tree_serialise_leaves(f"{checkpoint_dir}/checkpoint_{itr}.eqx", flow_model)

    return flow_model, df, df_ema
|