off 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- off/__init__.py +23 -0
- off/atom_energies.py +151 -0
- off/config/_config.py +108 -0
- off/dft_distrax/__init__.py +27 -0
- off/dft_distrax/dft_distrax.py +216 -0
- off/flow/__init__.py +29 -0
- off/flow/equiv_flows.py +99 -0
- off/functionals/__init__.py +35 -0
- off/functionals/core_correction.py +84 -0
- off/functionals/exchange_correlation.py +174 -0
- off/functionals/external.py +49 -0
- off/functionals/functional.py +129 -0
- off/functionals/hartree.py +62 -0
- off/functionals/kinetic.py +87 -0
- off/main.py +172 -0
- off/ode_solver/__init__.py +32 -0
- off/ode_solver/eqx_ode.py +76 -0
- off/plot_binding_csv.py +63 -0
- off/plot_pes_ema.py +259 -0
- off/plot_pes_mpl.py +280 -0
- off/promolecular/__init__.py +27 -0
- off/promolecular/promolecular_dist.py +465 -0
- off/quadrature.py +261 -0
- off/quadrature_scan.py +188 -0
- off/scan_pes.py +133 -0
- off/test_fwd_rev.py +290 -0
- off/train/__init__.py +44 -0
- off/train/loop.py +228 -0
- off/train/loss.py +149 -0
- off/train/utils.py +38 -0
- off/utils.py +618 -0
- off-0.1.0.dist-info/METADATA +154 -0
- off-0.1.0.dist-info/RECORD +37 -0
- off-0.1.0.dist-info/WHEEL +5 -0
- off-0.1.0.dist-info/entry_points.txt +3 -0
- off-0.1.0.dist-info/licenses/LICENSE +21 -0
- off-0.1.0.dist-info/top_level.txt +1 -0
off/plot_pes_mpl.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
1
|
+
"""
|
|
2
|
+
PES (left) and binding-energy curve (right), side by side — matplotlib.
|
|
3
|
+
|
|
4
|
+
MC (line + o) : per bond length, mean of training_metrics_ema.csv over an
|
|
5
|
+
epoch range (E + CC) + E_NN -> E_total.
|
|
6
|
+
Integration (pentagons) : E_total from energy_summary.json (quadrature_scan.py).
|
|
7
|
+
|
|
8
|
+
Right panel: binding energy [a.u.] = E_total(R) - Σ_atoms E(atom), where the atomic
|
|
9
|
+
references come from atom_energies.csv (column E_ema_Ha for the MC line, E_grid_Ha for
|
|
10
|
+
the integration points), summed over the molecule's constituent atoms.
|
|
11
|
+
|
|
12
|
+
Usage
|
|
13
|
+
-----
|
|
14
|
+
python plot_pes_mpl.py --H2
|
|
15
|
+
python plot_pes_mpl.py --N2 --epoch_min 9000 --epoch_max 10000
|
|
16
|
+
python plot_pes_mpl.py --H2 --atom_csv Results/atom_energies.csv
|
|
17
|
+
python plot_pes_mpl.py --H2 --results_root /scratch/al3x/MyRuns --out /tmp/h2.png
|
|
18
|
+
|
|
19
|
+
Output: static PNG + SVG under Results/{mol}/ (no HTML, no extra dependencies).
|
|
20
|
+
Run quadrature_scan.py first so each bl_* has an energy_summary.json.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import argparse
|
|
24
|
+
import json
|
|
25
|
+
import re
|
|
26
|
+
from pathlib import Path
|
|
27
|
+
|
|
28
|
+
import pandas as pd
|
|
29
|
+
import matplotlib
|
|
30
|
+
matplotlib.use("Agg") # headless: works on the cluster, no display needed
|
|
31
|
+
import matplotlib.pyplot as plt
|
|
32
|
+
|
|
33
|
+
_SCRIPT_DIR = Path(__file__).resolve().parent

# Atomic numbers Z for the first-row elements this script understands (H..Ne).
Z_TABLE = {
    symbol: z
    for z, symbol in enumerate(
        ("H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne"), start=1)
}

# Every name that gets its own --<NAME> flag: the bare atoms (same order as
# Z_TABLE) followed by the molecules.
KNOWN_MOLS = [*Z_TABLE, "H2", "N2", "O2", "F2", "HF", "CO", "LiH", "H10"]
|
|
38
|
+
|
|
39
|
+
# ── CLI ───────────────────────────────────────────────────────────────────────
parser = argparse.ArgumentParser(
    description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter,
    # No prefix matching: with flags like --H, --H2 and --H10 an abbreviation
    # (e.g. --H1) would silently select the wrong molecule.
    allow_abbrev=False)
for _m in KNOWN_MOLS:
    parser.add_argument(f"--{_m}", action="store_true", help=f"Plot molecule {_m}")
parser.add_argument("--mol", nargs="+", default=[], metavar="NAME",
                    help="Molecule name(s) (alternative to the flags)")
parser.add_argument("--method", default=None,
                    help="Restrict to one method directory (default: all methods)")
parser.add_argument("--epoch_min", type=int, default=None,
                    help="Average EMA rows with epoch >= this")
parser.add_argument("--epoch_max", type=int, default=None,
                    help="Average EMA rows with epoch <= this")
parser.add_argument("--window", type=int, default=1000,
                    help="If no epoch range given: average the last N EMA rows")
parser.add_argument("--results_root", default=None,
                    help="Override Results root (default: <script_dir>/Results)")
parser.add_argument("--atom_csv", default=None,
                    help="Path to atom_energies.csv (default: search mol dir / root / CWD)")
parser.add_argument("--out", default=None,
                    help="Output image path (default: Results/{mol}/pes_{mol}.png)")
args = parser.parse_args()

# Reject averaging options that would otherwise silently produce an empty
# (or, for a negative window, unexpectedly large) EMA selection later on.
if args.window < 1:
    parser.error("--window must be a positive number of rows")
if (args.epoch_min is not None and args.epoch_max is not None
        and args.epoch_min > args.epoch_max):
    parser.error("--epoch_min must not be larger than --epoch_max")

# Molecules from --mol first, then from the per-molecule flags; duplicates are
# dropped while keeping first-seen order.
selected = list(args.mol) + [m for m in KNOWN_MOLS if getattr(args, m)]
selected = list(dict.fromkeys(selected))
if not selected:
    parser.error("Pick a molecule, e.g. --H2 or --mol H2")
# --out names one image file; with several molecules each plot would overwrite
# the previous one.
if args.out and len(selected) > 1:
    parser.error("--out can only be used with a single molecule")

root = Path(args.results_root).resolve() if args.results_root else (_SCRIPT_DIR / "Results")
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# ── helpers ───────────────────────────────────────────────────────────────────
|
|
72
|
+
# Sentinel meaning "take this setting from the parsed command line".
_FROM_CLI = object()


def ema_mean(df: pd.DataFrame, epoch_min=_FROM_CLI, epoch_max=_FROM_CLI,
             window=_FROM_CLI):
    """Electronic energy: mean of (E + CC) over the chosen epoch range / window.

    ``df`` is a training_metrics_ema.csv table with an "E" column, an optional
    "CC" column and, when an epoch range is used, an "epoch" column.

    If either epoch bound is set, rows with ``epoch_min <= epoch <= epoch_max``
    are averaged (a missing bound defaults to the table's min / max epoch);
    otherwise the last ``window`` rows are used.  Each of the three settings
    defaults to the matching --epoch_min / --epoch_max / --window CLI option.

    Returns mean(E) + mean(CC) as a float, or None when no rows are selected or
    the result is NaN (e.g. every selected "E" value is missing).
    """
    if epoch_min is _FROM_CLI:
        epoch_min = args.epoch_min
    if epoch_max is _FROM_CLI:
        epoch_max = args.epoch_max
    if window is _FROM_CLI:
        window = args.window

    if epoch_min is not None or epoch_max is not None:
        lo = epoch_min if epoch_min is not None else df["epoch"].min()
        hi = epoch_max if epoch_max is not None else df["epoch"].max()
        sel = df[(df["epoch"] >= lo) & (df["epoch"] <= hi)]
    else:
        sel = df.tail(window)
    if sel.empty:
        return None
    e = float(sel["E"].mean())
    if "CC" in sel.columns:
        e += float(sel["CC"].mean())
    # A NaN mean carries no information; report it like an empty selection.
    return None if pd.isna(e) else e
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def enn_fallback(mol: str, R: float):
    """Nuclear-repulsion energy E_NN for simple geometries (atomic units).

    Used only if energy_summary.json is missing (or has no "E_NN").  R is the
    internuclear spacing taken from the bl_<R> directory name.

    Supported formulas:
      * homonuclear diatomic X2              -> Z_X**2 / R
      * heteronuclear diatomic XY            -> Z_X * Z_Y / R
        (any two element symbols present in Z_TABLE, e.g. HF, CO, LiH)
      * linear, equally spaced chain Hn      -> Σ_{i<j} 1 / (|i - j| * R)

    Returns None for any other formula.
    """
    m = re.fullmatch(r"([A-Z][a-z]?)2", mol)  # homonuclear diatomic
    if m and m.group(1) in Z_TABLE:
        z = Z_TABLE[m.group(1)]
        return z * z / R
    m = re.fullmatch(r"([A-Z][a-z]?)([A-Z][a-z]?)", mol)  # heteronuclear diatomic
    if m and m.group(1) in Z_TABLE and m.group(2) in Z_TABLE:
        return Z_TABLE[m.group(1)] * Z_TABLE[m.group(2)] / R
    m = re.fullmatch(r"H(\d+)", mol)  # linear equal-spaced Hn chain
    if m:
        n = int(m.group(1))
        return sum(1.0 / (abs(i - j) * R) for i in range(n) for j in range(i + 1, n))
    return None
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def constituents(mol: str) -> dict:
    """Count the atoms in a simple chemical formula.

    Each element symbol (a capital letter plus an optional lowercase letter) may
    be followed by a multiplicity; no number means 1, and repeated symbols are
    summed.  Characters that do not fit this pattern are ignored.

    Examples: "H2" -> {"H": 2}, "HF" -> {"H": 1, "F": 1}, "H10" -> {"H": 10}.
    """
    counts = {}
    for token in re.finditer(r"([A-Z][a-z]?)(\d*)", mol):
        symbol, digits = token.groups()
        counts[symbol] = counts.get(symbol, 0) + int(digits or 1)
    return counts
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def load_atom_csv(mol_dir: Path):
    """Find and load atom_energies.csv, indexed by its "atom" column.

    Search order: the --atom_csv path (if given), then
    ``<mol_dir>/atom_energies.csv``, ``<root>/atom_energies.csv`` and finally
    ``atom_energies.csv`` in the current working directory.  The first existing
    file wins.

    Returns the DataFrame, or None when no candidate file exists.
    """
    cands = []
    if args.atom_csv:
        explicit = Path(args.atom_csv)
        if explicit.is_file():
            cands.append(explicit)
        else:
            # The user asked for a specific file: say so instead of silently
            # using whatever the default search turns up.
            print(f"[warn] --atom_csv {explicit} not found; "
                  f"searching the default locations instead")
    cands += [mol_dir / "atom_energies.csv",
              root / "atom_energies.csv",
              Path("atom_energies.csv")]
    for p in cands:
        if p.is_file():
            return pd.read_csv(p).set_index("atom")
    return None
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def atom_ref(atom_df, mol: str, col: str):
    """Σ_atoms count * E(atom) from column ``col`` of the atomic-energy table.

    Returns None (no usable reference) when:
      * ``atom_df`` is None or has no column ``col``;
      * ``mol`` parses to no atoms at all (otherwise the reference would be a
        bogus 0.0 and the binding curve would just be -E_total);
      * any constituent atom is missing from the table or its entry is NaN.
    """
    if atom_df is None or col not in atom_df.columns:
        return None
    counts = constituents(mol)
    if not counts:
        return None
    tot = 0.0
    for el, n in counts.items():
        if el not in atom_df.index:
            return None
        value = float(atom_df.loc[el, col])
        if pd.isna(value):
            return None
        tot += n * value
    return tot
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
# Column order of the per-bond-length table returned by gather().
_GATHER_COLUMNS = ["R", "grid_E", "ema_E", "epoch"]


def _bond_length(bl_dir: Path):
    """Return R parsed from a ``bl_<R>[_...]`` name, or None if it is not a number."""
    try:
        return float(bl_dir.name.split("_")[1])
    except (IndexError, ValueError):
        return None


def gather(method_dir: Path, mol: str) -> pd.DataFrame:
    """Collect the total energies of every ``bl_<R>`` directory of one method.

    Per bond length R:
      * grid_E : "E_total" from energy_summary.json, if present;
      * ema_E  : ema_mean() of training_metrics_ema.csv plus E_NN, where E_NN
                 is "E_NN" from the JSON or, failing that, enn_fallback();
      * epoch  : "epoch" from the JSON, if present.

    Entries that are not directories, or whose name after ``bl_`` is not a
    number, are skipped.  Missing values are None.  Always returns a DataFrame
    with columns R, grid_E, ema_E, epoch (possibly with zero rows), sorted by R.
    """
    bond_dirs = []
    for bl in method_dir.glob("bl_*"):
        R = _bond_length(bl)
        if R is None or not bl.is_dir():
            continue
        bond_dirs.append((R, bl))
    bond_dirs.sort(key=lambda item: item[0])

    rows = []
    for R, bl in bond_dirs:
        grid_E = e_nn = epoch = None

        es = bl / "energy_summary.json"
        if es.is_file():
            try:
                with open(es, encoding="utf-8") as f:
                    d = json.load(f)
            except (OSError, ValueError) as err:  # ValueError covers JSONDecodeError
                print(f"[{mol}] could not read {es}: {err}")
            else:
                grid_E = d.get("E_total")
                e_nn = d.get("E_NN")
                epoch = d.get("epoch")
        if e_nn is None:
            e_nn = enn_fallback(mol, R)

        ema_E = None
        ec = bl / "training_metrics_ema.csv"
        if ec.is_file() and e_nn is not None:
            try:
                df = pd.read_csv(ec)
            except Exception as err:  # best effort: one bad CSV must not abort the scan
                print(f"[{mol}] could not read {ec}: {err}")
                df = None
            if df is not None and not df.empty and "E" in df.columns:
                em = ema_mean(df)
                if em is not None:
                    ema_E = em + e_nn

        rows.append(dict(R=R, grid_E=grid_E, ema_E=ema_E, epoch=epoch))
    # Passing columns= keeps the schema even when no bl_* directory was found;
    # a column-less frame would make sort_values("R") raise KeyError.
    return (pd.DataFrame(rows, columns=_GATHER_COLUMNS)
            .sort_values("R").reset_index(drop=True))
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def short(method: str) -> str:
    """Shorten a method-directory name for use as a legend label.

    Removes, in this order and wherever they occur, the substrings shared by
    every run: "tf_w_lam0.2_", "_dopri8" and "_sched_MIX".
    """
    label = method
    for common in ("tf_w_lam0.2_", "_dopri8", "_sched_MIX"):
        label = label.replace(common, "")
    return label
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def _binding_rows(method: str, df: pd.DataFrame, ref_ema, ref_grid) -> list:
    """Rows for binding_{mol}.csv built from one method's gather() table.

    ΔE = E(A) + E(B) - E(AB) = Σ E(atom) - E_total (>0 = bound), computed
    separately for the MC and integration energies; a column group is left out
    for a bond length whose energy or atomic reference is missing.
    """
    rows = []
    for _, r in df.iterrows():
        row = {"method": method, "R_bohr": float(r["R"]), "epoch": r["epoch"]}
        if pd.notna(r["ema_E"]) and ref_ema is not None:
            e_ab = float(r["ema_E"])
            row.update(E_atoms_mc=ref_ema, E_AB_mc=e_ab, dE_mc_Ha=ref_ema - e_ab)
        if pd.notna(r["grid_E"]) and ref_grid is not None:
            e_ab = float(r["grid_E"])
            row.update(E_atoms_grid=ref_grid, E_AB_grid=e_ab,
                       dE_grid_Ha=ref_grid - e_ab)
        rows.append(row)
    return rows


def plot_mol(mol: str):
    """Plot the PES (left) and binding-energy curve (right) for one molecule.

    Reads every method directory under ``<root>/<mol>`` (or only --method),
    draws the MC curve and the integration points of each method, and saves
    pes_<mol>.png / .svg plus binding_<mol>.csv (in ``<root>/<mol>`` unless
    --out is given).  Prints a message and returns early when the molecule
    directory is missing or no method has any data.
    """
    mdir = root / mol
    if not mdir.is_dir():
        print(f"[{mol}] not found at {mdir}")
        return

    if args.method:
        methods = [mdir / args.method]
    else:
        methods = sorted(d for d in mdir.iterdir() if d.is_dir())
    # With several methods on one figure, legend entries must name their method;
    # otherwise they would all read "MC" / "Integration".
    multi = len(methods) > 1

    # Atomic references for the binding-energy panel (same CSV for every method).
    atom_df = load_atom_csv(mdir)
    ref_ema = atom_ref(atom_df, mol, "E_ema_Ha")
    ref_grid = atom_ref(atom_df, mol, "E_grid_Ha")
    if ref_ema is None and ref_grid is None:
        print(f"[{mol}] atom_energies.csv not found / missing atoms — binding panel empty")

    colors = plt.cm.tab10.colors
    fig, (axL, axR) = plt.subplots(1, 2, figsize=(14, 6))
    any_data = False
    bind_rows = []

    for i, md in enumerate(methods):
        if not md.is_dir():
            print(f"[{mol}] method dir not found: {md.name}")
            continue
        df = gather(md, mol)
        if df.empty:
            continue
        c = colors[i % len(colors)]
        tag = f"{short(md.name)} " if multi else ""
        # Integration markers stay orange; with several methods each one gets an
        # outline in its method colour so the runs can be told apart.
        grid_style = dict(color="orange", marker="p", s=100, zorder=5,
                          edgecolors=c if multi else "none",
                          linewidths=1.5 if multi else 0.6)

        ema = df.dropna(subset=["ema_E"])
        grid = df.dropna(subset=["grid_E"])

        # ── left: PES ─────────────────────────────────────────────────────────
        if not ema.empty:
            any_data = True
            axL.plot(ema["R"], ema["ema_E"], "-o", color=c, lw=2, ms=12,
                     label=f"{tag}MC")
        if not grid.empty:
            any_data = True
            axL.scatter(grid["R"], grid["grid_E"], label=f"{tag}Integration",
                        **grid_style)

        # ── right: ΔE = E(A) + E(B) - E(AB) = Σ E(atom) - E_total (>0 = bound)
        if not ema.empty and ref_ema is not None:
            axR.plot(ema["R"], ref_ema - ema["ema_E"], "-o", color=c, lw=2, ms=12,
                     label=f"{tag}MC")
        if not grid.empty and ref_grid is not None:
            axR.scatter(grid["R"], ref_grid - grid["grid_E"],
                        label=f"{tag}Integration", **grid_style)

        bind_rows.extend(_binding_rows(md.name, df, ref_ema, ref_grid))

    if not any_data:
        print(f"[{mol}] no data — run quadrature_scan.py first "
              f"(creates energy_summary.json in each bl_*).")
        plt.close(fig)
        return

    axL.set_xlabel("R [Bohr]")
    axL.set_ylabel(r"E[$\rho$] + V$_{NN}$(R) [a.u.]")
    axL.legend(fontsize=9)

    axR.axhline(0, color="k", lw=0.8, ls="--")
    axR.set_xlabel("R [Bohr]")
    axR.set_ylabel(r"$\Delta$E = E(A) + E(B) - E(AB) [a.u.]")
    # Without atomic references nothing labelled is drawn on the right; calling
    # legend() then would only emit a matplotlib "No artists" warning.
    if axR.get_legend_handles_labels()[0]:
        axR.legend(fontsize=9)

    fig.tight_layout()

    out = Path(args.out).resolve() if args.out else mdir / f"pes_{mol}.png"
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out, dpi=150)
    fig.savefig(out.with_suffix(".svg"))
    plt.close(fig)
    print(f"[{mol}] saved → {out}")
    print(f"[{mol}] saved → {out.with_suffix('.svg')}")

    if bind_rows:
        bdf = (pd.DataFrame(bind_rows)
               .sort_values(["method", "R_bohr"]).reset_index(drop=True))
        bcsv = out.with_name(f"binding_{mol}.csv")
        bdf.to_csv(bcsv, index=False, float_format="%.8f")
        print(f"[{mol}] saved → {bcsv}")
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
# Entry point: one figure (and binding CSV) per requested molecule, in the order
# the molecules were selected on the command line.
for molecule in selected:
    plot_mol(molecule)
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
# MIT License
|
|
2
|
+
|
|
3
|
+
# Copyright (c) 2025 AlexandreDeCamargo
|
|
4
|
+
|
|
5
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
# in the Software without restriction, including without limitation the rights
|
|
8
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
# furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
# copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
# SOFTWARE.
|
|
22
|
+
|
|
23
|
+
__version__ = "0.1.0"
|
|
24
|
+
|
|
25
|
+
from .promolecular_dist import (
|
|
26
|
+
ProMolecularDensity
|
|
27
|
+
)
|