off 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- off/__init__.py +23 -0
- off/atom_energies.py +151 -0
- off/config/_config.py +108 -0
- off/dft_distrax/__init__.py +27 -0
- off/dft_distrax/dft_distrax.py +216 -0
- off/flow/__init__.py +29 -0
- off/flow/equiv_flows.py +99 -0
- off/functionals/__init__.py +35 -0
- off/functionals/core_correction.py +84 -0
- off/functionals/exchange_correlation.py +174 -0
- off/functionals/external.py +49 -0
- off/functionals/functional.py +129 -0
- off/functionals/hartree.py +62 -0
- off/functionals/kinetic.py +87 -0
- off/main.py +172 -0
- off/ode_solver/__init__.py +32 -0
- off/ode_solver/eqx_ode.py +76 -0
- off/plot_binding_csv.py +63 -0
- off/plot_pes_ema.py +259 -0
- off/plot_pes_mpl.py +280 -0
- off/promolecular/__init__.py +27 -0
- off/promolecular/promolecular_dist.py +465 -0
- off/quadrature.py +261 -0
- off/quadrature_scan.py +188 -0
- off/scan_pes.py +133 -0
- off/test_fwd_rev.py +290 -0
- off/train/__init__.py +44 -0
- off/train/loop.py +228 -0
- off/train/loss.py +149 -0
- off/train/utils.py +38 -0
- off/utils.py +618 -0
- off-0.1.0.dist-info/METADATA +154 -0
- off-0.1.0.dist-info/RECORD +37 -0
- off-0.1.0.dist-info/WHEEL +5 -0
- off-0.1.0.dist-info/entry_points.txt +3 -0
- off-0.1.0.dist-info/licenses/LICENSE +21 -0
- off-0.1.0.dist-info/top_level.txt +1 -0
off/main.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import json
|
|
3
|
+
from fractions import Fraction
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from .train.loop import training
|
|
6
|
+
from .config._config import Config
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _lam(value: str) -> float:
|
|
10
|
+
"""Accept λ as a fraction ('1/9', '1/5') or plain float ('0.2', '2.0')."""
|
|
11
|
+
try:
|
|
12
|
+
return float(Fraction(value))
|
|
13
|
+
except (ValueError, ZeroDivisionError):
|
|
14
|
+
raise argparse.ArgumentTypeError(
|
|
15
|
+
f"Invalid λ value '{value}'. Use a fraction (1/9, 1/5) or float (0.111, 2.0)."
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
SINGLE_ATOMS = {'H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne'}
|
|
19
|
+
|
|
20
|
+
def _method_tag(args) -> str:
|
|
21
|
+
"""Encode functional/solver choices into a compact directory name."""
|
|
22
|
+
kin_tag = args.kin
|
|
23
|
+
if args.kin in ('w', 'tf_w'):
|
|
24
|
+
kin_tag += f"_lam{args.lam:.6g}"
|
|
25
|
+
|
|
26
|
+
tag = f"{kin_tag}_{args.cc}_{args.x}_{args.c}_{args.solver}_{args.prior}"
|
|
27
|
+
if args.sched.lower() not in ['c', 'const']:
|
|
28
|
+
tag += f"_sched_{args.sched}"
|
|
29
|
+
if args.hart.lower() != 'coulomb':
|
|
30
|
+
tag += f"_hart_{args.hart}"
|
|
31
|
+
return tag.lower()
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def setup_directories(args):
    """Create and return the results and checkpoint directories for a run.

    Layout: Results/{mol}/{method}/bl_{bond_length}/
      - Single atoms (SINGLE_ATOMS) always use bl_0.0000 (bond length has no
        meaning).
      - Diatomics/polyatomics use the supplied --bond_length value.

    The bond length is written with four decimals. This matches the
    bl_X.XXXX convention the PES plotting scripts expect. It also keeps bond
    lengths that differ only in the 3rd/4th decimal (e.g. 1.401 vs 1.404)
    from sharing, and overwriting, one directory. The previous code wrote only
    two decimals.

    This makes bond-length scans trivial:
        glob('Results/H2/{method}/bl_*/')

    Returns:
        (results_dir, ckpt_dir) as strings. Both directories are created
        (including parents) if they do not exist yet.
    """
    bl = 0.0 if args.mol_name in SINGLE_ATOMS else args.bond_length
    results_dir = f"Results/{args.mol_name}/{_method_tag(args)}/bl_{bl:.4f}"
    ckpt_dir = f"{results_dir}/Checkpoints"

    for directory in (results_dir, ckpt_dir):
        Path(directory).mkdir(parents=True, exist_ok=True)

    return results_dir, ckpt_dir
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def save_job_params(results_dir, args):
    """Record the run configuration as ``{results_dir}/job_params.json``.

    The JSON keys are descriptive names mapped from the short CLI attribute
    names (e.g. ``bs`` -> ``batch_size``). The file is written with a 4-space
    indent, and the same dict is returned to the caller.
    """
    # (json key, attribute on the argparse namespace), in output order.
    field_map = (
        ('mol_name', 'mol_name'),
        ('bond_length', 'bond_length'),
        ('epochs', 'epochs'),
        ('batch_size', 'bs'),
        ('hidden_layer', 'hl'),
        ('lr', 'lr'),
        ('kinetic', 'kin'),
        ('lam', 'lam'),
        ('external', 'nuc'),
        ('hartree', 'hart'),
        ('exchange', 'x'),
        ('correlation', 'c'),
        ('core_correction', 'cc'),
        ('scheduler', 'sched'),
        ('solver', 'solver'),
        ('prior', 'prior'),
    )
    params = {'model': 'cnf'}
    params.update((key, getattr(args, attr)) for key, attr in field_map)

    target = f"{results_dir}/job_params.json"
    with open(target, "w") as handle:
        json.dump(params, handle, indent=4)

    return params
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _build_parser():
    """Construct the command-line parser for a single training run."""
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Model parameters
    add("--mol_name", type=str, default='H',
        help="Molecule name")
    add("--bond_length", type=float, default=4.4,
        help="Bond length for the molecule (Bohr)")
    add("--epochs", type=int, default=500,
        help="Number of training epochs")
    add("--bs", type=int, default=512,
        help="Batch size")
    add("--hl", type=int, default=64,
        help="Hidden layer size")
    add("--lr", type=float, default=3e-4,
        help="Learning rate")
    add("--prior", type=str, default='promolecular',
        choices=['promolecular', 'db_sir'],
        help="Prior distribution type")

    # Functionals
    add("--kin", type=str, default='tf_w',
        choices=['tf', 'w', 'tf_w'],
        help="Kinetic energy functional")
    add("--lam", type=_lam, default=1/5,
        help="Weizsäcker prefactor λ in TF-λW: fraction or float ")
    add("--nuc", type=str, default='np',
        help="Nuclear potential functional")
    add("--hart", type=str, default='coulomb',
        help="Hartree energy functional")
    add("--x", type=str, default='lda',
        choices=['lda', 'b88_x', 'lda_b88_x'],
        help="Exchange energy functional")
    add("--c", type=str, default='none',
        choices=['vwn_c', 'pw92_c', 'none'],
        help="Correlation energy functional")
    add("--cc", type=str, default='none',
        choices=['kato', 'hutcheon', 'none'],
        help="Core correction functional")

    # Training settings
    add("--sched", type=str, default='mix',
        help="Learning rate scheduler type")
    add("--solver", type=str, default='dopri8',
        choices=['dopri5', 'tsit5', 'dopri8'],
        help="ODE solver")
    add("--ckpt_freq", type=int, default=15,
        help="Checkpoint saving frequency (epochs)")

    return parser


def main():
    """Command-line entry point.

    Parses arguments, passes them to ``Config``, creates the result and
    checkpoint directories, and writes ``job_params.json``. It then calls
    ``training`` and prints the last EMA energy from the returned frame.
    """
    args = _build_parser().parse_args()

    Config.from_args(args)

    results_dir, ckpt_dir = setup_directories(args)
    Config.set_directories(results_dir, ckpt_dir)

    job_params = save_job_params(results_dir, args)
    print("Starting training with parameters:")
    print(json.dumps(job_params, indent=2))
    print(f"\nResults will be saved to: {results_dir}")

    # training() returns (model, metrics, ema_metrics); only the EMA frame is
    # used here.
    _model, _metrics, df_ema = training(
        mol_name=args.mol_name,
        bond_length=args.bond_length,
        tw_kin=args.kin,
        lam=args.lam,
        n_pot=args.nuc,
        h_pot=args.hart,
        x_pot=args.x,
        c_pot=args.c,
        cc_pot=args.cc,
        batch_size=args.bs,
        hidden_layer=args.hl,
        epochs=args.epochs,
        lr=args.lr,
        scheduler_type=args.sched,
        prior_type=args.prior,
        checkpoint_dir=ckpt_dir,
        checkpoint_freq=args.ckpt_freq,
        solver_type=args.solver,
    )

    print("\nTraining complete!")
    print(f"Results saved to: {results_dir}")
    print(f"Final energy (EMA): {df_ema['E'].iloc[-1]:.6f}")


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
# MIT License
|
|
2
|
+
|
|
3
|
+
# Copyright (c) 2025 AlexandreDeCamargo
|
|
4
|
+
|
|
5
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
# in the Software without restriction, including without limitation the rights
|
|
8
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
# furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
# copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
# SOFTWARE.
|
|
22
|
+
|
|
23
|
+
# Version string of this subpackage; matches the 0.1.0 wheel it ships in.
__version__ = "0.1.0"
|
|
24
|
+
|
|
25
|
+
from .eqx_ode import (
|
|
26
|
+
fwd_ode,
|
|
27
|
+
rev_ode,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
from ..flow.equiv_flows import (
|
|
31
|
+
CNF,
|
|
32
|
+
)
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
from diffrax import diffeqsolve, ODETerm, SaveAt, PIDController
|
|
2
|
+
import jax
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
import functools
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@functools.partial(jax.vmap, in_axes=(None, 0, 0))
def forward(model, x, t):
    """Evaluate ``model`` sample-by-sample over a batch.

    ``jax.vmap`` maps over the leading axis of ``x`` and ``t`` while ``model``
    is shared across the batch (``in_axes=None``). Per-sample outputs are
    stacked along axis 0, the vmap default.
    """
    result = model(x, t)
    return result
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def fwd_ode(flow_model, x_and_logpx, solver, *, rtol=1e-8, atol=1e-8, data_dim=3):
    """Integrate the packed CNF state forward in time from t=0 to t=1.

    Args:
        flow_model: callable evaluated per sample via ``forward`` as
            ``flow_model(x, t)``. It must return the time derivative of the
            packed state.
        x_and_logpx: packed state. Its last axis is split as
            ``[:data_dim]`` (positions), ``[data_dim]`` (log-density change)
            and ``[data_dim+1:]`` (remaining, score-like channels). The leading
            axis is treated as the batch (used to build the per-sample time
            column).
        solver: a diffrax solver instance.
        rtol, atol: tolerances of the adaptive ``PIDController``. The defaults
            reproduce the previous hard-coded 1e-8 values.
        data_dim: number of leading position channels (default 3).

    Returns:
        ``(z_t1, logp_diff_t1, score_t1)``: the three state slices at t=1.
    """
    t0 = 0.
    t1 = 1.
    # Initial step = the whole interval; the adaptive controller rejects and
    # shrinks it as needed.
    dt0 = t1 - t0

    def vector_field(t, x, args):
        # Broadcast the scalar time to a (batch, 1) column so it can be
        # vmapped alongside x.
        return forward(flow_model, x, t * jnp.ones((x.shape[0], 1)))

    sol = diffeqsolve(ODETerm(vector_field), solver, t0, t1, dt0, x_and_logpx,
                      stepsize_controller=PIDController(rtol=rtol, atol=atol),
                      saveat=SaveAt(ts=jnp.array([t0, t1])))

    final = sol.ys[-1]  # state saved at t1
    z_t1 = final[:, :data_dim]
    logp_diff_t1 = final[:, data_dim:data_dim + 1]
    score_t1 = final[:, data_dim + 1:]
    return z_t1, logp_diff_t1, score_t1
|
|
34
|
+
|
|
35
|
+
# def rev_ode(flow_model, z_and_logpz, solver):
|
|
36
|
+
|
|
37
|
+
# t0 = 0.
|
|
38
|
+
# t1 = 1.
|
|
39
|
+
# dt0 = t1 - t0
|
|
40
|
+
# vector_field = lambda t,x,args: forward(flow_model,x,t*jnp.ones((x.shape[0],1)))
|
|
41
|
+
# term = ODETerm(vector_field)
|
|
42
|
+
# solver = solver
|
|
43
|
+
# saveat = SaveAt(ts=jnp.array([1., 0.]))
|
|
44
|
+
# stepsize_controller = PIDController(rtol=1e-8, atol=1e-8)
|
|
45
|
+
|
|
46
|
+
# sol = diffeqsolve(term, solver, t1, t0, -dt0, z_and_logpz,
|
|
47
|
+
# stepsize_controller=stepsize_controller,
|
|
48
|
+
# saveat=saveat)
|
|
49
|
+
# data_dim = 3
|
|
50
|
+
# z_t, logp_diff_t, score_diff_t = sol.ys[:, :, :data_dim], sol.ys[:, :, data_dim:data_dim+1], sol.ys[:, :, data_dim+1:]
|
|
51
|
+
# z_t0, logp_diff_t0, score_diff_t0 = z_t[-1], logp_diff_t[-1], score_diff_t[-1]
|
|
52
|
+
|
|
53
|
+
# return z_t0, logp_diff_t0, score_diff_t0
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def rev_ode(flow_model, z_and_logpz, solver, *, rtol=1e-8, atol=1e-8, data_dim=3):
    """Integrate the packed CNF state backward in time from t=1 to t=0.

    Args:
        flow_model: callable evaluated per sample via ``forward`` as
            ``flow_model(x, t)``.
        z_and_logpz: packed state at t=1. Its last axis is split as
            ``[:data_dim]`` (positions) and ``[data_dim]`` (log-density
            change); any further channels are integrated but not returned.
        solver: a diffrax solver instance.
        rtol, atol: tolerances of the adaptive ``PIDController``. The defaults
            reproduce the previous hard-coded 1e-8 values.
        data_dim: number of leading position channels (default 3).

    Returns:
        ``(z_t0, logp_diff_t0)``: positions and log-density change at t=0.
    """
    t0 = 0.
    t1 = 1.
    dt0 = t1 - t0

    def vector_field(t, x, args):
        # Broadcast the scalar time to a (batch, 1) column so it can be
        # vmapped alongside x.
        return forward(flow_model, x, t * jnp.ones((x.shape[0], 1)))

    # Integrate from t1 down to t0 with a negative initial step; the save
    # times must run in the same (decreasing) direction.
    sol = diffeqsolve(ODETerm(vector_field), solver, t1, t0, -dt0, z_and_logpz,
                      stepsize_controller=PIDController(rtol=rtol, atol=atol),
                      saveat=SaveAt(ts=jnp.array([t1, t0])))

    final = sol.ys[-1]  # state saved at t0
    z_t0 = final[:, :data_dim]
    logp_diff_t0 = final[:, data_dim:data_dim + 1]
    return z_t0, logp_diff_t0
|
off/plot_binding_csv.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Plot a binding_{mol}.csv: PES (left) and ΔE binding (right).
|
|
3
|
+
MC = blue line + markers
|
|
4
|
+
grid = orange dots
|
|
5
|
+
|
|
6
|
+
Usage:
|
|
7
|
+
python plot_binding_csv.py Results/N2/binding_N2.csv
|
|
8
|
+
python plot_binding_csv.py Results/N2/binding_N2.csv --out n2.png
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import argparse
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
|
|
14
|
+
import pandas as pd
|
|
15
|
+
import matplotlib
|
|
16
|
+
matplotlib.use("Agg")
|
|
17
|
+
import matplotlib.pyplot as plt
|
|
18
|
+
|
|
19
|
+
ap = argparse.ArgumentParser()
ap.add_argument("csv", help="path to binding_{mol}.csv")
ap.add_argument("--out", default=None, help="output image (default: <csv>.png)")
args = ap.parse_args()

df = pd.read_csv(args.csv)
fig, (axL, axR) = plt.subplots(1, 2, figsize=(14, 6))

# A CSV may contain several methods. Every series keeps the documented colours
# (MC = blue, grid = orange). Previously all methods were drawn identically and
# the legend repeated "MC"/"grid" once per method. Now each method gets its own
# marker/linestyle and a "(method)" label suffix. With a single method the
# figure is unchanged ("-o" lines, plain labels).
MARKERS = ("o", "s", "^", "D", "v", "P", "X", "*")
LINESTYLES = ("-", "--", "-.", ":")
groups = df.groupby("method")
multi_method = groups.ngroups > 1

for i, (method, g) in enumerate(groups):
    g = g.sort_values("R_bohr")
    suffix = f" ({method})" if multi_method else ""
    marker = MARKERS[i % len(MARKERS)]
    linestyle = LINESTYLES[i % len(LINESTYLES)]

    # ── left: total energy (PES) ──────────────────────────────────────────────
    axL.plot(g.R_bohr, g.E_AB_mc, linestyle=linestyle, marker=marker,
             color="tab:blue", lw=2, ms=7, label=f"MC{suffix}")
    axL.scatter(g.R_bohr, g.E_AB_grid, color="orange", marker=marker, s=70,
                edgecolors="black", linewidths=0.5, zorder=5, label=f"grid{suffix}")
    if "E_atoms_grid" in g.columns:  # dissociation limit (grid)
        axL.axhline(g.E_atoms_grid.iloc[0], color="orange", ls=":", lw=1.2,
                    alpha=0.9, label=rf"2·E(atom) grid{suffix}")

    # ── right: ΔE = E(A) + E(B) - E(AB) ───────────────────────────────────────
    axR.plot(g.R_bohr, g.dE_mc_Ha, linestyle=linestyle, marker=marker,
             color="tab:blue", lw=2, ms=7, label=f"MC{suffix}")
    axR.scatter(g.R_bohr, g.dE_grid_Ha, color="orange", marker=marker, s=70,
                edgecolors="black", linewidths=0.5, zorder=5, label=f"grid{suffix}")

axL.set_xlabel("R [Bohr]")
axL.set_ylabel(r"E[$\rho$] + V$_{NN}$(R) [a.u.]")
axL.set_title("PES")
axL.grid(alpha=0.3)
axL.legend(fontsize=8)

axR.axhline(0, color="k", lw=0.8, ls="--")
axR.set_xlabel("R [Bohr]")
axR.set_ylabel(r"$\Delta$E = E(A) + E(B) - E(AB) [a.u.]")
axR.set_title("Binding energy")
axR.grid(alpha=0.3)
axR.legend(fontsize=8)

fig.suptitle(Path(args.csv).stem)
fig.tight_layout()

out = Path(args.out) if args.out else Path(args.csv).with_suffix(".png")
fig.savefig(out, dpi=150)
fig.savefig(out.with_suffix(".svg"))
print("saved →", out)
print("saved →", out.with_suffix(".svg"))
|
off/plot_pes_ema.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Quick PES plot from EMA training logs — no grid integration needed.
|
|
3
|
+
|
|
4
|
+
Reads training_metrics_ema.csv from each bl_* directory and uses the
|
|
5
|
+
last epoch's EMA energy (E + CC) as E_total.
|
|
6
|
+
|
|
7
|
+
Usage
|
|
8
|
+
-----
|
|
9
|
+
# PES only:
|
|
10
|
+
python plot_pes_ema.py \
|
|
11
|
+
--scan_dir Results/H2/tf_w_lam0.2_hutcheon_lda_none_dopri8_promolecular_sched_MIX
|
|
12
|
+
|
|
13
|
+
# With binding energy (needs the isolated-atom dir, e.g. N for N2):
|
|
14
|
+
python plot_pes_ema.py \
|
|
15
|
+
--scan_dir Results/N2/tf_w_lam0.2_none_lda_none_dopri8_promolecular_sched_MIX_hart_COULOMB_ALLPAIRS \
|
|
16
|
+
--atom_dir Results/N/tf_w_lam0.2_none_lda_none_dopri8_promolecular_sched_MIX_hart_COULOMB_ALLPAIRS/bl_0.0000
|
|
17
|
+
|
|
18
|
+
# Read R=8.0 and R=9.0 at epoch 20000, the rest at their last epoch:
|
|
19
|
+
python plot_pes_ema.py --scan_dir Results/N2/... --epoch_at 8.0:20000 9.0:20000
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
import argparse
|
|
23
|
+
from pathlib import Path
|
|
24
|
+
|
|
25
|
+
import matplotlib.pyplot as plt
|
|
26
|
+
import numpy as np
|
|
27
|
+
import pandas as pd
|
|
28
|
+
|
|
29
|
+
parser = argparse.ArgumentParser()
parser.add_argument("--scan_dir", type=str, required=True,
                    help="Method directory containing bl_X.XXXX subdirectories")
parser.add_argument("--atom_dir", type=str, default=None,
                    help="bl_0.0000 directory of the isolated atom "
                         "(binding energy reference)")
parser.add_argument("--pes_csv", type=str, default=None,
                    help="pes.csv from scan_pes.py to overlay as grid-integration "
                         "points. If omitted, looks for pes.csv inside --scan_dir.")
parser.add_argument("--avg_window", type=int, default=1,
                    help="Average the last N rows of training_metrics_ema.csv "
                         "instead of taking just the last value (default: 1).")
parser.add_argument("--bls", type=float, nargs="+", default=None,
                    help="Only include these bond lengths, e.g. --bls 2.0 3.0 4.0 9.0. "
                         "If omitted, include all bl_* directories found.")
parser.add_argument("--epoch_at", type=str, nargs="+", default=None, metavar="R:EPOCH",
                    help="Per-bond-length epoch override, e.g. --epoch_at 8.0:20000 9.0:20000. "
                         "Those bond lengths use the EMA as of that epoch; all others use "
                         "their last epoch.")
args = parser.parse_args()

# DataFrame.tail(0) selects nothing (NaN energies) and tail(-N) silently means
# "all but the first N rows", so only positive windows make sense.
if args.avg_window < 1:
    parser.error(f"--avg_window must be >= 1, got {args.avg_window}")

scan_dir = Path(args.scan_dir).resolve()

# Parse --epoch_at "R:EPOCH" pairs into {round(R,4): epoch}
EPOCH_OVERRIDE = {}
if args.epoch_at:
    for pair in args.epoch_at:
        if ":" not in pair:
            parser.error(f"--epoch_at expects R:EPOCH pairs, got '{pair}'")
        r_str, e_str = pair.split(":", 1)
        try:
            EPOCH_OVERRIDE[round(float(r_str), 4)] = int(e_str)
        except ValueError:
            # Report bad numbers as a usage error instead of a raw traceback.
            parser.error(f"--epoch_at expects R:EPOCH pairs with a float R and "
                         f"an integer EPOCH, got '{pair}'")
    print(f"Epoch overrides: {EPOCH_OVERRIDE}")
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def read_last_ema(bl_dir: Path, window: int = 500, at_epoch: int = None):
    """Read the smoothed electronic energy from ``training_metrics_ema.csv``.

    The returned energy is the mean of column ``E`` over the last `window`
    rows, plus the mean of ``CC`` over the same rows when that column exists.
    Nuclear repulsion E_NN is NOT included. The returned epoch is the one on
    the final (possibly truncated) row; the window only smooths EMA noise.

    If `at_epoch` is given, rows with ``epoch > at_epoch`` are dropped first,
    so the value is read *as of* that epoch.

    Returns ``(None, None)`` when the CSV is missing, empty, or has no rows at
    or before `at_epoch`.
    """
    metrics_path = bl_dir / "training_metrics_ema.csv"
    if not metrics_path.exists():
        return None, None

    log = pd.read_csv(metrics_path)
    if not log.empty and at_epoch is not None:
        log = log[log["epoch"] <= at_epoch]  # read the EMA as of this epoch
    if log.empty:
        return None, None

    recent = log.tail(window)
    energy = float(recent["E"].mean())
    if "CC" in recent.columns:
        energy += float(recent["CC"].mean())
    return energy, int(log.iloc[-1]["epoch"])
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def e_nn(mol_name: str, R: float) -> float:
    """Return the nuclear repulsion Z_A * Z_B / R [Ha] of a diatomic at R [Bohr].

    "HF", "CO" and "NO" are handled as heteronuclear pairs. Any other name is
    read as a homonuclear ``<symbol><count>`` (e.g. "H2", "N2"): trailing
    digits are stripped to get the symbol, and unknown symbols fall back to
    Z = 1. A non-positive R yields 0.0.
    """
    charges = {"H": 1, "He": 2, "Li": 3, "Be": 4, "B": 5, "C": 6,
               "N": 7, "O": 8, "F": 9, "Ne": 10}
    heteronuclear = {"HF": ("H", "F"), "CO": ("C", "O"), "NO": ("N", "O")}

    if mol_name in heteronuclear:
        sym_a, sym_b = heteronuclear[mol_name]
        za, zb = charges[sym_a], charges[sym_b]
    else:
        za = zb = charges.get(mol_name.rstrip("0123456789"), 1)

    if R > 0:
        return za * zb / R
    return 0.0
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
# ── Isolated-atom reference ───────────────────────────────────────────────────
E_atom = None
if args.atom_dir is not None:
    E_atom, ep_atom = read_last_ema(Path(args.atom_dir).resolve(), window=args.avg_window)
    if E_atom is not None:
        print(f"E(atom) = {E_atom:+.6f} Ha (epoch {ep_atom})")
    else:
        print(f"WARNING: could not read atom EMA from {args.atom_dir}")


# ── Scan over bl_* directories ────────────────────────────────────────────────
def _bl_value(path):
    """Return the bond length encoded in a ``bl_<R>`` name, or None if unparsable."""
    try:
        return float(path.name.split("_")[1])
    except (IndexError, ValueError):
        return None


# Only directories whose name parses as bl_<float> take part in the scan.
# Previously any stray file or oddly named entry matching bl_* (e.g.
# bl_notes.txt) crashed the float() sort key.
bl_dirs = sorted((d for d in scan_dir.glob("bl_*")
                  if d.is_dir() and _bl_value(d) is not None),
                 key=_bl_value)

if not bl_dirs:
    raise FileNotFoundError(f"No bl_* directories found in {scan_dir}")

if args.bls is not None:
    keep = {round(bl, 4) for bl in args.bls}
    bl_dirs = [d for d in bl_dirs if round(_bl_value(d), 4) in keep]
    if not bl_dirs:
        raise FileNotFoundError(
            f"None of --bls {args.bls} match any bl_X.XXXX in {scan_dir}")
    print(f"Filtered to {len(bl_dirs)} requested bond lengths: "
          f"{[d.name for d in bl_dirs]}")
|
|
130
|
+
|
|
131
|
+
import json, re
|
|
132
|
+
|
|
133
|
+
# Detect the molecule name from the first bl_* directory that has a
# job_params.json; fall back to "H2" when none is found.
mol_name = "H2"
for d in bl_dirs:
    jp = d / "job_params.json"
    if jp.exists():
        # Context manager so the file handle is closed deterministically
        # (json.load(open(...)) leaked it).
        with open(jp) as fh:
            mol_name = json.load(fh)["mol_name"]
        break
print(f"Molecule: {mol_name}")

# Parse mol_name for LaTeX labels: A_n -> atom_sym='A', mol_latex='\mathrm{A}_n'
_m = re.fullmatch(r"([A-Z][a-z]?)(\d+)?", mol_name)
if _m and _m.group(2):
    atom_sym = _m.group(1)
    mol_latex = rf"\mathrm{{{atom_sym}}}_{{{_m.group(2)}}}"
else:
    atom_sym = mol_name
    mol_latex = rf"\mathrm{{{mol_name}}}"
if mol_name in ("HF", "CO", "NO"):
    print(f"WARNING: {mol_name} is heteronuclear — binding uses 2*E_atom, "
          "which assumes homonuclear dissociation and will be physically wrong")
|
|
153
|
+
|
|
154
|
+
# One record per bond length: total energy (electronic EMA + E_NN), the epoch
# it was read at and, when an atom reference exists, E(mol) - 2 E(atom).
rows = []
for run_dir in bl_dirs:
    bond = float(run_dir.name.split("_")[1])
    if bond == 0.0:
        continue  # skip atom directory if present

    # Epoch override for this R, or None to read the last epoch.
    e_elec, last_epoch = read_last_ema(
        run_dir, window=args.avg_window,
        at_epoch=EPOCH_OVERRIDE.get(round(bond, 4)))
    if e_elec is None:
        print(f" bl={bond:.4f}: no EMA csv — skipping")
        continue

    nn_rep = e_nn(mol_name, bond)
    e_tot = e_elec + nn_rep  # add nuclear repulsion
    record = {"R": bond, "E_total": e_tot, "epoch": last_epoch}
    line = f" R={bond:.4f} epoch={last_epoch:>6} E={e_tot:+.6f} Ha (E_NN={nn_rep:+.4f})"
    if E_atom is not None:
        record["bind_Ha"] = e_tot - 2.0 * E_atom  # E(mol) - 2E(atom)
        line += f" bind={record['bind_Ha']:+.6f} Ha"
    rows.append(record)
    print(line)

if not rows:
    raise RuntimeError("No data found — check scan_dir.")

df = pd.DataFrame(rows).sort_values("R").reset_index(drop=True)
max_epoch = df["epoch"].max()
# True for bond lengths that reached the furthest epoch seen in the scan.
complete = df["epoch"].eq(max_epoch)
|
|
181
|
+
|
|
182
|
+
# ── Optional: grid-integration results from scan_pes.py ──────────────────────
pes_path = Path(args.pes_csv).resolve() if args.pes_csv else scan_dir / "pes.csv"
pes_df = None
if pes_path.exists():
    pes_df = pd.read_csv(pes_path).sort_values("R_bohr").reset_index(drop=True)
    if args.bls is not None:
        keep = {round(bl, 4) for bl in args.bls}
        pes_df = pes_df[pes_df["R_bohr"].round(4).isin(keep)].reset_index(drop=True)
    print(f"\nGrid overlay: {pes_path} ({len(pes_df)} points)")
else:
    print(f"\nNo pes.csv at {pes_path} — plotting EMA only")

# ── Side-by-side binding-energy comparison (EMA vs grid) ─────────────────────
if pes_df is not None and E_atom is not None and "E_bind_Ha" in pes_df.columns:
    # Join on R rounded to 4 decimals (the bl_X.XXXX resolution), like the
    # --bls / --epoch_at matching. An exact float key left rows unmatched
    # whenever pes.csv stored R with slightly different float noise.
    ema_bind = df[["R", "bind_Ha"]].assign(R=df["R"].round(4))
    grid_bind = (pes_df[["R_bohr", "E_bind_Ha"]]
                 .rename(columns={"R_bohr": "R", "E_bind_Ha": "grid_bind_Ha"})
                 .assign(R=lambda t: t["R"].round(4)))
    cmp = (ema_bind.merge(grid_bind, on="R", how="outer")
           .sort_values("R").reset_index(drop=True))
    cmp["delta_Ha"] = cmp["bind_Ha"] - cmp["grid_bind_Ha"]
    print("\n=== Binding energy: EMA vs grid (Ha) ===")
    print(f"{'R':>8} {'EMA':>12} {'Grid':>12} {'Δ(EMA-Grid)':>14}")
    for _, r in cmp.iterrows():
        ema_s = f"{r['bind_Ha']:+12.6f}" if pd.notna(r['bind_Ha']) else f"{'—':>12}"
        grid_s = f"{r['grid_bind_Ha']:+12.6f}" if pd.notna(r['grid_bind_Ha']) else f"{'—':>12}"
        d_s = f"{r['delta_Ha']:+14.6f}" if pd.notna(r['delta_Ha']) else f"{'—':>14}"
        print(f"{r['R']:8.4f} {ema_s} {grid_s} {d_s}")
|
|
208
|
+
|
|
209
|
+
# ── Plot ──────────────────────────────────────────────────────────────────────
n_panels = 2 if E_atom is not None else 1
fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels + 1, 4))
if n_panels == 1:
    axes = [axes]

R_vals = df["R"].values
# Plain boolean mask: True where the row reached the largest epoch in the scan.
done = np.asarray(complete, dtype=bool)
# Draw the "incomplete" series only when it has points. Otherwise the legend
# shows an entry for an empty series.
has_incomplete = bool((~done).any())

# Panel 1: PES (E_total)
ax = axes[0]
ax.plot(R_vals[done], df["E_total"].values[done],
        "o-", color="tab:blue", lw=1.8, label=f"EMA (epoch={max_epoch})")
if has_incomplete:
    ax.plot(R_vals[~done], df["E_total"].values[~done],
            "^", color="tab:blue", alpha=0.5, markerfacecolor="none",
            markersize=7, label="incomplete")
# Guard on the column, as panel 2 does for E_bind_Ha.
if pes_df is not None and "E_total" in pes_df.columns:
    ax.plot(pes_df["R_bohr"], pes_df["E_total"], "o", color="gold",
            markersize=6, markeredgecolor="black", markeredgewidth=0.4,
            linestyle="none", label="scan_pes (grid)", zorder=5)
ax.set_xlabel("R [Bohr]")
ax.set_ylabel("E [Ha]")
ax.set_title("Potential Energy Surface")
ax.legend(fontsize=8)
ax.grid(True, alpha=0.25)

# Panel 2: binding energy E(mol) - 2 E(atom)
if E_atom is not None:
    ax2 = axes[1]
    ax2.plot(R_vals[done], df["bind_Ha"].values[done],
             "o-", color="tab:orange", lw=1.8, label=f"EMA (epoch={max_epoch})")
    if has_incomplete:
        ax2.plot(R_vals[~done], df["bind_Ha"].values[~done],
                 "^", color="tab:orange", alpha=0.5, markerfacecolor="none", markersize=7)
    if pes_df is not None and "E_bind_Ha" in pes_df.columns:
        # scan_pes already stores E_bind = E(mol) - 2E(atom) — plot directly
        ax2.plot(pes_df["R_bohr"], pes_df["E_bind_Ha"], "o", color="gold",
                 markersize=6, markeredgecolor="black", markeredgewidth=0.4,
                 linestyle="none", label="scan_pes (grid)", zorder=5)
    ax2.axhline(0, color="k", lw=0.8, ls="--", alpha=0.5)
    ax2.set_xlabel("R [Bohr]")
    ax2.set_ylabel(rf"$E({mol_latex}) - 2E(\mathrm{{{atom_sym}}})$ [Ha]")
    ax2.set_title("Binding Energy")
    ax2.legend(fontsize=8)
    ax2.grid(True, alpha=0.25)

fig.suptitle(scan_dir.name, fontsize=8)
fig.tight_layout()

out = scan_dir / "pes_ema.png"
fig.savefig(out, dpi=150)
fig.savefig(out.with_suffix(".svg"), transparent=True)
print(f"\nSaved → {out}")
|