off 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- off/__init__.py +23 -0
- off/atom_energies.py +151 -0
- off/config/_config.py +108 -0
- off/dft_distrax/__init__.py +27 -0
- off/dft_distrax/dft_distrax.py +216 -0
- off/flow/__init__.py +29 -0
- off/flow/equiv_flows.py +99 -0
- off/functionals/__init__.py +35 -0
- off/functionals/core_correction.py +84 -0
- off/functionals/exchange_correlation.py +174 -0
- off/functionals/external.py +49 -0
- off/functionals/functional.py +129 -0
- off/functionals/hartree.py +62 -0
- off/functionals/kinetic.py +87 -0
- off/main.py +172 -0
- off/ode_solver/__init__.py +32 -0
- off/ode_solver/eqx_ode.py +76 -0
- off/plot_binding_csv.py +63 -0
- off/plot_pes_ema.py +259 -0
- off/plot_pes_mpl.py +280 -0
- off/promolecular/__init__.py +27 -0
- off/promolecular/promolecular_dist.py +465 -0
- off/quadrature.py +261 -0
- off/quadrature_scan.py +188 -0
- off/scan_pes.py +133 -0
- off/test_fwd_rev.py +290 -0
- off/train/__init__.py +44 -0
- off/train/loop.py +228 -0
- off/train/loss.py +149 -0
- off/train/utils.py +38 -0
- off/utils.py +618 -0
- off-0.1.0.dist-info/METADATA +154 -0
- off-0.1.0.dist-info/RECORD +37 -0
- off-0.1.0.dist-info/WHEEL +5 -0
- off-0.1.0.dist-info/entry_points.txt +3 -0
- off-0.1.0.dist-info/licenses/LICENSE +21 -0
- off-0.1.0.dist-info/top_level.txt +1 -0
off/quadrature.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
1
|
+
import glob
|
|
2
|
+
import json
|
|
3
|
+
import re
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import jax
|
|
8
|
+
import jax.numpy as jnp
|
|
9
|
+
import jax.random as jrnd
|
|
10
|
+
import equinox as eqx
|
|
11
|
+
from pyscf import gto, dft
|
|
12
|
+
|
|
13
|
+
# Enable 64-bit floats in JAX for the whole process (the grid arrays below are
# created with dtype=jnp.float64). NOTE: this is a global side effect executed
# at import time of this module.
jax.config.update("jax_enable_x64", True)
|
|
14
|
+
|
|
15
|
+
from .flow.equiv_flows import CNF
|
|
16
|
+
from .ode_solver.eqx_ode import fwd_ode, rev_ode
|
|
17
|
+
from .utils import one_hot_encode, coordinates, get_solver
|
|
18
|
+
from .promolecular.promolecular_dist import ProMolecularDensity, AtomDBDistribution
|
|
19
|
+
from .train.loss import build_energy_functional
|
|
20
|
+
|
|
21
|
+
# Angstrom -> Bohr factor; grid_energy() uses it to convert nuclear coordinates
# given in Angstrom before computing the nuclear terms.
# NOTE(review): PySCF converts the grid geometry with its own constant
# (pyscf.lib.param.BOHR); confirm the two agree to the precision you need.
AA_TO_BOHR = 1.8897259886
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# ── model / prior loading ─────────────────────────────────────────────────────
|
|
25
|
+
def last_checkpoint(results_dir):
    """Return ``(path, epoch)`` of the highest-epoch checkpoint of a run.

    Checkpoints are looked up as
    ``results_dir/Checkpoints/checkpoint_<epoch>.eqx``. Files that match
    ``checkpoint_*.eqx`` but have no integer epoch (e.g.
    ``checkpoint_final.eqx``) are ignored instead of crashing the sort.

    Parameters
    ----------
    results_dir : str or os.PathLike
        Run directory containing a ``Checkpoints/`` sub-directory.

    Returns
    -------
    (str, int)
        Path of the newest numbered checkpoint, and its epoch.

    Raises
    ------
    FileNotFoundError
        If no numbered checkpoint exists.
    """
    ckpt_dir = Path(results_dir) / "Checkpoints"
    numbered = []
    # Path.glob only treats the pattern as a glob, so metacharacters such as
    # '[' inside results_dir itself are taken literally.
    for path in ckpt_dir.glob("checkpoint_*.eqx"):
        match = re.fullmatch(r"checkpoint_(\d+)\.eqx", path.name)
        if match is not None:
            numbered.append((int(match.group(1)), path))
    if not numbered:
        raise FileNotFoundError(f"No checkpoints in {results_dir}/Checkpoints/")
    epoch, path = max(numbered, key=lambda item: item[0])
    return str(path), epoch
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def load_model(results_dir, p):
    """Reconstruct the trained CNF described by ``p`` and restore its weights.

    The geometry comes from ``coordinates(p['mol_name'], p['bond_length'])``
    and the network width from ``p['hidden_layer']``. The freshly built model
    only serves as a template for ``eqx.tree_deserialise_leaves``, which fills
    it with the leaves stored in the highest-epoch checkpoint under
    ``results_dir``.

    Returns
    -------
    (model, solver, Ne, atoms, z, coords, epoch)
    """
    n_elec, atom_symbols, charges, positions = coordinates(p['mol_name'], p['bond_length'])
    template = CNF(
        din=3,
        dim=p['hidden_layer'],
        mu=positions,
        one_hot=one_hot_encode(charges),
        key=jrnd.PRNGKey(0),
    )
    ckpt_path, ckpt_epoch = last_checkpoint(results_dir)
    trained = eqx.tree_deserialise_leaves(ckpt_path, template)
    solver = get_solver(p['solver'])
    return trained, solver, n_elec, atom_symbols, charges, positions, ckpt_epoch
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def build_prior(p, z, coords, Ne):
    """Return the base distribution the flow was trained against.

    The choice must mirror training: ``p['prior'] == 'db_sir'`` selects an
    ``AtomDBDistribution`` built on AtomDB's ``"hci"`` promolecule. Any other
    value, including a missing key, falls back to ``ProMolecularDensity``.
    """
    if p.get('prior') != 'db_sir':
        return ProMolecularDensity(z.ravel(), coords)
    # atomdb is only needed for this prior, so it is imported lazily.
    from atomdb import make_promolecule
    promol = make_promolecule(atnums=z, coords=coords, dataset="hci")
    return AtomDBDistribution(db_prior=promol, z=z, coords=coords, Ne=Ne)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# ── grid construction (PySCF) ─────────────────────────────────────────────────
|
|
55
|
+
def _grids_from_mol(mol, level):
    """Build PySCF's Becke grid for ``mol`` at ``level``.

    Returns ``(coords, weights)`` as float64 JAX arrays; PySCF reports the
    grid in Bohr.
    """
    becke = dft.gen_grid.Grids(mol)
    becke.level = level
    becke.build()
    points = jnp.asarray(becke.coords, dtype=jnp.float64)
    weights = jnp.asarray(becke.weights, dtype=jnp.float64)
    return points, weights
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def build_grid(atoms, coords, Ne, grid_level=3, basis="6-31G(d,p)", unit="B"):
    """Build a PySCF molecular quadrature grid for the given nuclei.

    Parameters
    ----------
    atoms : sequence of str
        Element symbols, paired positionally with ``coords``.
    coords : array_like
        One ``(x, y, z)`` row per atom, expressed in ``unit``.
    Ne : int
        Electron count; only its parity is used, as the PySCF ``spin``.
    grid_level, basis, unit :
        Forwarded to PySCF (``unit`` is 'B'/'Bohr' or 'Angstrom').

    Returns
    -------
    (coords, weights)
        Grid points and weights, in Bohr.
    """
    atom_lines = []
    for symbol, xyz in zip(atoms, np.asarray(coords)):
        atom_lines.append(f"{symbol} {xyz[0]:.10f} {xyz[1]:.10f} {xyz[2]:.10f}")
    mol = gto.M(atom="; ".join(atom_lines), basis=basis, unit=unit, verbose=0,
                spin=int(Ne) % 2)
    return _grids_from_mol(mol, grid_level)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def get_grid(geom, level=3, *, units="angstrom", basis="6-31G(d,p)", spin=0):
    """Quadrature grid for an arbitrary geometry (user-facing helper).

    Parameters
    ----------
    geom : str
        Anything PySCF accepts as ``gto.M(atom=...)``, e.g.
        ``"H 0 0 0; H 0 0 0.74"`` or a multi-line XYZ-style block.
    level : int
        PySCF grid level (the "grid size").
    units : {'angstrom', 'bohr'}
        Units of ``geom``. Any string starting with 'b'/'B' means Bohr;
        everything else means Angstrom (PySCF's default).
    basis, spin :
        Passed straight to ``pyscf.gto.M``. The basis only influences the
        atom-centred grid partitioning, not the flow density.

    Returns
    -------
    (weights, coords)
        Weights come first, matching ``w_grid, x_grid = get_grid(...)``. The
        internal :func:`build_grid` uses the reverse order ``(coords, weights)``.
    """
    if str(units).lower().startswith("b"):
        pyscf_unit = "Bohr"
    else:
        pyscf_unit = "Angstrom"
    mol = gto.M(atom=geom, basis=basis, unit=pyscf_unit, spin=spin, verbose=0)
    points, weights = _grids_from_mol(mol, level)
    return weights, points


# camelCase alias of get_grid.
getGrid = get_grid
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def rho_on_grid(model, solver, prior, grid_coords, chunk=256):
    """Evaluate the flow at the grid points, ``chunk`` points at a time.

    For each chunk:
    1. The points are passed through ``rev_ode`` in the same
       ``[position | log-density | score]`` column layout that ``fwd_ode``
       receives, with the two extra slots zeroed.
    2. The base distribution's ``log_prob`` and ``score`` are evaluated at
       the result.
    3. ``fwd_ode`` maps everything forward again.

    Returns
    -------
    (positions, rho, score) : numpy arrays
        Forward-mapped positions, ``rho = exp(log rho_phi)`` flattened to 1-D,
        and the score ``grad log rho_phi``, concatenated over all chunks.
    """
    n_points = grid_coords.shape[0]
    positions, densities, scores = [], [], []
    for start in range(0, n_points, chunk):
        pts = grid_coords[start:start + chunk]
        m = pts.shape[0]
        state = jnp.concatenate([pts, jnp.zeros((m, 1)), jnp.zeros((m, 3))], axis=1)
        z0, _unused = rev_ode(model, state, solver)
        base_state = jnp.concatenate([z0, prior.log_prob(z0), prior.score(z0)], axis=1)
        x_t, logp_t, score_t = fwd_ode(model, base_state, solver)
        positions.append(np.array(x_t))
        densities.append(np.array(jnp.exp(logp_t)).ravel())
        scores.append(np.array(score_t))
    return np.concatenate(positions), np.concatenate(densities), np.concatenate(scores)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def quadrature_energy(functional, x_np, rho_np, sc_np, grid_coords, grid_weights,
                      mol_dict, Ne, chunk=256):
    """Integrate every energy term on the grid.

    The local terms use ``functional``'s component functionals; the Hartree term
    is the grid double sum (true 1/r), not the functional's MC pairwise estimator.

    Parameters
    ----------
    functional :
        Energy functional exposing ``kinetic``, ``exchange``, ``correlation``,
        ``external``, ``core_correction`` and ``_integrate``. ``correlation``
        and ``core_correction`` may be None, in which case the term is 0.0.
    x_np, rho_np, sc_np : numpy arrays
        Per-grid-point positions, density values and scores (as produced by
        ``rho_on_grid``), in the same order as the grid.
    grid_coords, grid_weights :
        Quadrature points and weights.
    mol_dict : dict
        Must provide ``'coords'`` (nuclear positions) and ``'z'`` (nuclear
        charges). It is also forwarded to every component functional.
    Ne :
        Electron count. It is forwarded to the functionals and scales V_H.
    chunk : int
        Grid points per functional call / per Hartree row block.

    Returns
    -------
    dict
        ``T, V_N, V_H, E_X, E_C, E_CC, E_NN`` and their sum ``E_total``.
    """
    w = np.array(grid_weights)
    G = rho_np.shape[0]
    rc = rho_np[:, None]
    # Quadrature weight times rho at each point, used as the integration measure.
    measure = jnp.asarray(w * rho_np)

    # Positional arguments every component functional is called with.
    # NOTE(review): the meaning of the trailing None depends on the component
    # functionals' signature — confirm against off/functionals.
    def _args(sl):
        return (jnp.array(rc[sl]), jnp.array(sc_np[sl]), jnp.array(x_np[sl]),
                Ne, mol_dict, None)

    # Evaluate `func` chunk by chunk at every grid point, then let
    # functional._integrate combine the per-point values with `measure`.
    def local(func): # ∫ f(...)·ρ dr via the shared _integrate
        out = np.zeros(G)
        for i in range(0, G, chunk):
            sl = slice(i, min(i + chunk, G))
            out[sl] = np.array(func(*_args(sl))).ravel()
        return float(functional._integrate(jnp.asarray(out), measure))

    T = local(functional.kinetic)
    E_X = local(functional.exchange)
    E_C = local(functional.correlation) if functional.correlation is not None else 0.0
    V_N = local(functional.external)
    E_CC = local(functional.core_correction) if functional.core_correction is not None else 0.0

    # Hartree — grid double sum
    # vc[i] = sum_j w_j * rho_j / |r_i - r_j|. Coincident pairs (r2 == 0,
    # including i == j) contribute 0 via 1/sqrt(inf). Each row block
    # materialises a (chunk, G, 3) difference array.
    gc = np.array(grid_coords)
    vc = np.zeros(G)
    for i in range(0, G, chunk):
        xi = gc[i:i+chunk]
        r2 = np.sum((gc[None, :, :] - xi[:, None, :]) ** 2, axis=-1)
        vc[i:i+chunk] = np.dot(1. / np.sqrt(np.where(r2 == 0., np.inf, r2)), w * rho_np)
    # NOTE(review): the Ne**2 factor suggests rho_np is normalised to 1
    # (grid_energy reports Ne_integral = sum w * Ne * rho). Confirm that
    # _integrate does not already include an Ne factor.
    V_H = float(0.5 * Ne ** 2 * functional._integrate(jnp.asarray(vc), measure))

    # Nuclear repulsion
    # Sum over unique nucleus pairs I < J of Z_I * Z_J / |R_I - R_J|; an empty
    # sum (single atom) gives 0.
    cn = np.array(mol_dict['coords']); zn = np.array(mol_dict['z']).ravel()
    E_NN = sum(zn[I] * zn[J] / float(np.linalg.norm(cn[I] - cn[J]))
               for I in range(len(cn)) for J in range(I + 1, len(cn)))

    return dict(T=T, V_N=V_N, V_H=V_H, E_X=E_X, E_C=E_C, E_CC=E_CC, E_NN=E_NN,
                E_total=T + V_N + V_H + E_X + E_C + E_CC + E_NN)
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
# ── high-level entry points ───────────────────────────────────────────────────
|
|
168
|
+
def grid_energy(model, prior, solver, coords, z, atoms, Ne, functional, *,
                grid_level=3, units="bohr", basis="6-31G(d,p)", chunk=256):
    """Build the grid, evaluate the flow density on it and integrate every term.

    Parameters
    ----------
    model, prior, solver :
        Trained CNF, its base distribution and the ODE solver.
    coords, z, atoms, Ne :
        Nuclear positions, charges, element symbols and electron count.
    functional :
        An EnergyFunctional (e.g. from ``build_energy_functional``).
    grid_level : int
        PySCF grid level (the "grid size").
    units : {'bohr', 'angstrom'}
        Units of ``coords``. The flow and the grid work in Bohr, so Angstrom
        input is converted for the nuclear terms.
    basis : str
        PySCF basis; only used to build the grid.
    chunk : int
        Points per batch when evaluating the flow and the functionals.

    Returns
    -------
    dict
        The ``quadrature_energy`` terms plus ``'Ne_integral'`` (``sum w * Ne * rho``).
    """
    coords = np.asarray(coords, dtype=float)
    in_bohr = str(units).lower().startswith("b")
    pyscf_unit = "Bohr" if in_bohr else "Angstrom"
    nuclei_bohr = coords if in_bohr else coords * AA_TO_BOHR

    # PySCF receives the geometry in its original unit and returns the grid in Bohr.
    grid_pts, grid_wts = build_grid(atoms, coords, Ne, grid_level=grid_level,
                                    basis=basis, unit=pyscf_unit)
    x_np, rho_np, sc_np = rho_on_grid(model, solver, prior, grid_pts, chunk=chunk)

    mol_dict = {'coords': jnp.asarray(nuclei_bohr, dtype=jnp.float64), 'z': jnp.asarray(z)}
    energies = quadrature_energy(functional, x_np, rho_np, sc_np, grid_pts, grid_wts,
                                 mol_dict, Ne, chunk=chunk)
    energies['Ne_integral'] = float(np.dot(np.array(grid_wts), Ne * rho_np))
    return energies
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def grid_energy_from_checkpoint(results_dir, *, grid_level=3, basis="6-31G(d,p)",
                                units="bohr", chunk=256, recompute=False, cache=True):
    """Grid-integrate the energy of a trained run directory in one call.

    Reads ``job_params.json`` (functional choice and geometry), loads the last
    checkpoint, rebuilds the training prior and calls :func:`grid_energy`.
    The geometry comes from ``coordinates``, which is documented as Bohr, so
    ``units`` should normally keep its default ``'bohr'``. It is forwarded to
    :func:`grid_energy` unchanged.

    The result is cached in ``results_dir/energy_summary.json``, together with
    the grid settings that produced it (``grid_level``, ``basis``, ``units``).
    A cached summary is reused only when those settings match the current
    call, so asking for a different grid never returns numbers computed on
    another grid. Summaries without these keys (written by older versions),
    and unreadable or corrupt cache files, count as a cache miss and are
    recomputed. ``recompute=True`` ignores any cache; ``cache=False`` neither
    reads nor writes it.

    Returns
    -------
    dict
        The :func:`grid_energy` terms plus ``epoch``, ``mol_name``,
        ``bond_length``, ``grid_level``, ``basis`` and ``units``.

    Raises
    ------
    FileNotFoundError
        If recomputation is needed and ``job_params.json`` or the checkpoints
        are missing.
    """
    results_dir = Path(results_dir)
    summary_path = results_dir / "energy_summary.json"
    # Everything here that changes the numbers; `chunk` only affects batching.
    settings = {"grid_level": grid_level, "basis": basis, "units": units}

    if cache and not recompute and summary_path.exists():
        try:
            with open(summary_path) as f:
                cached = json.load(f)
        except (OSError, ValueError):  # JSONDecodeError is a ValueError
            cached = None  # unreadable / truncated cache: recompute below
        if isinstance(cached, dict) and all(cached.get(k) == v for k, v in settings.items()):
            return cached

    with open(results_dir / "job_params.json") as f:
        p = json.load(f)
    model, solver, Ne, atoms, z, coords, epoch = load_model(results_dir, p)
    prior = build_prior(p, z, coords, Ne)
    functional = build_energy_functional(
        kinetic_name=p['kinetic'], lam=p['lam'], exchange_name=p['exchange'],
        correlation_name=p['correlation'], hartree_name=p['hartree'],
        external_name=p['external'], core_correction_name=p['core_correction'],
    )
    en = grid_energy(model, prior, solver, coords, z, atoms, Ne, functional,
                     grid_level=grid_level, units=units, basis=basis, chunk=chunk)
    en.update(epoch=epoch, mol_name=p['mol_name'], bond_length=p['bond_length'],
              **settings)
    if cache:
        with open(summary_path, "w") as f:
            json.dump(en, f, indent=4)
    return en
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def _print_energy(results_dir, en):
|
|
228
|
+
print(f"\n{results_dir}")
|
|
229
|
+
print(f" mol={en.get('mol_name')} R={en.get('bond_length')} epoch={en.get('epoch')}")
|
|
230
|
+
print(" " + "-" * 36)
|
|
231
|
+
for k in ("T", "V_N", "V_H", "E_X", "E_C", "E_CC", "E_NN"):
|
|
232
|
+
print(f" {k:8s} = {en[k]:+.6f} Ha")
|
|
233
|
+
print(" " + "-" * 36)
|
|
234
|
+
print(f" {'E_total':8s} = {en['E_total']:+.6f} Ha")
|
|
235
|
+
print(f" {'N_e':8s} = {en['Ne_integral']:.4f} (∫ρ, should be Ne)")
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def main():
    """Command-line entry point: grid energy of one or more trained run directories."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Grid (quadrature) energy of a trained OFF run directory.")
    parser.add_argument(
        "results_dir", nargs="+",
        help="bl_* run dir(s) with job_params.json and Checkpoints/ "
             "(shell globs like Results/H2/<method>/bl_* are fine)")
    # NOTE: this CLI defaults to grid level 1, coarser than the default of 3
    # used by grid_energy_from_checkpoint and the scan scripts.
    parser.add_argument("--grid_level", type=int, default=1, help="PySCF grid level")
    parser.add_argument("--bs", type=int, default=256, help="grid chunk size")
    parser.add_argument("--basis", type=str, default="6-31G(d,p)",
                        help="PySCF basis (sets grid partitioning only)")
    parser.add_argument("--recompute", action="store_true",
                        help="ignore cached energy_summary.json and recompute")
    opts = parser.parse_args()

    for run_dir in opts.results_dir:
        summary = grid_energy_from_checkpoint(
            run_dir,
            grid_level=opts.grid_level,
            basis=opts.basis,
            chunk=opts.bs,
            recompute=opts.recompute,
        )
        _print_energy(run_dir, summary)


if __name__ == "__main__":
    main()
|
off/quadrature_scan.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Grid-quadrature total energy for every bond length of a molecule.
|
|
3
|
+
|
|
4
|
+
Thin CLI around ``off/quadrature.py``: it walks every method directory and
|
|
5
|
+
every bl_* subdirectory under Results/{mol}/, calls
|
|
6
|
+
``quadrature.grid_energy_from_checkpoint`` on each (which builds the PySCF grid,
|
|
7
|
+
evaluates ρ_φ via the flow, and integrates all energy terms), and writes one CSV
|
|
8
|
+
per molecule. For molecules it also integrates the constituent single atoms
|
|
9
|
+
under the same method tag and reports the binding energy.
|
|
10
|
+
|
|
11
|
+
Directory layout assumed (same as main.py):
|
|
12
|
+
Results/{mol}/{method}/bl_X.XXXX/
|
|
13
|
+
Checkpoints/checkpoint_*.eqx
|
|
14
|
+
job_params.json
|
|
15
|
+
Results/{atom}/{method}/bl_0.0000/ (binding reference)
|
|
16
|
+
Results/ is located next to this script, so it runs from anywhere.
|
|
17
|
+
|
|
18
|
+
Usage
|
|
19
|
+
-----
|
|
20
|
+
python quadrature_scan.py --H2
|
|
21
|
+
python quadrature_scan.py --H2 --N2
|
|
22
|
+
python quadrature_scan.py --mol H2 H10
|
|
23
|
+
python quadrature_scan.py --H10 --recompute
|
|
24
|
+
|
|
25
|
+
Output (one CSV per molecule, under Results/{mol}/):
|
|
26
|
+
Results/{mol}/quadrature_{mol}.csv
|
|
27
|
+
columns: method, R_bohr, epoch, E_total, T, V_N, V_H, E_X, E_C, E_CC,
|
|
28
|
+
E_NN, Ne_int, E_atoms, dE_bind_Ha
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
import argparse
import gc
import os
import re
import sys
from pathlib import Path

import jax
import pandas as pd

if __package__:
    # Run as ``python -m off.quadrature_scan``.
    from .quadrature import grid_energy_from_checkpoint
else:
    # Run as a plain script (``python quadrature_scan.py``). quadrature.py uses
    # package-relative imports (``from .flow...``), so importing it as a
    # top-level module named ``quadrature`` fails with "attempted relative
    # import with no known parent package". Import it through its package
    # instead, after putting the package's parent directory on sys.path.
    sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from off.quadrature import grid_energy_from_checkpoint
|
|
43
|
+
|
|
44
|
+
_SCRIPT_DIR = Path(__file__).resolve().parent

# Names selectable through a dedicated --<name> flag, and the subset that are
# single atoms (scan_molecule computes no binding-energy reference for those).
KNOWN_MOLS = ["H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne",
              "H2", "N2", "O2", "F2", "HF", "CO", "LiH", "H10"]
SINGLE_ATOMS = {"H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne"}

# ── CLI ───────────────────────────────────────────────────────────────────────
parser = argparse.ArgumentParser(
    description=__doc__,
    formatter_class=argparse.RawDescriptionHelpFormatter,
    allow_abbrev=False,  # so --H is not treated as a prefix of --H2 / --H10
)
for _m in KNOWN_MOLS:
    parser.add_argument(f"--{_m}", action="store_true", help=f"Scan molecule {_m}")
parser.add_argument("--mol", type=str, nargs="+", default=[], metavar="NAME",
                    help="Molecule name(s) to scan (alternative to the flags)")
parser.add_argument("--results_root", type=str, default=None,
                    help="Override Results root (default: <script_dir>/Results)")
parser.add_argument("--bs", type=int, default=256, help="Grid chunk size")
parser.add_argument("--grid_level", type=int, default=3, help="PySCF grid level")
parser.add_argument("--recompute", action="store_true",
                    help="Re-run grid integration even if energy_summary.json is cached")
parser.add_argument("--out", type=str, default=None,
                    help="Output CSV path (default: Results/{mol}/quadrature_{mol}.csv)")
args = parser.parse_args()

# --mol names first, then flagged molecules in KNOWN_MOLS order; dict.fromkeys
# drops duplicates while keeping the first occurrence.
flagged = (m for m in KNOWN_MOLS if getattr(args, m))
selected = list(dict.fromkeys([*args.mol, *flagged]))
if not selected:
    parser.error("No molecule selected. Use a flag (e.g. --H2) or --mol H2 [...].")

if args.results_root:
    root = Path(args.results_root).resolve()
else:
    root = _SCRIPT_DIR / "Results"
print(f"Results root : {root}\n")
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def constituents(mol: str) -> dict:
    """Count atoms per element in a simple chemical formula.

    Examples: ``N2 -> {N: 2}``, ``HF -> {H: 1, F: 1}``, ``H10 -> {H: 10}``.
    Repeated symbols accumulate and keep first-seen order
    (``CH3OH -> {C: 1, H: 4, O: 1}``).
    """
    counts = {}
    for symbol, digits in re.findall(r"([A-Z][a-z]?)(\d*)", mol):
        if not symbol:
            continue
        counts[symbol] = counts.get(symbol, 0) + int(digits or 1)
    return counts
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def atom_reference(method_name: str, mol: str):
    """Sum of isolated-atom grid energies for ``mol`` under one method tag.

    Each element of ``constituents(mol)`` is grid-integrated from
    ``root/<element>/<method_name>/bl_0.0000``, using this script's
    ``--grid_level``, ``--bs`` and ``--recompute`` settings, and weighted by
    its count in the formula.

    Returns ``(E_atoms, {element: E_atom})``. If any atom run is missing or
    fails, the reason is printed and ``(None, None)`` is returned.
    """
    per_element = {}
    weighted_sum = 0.0
    for element, count in constituents(mol).items():
        atom_dir = root / element / method_name / "bl_0.0000"
        if not (atom_dir / "job_params.json").exists():
            print(f" atom reference: {element} not found at {atom_dir} — binding skipped")
            return None, None
        try:
            result = grid_energy_from_checkpoint(
                atom_dir, grid_level=args.grid_level, chunk=args.bs, recompute=args.recompute)
        except Exception as err:  # best effort: a bad atom only disables binding energies
            print(f" atom reference: {element} FAILED — {err}")
            return None, None
        e_atom = result['E_total']
        per_element[element] = e_atom
        weighted_sum += count * e_atom
    return weighted_sum, per_element
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def scan_molecule(mol: str):
    """Grid-integrate every ``bl_*`` run of every method for ``mol`` and save a CSV.

    Walks ``root/<mol>/<method>/bl_<R>/`` and calls
    ``grid_energy_from_checkpoint`` on each run. ``root``, ``args`` and
    ``selected`` are this script's module-level CLI state. Unless ``mol`` is a
    single atom, each row also gets the binding energy against the isolated
    atoms of the same method (see ``atom_reference``). Runs that are missing,
    fail, or have an unparseable ``bl_*`` name are reported and skipped rather
    than aborting the scan. The table is printed and written to the path
    chosen by ``_csv_path``.
    """
    mol_dir = root / mol
    if not mol_dir.is_dir():
        print(f"[{mol}] SKIP — {mol_dir} not found\n")
        return

    is_atom = mol in SINGLE_ATOMS
    rows = []
    for method_dir in sorted(d for d in mol_dir.iterdir() if d.is_dir()):
        bl_dirs = _bond_length_dirs(method_dir)
        if not bl_dirs:
            continue
        print(f"[{mol}] method: {method_dir.name} ({len(bl_dirs)} bond lengths)")

        # Single-atom reference for the binding energy (same method tag).
        E_atoms = None
        if not is_atom:
            E_atoms, detail = atom_reference(method_dir.name, mol)
            if E_atoms is not None:
                ref = " ".join(f"{n}*E({el})={detail[el]:+.6f}"
                               for el, n in constituents(mol).items())
                print(f" atom reference (grid): {ref} -> Σ = {E_atoms:+.6f} Ha")

        for bl_dir in bl_dirs:
            if not (bl_dir / "job_params.json").exists():
                print(f" {bl_dir.name}: missing job_params.json — skipping")
                continue
            try:
                data = grid_energy_from_checkpoint(
                    bl_dir, grid_level=args.grid_level, chunk=args.bs,
                    recompute=args.recompute)
            except Exception as e:  # best effort: one bad run must not stop the scan
                print(f" {bl_dir.name}: FAILED — {e}")
                continue
            row = {
                "method": method_dir.name,
                "R_bohr": data['bond_length'],
                "epoch": data.get('epoch', '?'),
                "E_total": data['E_total'],
                "T": data['T'],
                "V_N": data['V_N'],
                "V_H": data['V_H'],
                "E_X": data['E_X'],
                "E_C": data.get('E_C', 0.0),
                "E_CC": data.get('E_CC', 0.0),
                "E_NN": data['E_NN'],
                "Ne_int": data['Ne_integral'],
            }
            if E_atoms is not None:
                row["E_atoms"] = E_atoms
                row["dE_bind_Ha"] = E_atoms - data['E_total']  # ΔE = ΣE(atom) - E(mol)
            rows.append(row)
            msg = (f" R={data['bond_length']:.4f} Bohr epoch={data.get('epoch','?'):>6}"
                   f" E_total={data['E_total']:+.6f} Ha")
            if E_atoms is not None:
                msg += f" ΔE={E_atoms - data['E_total']:+.6f} Ha"
            print(msg)
            # Release compiled JAX executables / buffers between runs.
            jax.clear_caches()
            gc.collect()
        print()

    if not rows:
        print(f"[{mol}] nothing to write (no checkpoints found)\n")
        return

    df = (pd.DataFrame(rows)
          .sort_values(["method", "R_bohr"])
          .reset_index(drop=True))
    out_path = _csv_path(mol, mol_dir)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(out_path, index=False, float_format="%.8f")

    print("=" * 96)
    print(df.to_string(index=False))
    print("=" * 96)
    print(f"[{mol}] saved → {out_path}\n")


def _bond_length_dirs(method_dir):
    """Return the ``bl_*`` entries of ``method_dir``, sorted by the bond length in their name.

    Entries whose suffix is not a number (e.g. a stray ``bl_old``) are
    reported and skipped; the original code crashed the sort on them.
    """
    parsed = []
    for d in method_dir.glob("bl_*"):
        try:
            parsed.append((float(d.name.split("_")[1]), d))
        except ValueError:
            print(f" {d.name}: not a bl_<R> directory — skipping")
    return [d for _, d in sorted(parsed, key=lambda item: item[0])]


def _csv_path(mol, mol_dir):
    """Return the CSV destination for ``mol``.

    - No ``--out``: ``mol_dir/quadrature_{mol}.csv``.
    - ``--out`` with a single molecule: exactly that path.
    - ``--out`` with several molecules: the molecule name is appended to the
      stem (``out.csv`` -> ``out_H2.csv``), so molecules do not overwrite each
      other's file.
    """
    if not args.out:
        return mol_dir / f"quadrature_{mol}.csv"
    out = Path(args.out).resolve()
    if len(selected) > 1:
        out = out.with_name(f"{out.stem}_{mol}{out.suffix}")
    return out
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
# Scan each requested molecule, in the order collected from the CLI.
for molecule in selected:
    scan_molecule(molecule)
|
off/scan_pes.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Potential Energy Surface scan over a set of bond-length result directories.
|
|
3
|
+
|
|
4
|
+
Thin CLI around ``off/quadrature.py``: grid-integrates every bl_* directory
|
|
5
|
+
under --scan_dir (via ``grid_energy_from_checkpoint``) and, optionally, an atom
|
|
6
|
+
reference for the binding energy, then writes pes.csv and a plot.
|
|
7
|
+
|
|
8
|
+
Usage
|
|
9
|
+
-----
|
|
10
|
+
python scan_pes.py \
|
|
11
|
+
--scan_dir Results/H2/<method> \
|
|
12
|
+
--atom_dir Results/H/<method>/bl_0.0000
|
|
13
|
+
|
|
14
|
+
Outputs (written inside --scan_dir):
|
|
15
|
+
pes.csv — R_bohr, epoch, E_total, T, V_N, V_H, E_X, E_C, E_NN, Ne_int, plus E_bind_Ha and D_e_eV when --atom_dir is given
|
|
16
|
+
pes.png / .svg — PES curve (E_total and D_e vs R)
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import argparse
import os
import sys
from pathlib import Path

import matplotlib
matplotlib.use("Agg")  # headless backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt
import pandas as pd

if __package__:
    # Run as ``python -m off.scan_pes``.
    from .quadrature import grid_energy_from_checkpoint
else:
    # Run as a plain script (``python scan_pes.py``). quadrature.py uses
    # package-relative imports (``from .flow...``), so importing it as a
    # top-level module named ``quadrature`` fails with "attempted relative
    # import with no known parent package". Import it through its package
    # instead, after putting the package's parent directory on sys.path.
    sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from off.quadrature import grid_energy_from_checkpoint
|
|
31
|
+
|
|
32
|
+
# ── CLI ───────────────────────────────────────────────────────────────────────
|
|
33
|
+
parser = argparse.ArgumentParser()
parser.add_argument(
    "--scan_dir", type=str, required=True,
    help="Method directory containing bl_X.XXXX subdirectories",
)
parser.add_argument(
    "--atom_dir", type=str, default=None,
    help="bl_0.0000 directory for the atom (binding-energy reference)",
)
parser.add_argument("--bs", type=int, default=256, help="Grid chunk size")
parser.add_argument("--grid_level", type=int, default=3, help="PySCF grid level")
parser.add_argument(
    "--recompute", action="store_true",
    help="Re-run integration even if energy_summary.json already exists",
)
args = parser.parse_args()
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def analyse(results_dir):
    """Grid energy summary of one run directory, using this script's CLI settings."""
    run_dir = Path(results_dir).resolve()
    return grid_energy_from_checkpoint(run_dir, grid_level=args.grid_level,
                                       chunk=args.bs, recompute=args.recompute)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
# ── Scan over all bl_* directories ───────────────────────────────────────────
|
|
52
|
+
scan_dir = Path(args.scan_dir).resolve()

# bl_* entries sorted by the bond length encoded in their name. Entries whose
# suffix is not a number (e.g. a stray ``bl_old``) are reported and skipped
# instead of crashing the sort.
_parsed = []
for _d in scan_dir.glob("bl_*"):
    try:
        _parsed.append((float(_d.name.split("_")[1]), _d))
    except ValueError:
        print(f" {_d.name}: not a bl_<R> directory — skipping")
bl_dirs = [d for _, d in sorted(_parsed, key=lambda item: item[0])]
if not bl_dirs:
    raise FileNotFoundError(f"No bl_* directories found in {scan_dir}")

print(f"\nFound {len(bl_dirs)} bond-length directories in:\n {scan_dir}\n")

# Optional atom reference (binding uses 2*E_atom — homonuclear diatomic)
E_atom = None
if args.atom_dir is not None:
    print("atom reference:")
    E_atom = analyse(args.atom_dir)['E_total']
    print(f" E(atom) = {E_atom:+.6f} Ha\n")

# Main scan
rows = []
for bl_dir in bl_dirs:
    if not (bl_dir / "job_params.json").exists():
        print(f" {bl_dir.name}: missing job_params.json — skipping")
        continue
    try:
        data = analyse(bl_dir)
    except Exception as e:  # one bad run should not abort the whole scan
        print(f" {bl_dir.name}: FAILED — {e}")
        continue
    R = data['bond_length']
    row = {'R_bohr': R,
           'epoch': data.get('epoch', '?'),
           'E_total': data['E_total'],
           'T': data['T'],
           'V_N': data['V_N'],
           'V_H': data['V_H'],
           'E_X': data['E_X'],
           'E_C': data.get('E_C', 0.0),
           'E_NN': data['E_NN'],
           'Ne_int': data['Ne_integral']}
    if E_atom is not None:
        E_bind = data['E_total'] - 2.0 * E_atom
        row['E_bind_Ha'] = E_bind
        row['D_e_eV'] = -E_bind * 27.2114
    rows.append(row)
    tag = f" R={R:.4f} Bohr epoch={row['epoch']:>6} E={data['E_total']:+.6f} Ha"
    if E_atom is not None:
        tag += f" E_bind={row['E_bind_Ha']:+.6f} Ha"
    print(tag)

# Without this guard, the CSV/plot code below fails with a KeyError on an empty frame.
if not rows:
    raise SystemExit(f"No bond length in {scan_dir} could be integrated — nothing to save.")
|
|
94
|
+
|
|
95
|
+
# ── Save CSV ─────────────────────────────────────────────────────────────────
|
|
96
|
+
df = pd.DataFrame(rows)
df = df.sort_values('R_bohr')
df = df.reset_index(drop=True)

csv_path = scan_dir / "pes.csv"
df.to_csv(csv_path, index=False, float_format='%.8f')
print(f"\nPES data saved → {csv_path}")
print(df.to_string(index=False))
|
|
101
|
+
|
|
102
|
+
# ── Plot ──────────────────────────────────────────────────────────────────────
|
|
103
|
+
has_binding = E_atom is not None
fig, axes = plt.subplots(1, 2 if has_binding else 1,
                         figsize=(11 if has_binding else 5, 4))
if not has_binding:
    axes = [axes]  # a single subplot comes back as a bare Axes, not an array

R_vals = df['R_bohr'].values
max_epoch = df['epoch'].max()
complete = df['epoch'] == max_epoch  # True if run finished

# One panel for the PES, plus one for D_e when an atom reference was given.
panels = [(axes[0], 'E_total', 'Energy [Ha]', 'Potential Energy Surface', 'tab:blue')]
if has_binding:
    panels.append((axes[1], 'D_e_eV', 'D_e [eV]', 'Dissociation Energy', 'tab:orange'))

for ax, y_col, ylabel, title, color in panels:
    y = df[y_col].values
    ax.plot(R_vals[complete], y[complete], 'o-', color=color, label=f'epoch={max_epoch}')
    ax.plot(R_vals[~complete], y[~complete], '^', color=color, alpha=0.5,
            label='incomplete', markerfacecolor='none')
    if y_col == 'D_e_eV':
        ax.axhline(0, color='k', linewidth=0.8, linestyle='--')
    ax.set_xlabel("R [Bohr]")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

# The figure is titled with the parent of scan_dir (the molecule directory);
# Path.name never contains a path separator.
fig.suptitle(scan_dir.parent.name, fontsize=9)
fig.tight_layout()
fig.savefig(scan_dir / "pes.svg", transparent=True)
fig.savefig(scan_dir / "pes.png", dpi=150)
print(f"PES plot saved → {scan_dir}/pes.png")
|