off 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
off/quadrature.py ADDED
@@ -0,0 +1,261 @@
1
+ import glob
2
+ import json
3
+ import re
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import jax
8
+ import jax.numpy as jnp
9
+ import jax.random as jrnd
10
+ import equinox as eqx
11
+ from pyscf import gto, dft
12
+
13
+ jax.config.update("jax_enable_x64", True)
14
+
15
+ from .flow.equiv_flows import CNF
16
+ from .ode_solver.eqx_ode import fwd_ode, rev_ode
17
+ from .utils import one_hot_encode, coordinates, get_solver
18
+ from .promolecular.promolecular_dist import ProMolecularDensity, AtomDBDistribution
19
+ from .train.loss import build_energy_functional
20
+
21
+ AA_TO_BOHR = 1.8897259886
22
+
23
+
24
+ # ── model / prior loading ─────────────────────────────────────────────────────
25
def last_checkpoint(results_dir):
    """Return ``(path, epoch)`` of the highest-epoch checkpoint.

    Looks for ``checkpoint_<epoch>.eqx`` files in ``results_dir/Checkpoints/``.
    Files that match the glob but do not carry a purely numeric epoch (for
    example ``checkpoint_best.eqx``) are ignored instead of crashing the lookup.

    Parameters
    ----------
    results_dir : str or path-like
        Run directory that contains a ``Checkpoints/`` subdirectory.

    Returns
    -------
    (str, int)
        Path of the selected checkpoint (as produced by ``glob``) and its epoch.

    Raises
    ------
    FileNotFoundError
        If no ``checkpoint_<epoch>.eqx`` file is present.
    """
    pattern = str(Path(results_dir) / "Checkpoints" / "checkpoint_*.eqx")
    numbered = []
    for path in glob.glob(pattern):
        # Match on the file name only, so the directory part of the path can
        # never influence which epoch is extracted.
        m = re.fullmatch(r'checkpoint_(\d+)\.eqx', Path(path).name)
        if m is not None:
            numbered.append((int(m.group(1)), path))
    if not numbered:
        raise FileNotFoundError(f"No checkpoints in {results_dir}/Checkpoints/")
    # Stable sort + last element: same tie-breaking as sorting the paths by epoch.
    numbered.sort(key=lambda item: item[0])
    epoch, path = numbered[-1]
    return path, epoch
33
+
34
+
35
def load_model(results_dir, p):
    """Rebuild the CNF described by job_params `p` and load its last checkpoint.

    Parameters
    ----------
    results_dir : run directory holding ``Checkpoints/``.
    p : dict of job parameters; reads ``mol_name``, ``bond_length``,
        ``hidden_layer`` and ``solver``.

    Returns
    -------
    (model, solver, Ne, atoms, z, coords, epoch)
    """
    Ne, atoms, z, coords = coordinates(p['mol_name'], p['bond_length'])

    # Template network with the same structure as the saved one; the checkpoint
    # leaves are loaded into it below.
    template = CNF(
        din=3,
        dim=p['hidden_layer'],
        mu=coords,
        one_hot=one_hot_encode(z),
        key=jrnd.PRNGKey(0),
    )
    ckpt_path, epoch = last_checkpoint(results_dir)
    trained = eqx.tree_deserialise_leaves(ckpt_path, template)
    solver = get_solver(p['solver'])
    return trained, solver, Ne, atoms, z, coords, epoch
43
+
44
+
45
def build_prior(p, z, coords, Ne):
    """Rebuild the base distribution used at training time (must match it).

    ``p['prior'] == 'db_sir'`` selects the AtomDB-backed prior; any other value
    (or a missing key) selects the pro-molecular density.
    """
    if p.get('prior') != 'db_sir':
        return ProMolecularDensity(z.ravel(), coords)

    # Imported lazily so atomdb is only required when this prior is requested.
    from atomdb import make_promolecule
    promolecule = make_promolecule(atnums=z, coords=coords, dataset="hci")
    return AtomDBDistribution(db_prior=promolecule, z=z, coords=coords, Ne=Ne)
52
+
53
+
54
+ # ── grid construction (PySCF) ─────────────────────────────────────────────────
55
def _grids_from_mol(mol, level):
    """Build a PySCF grid for `mol` at the given level.

    Returns ``(coords, weights)`` as float64 JAX arrays (in Bohr, per the
    original note on this helper).
    """
    mesh = dft.gen_grid.Grids(mol)
    mesh.level = level
    mesh.build()
    points = jnp.asarray(mesh.coords, dtype=jnp.float64)
    weights = jnp.asarray(mesh.weights, dtype=jnp.float64)
    return points, weights
62
+
63
+
64
def build_grid(atoms, coords, Ne, grid_level=3, basis="6-31G(d,p)", unit="B"):
    """PySCF molecular quadrature grid for explicit atoms and coordinates.

    `coords` are interpreted in `unit` ('B'/'Bohr' or 'Angstrom'); the result
    is the ``(coords, weights)`` pair returned by :func:`_grids_from_mol`.
    """
    entries = []
    for symbol, xyz in zip(atoms, np.asarray(coords)):
        entries.append(f"{symbol} {xyz[0]:.10f} {xyz[1]:.10f} {xyz[2]:.10f}")
    geometry = "; ".join(entries)

    # spin is set from the parity of the electron count.
    unpaired = int(Ne) % 2
    mol = gto.M(atom=geometry, basis=basis, unit=unit, verbose=0, spin=unpaired)
    return _grids_from_mol(mol, grid_level)
71
+
72
+
73
def get_grid(geom, level=3, *, units="angstrom", basis="6-31G(d,p)", spin=0):
    """User-facing quadrature grid.

    Parameters
    ----------
    geom : str
        Geometry in PySCF's ``atom=`` format — e.g. ``"H 0 0 0; H 0 0 0.74"``
        or a multi-line XYZ-style block.
    level : int
        PySCF grid level (the "grid size").
    units : {'angstrom', 'bohr'}
        Units the geometry is given in. Any value starting with ``b``
        (case-insensitive) means Bohr; everything else means Angstrom.
    basis, spin :
        Forwarded to ``pyscf.gto.M``. The basis only sets the atom-centred
        grid partitioning; it does not affect the flow density.

    Returns
    -------
    (weights, coords) :
        Note the order, to match the listing ``w_grid, x_grid = get_grid(...)``;
        the internal :func:`build_grid` returns the opposite ``(coords, weights)``.
    """
    if str(units).lower().startswith("b"):
        unit = "Bohr"
    else:
        unit = "Angstrom"
    mol = gto.M(atom=geom, basis=basis, unit=unit, spin=spin, verbose=0)
    points, w = _grids_from_mol(mol, level)
    return w, points


# camelCase alias of get_grid.
getGrid = get_grid
102
+
103
+
104
def rho_on_grid(model, solver, prior, grid_coords, chunk=256):
    """Evaluate (positions, ρ_φ, score = ∇log ρ_φ) at the grid points.

    The grid is processed in slices of `chunk` points: each slice is pushed
    through ``rev_ode``, the prior's log-probability and score are evaluated
    at the result, and ``fwd_ode`` carries them forward again.

    Returns three NumPy arrays: positions, densities (flattened) and scores.
    """
    positions, densities, scores = [], [], []
    n_points = grid_coords.shape[0]
    for start in range(0, n_points, chunk):
        batch = grid_coords[start:start + chunk]
        n = batch.shape[0]
        # Reverse-ODE input: the points followed by one and then three zero columns.
        rev_state = jnp.concatenate(
            [batch, jnp.zeros((n, 1)), jnp.zeros((n, 3))], axis=1)
        z_base, _ = rev_ode(model, rev_state, solver)

        logp_base = prior.log_prob(z_base)
        score_base = prior.score(z_base)
        fwd_state = jnp.concatenate([z_base, logp_base, score_base], axis=1)
        x_out, logp_out, score_out = fwd_ode(model, fwd_state, solver)

        positions.append(np.array(x_out))
        densities.append(np.array(jnp.exp(logp_out)).ravel())
        scores.append(np.array(score_out))
    return (np.concatenate(positions),
            np.concatenate(densities),
            np.concatenate(scores))
118
+
119
+
120
def quadrature_energy(functional, x_np, rho_np, sc_np, grid_coords, grid_weights,
                      mol_dict, Ne, chunk=256):
    """Integrate every energy term on the grid.

    The local terms (kinetic, exchange, correlation, external, core correction)
    are evaluated chunk-wise with ``functional``'s component functionals and
    reduced through ``functional._integrate`` with measure ``weights * ρ``.
    The Hartree term is the grid double sum (true 1/r), not the functional's
    MC pairwise estimator. The nuclear repulsion is the pairwise sum over
    ``mol_dict['coords']`` / ``mol_dict['z']``.

    Returns a dict with keys ``T, V_N, V_H, E_X, E_C, E_CC, E_NN, E_total``.
    """
    weights = np.array(grid_weights)
    n_grid = rho_np.shape[0]
    rho_col = rho_np[:, None]
    weighted_rho = weights * rho_np
    measure = jnp.asarray(weighted_rho)

    def integrate_local(term):
        # Evaluate `term` on every grid point (in chunks), then integrate.
        values = np.zeros(n_grid)
        for start in range(0, n_grid, chunk):
            sl = slice(start, min(start + chunk, n_grid))
            chunk_out = term(jnp.array(rho_col[sl]), jnp.array(sc_np[sl]),
                             jnp.array(x_np[sl]), Ne, mol_dict, None)
            values[sl] = np.array(chunk_out).ravel()
        return float(functional._integrate(jnp.asarray(values), measure))

    T = integrate_local(functional.kinetic)
    E_X = integrate_local(functional.exchange)
    if functional.correlation is not None:
        E_C = integrate_local(functional.correlation)
    else:
        E_C = 0.0
    V_N = integrate_local(functional.external)
    if functional.core_correction is not None:
        E_CC = integrate_local(functional.core_correction)
    else:
        E_CC = 0.0

    # Hartree — grid double sum, one block of rows at a time.
    points = np.array(grid_coords)
    potential = np.zeros(n_grid)
    for start in range(0, n_grid, chunk):
        stop = start + chunk
        block = points[start:stop]
        dist2 = np.sum((points[None, :, :] - block[:, None, :]) ** 2, axis=-1)
        # Coincident points get distance² = inf, so their 1/r contribution is 0.
        inv_r = 1. / np.sqrt(np.where(dist2 == 0., np.inf, dist2))
        potential[start:stop] = np.dot(inv_r, weighted_rho)
    V_H = float(0.5 * Ne ** 2 * functional._integrate(jnp.asarray(potential), measure))

    # Nuclear repulsion: sum over unordered pairs of nuclei.
    nuc_xyz = np.array(mol_dict['coords'])
    charges = np.array(mol_dict['z']).ravel()
    n_nuc = len(nuc_xyz)
    E_NN = sum(charges[a] * charges[b] / float(np.linalg.norm(nuc_xyz[a] - nuc_xyz[b]))
               for a in range(n_nuc) for b in range(a + 1, n_nuc))

    E_total = T + V_N + V_H + E_X + E_C + E_CC + E_NN
    return dict(T=T, V_N=V_N, V_H=V_H, E_X=E_X, E_C=E_C, E_CC=E_CC, E_NN=E_NN,
                E_total=E_total)
165
+
166
+
167
+ # ── high-level entry points ───────────────────────────────────────────────────
168
def grid_energy(model, prior, solver, coords, z, atoms, Ne, functional, *,
                grid_level=3, units="bohr", basis="6-31G(d,p)", chunk=256):
    """Build the grid, evaluate ρ_φ, and integrate all energy terms.

    Parameters
    ----------
    model, prior, solver : the trained CNF, its base distribution, ODE solver.
    coords, z, atoms, Ne : molecular geometry / charges / electron count.
    functional : an EnergyFunctional (e.g. from build_energy_functional).
    grid_level : PySCF grid level (the "grid size").
    units : 'bohr' or 'angstrom' — how `coords` are given. Values starting
        with ``b`` (case-insensitive) are Bohr; anything else is treated as
        Angstrom and multiplied by ``AA_TO_BOHR`` for the nuclear coordinates.
    basis : forwarded to :func:`build_grid`.
    chunk : number of grid points processed per batch.

    Returns the energy dict plus 'Ne_integral'.
    """
    xyz = np.asarray(coords, dtype=float)
    if str(units).lower().startswith("b"):
        unit, xyz_bohr = "Bohr", xyz
    else:
        unit, xyz_bohr = "Angstrom", xyz * AA_TO_BOHR

    # The grid builder receives the coordinates in their original unit.
    grid_xyz, grid_w = build_grid(atoms, xyz, Ne, grid_level=grid_level,
                                  basis=basis, unit=unit)
    x_np, rho_np, sc_np = rho_on_grid(model, solver, prior, grid_xyz, chunk=chunk)

    mol_dict = {
        'coords': jnp.asarray(xyz_bohr, dtype=jnp.float64),
        'z': jnp.asarray(z),
    }
    energies = quadrature_energy(functional, x_np, rho_np, sc_np, grid_xyz, grid_w,
                                 mol_dict, Ne, chunk=chunk)
    # Σ w·Ne·ρ on the grid, reported as the integrated electron count.
    energies['Ne_integral'] = float(np.dot(np.array(grid_w), Ne * rho_np))
    return energies
191
+
192
+
193
def _grid_settings(grid_level, basis, units):
    """Grid settings that determine the numerical result, in JSON-friendly form."""
    unit = "Bohr" if str(units).lower().startswith("b") else "Angstrom"
    return {'grid_level': grid_level, 'basis': basis, 'units': unit}


def grid_energy_from_checkpoint(results_dir, *, grid_level=3, basis="6-31G(d,p)",
                                units="bohr", chunk=256, recompute=False, cache=True):
    """One call from a trained run directory: read job_params.json (functional +
    geometry), load the last checkpoint, and integrate.

    Parameters
    ----------
    results_dir : run directory with ``job_params.json`` and ``Checkpoints/``.
    grid_level, basis, chunk : forwarded to :func:`grid_energy`.
    units : forwarded to :func:`grid_energy`; the default ``"bohr"`` matches the
        original note that ``coordinates`` always yields Bohr geometries.
    recompute : ignore an existing cache and integrate again.
    cache : when False, neither read nor write the cache file.

    Returns
    -------
    dict
        The energy dict from :func:`grid_energy` plus ``epoch``, ``mol_name``,
        ``bond_length`` and the grid settings (``grid_level``, ``basis``,
        ``units``) the numbers were computed with.

    Notes
    -----
    The result is cached in ``results_dir/energy_summary.json``. A cached
    summary is reused only when it records the same grid settings as the
    current request. Summaries written without that record, or files that
    cannot be parsed, are recomputed, so a result from one grid level is
    never returned for another.
    """
    results_dir = Path(results_dir)
    summary_path = results_dir / "energy_summary.json"
    settings = _grid_settings(grid_level, basis, units)

    if cache and not recompute and summary_path.exists():
        try:
            with open(summary_path) as f:
                cached = json.load(f)
        except ValueError:
            # Truncated or corrupt cache file (json.JSONDecodeError): recompute.
            cached = None
        if isinstance(cached, dict) and all(
                cached.get(key) == value for key, value in settings.items()):
            return cached

    with open(results_dir / "job_params.json") as f:
        p = json.load(f)
    model, solver, Ne, atoms, z, coords, epoch = load_model(results_dir, p)
    prior = build_prior(p, z, coords, Ne)
    functional = build_energy_functional(
        kinetic_name=p['kinetic'], lam=p['lam'], exchange_name=p['exchange'],
        correlation_name=p['correlation'], hartree_name=p['hartree'],
        external_name=p['external'], core_correction_name=p['core_correction'],
    )
    en = grid_energy(model, prior, solver, coords, z, atoms, Ne, functional,
                     grid_level=grid_level, units=units, basis=basis, chunk=chunk)
    en.update(epoch=epoch, mol_name=p['mol_name'], bond_length=p['bond_length'])
    # Record what the numbers were computed with, so the cache can be validated.
    en.update(settings)
    if cache:
        with open(summary_path, "w") as f:
            json.dump(en, f, indent=4)
    return en
225
+
226
+
227
+ def _print_energy(results_dir, en):
228
+ print(f"\n{results_dir}")
229
+ print(f" mol={en.get('mol_name')} R={en.get('bond_length')} epoch={en.get('epoch')}")
230
+ print(" " + "-" * 36)
231
+ for k in ("T", "V_N", "V_H", "E_X", "E_C", "E_CC", "E_NN"):
232
+ print(f" {k:8s} = {en[k]:+.6f} Ha")
233
+ print(" " + "-" * 36)
234
+ print(f" {'E_total':8s} = {en['E_total']:+.6f} Ha")
235
+ print(f" {'N_e':8s} = {en['Ne_integral']:.4f} (∫ρ, should be Ne)")
236
+
237
+
238
def main():
    """CLI entry point: integrate and print the energy of each run directory."""
    import argparse

    cli = argparse.ArgumentParser(
        description="Grid (quadrature) energy of a trained OFF run directory.")
    cli.add_argument(
        "results_dir", nargs="+",
        help="bl_* run dir(s) with job_params.json and Checkpoints/ "
             "(shell globs like Results/H2/<method>/bl_* are fine)")
    cli.add_argument("--grid_level", type=int, default=1, help="PySCF grid level")
    cli.add_argument("--bs", type=int, default=256, help="grid chunk size")
    cli.add_argument(
        "--basis", type=str, default="6-31G(d,p)",
        help="PySCF basis (sets grid partitioning only)")
    cli.add_argument(
        "--recompute", action="store_true",
        help="ignore cached energy_summary.json and recompute")
    opts = cli.parse_args()

    for run_dir in opts.results_dir:
        summary = grid_energy_from_checkpoint(
            run_dir,
            grid_level=opts.grid_level,
            basis=opts.basis,
            chunk=opts.bs,
            recompute=opts.recompute,
        )
        _print_energy(run_dir, summary)


if __name__ == "__main__":
    main()
off/quadrature_scan.py ADDED
@@ -0,0 +1,188 @@
1
+ """
2
+ Grid-quadrature total energy for every bond length of a molecule.
3
+
4
+ Thin CLI around ``of_flows/quadrature.py``: it walks every method directory and
5
+ every bl_* subdirectory under Results/{mol}/, calls
6
+ ``quadrature.grid_energy_from_checkpoint`` on each (which builds the PySCF grid,
7
+ evaluates ρ_φ via the flow, and integrates all energy terms), and writes one CSV
8
+ per molecule. For molecules it also integrates the constituent single atoms
9
+ under the same method tag and reports the binding energy.
10
+
11
+ Directory layout assumed (same as main.py):
12
+ Results/{mol}/{method}/bl_X.XXXX/
13
+ Checkpoints/checkpoint_*.eqx
14
+ job_params.json
15
+ Results/{atom}/{method}/bl_0.0000/ (binding reference)
16
+ Results/ is located next to this script, so it runs from anywhere.
17
+
18
+ Usage
19
+ -----
20
+ python quadrature_scan.py --H2
21
+ python quadrature_scan.py --H2 --N2
22
+ python quadrature_scan.py --mol H2 H10
23
+ python quadrature_scan.py --H10 --recompute
24
+
25
+ Output (one CSV per molecule, under Results/{mol}/):
26
+ Results/{mol}/quadrature_{mol}.csv
27
+ columns: method, R_bohr, epoch, E_total, T, V_N, V_H, E_X, E_C, E_CC,
28
+ E_NN, Ne_int, E_atoms, dE_bind_Ha
29
+ """
30
+
31
import sys, os
# Kept for backward compatibility: sibling modules remain importable by bare name.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import gc
import importlib
import re
import argparse
from pathlib import Path

import jax
import pandas as pd

# quadrature.py uses package-relative imports (``from .flow...``), so it cannot be
# loaded as a top-level module named ``quadrature``; it has to be imported as a
# member of its package.
try:
    # Run inside the package (e.g. ``python -m <package>.quadrature_scan``).
    from .quadrature import grid_energy_from_checkpoint
except ImportError:
    # Run as a plain script: expose the directory that contains the package and
    # import ``<package>.quadrature`` so its relative imports resolve.
    _pkg_dir = Path(__file__).resolve().parent
    sys.path.insert(0, str(_pkg_dir.parent))
    grid_energy_from_checkpoint = importlib.import_module(
        f"{_pkg_dir.name}.quadrature").grid_energy_from_checkpoint
43
+
44
# Directory of this script; the default Results/ root sits next to it.
_SCRIPT_DIR = Path(__file__).resolve().parent

SINGLE_ATOMS = {"H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne"}
KNOWN_MOLS = ["H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne",
              "H2", "N2", "O2", "F2", "HF", "CO", "LiH", "H10"]

# ── CLI ───────────────────────────────────────────────────────────────────────
parser = argparse.ArgumentParser(
    description=__doc__,
    formatter_class=argparse.RawDescriptionHelpFormatter,
    allow_abbrev=False,  # so --H is not treated as a prefix of --H2 / --H10
)
# One boolean flag per known system, e.g. --H2.
for _m in KNOWN_MOLS:
    parser.add_argument(f"--{_m}", action="store_true", help=f"Scan molecule {_m}")
parser.add_argument(
    "--mol", type=str, nargs="+", default=[], metavar="NAME",
    help="Molecule name(s) to scan (alternative to the flags)")
parser.add_argument(
    "--results_root", type=str, default=None,
    help="Override Results root (default: <script_dir>/Results)")
parser.add_argument("--bs", type=int, default=256, help="Grid chunk size")
parser.add_argument("--grid_level", type=int, default=3, help="PySCF grid level")
parser.add_argument(
    "--recompute", action="store_true",
    help="Re-run grid integration even if energy_summary.json is cached")
parser.add_argument(
    "--out", type=str, default=None,
    help="Output CSV path (default: Results/{mol}/quadrature_{mol}.csv)")
args = parser.parse_args()

# --mol names first, then flagged systems in KNOWN_MOLS order; duplicates dropped
# while keeping first occurrence.
_flagged = [m for m in KNOWN_MOLS if getattr(args, m)]
selected = list(dict.fromkeys([*args.mol, *_flagged]))
if not selected:
    parser.error("No molecule selected. Use a flag (e.g. --H2) or --mol H2 [...].")

if args.results_root:
    root = Path(args.results_root).resolve()
else:
    root = _SCRIPT_DIR / "Results"
print(f"Results root : {root}\n")
75
+
76
+
77
def constituents(mol: str) -> dict:
    """Element counts of a formula string.

    Examples: ``N2 -> {N: 2}``, ``HF -> {H: 1, F: 1}``, ``H10 -> {H: 10}``.
    An element symbol is one capital letter optionally followed by one
    lower-case letter; a missing count means 1. Repeated elements accumulate.
    """
    counts = {}
    for match in re.finditer(r"([A-Z][a-z]?)(\d*)", mol):
        element, digits = match.groups()
        if not element:
            continue
        multiplicity = int(digits) if digits else 1
        counts[element] = counts.get(element, 0) + multiplicity
    return counts
84
+
85
+
86
def atom_reference(method_name: str, mol: str):
    """Grid energy reference Σ_atoms count·E(atom) under the same method tag.

    Each constituent element is looked up in
    ``root/<element>/<method_name>/bl_0.0000`` and integrated with the
    module-level CLI settings (``args``).

    Returns (E_atoms, {element: E_atom}), or (None, None) if any atom is
    missing or its integration fails (a message is printed in both cases).
    """
    per_atom = {}
    total = 0.0
    for element, count in constituents(mol).items():
        atom_dir = root / element / method_name / "bl_0.0000"
        if not (atom_dir / "job_params.json").exists():
            print(f" atom reference: {element} not found at {atom_dir} — binding skipped")
            return None, None
        try:
            summary = grid_energy_from_checkpoint(
                atom_dir,
                grid_level=args.grid_level,
                chunk=args.bs,
                recompute=args.recompute,
            )
        except Exception as e:
            # Best effort: a failed atom only disables the binding-energy columns.
            print(f" atom reference: {element} FAILED — {e}")
            return None, None
        per_atom[element] = summary['E_total']
        total += count * summary['E_total']
    return total, per_atom
105
+
106
+
107
def scan_molecule(mol: str):
    """Integrate every method / bond-length run of `mol` and write one CSV.

    For each method directory under ``root/<mol>/`` the ``bl_*`` runs are
    integrated in order of bond length. For non-atomic systems the atom
    reference of the same method is used to add ``E_atoms`` and ``dE_bind_Ha``.
    Runs that lack ``job_params.json`` or fail to integrate are reported and
    skipped. Nothing is written when no run succeeds.
    """
    mol_dir = root / mol
    if not mol_dir.is_dir():
        print(f"[{mol}] SKIP — {mol_dir} not found\n")
        return

    is_atom = mol in SINGLE_ATOMS
    rows = []
    method_dirs = sorted(d for d in mol_dir.iterdir() if d.is_dir())
    for method_dir in method_dirs:
        bl_dirs = sorted(method_dir.glob("bl_*"),
                         key=lambda d: float(d.name.split("_")[1]))
        if not bl_dirs:
            continue
        print(f"[{mol}] method: {method_dir.name} ({len(bl_dirs)} bond lengths)")

        # Single-atom reference for the binding energy (same method tag).
        E_atoms = None
        if not is_atom:
            E_atoms, detail = atom_reference(method_dir.name, mol)
            if E_atoms is not None:
                ref = " ".join(f"{n}*E({el})={detail[el]:+.6f}"
                               for el, n in constituents(mol).items())
                print(f" atom reference (grid): {ref} -> Σ = {E_atoms:+.6f} Ha")

        for bl_dir in bl_dirs:
            if not (bl_dir / "job_params.json").exists():
                print(f" {bl_dir.name}: missing job_params.json — skipping")
                continue
            try:
                data = grid_energy_from_checkpoint(
                    bl_dir,
                    grid_level=args.grid_level,
                    chunk=args.bs,
                    recompute=args.recompute,
                )
            except Exception as e:
                print(f" {bl_dir.name}: FAILED — {e}")
                continue

            row = {"method": method_dir.name,
                   "R_bohr": data['bond_length'],
                   "epoch": data.get('epoch', '?')}
            for key in ("E_total", "T", "V_N", "V_H", "E_X"):
                row[key] = data[key]
            row["E_C"] = data.get('E_C', 0.0)
            row["E_CC"] = data.get('E_CC', 0.0)
            row["E_NN"] = data['E_NN']
            row["Ne_int"] = data['Ne_integral']

            msg = (f" R={data['bond_length']:.4f} Bohr epoch={data.get('epoch','?'):>6}"
                   f" E_total={data['E_total']:+.6f} Ha")
            if E_atoms is not None:
                binding = E_atoms - data['E_total']  # ΔE = ΣE(atom) - E(mol)
                row["E_atoms"] = E_atoms
                row["dE_bind_Ha"] = binding
                msg += f" ΔE={binding:+.6f} Ha"
            rows.append(row)
            print(msg)

            # NOTE(review): nesting of this cleanup was reconstructed from a
            # flattened listing — assumed to run after every bond length.
            jax.clear_caches()
            gc.collect()
        print()

    if not rows:
        print(f"[{mol}] nothing to write (no checkpoints found)\n")
        return

    df = pd.DataFrame(rows).sort_values(["method", "R_bohr"]).reset_index(drop=True)
    if args.out:
        out_path = Path(args.out).resolve()
    else:
        out_path = mol_dir / f"quadrature_{mol}.csv"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(out_path, index=False, float_format="%.8f")

    banner = "=" * 96
    print(banner)
    print(df.to_string(index=False))
    print(banner)
    print(f"[{mol}] saved → {out_path}\n")


# Scan every selected system in the order it was requested.
for mol in selected:
    scan_molecule(mol)
off/scan_pes.py ADDED
@@ -0,0 +1,133 @@
1
+ """
2
+ Potential Energy Surface scan over a set of bond-length result directories.
3
+
4
+ Thin CLI around ``of_flows/quadrature.py``: grid-integrates every bl_* directory
5
+ under --scan_dir (via ``grid_energy_from_checkpoint``) and, optionally, an atom
6
+ reference for the binding energy, then writes pes.csv and a plot.
7
+
8
+ Usage
9
+ -----
10
+ python scan_pes.py \
11
+ --scan_dir Results/H2/<method> \
12
+ --atom_dir Results/H/<method>/bl_0.0000
13
+
14
+ Outputs (written inside --scan_dir):
15
+ pes.csv — one row per bond length: R_bohr, epoch, E_total and its energy components, Ne_int; plus E_bind_Ha and D_e_eV when --atom_dir is given
16
+ pes.png / .svg — PES curve (E_total and D_e vs R)
17
+ """
18
+
19
import sys, os
# Kept for backward compatibility: sibling modules remain importable by bare name.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import argparse
import importlib
from pathlib import Path

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import pandas as pd

# quadrature.py uses package-relative imports (``from .flow...``), so it cannot be
# loaded as a top-level module named ``quadrature``; it has to be imported as a
# member of its package.
try:
    # Run inside the package (e.g. ``python -m <package>.scan_pes``).
    from .quadrature import grid_energy_from_checkpoint
except ImportError:
    # Run as a plain script: expose the directory that contains the package and
    # import ``<package>.quadrature`` so its relative imports resolve.
    _pkg_dir = Path(__file__).resolve().parent
    sys.path.insert(0, str(_pkg_dir.parent))
    grid_energy_from_checkpoint = importlib.import_module(
        f"{_pkg_dir.name}.quadrature").grid_energy_from_checkpoint
31
+
32
+ # ── CLI ───────────────────────────────────────────────────────────────────────
33
# Command-line options; parsed at import time because this file is a script.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--scan_dir", type=str, required=True,
    help="Method directory containing bl_X.XXXX subdirectories")
parser.add_argument(
    "--atom_dir", type=str, default=None,
    help="bl_0.0000 directory for the atom (binding-energy reference)")
parser.add_argument("--bs", type=int, default=256, help="Grid chunk size")
parser.add_argument("--grid_level", type=int, default=3, help="PySCF grid level")
parser.add_argument(
    "--recompute", action="store_true",
    help="Re-run integration even if energy_summary.json already exists")
args = parser.parse_args()
43
+
44
+
45
def analyse(results_dir):
    """Grid-integrate one run directory using the CLI settings in ``args``."""
    run_dir = Path(results_dir).resolve()
    return grid_energy_from_checkpoint(
        run_dir,
        grid_level=args.grid_level,
        chunk=args.bs,
        recompute=args.recompute,
    )
49
+
50
+
51
+ # ── Scan over all bl_* directories ───────────────────────────────────────────
52
# Bond-length run directories, ordered by the numeric part of ``bl_<R>``.
scan_dir = Path(args.scan_dir).resolve()
bl_dirs = sorted(scan_dir.glob("bl_*"),
                 key=lambda d: float(d.name.split("_")[1]))
if not bl_dirs:
    raise FileNotFoundError(f"No bl_* directories found in {scan_dir}")

print(f"\nFound {len(bl_dirs)} bond-length directories in:\n {scan_dir}\n")

# Optional atom reference (binding uses 2*E_atom — homonuclear diatomic)
E_atom = None
if args.atom_dir is not None:
    print("atom reference:")
    E_atom = analyse(args.atom_dir)['E_total']
    print(f" E(atom) = {E_atom:+.6f} Ha\n")

# Main scan: one row per bond length that has a job_params.json.
rows = []
for bl_dir in bl_dirs:
    if not (bl_dir / "job_params.json").exists():
        print(f" {bl_dir.name}: missing job_params.json — skipping")
        continue
    data = analyse(bl_dir)
    R = data['bond_length']
    row = {'R_bohr': R,
           'epoch': data.get('epoch', '?'),
           'E_total': data['E_total'],
           'T': data['T'],
           'V_N': data['V_N'],
           'V_H': data['V_H'],
           'E_X': data['E_X'],
           'E_C': data.get('E_C', 0.0),
           # E_total includes the core correction, so export it with the
           # other components (quadrature_scan.py writes the same column).
           'E_CC': data.get('E_CC', 0.0),
           'E_NN': data['E_NN'],
           'Ne_int': data['Ne_integral']}
    if E_atom is not None:
        E_bind = data['E_total'] - 2.0 * E_atom
        row['E_bind_Ha'] = E_bind
        row['D_e_eV'] = -E_bind * 27.2114  # Hartree → eV
    rows.append(row)
    tag = f" R={R:.4f} Bohr epoch={row['epoch']:>6} E={data['E_total']:+.6f} Ha"
    if E_atom is not None:
        tag += f" E_bind={row['E_bind_Ha']:+.6f} Ha"
    print(tag)
94
+
95
+ # ── Save CSV ─────────────────────────────────────────────────────────────────
96
# Without any integrated run there is no 'R_bohr' column to sort on; stop with a
# clear message instead of failing inside pandas.
if not rows:
    raise SystemExit(
        f"No bl_* directory in {scan_dir} contains job_params.json — nothing to write")

# One row per bond length, ordered by R.
df = pd.DataFrame(rows).sort_values('R_bohr').reset_index(drop=True)
csv_path = scan_dir / "pes.csv"
df.to_csv(csv_path, index=False, float_format='%.8f')
print(f"\nPES data saved → {csv_path}")
print(df.to_string(index=False))
101
+
102
+ # ── Plot ──────────────────────────────────────────────────────────────────────
103
# One panel for the total energy; a second one for D_e when an atom reference exists.
has_reference = E_atom is not None
n_panels = 2 if has_reference else 1
fig, axes = plt.subplots(1, n_panels, figsize=(11 if has_reference else 5, 4))
if not has_reference:
    axes = [axes]

R_vals = df['R_bohr'].values
max_epoch = df['epoch'].max()
complete = df['epoch'] == max_epoch  # True if run finished

panels = [(axes[0], 'E_total', 'Energy [Ha]', 'Potential Energy Surface', 'tab:blue')]
if has_reference:
    panels.append((axes[1], 'D_e_eV', 'D_e [eV]', 'Dissociation Energy', 'tab:orange'))

for ax, y_col, ylabel, title, color in panels:
    y = df[y_col].values
    # Runs at the highest epoch are drawn as a connected curve; the others as
    # hollow markers.
    ax.plot(R_vals[complete], y[complete], 'o-', color=color,
            label=f'epoch={max_epoch}')
    ax.plot(R_vals[~complete], y[~complete], '^', color=color, alpha=0.5,
            label='incomplete', markerfacecolor='none')
    if y_col == 'D_e_eV':
        ax.axhline(0, color='k', linewidth=0.8, linestyle='--')
    ax.set_xlabel("R [Bohr]")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

# Path.name never contains a separator, so this is the parent directory's name.
fig.suptitle(scan_dir.parent.name, fontsize=9)
fig.tight_layout()
fig.savefig(scan_dir / "pes.svg", transparent=True)
fig.savefig(scan_dir / "pes.png", dpi=150)
print(f"PES plot saved → {scan_dir}/pes.png")