off 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
off/main.py ADDED
@@ -0,0 +1,172 @@
1
+ import argparse
2
+ import json
3
+ from fractions import Fraction
4
+ from pathlib import Path
5
+ from .train.loop import training
6
+ from .config._config import Config
7
+
8
+
9
+ def _lam(value: str) -> float:
10
+ """Accept λ as a fraction ('1/9', '1/5') or plain float ('0.2', '2.0')."""
11
+ try:
12
+ return float(Fraction(value))
13
+ except (ValueError, ZeroDivisionError):
14
+ raise argparse.ArgumentTypeError(
15
+ f"Invalid λ value '{value}'. Use a fraction (1/9, 1/5) or float (0.111, 2.0)."
16
+ )
17
+
18
# Element symbols (H through Ne) that are run as isolated atoms: for these,
# setup_directories ignores --bond_length and always uses a zero bond length.
SINGLE_ATOMS = set("H He Li Be B C N O F Ne".split())
19
+
20
+ def _method_tag(args) -> str:
21
+ """Encode functional/solver choices into a compact directory name."""
22
+ kin_tag = args.kin
23
+ if args.kin in ('w', 'tf_w'):
24
+ kin_tag += f"_lam{args.lam:.6g}"
25
+
26
+ tag = f"{kin_tag}_{args.cc}_{args.x}_{args.c}_{args.solver}_{args.prior}"
27
+ if args.sched.lower() not in ['c', 'const']:
28
+ tag += f"_sched_{args.sched}"
29
+ if args.hart.lower() != 'coulomb':
30
+ tag += f"_hart_{args.hart}"
31
+ return tag.lower()
32
+
33
+
34
def setup_directories(args):
    """Create the results and checkpoint directories for a run and return them.

    Layout: ``Results/{mol}/{method}/bl_{bond_length:.2f}/Checkpoints``, where
    ``{method}`` is produced by ``_method_tag(args)``.

    - Single atoms (``SINGLE_ATOMS``) always use ``bl_0.00``, because a bond
      length has no meaning for them.
    - Molecules use ``--bond_length`` formatted with two decimals. Bond lengths
      that round to the same value (e.g. 1.401 and 1.404 -> ``bl_1.40``)
      therefore share one directory.

    Bond-length scans are then a single glob:
    ``glob('Results/H2/{method}/bl_*/')``.

    Returns:
        ``(results_dir, ckpt_dir)`` as ``/``-separated strings.
    """
    bl = 0.0 if args.mol_name in SINGLE_ATOMS else args.bond_length
    results_dir = f"Results/{args.mol_name}/{_method_tag(args)}/bl_{bl:.2f}"
    ckpt_dir = f"{results_dir}/Checkpoints"

    # parents=True also creates results_dir; exist_ok lets reruns reuse them.
    Path(ckpt_dir).mkdir(parents=True, exist_ok=True)

    return results_dir, ckpt_dir
51
+
52
+
53
def save_job_params(results_dir, args):
    """Write the run's hyper-parameters to ``<results_dir>/job_params.json``.

    Args:
        results_dir: Directory that receives the JSON file (must already exist).
        args: Parsed command-line namespace providing the attributes listed in
            ``field_map`` below.

    Returns:
        The dict that was written (``'model'`` first, then the fields in
        ``field_map`` order), serialised with a 4-space indent.
    """
    # (JSON key, attribute name on ``args``) in the order they appear on disk.
    field_map = (
        ('mol_name', 'mol_name'),
        ('bond_length', 'bond_length'),
        ('epochs', 'epochs'),
        ('batch_size', 'bs'),
        ('hidden_layer', 'hl'),
        ('lr', 'lr'),
        ('kinetic', 'kin'),
        ('lam', 'lam'),
        ('external', 'nuc'),
        ('hartree', 'hart'),
        ('exchange', 'x'),
        ('correlation', 'c'),
        ('core_correction', 'cc'),
        ('scheduler', 'sched'),
        ('solver', 'solver'),
        ('prior', 'prior'),
    )
    params = {'model': 'cnf'}
    params.update((key, getattr(args, attr)) for key, attr in field_map)

    with open(f"{results_dir}/job_params.json", "w") as fh:
        json.dump(params, fh, indent=4)

    return params
79
+
80
+
81
def _build_parser():
    """Return the argparse parser describing one training run's options."""
    parser = argparse.ArgumentParser()
    # Model parameters
    parser.add_argument("--mol_name", type=str, default='H',
                        help="Molecule name")
    parser.add_argument("--bond_length", type=float, default=4.4,
                        help="Bond length for the molecule (Bohr)")
    parser.add_argument("--epochs", type=int, default=500,
                        help="Number of training epochs")
    parser.add_argument("--bs", type=int, default=512,
                        help="Batch size")
    parser.add_argument("--hl", type=int, default=64,
                        help="Hidden layer size")
    parser.add_argument("--lr", type=float, default=3e-4,
                        help="Learning rate")
    parser.add_argument("--prior", type=str, default='promolecular',
                        choices=['promolecular', 'db_sir'],
                        help="Prior distribution type")

    # Functionals
    parser.add_argument("--kin", type=str, default='tf_w',
                        choices=['tf', 'w', 'tf_w'],
                        help="Kinetic energy functional")
    parser.add_argument("--lam", type=_lam, default=1/5,
                        help="Weizsäcker prefactor λ in TF-λW: fraction or float ")
    parser.add_argument("--nuc", type=str, default='np',
                        help="Nuclear potential functional")
    parser.add_argument("--hart", type=str, default='coulomb',
                        help="Hartree energy functional")
    parser.add_argument("--x", type=str, default='lda',
                        choices=['lda', 'b88_x', 'lda_b88_x'],
                        help="Exchange energy functional")
    parser.add_argument("--c", type=str, default='none',
                        choices=['vwn_c', 'pw92_c', 'none'],
                        help="Correlation energy functional")
    parser.add_argument("--cc", type=str, default='none',
                        choices=['kato', 'hutcheon', 'none'],
                        help="Core correction functional")

    # Training settings
    parser.add_argument("--sched", type=str, default='mix',
                        help="Learning rate scheduler type")
    parser.add_argument("--solver", type=str, default='dopri8',
                        choices=['dopri5', 'tsit5', 'dopri8'],
                        help="ODE solver")
    parser.add_argument("--ckpt_freq", type=int, default=15,
                        help="Checkpoint saving frequency (epochs)")
    return parser


def main(argv=None):
    """Parse options, prepare output directories, save parameters and train.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so ``main()`` behaves as a CLI.
    """
    args = _build_parser().parse_args(argv)

    # Config is initialised from the parsed options before the directories
    # are registered with it (same order as before).
    Config.from_args(args)

    # Setup directories
    results_dir, ckpt_dir = setup_directories(args)
    Config.set_directories(results_dir, ckpt_dir)

    # Save parameters
    job_params = save_job_params(results_dir, args)
    print("Starting training with parameters:")
    print(json.dumps(job_params, indent=2))
    print(f"\nResults will be saved to: {results_dir}")

    # Run training; only the EMA metrics table is used afterwards.
    _model, _df, df_ema = training(
        mol_name=args.mol_name,
        bond_length=args.bond_length,
        tw_kin=args.kin,
        lam=args.lam,
        n_pot=args.nuc,
        h_pot=args.hart,
        x_pot=args.x,
        c_pot=args.c,
        cc_pot=args.cc,
        batch_size=args.bs,
        hidden_layer=args.hl,
        epochs=args.epochs,
        lr=args.lr,
        scheduler_type=args.sched,
        prior_type=args.prior,
        checkpoint_dir=ckpt_dir,
        checkpoint_freq=args.ckpt_freq,
        solver_type=args.solver,
    )

    print("\nTraining complete!")
    print(f"Results saved to: {results_dir}")
    # An empty EMA log (e.g. zero epochs) would make .iloc[-1] raise
    # IndexError after the whole run; report it instead.
    if len(df_ema):
        print(f"Final energy (EMA): {df_ema['E'].iloc[-1]:.6f}")
    else:
        print("Final energy (EMA): no EMA rows were recorded")


if __name__ == "__main__":
    main()
@@ -0,0 +1,32 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2025 AlexandreDeCamargo
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ __version__ = "0.1.0"
24
+
25
+ from .eqx_ode import (
26
+ fwd_ode,
27
+ rev_ode,
28
+ )
29
+
30
+ from ..flow.equiv_flows import (
31
+ CNF,
32
+ )
@@ -0,0 +1,76 @@
1
+ from diffrax import diffeqsolve, ODETerm, SaveAt, PIDController
2
+ import jax
3
+ import jax.numpy as jnp
4
+ import functools
5
+
6
+
7
def forward(model, x, t):
    """Evaluate ``model(x, t)``; after the rebinding below, over a whole batch.

    ``model`` is shared (not mapped), while ``x`` and ``t`` are mapped over
    their leading axis; per-sample outputs are stacked along axis 0.
    """
    return model(x, t)


# Equivalent to decorating with functools.partial(jax.vmap, in_axes=..., out_axes=0).
forward = jax.vmap(forward, in_axes=(None, 0, 0), out_axes=0)
10
+
11
+
12
+
13
def fwd_ode(flow_model, x_and_logpx, solver, *, rtol=1e-8, atol=1e-8, data_dim=3):
    """Integrate the flow ODE forward in time, from t=0 to t=1.

    Args:
        flow_model: Model evaluated per sample through the vmapped ``forward``;
            its output is used as the right-hand side of the ODE.
        x_and_logpx: State at t=0, with the batch on the leading axis (the
            vector field maps over ``x.shape[0]``). The last axis is split as
            ``[:data_dim] | [data_dim:data_dim+1] | [data_dim+1:]``, which the
            code names z, logp_diff and score — TODO confirm with callers.
        solver: diffrax solver instance, passed straight to ``diffeqsolve``.
        rtol, atol: Tolerances of the adaptive ``PIDController``. The defaults
            are the values that used to be hard-coded here.
        data_dim: Number of leading channels holding positions (default 3).

    Returns:
        ``(z_t1, logp_diff_t1, score_t1)``: the state at t=1 split into the
        three channel groups described above.
    """
    t0, t1 = 0.0, 1.0
    dt0 = t1 - t0  # initial step size; the PID controller adapts it

    def vector_field(t, x, args):
        # Give every sample its own (1,)-shaped copy of the scalar time so
        # ``forward`` can map over the batch axis of both x and t.
        return forward(flow_model, x, t * jnp.ones((x.shape[0], 1)))

    sol = diffeqsolve(
        ODETerm(vector_field), solver, t0, t1, dt0, x_and_logpx,
        stepsize_controller=PIDController(rtol=rtol, atol=atol),
        saveat=SaveAt(ts=jnp.array([t0, t1])),
    )
    final = sol.ys[-1]  # state saved at t1 (last entry of ts)
    z_t1 = final[:, :data_dim]
    logp_diff_t1 = final[:, data_dim:data_dim + 1]
    score_t1 = final[:, data_dim + 1:]
    return z_t1, logp_diff_t1, score_t1
34
+
35
+ # def rev_ode(flow_model, z_and_logpz, solver):
36
+
37
+ # t0 = 0.
38
+ # t1 = 1.
39
+ # dt0 = t1 - t0
40
+ # vector_field = lambda t,x,args: forward(flow_model,x,t*jnp.ones((x.shape[0],1)))
41
+ # term = ODETerm(vector_field)
42
+ # solver = solver
43
+ # saveat = SaveAt(ts=jnp.array([1., 0.]))
44
+ # stepsize_controller = PIDController(rtol=1e-8, atol=1e-8)
45
+
46
+ # sol = diffeqsolve(term, solver, t1, t0, -dt0, z_and_logpz,
47
+ # stepsize_controller=stepsize_controller,
48
+ # saveat=saveat)
49
+ # data_dim = 3
50
+ # z_t, logp_diff_t, score_diff_t = sol.ys[:, :, :data_dim], sol.ys[:, :, data_dim:data_dim+1], sol.ys[:, :, data_dim+1:]
51
+ # z_t0, logp_diff_t0, score_diff_t0 = z_t[-1], logp_diff_t[-1], score_diff_t[-1]
52
+
53
+ # return z_t0, logp_diff_t0, score_diff_t0
54
+
55
+
56
def rev_ode(flow_model, z_and_logpz, solver, *, rtol=1e-8, atol=1e-8, data_dim=3):
    """Integrate the flow ODE backward in time, from t=1 to t=0.

    Args:
        flow_model: Model evaluated per sample through the vmapped ``forward``;
            its output is used as the right-hand side of the ODE.
        z_and_logpz: State at t=1, with the batch on the leading axis. The last
            axis is split as ``[:data_dim] | [data_dim:data_dim+1] | rest``
            (z, logp_diff, score in this module's naming — TODO confirm).
        solver: diffrax solver instance, passed straight to ``diffeqsolve``.
        rtol, atol: Tolerances of the adaptive ``PIDController``. The defaults
            are the values that used to be hard-coded here.
        data_dim: Number of leading channels holding positions (default 3).

    Returns:
        ``(z_t0, logp_diff_t0)``: the first two channel groups of the state at
        t=0. The remaining (score) channels are integrated but not returned —
        unlike ``fwd_ode``, which returns three values.
    """
    t0, t1 = 0.0, 1.0
    dt0 = t1 - t0

    def vector_field(t, x, args):
        # Give every sample its own (1,)-shaped copy of the scalar time so
        # ``forward`` can map over the batch axis of both x and t.
        return forward(flow_model, x, t * jnp.ones((x.shape[0], 1)))

    # Integrate from t1 down to t0 with a negative initial step.
    sol = diffeqsolve(
        ODETerm(vector_field), solver, t1, t0, -dt0, z_and_logpz,
        stepsize_controller=PIDController(rtol=rtol, atol=atol),
        saveat=SaveAt(ts=jnp.array([t1, t0])),
    )
    final = sol.ys[-1]  # state saved at t0 (last entry of ts)
    z_t0 = final[:, :data_dim]
    logp_diff_t0 = final[:, data_dim:data_dim + 1]
    return z_t0, logp_diff_t0
@@ -0,0 +1,63 @@
1
+ """
2
+ Plot a binding_{mol}.csv: PES (left) and ΔE binding (right).
3
+ MC = blue line + markers
4
+ grid = orange dots
5
+
6
+ Usage:
7
+ python plot_binding_csv.py Results/N2/binding_N2.csv
8
+ python plot_binding_csv.py Results/N2/binding_N2.csv --out n2.png
9
+ """
10
+
11
+ import argparse
12
+ from pathlib import Path
13
+
14
+ import pandas as pd
15
+ import matplotlib
16
+ matplotlib.use("Agg")
17
+ import matplotlib.pyplot as plt
18
+
19
_METHOD_MARKERS = ("o", "s", "^", "D", "v", "P", "X")


def main(argv=None):
    """Plot a binding_{mol}.csv: PES (left) and ΔE binding energy (right).

    Reads the columns ``method``, ``R_bohr``, ``E_AB_mc``, ``E_AB_grid``,
    ``dE_mc_Ha``, ``dE_grid_Ha`` and, if present, ``E_atoms_grid``. MC values
    are drawn as blue lines with markers and grid values as orange dots.
    When the CSV contains several methods, each method gets its own marker
    shape and its name as a label prefix, so the curves can be told apart.
    The figure is saved as PNG (or ``--out``) plus an SVG next to it.

    Args:
        argv: Argument list for argparse; ``None`` reads ``sys.argv[1:]``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("csv", help="path to binding_{mol}.csv")
    ap.add_argument("--out", default=None, help="output image (default: <csv>.png)")
    args = ap.parse_args(argv)

    df = pd.read_csv(args.csv)
    fig, (axL, axR) = plt.subplots(1, 2, figsize=(14, 6))

    groups = list(df.groupby("method"))
    several = len(groups) > 1

    for i, (method, g) in enumerate(groups):
        g = g.sort_values("R_bohr")
        marker = _METHOD_MARKERS[i % len(_METHOD_MARKERS)]
        # Identical colours/labels for every method made them indistinguishable;
        # prefix labels with the method name whenever more than one is plotted.
        prefix = f"{method} " if several else ""

        # ── left: total energy (PES) ──────────────────────────────────────────
        axL.plot(g.R_bohr, g.E_AB_mc, linestyle="-", marker=marker,
                 color="tab:blue", lw=2, ms=7, label=f"{prefix}MC")
        axL.scatter(g.R_bohr, g.E_AB_grid, color="orange", marker=marker, s=70,
                    edgecolors="black", linewidths=0.5, zorder=5,
                    label=f"{prefix}grid")
        if "E_atoms_grid" in g.columns:  # dissociation limit (grid)
            axL.axhline(g.E_atoms_grid.iloc[0], color="orange", ls=":", lw=1.2,
                        alpha=0.9, label=prefix + "2·E(atom) grid")

        # ── right: ΔE = E(A) + E(B) - E(AB) ───────────────────────────────────
        axR.plot(g.R_bohr, g.dE_mc_Ha, linestyle="-", marker=marker,
                 color="tab:blue", lw=2, ms=7, label=f"{prefix}MC")
        axR.scatter(g.R_bohr, g.dE_grid_Ha, color="orange", marker=marker, s=70,
                    edgecolors="black", linewidths=0.5, zorder=5,
                    label=f"{prefix}grid")

    axL.set_xlabel("R [Bohr]")
    axL.set_ylabel(r"E[$\rho$] + V$_{NN}$(R) [a.u.]")
    axL.set_title("PES")
    axL.grid(alpha=0.3)
    axL.legend(fontsize=8)

    axR.axhline(0, color="k", lw=0.8, ls="--")
    axR.set_xlabel("R [Bohr]")
    axR.set_ylabel(r"$\Delta$E = E(A) + E(B) - E(AB) [a.u.]")
    axR.set_title("Binding energy")
    axR.grid(alpha=0.3)
    axR.legend(fontsize=8)

    fig.suptitle(Path(args.csv).stem)
    fig.tight_layout()

    out = Path(args.out) if args.out else Path(args.csv).with_suffix(".png")
    svg = out.with_suffix(".svg")
    fig.savefig(out, dpi=150)
    print("saved →", out)
    if svg != out:  # --out x.svg would otherwise write the same file twice
        fig.savefig(svg)
        print("saved →", svg)


if __name__ == "__main__":
    main()
off/plot_pes_ema.py ADDED
@@ -0,0 +1,259 @@
1
+ """
2
+ Quick PES plot from EMA training logs — no grid integration needed.
3
+
4
+ Reads training_metrics_ema.csv from each bl_* directory and uses the
5
+ last epoch's EMA energy (E + CC, averaged over the last --avg_window rows) plus the nuclear repulsion E_NN as E_total.
6
+
7
+ Usage
8
+ -----
9
+ # PES only:
10
+ python plot_pes_ema.py \
11
+ --scan_dir Results/H2/tf_w_lam0.2_hutcheon_lda_none_dopri8_promolecular_sched_MIX
12
+
13
+ # With binding energy (needs the isolated-atom dir, e.g. N for N2):
14
+ python plot_pes_ema.py \
15
+ --scan_dir Results/N2/tf_w_lam0.2_none_lda_none_dopri8_promolecular_sched_MIX_hart_COULOMB_ALLPAIRS \
16
+ --atom_dir Results/N/tf_w_lam0.2_none_lda_none_dopri8_promolecular_sched_MIX_hart_COULOMB_ALLPAIRS/bl_0.0000
17
+
18
+ # Read R=8.0 and R=9.0 at epoch 20000, the rest at their last epoch:
19
+ python plot_pes_ema.py --scan_dir Results/N2/... --epoch_at 8.0:20000 9.0:20000
20
+ """
21
+
22
+ import argparse
23
+ from pathlib import Path
24
+
25
+ import matplotlib.pyplot as plt
26
+ import numpy as np
27
+ import pandas as pd
28
+
29
parser = argparse.ArgumentParser()
parser.add_argument("--scan_dir", type=str, required=True,
                    help="Method directory containing bl_* subdirectories")
parser.add_argument("--atom_dir", type=str, default=None,
                    help="bl_* directory of the isolated atom (binding energy reference)")
parser.add_argument("--pes_csv", type=str, default=None,
                    help="pes.csv from scan_pes.py to overlay as grid-integration "
                         "points. If omitted, looks for pes.csv inside --scan_dir.")
parser.add_argument("--avg_window", type=int, default=1,
                    help="Average the last N rows of training_metrics_ema.csv "
                         "instead of taking just the last value (default: 1).")
parser.add_argument("--bls", type=float, nargs="+", default=None,
                    help="Only include these bond lengths, e.g. --bls 2.0 3.0 4.0 9.0. "
                         "If omitted, include all bl_* directories found.")
parser.add_argument("--epoch_at", type=str, nargs="+", default=None, metavar="R:EPOCH",
                    help="Per-bond-length epoch override, e.g. --epoch_at 8.0:20000 9.0:20000. "
                         "Those bond lengths use the EMA as of that epoch; all others use "
                         "their last epoch.")
args = parser.parse_args()
if args.avg_window < 1:
    # tail(0) would average nothing (NaN) and tail(-n) silently drops rows.
    parser.error(f"--avg_window must be >= 1, got {args.avg_window}")

scan_dir = Path(args.scan_dir).resolve()

# Parse --epoch_at "R:EPOCH" pairs into {round(R, 4): epoch}. The rounding
# matches the key used when overrides are looked up per bl_* directory.
# Malformed numbers are reported through argparse instead of a traceback.
EPOCH_OVERRIDE = {}
for pair in args.epoch_at or ():
    r_str, sep, e_str = pair.partition(":")
    if not sep:
        parser.error(f"--epoch_at expects R:EPOCH pairs, got '{pair}'")
    try:
        EPOCH_OVERRIDE[round(float(r_str), 4)] = int(e_str)
    except ValueError:
        parser.error(f"--epoch_at expects R:EPOCH with a numeric R and an integer "
                     f"EPOCH, got '{pair}'")
if EPOCH_OVERRIDE:
    print(f"Epoch overrides: {EPOCH_OVERRIDE}")
60
+
61
+
62
def read_last_ema(bl_dir: Path, window: int = 500, at_epoch: "int | None" = None):
    """Read a run's electronic energy from its EMA training log.

    Averages the last ``window`` rows of ``<bl_dir>/training_metrics_ema.csv``
    (the window only smooths EMA noise) and returns ``(E_electronic, epoch)``:

    - ``E_electronic = mean(E) + mean(CC)``, with the ``CC`` term included
      only if that column exists. Nuclear repulsion E_NN is NOT included.
    - ``epoch`` is the epoch of the last row read, not an average.

    If ``at_epoch`` is given, rows with ``epoch > at_epoch`` are dropped first,
    so the value is read *as of* that epoch.

    Returns ``(None, None)`` when the CSV is missing or empty, or has no rows at
    or before ``at_epoch``.

    Raises:
        ValueError: if ``window < 1``. ``DataFrame.tail`` would otherwise return
            nothing (NaN energy) for 0, or drop leading rows for negatives.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")

    csv = bl_dir / "training_metrics_ema.csv"
    if not csv.exists():
        return None, None
    df = pd.read_csv(csv)
    if df.empty:
        return None, None
    if at_epoch is not None:
        df = df[df["epoch"] <= at_epoch]  # read the EMA as of this epoch
        if df.empty:
            return None, None

    tail = df.tail(window)
    E_elec = float(tail["E"].mean())
    if "CC" in tail.columns:
        E_elec += float(tail["CC"].mean())
    epoch = int(df.iloc[-1]["epoch"])
    return E_elec, epoch
86
+
87
+
88
def e_nn(mol_name: str, R: float) -> float:
    """Nuclear-nuclear repulsion energy [Ha] of a diatomic at bond length R [Bohr].

    The two nuclear charges are taken from ``mol_name`` read as a formula of
    H–Ne element symbols with optional counts, e.g. ``"H2"``, ``"N2"``,
    ``"HF"``, ``"CO"``, ``"LiH"``. Names that do not spell exactly two known
    atoms fall back to the previous rule: strip trailing digits, treat the
    rest as one element symbol used for both nuclei, and use Z=1 if it is
    unknown.

    Returns ``Z_A * Z_B / R``, or 0.0 when ``R <= 0``.
    """
    import re

    Z = {"H": 1, "He": 2, "Li": 3, "Be": 4, "B": 5, "C": 6,
         "N": 7, "O": 8, "F": 9, "Ne": 10}

    # Expand e.g. "N2" -> ["N", "N"], "LiH" -> ["Li", "H"], but only if the
    # tokens reproduce the whole name and every symbol is known.
    tokens = re.findall(r"([A-Z][a-z]?)(\d*)", mol_name)
    atoms = []
    if ("".join(sym + cnt for sym, cnt in tokens) == mol_name
            and all(sym in Z for sym, _ in tokens)):
        for sym, cnt in tokens:
            atoms.extend([sym] * int(cnt or 1))

    if len(atoms) == 2:
        za, zb = Z[atoms[0]], Z[atoms[1]]
    else:
        # Backward-compatible fallback (single atoms, unknown or polyatomic names).
        za = zb = Z.get(mol_name.rstrip("0123456789"), 1)
    return za * zb / R if R > 0 else 0.0
103
+
104
+
105
# ── Isolated-atom reference (binding-energy baseline) ─────────────────────────
E_atom = None
if args.atom_dir is not None:
    E_atom, ep_atom = read_last_ema(Path(args.atom_dir).resolve(), window=args.avg_window)
    if E_atom is not None:
        print(f"E(atom) = {E_atom:+.6f} Ha (epoch {ep_atom})")
    else:
        print(f"WARNING: could not read atom EMA from {args.atom_dir}")


def _bl_value(path):
    """Return the bond length encoded in a ``bl_<R>`` name, or None if it does not parse."""
    try:
        return float(path.name.split("_")[1])
    except (IndexError, ValueError):
        return None


# ── Scan over bl_* directories ────────────────────────────────────────────────
# Only directories whose suffix parses as a float count as PES points. Any other
# match (stray files, bl_old/ ...) used to crash the float() sort key; now it is
# reported and skipped.
bl_dirs = []
for entry in scan_dir.glob("bl_*"):
    if entry.is_dir() and _bl_value(entry) is not None:
        bl_dirs.append(entry)
    else:
        print(f"  ignoring {entry.name}: not a bl_<R> directory")
bl_dirs.sort(key=_bl_value)

if not bl_dirs:
    raise FileNotFoundError(f"No bl_* directories found in {scan_dir}")

if args.bls is not None:
    keep = {round(bl, 4) for bl in args.bls}
    bl_dirs = [d for d in bl_dirs if round(_bl_value(d), 4) in keep]
    if not bl_dirs:
        raise FileNotFoundError(
            f"None of --bls {args.bls} match any bl_X.XXXX in {scan_dir}")
    print(f"Filtered to {len(bl_dirs)} requested bond lengths: "
          f"{[d.name for d in bl_dirs]}")
130
+
131
import json
import re

# Detect the molecule name from the first bl_* directory that has a
# job_params.json (the file save_job_params in main.py writes); default "H2".
mol_name = "H2"
for d in bl_dirs:
    jp = d / "job_params.json"
    if jp.exists():
        with jp.open() as fh:  # close the handle instead of leaking it
            mol_name = json.load(fh)["mol_name"]
        break
print(f"Molecule: {mol_name}")

# Parse mol_name for LaTeX labels: A_n -> atom_sym='A', mol_latex='\mathrm{A}_n'
_m = re.fullmatch(r"([A-Z][a-z]?)(\d+)?", mol_name)
if _m and _m.group(2):
    atom_sym = _m.group(1)
    mol_latex = rf"\mathrm{{{atom_sym}}}_{{{_m.group(2)}}}"
else:
    atom_sym = mol_name
    mol_latex = rf"\mathrm{{{mol_name}}}"

# The binding energy below is E(mol) - 2*E_atom, which only makes sense for a
# homonuclear molecule. Warn for any multi-element formula (not just HF/CO/NO).
if _m is None and re.fullmatch(r"(?:[A-Z][a-z]?\d*)+", mol_name):
    print(f"WARNING: {mol_name} is heteronuclear — binding uses 2*E_atom, "
          f"which assumes homonuclear dissociation and will be physically wrong")
153
+
154
def _pes_row(bl_dir):
    """Build the PES row for one bl_* directory and print it; None means skip."""
    R = float(bl_dir.name.split("_")[1])
    if R == 0.0:
        return None  # an atom directory inside the scan is not a PES point
    E_elec, epoch = read_last_ema(bl_dir, window=args.avg_window,
                                  at_epoch=EPOCH_OVERRIDE.get(round(R, 4)))
    if E_elec is None:
        print(f" bl={R:.4f}: no EMA csv — skipping")
        return None

    E_NN = e_nn(mol_name, R)
    E_total = E_elec + E_NN  # electronic energy plus nuclear repulsion
    row = {"R": R, "E_total": E_total, "epoch": epoch}
    tag = f" R={R:.4f} epoch={epoch:>6} E={E_total:+.6f} Ha (E_NN={E_NN:+.4f})"
    if E_atom is not None:
        row["bind_Ha"] = E_total - 2.0 * E_atom  # E(mol) - 2 E(atom)
        tag += f" bind={row['bind_Ha']:+.6f} Ha"
    print(tag)
    return row


# One row per usable bond length, in bl_dirs order.
rows = [row for row in map(_pes_row, bl_dirs) if row is not None]

if not rows:
    raise RuntimeError("No data found — check scan_dir.")

df = pd.DataFrame(rows).sort_values("R").reset_index(drop=True)
max_epoch = df["epoch"].max()
# Runs whose (possibly overridden) epoch is below the maximum count as incomplete.
complete = df["epoch"] == max_epoch
181
+
182
# ── Optional: grid-integration results from scan_pes.py ──────────────────────
pes_path = Path(args.pes_csv).resolve() if args.pes_csv else scan_dir / "pes.csv"
pes_df = None
if pes_path.exists():
    pes_df = pd.read_csv(pes_path).sort_values("R_bohr").reset_index(drop=True)
    if args.bls is not None:
        keep = {round(bl, 4) for bl in args.bls}
        pes_df = pes_df[pes_df["R_bohr"].round(4).isin(keep)].reset_index(drop=True)
    print(f"\nGrid overlay: {pes_path} ({len(pes_df)} points)")
else:
    print(f"\nNo pes.csv at {pes_path} — plotting EMA only")

# ── Side-by-side binding-energy comparison (EMA vs grid) ─────────────────────
# Bond lengths are matched after rounding to 4 decimals, the same convention as
# --bls / --epoch_at. R parsed from bl_* names and R written by scan_pes.py can
# differ in the last float bits, which would otherwise split one bond length
# into two half-empty rows of the outer merge.
if pes_df is not None and E_atom is not None and "E_bind_Ha" in pes_df.columns:
    ema_bind = df[["R", "bind_Ha"]].assign(R=df["R"].round(4))
    grid_bind = (pes_df[["R_bohr", "E_bind_Ha"]]
                 .rename(columns={"R_bohr": "R", "E_bind_Ha": "grid_bind_Ha"})
                 .assign(R=lambda t: t["R"].round(4)))
    cmp = (ema_bind.merge(grid_bind, on="R", how="outer")
           .sort_values("R").reset_index(drop=True))
    cmp["delta_Ha"] = cmp["bind_Ha"] - cmp["grid_bind_Ha"]
    print("\n=== Binding energy: EMA vs grid (Ha) ===")
    print(f"{'R':>8} {'EMA':>12} {'Grid':>12} {'Δ(EMA-Grid)':>14}")
    for _, r in cmp.iterrows():
        ema_s = f"{r['bind_Ha']:+12.6f}" if pd.notna(r['bind_Ha']) else f"{'—':>12}"
        grid_s = f"{r['grid_bind_Ha']:+12.6f}" if pd.notna(r['grid_bind_Ha']) else f"{'—':>12}"
        d_s = f"{r['delta_Ha']:+14.6f}" if pd.notna(r['delta_Ha']) else f"{'—':>14}"
        print(f"{r['R']:8.4f} {ema_s} {grid_s} {d_s}")
208
+
209
# ── Plot ──────────────────────────────────────────────────────────────────────
n_panels = 2 if E_atom is not None else 1
fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels + 1, 4))
if n_panels == 1:
    axes = [axes]

R_vals = df["R"].values
done = np.asarray(complete, dtype=bool)  # rows that reached max_epoch
# Early-stopped runs are drawn as hollow triangles. When there are none, that
# series is skipped so the legend shows no empty "incomplete" entry.
any_incomplete = not done.all()

# Panel 1: PES (E_total)
ax = axes[0]
E_vals = df["E_total"].values
ax.plot(R_vals[done], E_vals[done],
        "o-", color="tab:blue", lw=1.8, label=f"EMA (epoch={max_epoch})")
if any_incomplete:
    ax.plot(R_vals[~done], E_vals[~done],
            "^", color="tab:blue", alpha=0.5, markerfacecolor="none",
            markersize=7, label="incomplete")
if pes_df is not None:
    ax.plot(pes_df["R_bohr"], pes_df["E_total"], "o", color="gold",
            markersize=6, markeredgecolor="black", markeredgewidth=0.4,
            linestyle="none", label="scan_pes (grid)", zorder=5)
ax.set_xlabel("R [Bohr]")
ax.set_ylabel("E [Ha]")
ax.set_title("Potential Energy Surface")
ax.legend(fontsize=8)
ax.grid(True, alpha=0.25)

# Panel 2: binding energy E(mol) - 2 E(atom)
if E_atom is not None:
    ax2 = axes[1]
    bind_vals = df["bind_Ha"].values
    ax2.plot(R_vals[done], bind_vals[done],
             "o-", color="tab:orange", lw=1.8, label=f"EMA (epoch={max_epoch})")
    if any_incomplete:
        ax2.plot(R_vals[~done], bind_vals[~done],
                 "^", color="tab:orange", alpha=0.5, markerfacecolor="none", markersize=7)
    if pes_df is not None and "E_bind_Ha" in pes_df.columns:
        # scan_pes already stores E_bind = E(mol) - 2E(atom) — plot directly
        ax2.plot(pes_df["R_bohr"], pes_df["E_bind_Ha"], "o", color="gold",
                 markersize=6, markeredgecolor="black", markeredgewidth=0.4,
                 linestyle="none", label="scan_pes (grid)", zorder=5)
    ax2.axhline(0, color="k", lw=0.8, ls="--", alpha=0.5)
    ax2.set_xlabel("R [Bohr]")
    ax2.set_ylabel(rf"$E({mol_latex}) - 2E(\mathrm{{{atom_sym}}})$ [Ha]")
    ax2.set_title("Binding Energy")
    ax2.legend(fontsize=8)
    ax2.grid(True, alpha=0.25)

fig.suptitle(scan_dir.name, fontsize=8)
fig.tight_layout()

out = scan_dir / "pes_ema.png"
fig.savefig(out, dpi=150)
fig.savefig(out.with_suffix(".svg"), transparent=True)
print(f"\nSaved → {out}")