off 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
off/test_fwd_rev.py ADDED
@@ -0,0 +1,290 @@
1
+ """
2
+ CNF analysis script: normalization, energy, binding energy, density plot.
3
+
4
+ Usage
5
+ -----
6
+ # Single molecule (energy + density):
7
+ python test_fwd_rev.py \
8
+ --results_dir of_flows/Results/H2/tf_w_lam0.2_none_lda_none_dopri8_promolecular_sched_MIX/bl_3.0000
9
+
10
+ # With binding energy (needs H atom result dir):
11
+ python test_fwd_rev.py \
12
+ --results_dir of_flows/Results/H2/tf_w_lam0.2_none_lda_none_dopri8_promolecular_sched_MIX/bl_3.0000 \
13
+ --atom_results_dir of_flows/Results/H/tf_w_lam0.2_none_lda_none_dopri8_promolecular_sched_MIX/bl_0.0000
14
+ """
15
+
16
+ import sys, os
17
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "of_flows"))
18
+
19
+ import argparse
20
+ import glob
21
+ import json
22
+ import re
23
+ from pathlib import Path
24
+
25
+ import jax
26
+ import jax.numpy as jnp
27
+ import jax.random as jrnd
28
+ import equinox as eqx
29
+ import matplotlib.pyplot as plt
30
+ import numpy as np
31
+ from pyscf import gto, dft
32
+ from atomdb import make_promolecule
33
+
34
+ jax.config.update("jax_enable_x64", True)
35
+
36
+ from flow.equiv_flows import CNF
37
+ from ode_solver.eqx_ode import fwd_ode, rev_ode
38
+ from utils import one_hot_encode, coordinates, get_solver
39
+ from promolecular.promolecular_dist import ProMolecularDensity, AtomDBDistribution, SIRDistribution
40
+ from train.loss import FUNCTIONAL_CLASSES, _build_kinetic
41
+ from functionals.functional import FunctionalInputs
42
+
43
+ # ── CLI ───────────────────────────────────────────────────────────────────────
44
# Command-line options: the trained run to analyse, an optional atomic
# reference run used for the binding energy, and two numerical knobs
# (points per ODE/quadrature chunk and the PySCF grid level).
parser = argparse.ArgumentParser()
for _flag, _opts in (
    ("--results_dir", dict(type=str, required=True,
                           help="Path to bl_X.XXXX result directory (contains job_params.json)")),
    ("--atom_results_dir", dict(type=str, default=None,
                                help="Path to the H atom bl_0.0000 result directory (for binding energy)")),
    ("--bs", dict(type=int, default=256, help="Grid chunk size")),
    ("--grid_level", dict(type=int, default=3, help="PySCF grid level")),
):
    parser.add_argument(_flag, **_opts)
args = parser.parse_args()
52
+
53
+ # ── helpers ───────────────────────────────────────────────────────────────────
54
def load_results(results_dir: str):
    """Load a training run and rebuild its CNF from the latest checkpoint.

    Parameters
    ----------
    results_dir : str
        Run directory containing ``job_params.json`` and a ``Checkpoints/``
        sub-directory with files named ``checkpoint_<epoch>.eqx``.

    Returns
    -------
    tuple
        ``(p, model, solver, Ne, atoms, z, coords)``: the parsed job-parameter
        dict, the CNF with leaves restored from the highest-epoch checkpoint,
        the ODE solver named by ``p['solver']``, and the molecular data
        returned by ``coordinates(p['mol_name'], p['bond_length'])``.

    Raises
    ------
    FileNotFoundError
        If ``Checkpoints/`` holds no ``checkpoint_<integer>.eqx`` file.
    """
    rdir = Path(results_dir).resolve()

    with open(rdir / "job_params.json", encoding="utf-8") as f:
        p = json.load(f)

    # Keep only files whose suffix is an integer epoch. Other names that match
    # the glob (e.g. "checkpoint_best.eqx") used to crash the sort key because
    # the regex match returned None.
    ckpt_re = re.compile(r"checkpoint_(\d+)\.eqx")
    ckpts = []
    for path in (rdir / "Checkpoints").glob("checkpoint_*.eqx"):
        match = ckpt_re.fullmatch(path.name)
        if match is not None:
            ckpts.append((int(match.group(1)), str(path)))
    if not ckpts:
        raise FileNotFoundError(f"No checkpoints found in {rdir}/Checkpoints/")
    _, last_ckpt = max(ckpts)  # highest epoch number wins
    print(f" Loading checkpoint: {last_ckpt}")

    Ne, atoms, z, coords = coordinates(p['mol_name'], p['bond_length'])
    z_one_hot = one_hot_encode(z)
    # The key only seeds a template model with the right structure; its array
    # leaves are then replaced by those stored in the checkpoint.
    model = CNF(din=3, dim=p['hidden_layer'], mu=coords, one_hot=z_one_hot,
                key=jrnd.PRNGKey(0))
    model = eqx.tree_deserialise_leaves(last_ckpt, model)
    solver = get_solver(p['solver'])

    return p, model, solver, Ne, atoms, z, coords
77
+
78
+
79
def build_prior(p, z, coords, Ne):
    """Return the base (prior) distribution selected by ``p['prior']``.

    ``'db_sir'`` builds an ``AtomDBDistribution`` from the AtomDB "slater"
    promolecule. Any other value gives the analytic ``ProMolecularDensity`` on
    the nuclear charges and positions. Only the selected distribution is
    constructed; the original always built the promolecular one and discarded
    it on the ``db_sir`` branch.

    Parameters
    ----------
    p : dict
        Job parameters; only ``p['prior']`` is read.
    z, coords :
        Atomic numbers and nuclear coordinates as returned by ``coordinates``.
    Ne : int
        Number of electrons (used by the AtomDB distribution).
    """
    if p['prior'] == 'db_sir':
        # Direct AtomDB sampling via per-atom inverse-CDF (no SIR needed)
        db_prior = make_promolecule(atnums=z, coords=coords, dataset="slater")
        return AtomDBDistribution(db_prior=db_prior, z=z, coords=coords, Ne=Ne)
    return ProMolecularDensity(z.ravel(), coords)
86
+
87
+
88
def build_pyscf_mol(atoms, coords, Ne):
    """Create a PySCF ``Mole`` for the given geometry.

    In this script the molecule is used to generate the DFT quadrature grid.
    Coordinates are given in Bohr (``unit="B"``). ``spin`` (N_alpha - N_beta)
    is 1 for an odd electron count and 0 otherwise.
    """
    atom_lines = []
    for symbol, xyz in zip(atoms, coords):
        atom_lines.append(f"{symbol} {xyz[0]:.8f} {xyz[1]:.8f} {xyz[2]:.8f}")
    n_unpaired = int(Ne) % 2
    return gto.M(
        atom="; ".join(atom_lines),
        basis="6-31G(d,p)",
        unit="B",
        verbose=0,
        spin=n_unpaired,
    )
93
+
94
+
95
def compute_rho_on_grid(model, solver, sampling_dist, grid_coords, chunk):
    """Evaluate the flow's density and score at arbitrary query points.

    Points are processed in chunks in two passes. First they are integrated
    backwards through the CNF (``rev_ode``) to base space. There the base
    log-density and score are evaluated, and that state is integrated forward
    again (``fwd_ode``) to obtain log-density and score at the mapped points.

    Parameters
    ----------
    model : CNF
        Trained flow.
    solver :
        ODE solver passed through to ``rev_ode`` / ``fwd_ode``.
    sampling_dist :
        Base distribution exposing ``log_prob`` and ``score``.
    grid_coords : array, shape (G, 3)
        Query points.
    chunk : int
        Number of points integrated per ODE call.

    Returns
    -------
    x, rho, score : np.ndarray
        Forward-mapped points (should reproduce ``grid_coords``), ``exp(log p)``
        flattened to shape (G,), and the forward-propagated score. In this
        script callers multiply ``rho`` by ``Ne`` to get the electron density.
    """
    n_total = grid_coords.shape[0]
    pieces = {"x": [], "rho": [], "score": []}
    for start in range(0, n_total, chunk):
        pts = grid_coords[start:start + chunk]
        m = pts.shape[0]
        # Reverse pass: the 1 + 3 trailing state columns start at zero and the
        # second output of rev_ode is discarded; only base positions are kept.
        rev_state = jnp.concatenate([pts, jnp.zeros((m, 1)), jnp.zeros((m, 3))], axis=1)
        base_pts, _ = rev_ode(model, rev_state, solver)
        # Forward pass seeded with the base log-density and score.
        fwd_state = jnp.concatenate(
            [base_pts, sampling_dist.log_prob(base_pts), sampling_dist.score(base_pts)],
            axis=1,
        )
        x_out, log_rho, score_out = fwd_ode(model, fwd_state, solver)
        pieces["x"].append(np.array(x_out))
        pieces["rho"].append(np.array(jnp.exp(log_rho)).ravel())
        pieces["score"].append(np.array(score_out))
    return tuple(np.concatenate(pieces[name]) for name in ("x", "rho", "score"))
114
+
115
+
116
def _hartree_energy(points, weights, rho, Ne, chunk):
    """Return 0.5·Ne²·Σ_j Σ_{k≠j} w_j ρ_j w_k ρ_k / |r_j − r_k| by quadrature.

    Coincident point pairs (r_j == r_k) are skipped rather than dividing by
    zero. Rows are processed ``chunk`` at a time, so the pairwise-distance
    array has shape (chunk, G) instead of (G, G).
    """
    w_rho = weights * rho
    v_coulomb = np.zeros(points.shape[0])
    for i in range(0, points.shape[0], chunk):
        rows = points[i:i + chunk]
        r2 = np.sum((points[None, :, :] - rows[:, None, :]) ** 2, axis=-1)
        inv_r = 1.0 / np.sqrt(np.where(r2 == 0.0, np.inf, r2))
        v_coulomb[i:i + chunk] = inv_r @ w_rho
    return float(0.5 * Ne ** 2 * np.dot(w_rho, v_coulomb))


def _nuclear_repulsion(coords, z):
    """Return Σ_{I<J} Z_I Z_J / |R_I − R_J|; 0.0 for fewer than two nuclei."""
    R = np.asarray(coords, dtype=np.float64)
    Z = np.asarray(z, dtype=np.float64).ravel()
    i, j = np.triu_indices(len(R), k=1)
    return float(np.sum(Z[i] * Z[j] / np.linalg.norm(R[i] - R[j], axis=-1)))


def compute_energy(p, x_np, rho_np, score_np, grid_coords, grid_weights, mol_dict, Ne, chunk):
    """Integrate every energy component of the learned density by quadrature.

    Each configured functional is evaluated chunk-wise. Its per-point value is
    integrated against ``w · ρ``. The Hartree term is computed directly as a
    pairwise double sum, so ``p['hartree']`` is not consulted. The
    nucleus–nucleus repulsion comes from ``mol_dict``.

    Parameters
    ----------
    p : dict
        Job parameters. Reads ``kinetic``, ``lam``, ``exchange``,
        ``external``, ``correlation`` and ``core_correction``; the value
        ``'none'`` disables either of the last two.
    x_np : np.ndarray
        Forward-mapped points at which the functionals are evaluated. They
        equal ``grid_coords`` up to the flow's round-trip error.
    rho_np : np.ndarray, shape (G,)
        Density at the grid points.
    score_np : np.ndarray
        Score at the grid points.
    grid_coords, grid_weights : array-like
        Quadrature points and weights.
    mol_dict : dict
        ``{'coords': ..., 'z': ...}`` — nuclear positions and charges.
    Ne : int
        Number of electrons.
    chunk : int
        Points processed per batch.

    Returns
    -------
    dict
        Contains ``T, V_N, V_H, E_X, E_C, E_CC, E_NN``, ``E_total`` and
        ``E_total_with_cc``. ``E_total`` excludes the core-correction term
        ``E_CC``, consistent with ``train.loop.log_metrics``, which logs the
        energy ``E`` as total minus ``cc``. ``E_total_with_cc`` adds it back,
        which reproduces the value this function previously returned as
        ``E_total``.
    """
    w = np.array(grid_weights)
    rho_col = rho_np[:, None]
    G = rho_np.shape[0]

    def _optional(name):
        # 'none' disables an optional term; it then integrates to 0.0.
        return FUNCTIONAL_CLASSES[name]() if name != 'none' else None

    funcs = {
        'T': _build_kinetic(p['kinetic'], p['lam']),
        'E_X': FUNCTIONAL_CLASSES[p['exchange']](),
        'V_N': FUNCTIONAL_CLASSES[p['external']](),
        'E_C': _optional(p['correlation']),
        'E_CC': _optional(p['core_correction']),
    }

    per_point = {name: np.zeros(G) for name in funcs}
    for i in range(0, G, chunk):
        sl = slice(i, min(i + chunk, G))
        inp = FunctionalInputs(den=jnp.array(rho_col[sl]), score=jnp.array(score_np[sl]),
                               x=jnp.array(x_np[sl]), Ne=Ne, mol=mol_dict, xp=None)
        for name, func in funcs.items():
            if func is not None:
                per_point[name][sl] = np.array(func(inp)).ravel()

    # ∫ e(r) ρ(r) dr ≈ Σ_g w_g e_g ρ_g
    E = {name: float(np.dot(w, vals * rho_np)) for name, vals in per_point.items()}

    # Hartree is evaluated on the quadrature points themselves (j ≠ k).
    V_H = _hartree_energy(np.array(grid_coords), w, rho_np, Ne, chunk)
    E_NN = _nuclear_repulsion(mol_dict['coords'], mol_dict['z'])

    E_total = E['T'] + E['V_N'] + V_H + E['E_X'] + E['E_C'] + E_NN
    return dict(T=E['T'], V_N=E['V_N'], V_H=V_H, E_X=E['E_X'], E_C=E['E_C'],
                E_CC=E['E_CC'], E_NN=E_NN, E_total=E_total,
                E_total_with_cc=E_total + E['E_CC'])
171
+
172
+
173
def run_analysis(results_dir: str, chunk: int, grid_level: int):
    """Run the full analysis pipeline for one training result directory.

    The pipeline:

    - loads the checkpointed model;
    - builds a PySCF integration grid of level ``grid_level``;
    - evaluates the learned density on that grid;
    - prints the round-trip and normalisation diagnostics plus the energy
      breakdown.

    Returns
    -------
    tuple
        ``(p, model, solver, Ne, atoms, coords, grid_coords, grid_weights,
        rho_np, score_np, sampling_dist, energies)``.
    """
    print(f"\n{'='*60}")
    print(f"Analysing: {results_dir}")
    p, model, solver, Ne, atoms, z, coords = load_results(results_dir)
    mol_dict = {'coords': coords, 'z': z}
    sampling_dist = build_prior(p, z, coords, Ne)

    # PySCF supplies only the molecular quadrature grid.
    grid = dft.gen_grid.Grids(build_pyscf_mol(atoms, coords, Ne))
    grid.level = grid_level
    grid.build()
    grid_coords, grid_weights = (jnp.array(arr, dtype=jnp.float64)
                                 for arr in (grid.coords, grid.weights))
    print(f" Grid: {grid_coords.shape[0]} points (level={grid_level})")

    print(" Computing ρ via rev→fwd ...")
    x_np, rho_np, score_np = compute_rho_on_grid(model, solver, sampling_dist,
                                                 grid_coords, chunk)

    # Sanity checks: fwd(rev(x)) should return x, and Ne·ρ should integrate to Ne.
    pos_err = float(np.max(np.abs(x_np - np.array(grid_coords))))
    Ne_est = float(np.dot(np.array(grid_weights), Ne * rho_np))
    print(f" Round-trip error : {pos_err:.3e}")
    print(f" ∫ρ_M dx : {Ne_est:.6f} (should be {Ne})")

    energies = compute_energy(p, x_np, rho_np, score_np,
                              grid_coords, grid_weights, mol_dict, Ne, chunk)

    print(f"\n === ENERGY ({p['kinetic']} / λ={p['lam']} / {p['exchange']}) ===")
    # Optional terms are printed only when enabled / non-zero.
    shown = {
        'E_C': p['correlation'] != 'none',
        'E_CC': p['core_correction'] != 'none',
        'E_NN': energies['E_NN'] != 0.0,
    }
    for label in ('T', 'V_N', 'V_H', 'E_X', 'E_C', 'E_CC', 'E_NN'):
        if shown.get(label, True):
            print(f" {label} = {energies[label]:+.6f} Ha")
    print(f" ─────────────────────")
    print(f" E_tot = {energies['E_total']:+.6f} Ha")
    if p['mol_name'] == 'H':
        print(f" (exact H = -0.500000 Ha)")

    return (p, model, solver, Ne, atoms, coords, grid_coords, grid_weights,
            rho_np, score_np, sampling_dist, energies)
217
+
218
+
219
+ # ── Main molecule ─────────────────────────────────────────────────────────────
220
(p, model, solver, Ne, atoms, coords, grid_coords, grid_weights,
 rho_np, score_np, sampling_dist, energies) = run_analysis(
    args.results_dir, args.bs, args.grid_level)

# ── Atom reference for binding energy ─────────────────────────────────────────
if args.atom_results_dir is not None:
    # Only the energy dict (last element) of the atomic run is needed.
    energies_H = run_analysis(args.atom_results_dir, args.bs, args.grid_level)[-1]

    E_mol, E_atom = energies['E_total'], energies_H['E_total']
    # Hard-coded for two H atoms (the reference is "E(H atom)", subtracted
    # twice). Negative E_bind = bound.
    E_bind = E_mol - 2.0 * E_atom
    D_e = -E_bind  # dissociation energy (positive = stable)

    print(f"\n=== BINDING ENERGY ===")
    print(f" E({p['mol_name']}, R={p['bond_length']:.4f} Bohr) = {E_mol:+.6f} Ha")
    print(f" E(H atom) = {E_atom:+.6f} Ha")
    print(f" E_bind = E(mol) - 2·E(H) = {E_bind:+.6f} Ha")
    # 27.2114 eV per Hartree
    print(f" D_e = 2·E(H) - E(mol) = {D_e:+.6f} Ha ({D_e*27.2114:.4f} eV)")
239
+
240
+ # ── Density plot along z-axis ─────────────────────────────────────────────────
241
# Sample the learned density on 300 points along the z-axis (x = y = 0),
# extending 3 Bohr beyond the outermost nuclei.
# NOTE(review): this assumes the molecule lies on the z-axis — confirm for
# non-linear geometries.
z_min = float(coords[:, 2].min()) - 3.0
z_max = float(coords[:, 2].max()) + 3.0
zt = np.linspace(z_min, z_max, 300)
line_pts = jnp.array(np.stack([np.zeros_like(zt),
                               np.zeros_like(zt),
                               zt], axis=1), dtype=jnp.float64)

print("\nComputing density along z-axis ...")
# Reuse the same rev→fwd evaluation as on the quadrature grid instead of a
# duplicated copy of its loop; only ρ is needed for the plot.
_, rho_pred, _ = compute_rho_on_grid(model, solver, sampling_dist, line_pts, args.bs)
261
+
262
# Distance between the first and last nucleus (0 for a single atom); shown
# in the legend.
R = 0.0
if len(coords) > 1:
    R = float(jnp.linalg.norm(coords[0] - coords[-1]))

fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(zt, Ne * rho_pred, color='tab:blue',
        label=rf"$N_e\,\rho_{{NF}}(z)$, R={R:.3f} Bohr")
ax.set_xlabel("z [Bohr]")
ax.set_ylabel(r"$\rho(z)$ [Bohr$^{-3}$]")
ax.set_title(f"{p['mol_name']} | {p['kinetic']} λ={p['lam']} | {p['exchange']}")
ax.legend()
fig.tight_layout()

# Save next to the analysed run: a transparent SVG and a 150-dpi PNG.
out_dir = Path(args.results_dir).resolve()
for _ext, _save_kwargs in (("svg", {"transparent": True}), ("png", {"dpi": 150})):
    fig.savefig(out_dir / f"density.{_ext}", **_save_kwargs)
print(f"Density plot saved → {out_dir}/density.png")
277
+
278
+ # ── Save energy summary ───────────────────────────────────────────────────────
279
# Energies plus run identification and the normalisation check ∫ Ne·ρ.
summary = dict(energies)
summary.update(
    mol_name=p['mol_name'],
    bond_length=p['bond_length'],
    Ne_integral=float(np.dot(np.array(grid_weights), Ne * rho_np)),
)
if args.atom_results_dir is not None:
    summary.update(E_atom=energies_H['E_total'], E_bind=E_bind, D_e=D_e)

summary_path = out_dir / "energy_summary.json"
with open(summary_path, "w") as fh:
    json.dump(summary, fh, indent=4)
print(f"Energy summary saved → {out_dir}/energy_summary.json")
off/train/__init__.py ADDED
@@ -0,0 +1,44 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2025 AlexandreDeCamargo
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ __version__ = "0.1.0"
24
+
25
+ from .loop import (
26
+ setup_molecule,
27
+ setup_model,
28
+ setup_optimizer,
29
+ setup_ema,
30
+ log_metrics,
31
+ training,
32
+ )
33
+
34
+ from .loss import (
35
+ create_loss_function
36
+ )
37
+
38
+ from .utils import (
39
+ step
40
+ )
41
+
42
+ from ..config._config import (
43
+ Config
44
+ )
off/train/loop.py ADDED
@@ -0,0 +1,228 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import jax.random as jrnd
4
+ import equinox as eqx
5
+ import optax
6
+ from optax import ema
7
+ import pandas as pd
8
+ import time
9
+ from typing import Optional
10
+
11
+ from ..flow.equiv_flows import CNF
12
+ from ..utils import one_hot_encode, coordinates, batch_generator, get_solver, get_scheduler
13
+ from ..promolecular.promolecular_dist import AtomDBDistribution,SIRDistribution,ProMolecularDensity
14
+ from .utils import step
15
+ from .loss import create_loss_function, F_values
16
+ from ..config._config import Config
17
+
18
+ jax.config.update("jax_enable_x64", True)
19
+
20
+
21
def setup_molecule(mol_name: str, bond_length: float = 0.74144):
    """Look up a molecule's geometry and electron count.

    Parameters
    ----------
    mol_name : str
        Molecule identifier passed to ``coordinates``.
    bond_length : float
        Forwarded to ``coordinates``. NOTE(review): the 0.74144 default looks
        like the H2 bond length in Ångström. ``training`` documents bond
        lengths in a.u. and defaults to 1.4008538753, so confirm the intended
        unit.

    Returns
    -------
    tuple
        ``(Ne, atoms, z, coords, mol)`` where ``mol`` is
        ``{'coords': coords, 'z': z}``.
    """
    geometry = coordinates(mol_name, bond_length)
    Ne, atoms, z, coords = geometry
    return Ne, atoms, z, coords, {'coords': coords, 'z': z}
26
+
27
+
28
def setup_model(coords, z, hidden_layer: int, key):
    """Construct a freshly initialised CNF for the molecule.

    The flow acts on 3-D points. The nuclear coordinates are passed as ``mu``
    and the one-hot encoded atomic numbers as ``one_hot``. The keyword form
    matches the call in ``test_fwd_rev.load_results``.
    """
    return CNF(din=3, dim=hidden_layer, mu=coords, one_hot=one_hot_encode(z), key=key)
34
+
35
+
36
def setup_optimizer(flow_model, epochs: int, lr: float, scheduler_type: str):
    """Create the optimiser and its initial state.

    Gradients are clipped to a global norm of 1.0, then passed to AdamW
    (weight decay 1e-5). AdamW's learning rate follows the schedule from
    ``get_scheduler``. Only array leaves of ``flow_model`` are optimised.

    Returns
    -------
    (optimizer, optimizer_state)
    """
    schedule = get_scheduler(epochs=epochs, sched_type=scheduler_type, lr=lr)
    transforms = [
        optax.clip_by_global_norm(1.0),
        optax.adamw(schedule, weight_decay=1e-5),
    ]
    optimizer = optax.chain(*transforms)
    trainable = eqx.filter(flow_model, eqx.is_array)
    return optimizer, optimizer.init(trainable)
45
+
46
+
47
def setup_ema():
    """Create an EMA tracker (decay 0.99) for the per-epoch energy components.

    Returns the optax ``ema`` transformation and its state, initialised on an
    all-zero ``F_values`` record.
    """
    zero = jnp.array(0.)
    tracker = ema(decay=0.99)
    initial = F_values(energy=zero, kin=zero, vnuc=zero, hart=zero, xc=zero, cc=zero)
    return tracker, tracker.init(initial)
56
+
57
+
58
def log_metrics(itr: int, loss_epoch: float, losses: F_values,
                energies_i_ema: F_values, elapsed_time: float):
    """Build the per-epoch logging rows for the raw and EMA-smoothed energies.

    In both rows the reported total ``'E'`` excludes the core-correction
    term: it is the given total minus ``cc``. ``cc`` itself is logged under
    ``'CC'``.

    Parameters
    ----------
    itr : int
        Epoch index, stored under ``'epoch'``.
    loss_epoch : float
        Total loss of this epoch (total for the raw row).
    losses : F_values
        Energy components of this epoch.
    energies_i_ema : F_values
        EMA-smoothed components; its ``energy`` field is the smoothed total.
    elapsed_time : float
        Duration of the epoch, stored under ``'t'``.

    Returns
    -------
    (dict, dict)
        Raw and EMA rows. Both have the keys ``epoch, E, T, V, H, XC, CC, t``,
        in that order.
    """
    def _row(total, parts):
        return {
            'epoch': itr,
            'E': total - parts.cc,
            'T': parts.kin,
            'V': parts.vnuc,
            'H': parts.hart,
            'XC': parts.xc,
            'CC': parts.cc,
            't': elapsed_time,
        }

    return _row(loss_epoch, losses), _row(energies_i_ema.energy, energies_i_ema)
84
+
85
+
86
def _build_sampling_dist(prior_type, prior_dist, z, coords, Ne):
    """Return the base distribution that training batches are drawn from.

    ``'db_sir'`` builds an ``AtomDBDistribution`` from the AtomDB "slater"
    promolecule, and ``prior_dist`` is ignored. Any other ``prior_type`` uses
    ``prior_dist`` when given. Otherwise it builds a ``ProMolecularDensity``
    from the nuclear charges and coordinates.
    """
    if prior_type == 'db_sir':
        from atomdb import make_promolecule  # optional dep — only needed for db_sir
        db_prior = make_promolecule(atnums=z, coords=coords, dataset="slater")
        return AtomDBDistribution(db_prior=db_prior, z=z, coords=coords, Ne=Ne)
    if prior_dist is None:
        return ProMolecularDensity(z.ravel(), coords)
    return prior_dist


def training(mol_name: str,
             bond_length: float = 1.4008538753,
             tw_kin: str = 'tf_w',
             lam: float = 1.0,
             n_pot: str = 'np',
             h_pot: str = 'coulomb',
             x_pot: str = 'lda',
             c_pot: str = 'vwn_c',
             cc_pot: str = 'kato',
             batch_size: int = 256,
             hidden_layer: int = 64,
             epochs: int = 100,
             lr: float = 1e-5,
             scheduler_type: str = 'ones',
             solver_type: str = 'tsit5',
             prior_type: str = 'promolecular',
             prior_dist: Optional[ProMolecularDensity] = None,
             checkpoint_dir: str = './checkpoints',
             checkpoint_freq: int = 50,
             ):
    """
    Main training loop.

    Parameters
    ----------
    mol_name : str
        Name of molecule
    bond_length: float
        Bond length in a.u.
    tw_kin : str
        Kinetic functional name
    lam : float
        Weight forwarded to the kinetic functional (``lam`` of
        ``create_loss_function``)
    n_pot : str
        External potential functional name
    h_pot : str
        Hartree functional name
    x_pot : str
        Exchange functional name
    c_pot : str
        Correlation functional name
    cc_pot : str
        Core correction functional name
    batch_size : int
        Batch size for training
    hidden_layer : int
        Hidden layer size for neural network
    epochs : int
        Number of training epochs (the loop runs epochs + 1 steps, 0..epochs)
    lr : float
        Learning rate
    scheduler_type : str
        Type of learning rate scheduler
    solver_type : str
        ODE solver type
    prior_type: str
        Type of prior distribution for sampling; ``'db_sir'`` samples from an
        AtomDB promolecule, anything else from ``prior_dist``
    prior_dist : ProMolecularDensity, optional
        Base distribution to sample from when ``prior_type`` is not
        ``'db_sir'``. Defaults to a ``ProMolecularDensity`` on the molecule's
        nuclei. (Previously a supplied value was silently overwritten.)
    checkpoint_dir : str
        Directory to save checkpoints; created if it does not exist
    checkpoint_freq : int
        Frequency of checkpoint saving (the final epoch is always saved)

    Returns
    -------
    flow_model : CNF
        Trained flow model
    df : pd.DataFrame
        Training metrics
    df_ema : pd.DataFrame
        EMA training metrics
    """
    import os

    # Setup
    Ne, atoms, z, coords, mol = setup_molecule(mol_name, bond_length)

    key = jrnd.PRNGKey(0)
    _, key = jrnd.split(key)

    # NOTE(review): the same key seeds both model initialisation and the
    # batch generator.
    flow_model = setup_model(coords, z, hidden_layer, key)
    solver = get_solver(solver_type)
    optimizer, optimizer_state = setup_optimizer(flow_model, epochs, lr, scheduler_type)
    energies_ema, energies_state = setup_ema()
    sampling_dist = _build_sampling_dist(prior_type, prior_dist, z, coords, Ne)

    gen_batches = batch_generator(key, batch_size, sampling_dist)

    grad_loss_fn = create_loss_function(
        kinetic_name=tw_kin,
        lam=lam,
        exchange_name=x_pot,
        correlation_name=c_pot,
        hartree_name=h_pot,
        external_name=n_pot,
        core_correction_name=cc_pot
    )

    # Serialising a checkpoint into a missing directory would fail mid-run.
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Training loop
    df = pd.DataFrame()
    df_ema = pd.DataFrame()

    for itr in range(epochs + 1):
        start_time = time.time()

        batch = next(gen_batches)

        loss, flow_model, optimizer_state = step(
            flow_model, batch, optimizer, optimizer_state,
            grad_loss_fn, solver, Ne, mol
        )
        # JAX dispatches asynchronously. Wait for the results so that
        # elapsed_time measures the computation, not just its dispatch. The
        # per-epoch print below already forced this sync, so no overlap is lost.
        loss = jax.block_until_ready(loss)

        elapsed_time = time.time() - start_time

        loss_epoch, losses = loss

        # Update EMA
        energies_i_ema, energies_state = energies_ema.update(losses, energies_state)

        # Log metrics
        r_instant, r_ema = log_metrics(itr, loss_epoch, losses, energies_i_ema, elapsed_time)

        df = pd.concat([df, pd.DataFrame([r_instant])], ignore_index=True)
        df_ema = pd.concat([df_ema, pd.DataFrame([r_ema])], ignore_index=True)

        print(f"Epoch {itr}: {r_ema}")

        # Rewritten every epoch, presumably so partial progress survives an
        # interrupted run.
        df.to_csv(f"{Config.results_dir}/training_metrics.csv", index=False)
        df_ema.to_csv(f"{Config.results_dir}/training_metrics_ema.csv", index=False)

        # Save checkpoint
        if itr % checkpoint_freq == 0 or itr == epochs:
            eqx.tree_serialise_leaves(f"{checkpoint_dir}/checkpoint_{itr}.eqx", flow_model)

    return flow_model, df, df_ema