off 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
off/plot_pes_mpl.py ADDED
@@ -0,0 +1,280 @@
1
"""
PES (left) and binding-energy curve (right), side by side — matplotlib.

MC (line + o)           : per bond length, mean of training_metrics_ema.csv over an
                          epoch range (E + CC) + E_NN -> E_total.
Integration (pentagons) : E_total from energy_summary.json (quadrature_scan.py).

Right panel: binding energy ΔE [a.u.] = Σ_atoms E(atom) - E_total(R)  (positive = bound),
where the atomic references come from atom_energies.csv (column E_ema_Ha for the MC
line, E_grid_Ha for the integration points), summed over the molecule's constituent atoms.

Usage
-----
python plot_pes_mpl.py --H2
python plot_pes_mpl.py --N2 --epoch_min 9000 --epoch_max 10000
python plot_pes_mpl.py --H2 --atom_csv Results/atom_energies.csv
python plot_pes_mpl.py --H2 --results_root /scratch/al3x/MyRuns --out /tmp/h2.png

Output: static PNG + SVG, plus binding_{mol}.csv with the ΔE values, under
Results/{mol}/ (or next to --out); no HTML, no extra dependencies.
Run quadrature_scan.py first so each bl_* has an energy_summary.json.
"""
22
+
23
+ import argparse
24
+ import json
25
+ import re
26
+ from pathlib import Path
27
+
28
+ import pandas as pd
29
+ import matplotlib
30
+ matplotlib.use("Agg") # headless: works on the cluster, no display needed
31
+ import matplotlib.pyplot as plt
32
+
33
_SCRIPT_DIR = Path(__file__).resolve().parent

# First-row atoms in order of nuclear charge; Z_TABLE maps symbol -> Z (1..10).
_ATOMS = ("H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne")
Z_TABLE = {symbol: z for z, symbol in enumerate(_ATOMS, start=1)}

# Every system that gets its own --NAME command-line flag: the atoms above,
# followed by the molecules.
KNOWN_MOLS = [*_ATOMS, "H2", "N2", "O2", "F2", "HF", "CO", "LiH", "H10"]
38
+
39
+ # ── CLI ───────────────────────────────────────────────────────────────────────
40
parser = argparse.ArgumentParser(
    description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter,
    allow_abbrev=False)  # so --H is not a prefix of --H2 / --H10
for _m in KNOWN_MOLS:
    parser.add_argument(f"--{_m}", action="store_true", help=f"Plot molecule {_m}")
parser.add_argument("--mol", nargs="+", default=[], metavar="NAME",
                    help="Molecule name(s) (alternative to the flags)")
parser.add_argument("--method", default=None,
                    help="Restrict to one method directory (default: all methods)")
parser.add_argument("--epoch_min", type=int, default=None,
                    help="Average EMA rows with epoch >= this")
parser.add_argument("--epoch_max", type=int, default=None,
                    help="Average EMA rows with epoch <= this")
parser.add_argument("--window", type=int, default=1000,
                    help="If no epoch range given: average the last N EMA rows")
parser.add_argument("--results_root", default=None,
                    help="Override Results root (default: <script_dir>/Results)")
parser.add_argument("--atom_csv", default=None,
                    help="Path to atom_energies.csv (default: search mol dir / root / CWD)")
parser.add_argument("--out", default=None,
                    help="Output image path (default: Results/{mol}/pes_{mol}.png)")
args = parser.parse_args()

# --mol names first, then the --NAME flags; dict.fromkeys drops duplicates
# while keeping first-seen order.
selected = list(dict.fromkeys(
    [*args.mol, *(m for m in KNOWN_MOLS if getattr(args, m))]))
if not selected:
    parser.error("Pick a molecule, e.g. --H2 or --mol H2")

# Reject option combinations that would otherwise misbehave silently later on.
if args.window < 1:
    # DataFrame.tail(-n) returns every row except the first n, so a
    # non-positive window would quietly average the wrong rows.
    parser.error("--window must be a positive integer")
if (args.epoch_min is not None and args.epoch_max is not None
        and args.epoch_min > args.epoch_max):
    # An inverted range selects no rows, so every MC point would vanish.
    parser.error("--epoch_min must not exceed --epoch_max")
if args.out and len(selected) > 1:
    # plot_mol writes every molecule to the same --out path.
    parser.error("--out names a single image; use it with one molecule only "
                 "(every molecule would overwrite the same file)")

root = Path(args.results_root).resolve() if args.results_root else (_SCRIPT_DIR / "Results")
69
+
70
+
71
+ # ── helpers ───────────────────────────────────────────────────────────────────
72
def ema_mean(df: pd.DataFrame, *, opts=None):
    """Electronic energy: mean of (E + CC) over the chosen epoch range / window.

    Parameters
    ----------
    df : pd.DataFrame
        Contents of training_metrics_ema.csv. Must have an ``E`` column; an
        optional ``CC`` column is added on top, and an ``epoch`` column is
        required when an epoch range is requested.
    opts : object, optional
        Anything with ``epoch_min``, ``epoch_max`` and ``window`` attributes.
        Defaults to the parsed command-line ``args``.

    Returns
    -------
    float or None
        ``mean(E) + mean(CC)`` over the selected rows. Returns None when no rows
        are selected, or when an epoch range is requested but ``df`` has no
        ``epoch`` column.
    """
    opts = args if opts is None else opts

    if opts.epoch_min is not None or opts.epoch_max is not None:
        if "epoch" not in df.columns:
            # Cannot honour the requested range; report "no value" rather than
            # crashing the whole run or averaging the wrong rows.
            return None
        epochs = df["epoch"]
        lo = opts.epoch_min if opts.epoch_min is not None else epochs.min()
        hi = opts.epoch_max if opts.epoch_max is not None else epochs.max()
        sel = df[(epochs >= lo) & (epochs <= hi)]
    else:
        sel = df.tail(opts.window)

    if sel.empty:
        return None
    e = float(sel["E"].mean())
    if "CC" in sel.columns:
        e += float(sel["CC"].mean())
    return e
86
+
87
+
88
def enn_fallback(mol: str, R: float):
    """Nuclear repulsion E_NN(R) [Ha] for simple geometries, or None if unknown.

    Used only when energy_summary.json is missing or has no E_NN. Handles:
      * a single atom from Z_TABLE             -> 0.0 (no nucleus-nucleus term)
      * a diatomic of Z_TABLE elements (A2, AB) -> Z_A * Z_B / R
      * a linear, equally spaced H_n chain      -> Σ_{i<j} 1 / ((j - i) R)
    ``R`` is the bond length in Bohr, taken from the bl_<R> directory name.
    """
    if mol in Z_TABLE:
        # One nucleus: E_NN vanishes. Returning None here would make gather()
        # drop every MC point of an atom whose JSON lacks E_NN.
        return 0.0

    m = re.fullmatch(r"([A-Z][a-z]?)2", mol)  # homonuclear diatomic
    if m and m.group(1) in Z_TABLE:
        z = Z_TABLE[m.group(1)]
        return z * z / R

    m = re.fullmatch(r"([A-Z][a-z]?)([A-Z][a-z]?)", mol)  # heteronuclear diatomic
    if m and m.group(1) in Z_TABLE and m.group(2) in Z_TABLE:
        return Z_TABLE[m.group(1)] * Z_TABLE[m.group(2)] / R

    m = re.fullmatch(r"H(\d+)", mol)  # linear equal-spaced Hn chain
    if m:
        n = int(m.group(1))
        return sum(1.0 / ((j - i) * R) for i in range(n) for j in range(i + 1, n))
    return None
103
+
104
+
105
def constituents(mol: str) -> dict:
    """Count the atoms of each element in a simple chemical formula.

    Examples: "H2" -> {"H": 2}, "HF" -> {"H": 1, "F": 1}, "H10" -> {"H": 10}.
    Elements keep their first-seen order, and repeated symbols accumulate.
    Characters outside an element-symbol/count token are ignored.
    """
    counts: dict = {}
    for token in re.finditer(r"([A-Z][a-z]?)(\d*)", mol):
        symbol, digits = token.groups()
        counts[symbol] = counts.get(symbol, 0) + int(digits or 1)
    return counts
112
+
113
+
114
def load_atom_csv(mol_dir: Path):
    """Find and load atom_energies.csv, indexed by its ``atom`` column.

    Search order: --atom_csv (if given), ``mol_dir``, the Results root, then the
    current working directory. The first readable candidate with an ``atom``
    column wins. Unreadable or malformed candidates are reported and skipped.

    Returns
    -------
    pd.DataFrame or None
        The table indexed by atom symbol, or None if no usable file was found.
    """
    cands = []
    if args.atom_csv:
        explicit = Path(args.atom_csv)
        if explicit.exists():
            cands.append(explicit)
        else:
            # Say so: silently using a different file could give wrong references.
            print(f"warning: --atom_csv {explicit} not found; searching default locations")
    cands += [mol_dir / "atom_energies.csv",
              root / "atom_energies.csv",
              Path("atom_energies.csv")]

    for p in cands:
        if not p.exists():
            continue
        try:
            table = pd.read_csv(p)
        except (OSError, ValueError) as err:  # pandas parse errors are ValueErrors
            print(f"warning: could not read {p}: {err}")
            continue
        if "atom" not in table.columns:
            print(f"warning: {p} has no 'atom' column; skipped")
            continue
        return table.set_index("atom")
    return None
126
+
127
+
128
def atom_ref(atom_df, mol: str, col: str):
    """Σ_atoms count * E(atom) for ``mol``, read from column ``col`` of ``atom_df``.

    Returns None, rather than a misleading number, when:
      * ``atom_df`` is None or lacks ``col``;
      * the formula yields no atoms (an empty sum would be a bogus 0.0);
      * any constituent atom is missing, duplicated, or has no value (NaN).
    """
    if atom_df is None or col not in atom_df.columns:
        return None
    counts = constituents(mol)
    if not counts:
        return None

    tot = 0.0
    for el, n in counts.items():
        if el not in atom_df.index:
            return None
        val = atom_df.loc[el, col]
        # A Series means several rows share this symbol, which is ambiguous.
        if isinstance(val, pd.Series) or pd.isna(val):
            return None
        tot += n * float(val)
    return tot
138
+
139
+
140
def _parse_bond_length(bl_dir: Path):
    """Bond length R encoded in a ``bl_<R>`` directory name, or None if unparsable."""
    try:
        return float(bl_dir.name.split("_")[1])
    except (IndexError, ValueError):
        return None


def _read_summary(bl_dir: Path) -> dict:
    """Contents of bl_dir/energy_summary.json, or {} if missing or unreadable."""
    es = bl_dir / "energy_summary.json"
    if not es.exists():
        return {}
    try:
        with open(es) as f:
            d = json.load(f)
    except (OSError, ValueError) as err:  # json.JSONDecodeError is a ValueError
        print(f"  warning: could not read {es}: {err}")
        return {}
    return d if isinstance(d, dict) else {}


def _read_ema(bl_dir: Path):
    """Electronic MC energy from bl_dir/training_metrics_ema.csv via ema_mean(), or None."""
    ec = bl_dir / "training_metrics_ema.csv"
    if not ec.exists():
        return None
    try:
        df = pd.read_csv(ec)
    except (OSError, ValueError) as err:  # includes pandas EmptyDataError/ParserError
        print(f"  warning: could not read {ec}: {err}")
        return None
    if df.empty or "E" not in df.columns:
        return None
    return ema_mean(df)


def gather(method_dir: Path, mol: str) -> pd.DataFrame:
    """Collect per-bond-length energies from ``method_dir/bl_*/``.

    Returns a DataFrame with columns ``R`` (Bohr), ``grid_E`` (E_total from
    energy_summary.json), ``ema_E`` (MC electronic energy + E_NN) and ``epoch``,
    sorted by R. Missing values are None. When no usable ``bl_*`` directory
    exists, the frame is empty but keeps these columns, so callers can still
    test ``.empty`` and index columns.
    """
    columns = ["R", "grid_E", "ema_E", "epoch"]
    rows = []
    for bl in sorted(method_dir.glob("bl_*")):
        if not bl.is_dir():
            continue  # e.g. stray bl_*.log files
        R = _parse_bond_length(bl)
        if R is None:
            print(f"  skipping {bl}: cannot parse bond length from the name")
            continue

        summary = _read_summary(bl)
        grid_E = summary.get("E_total")
        e_nn = summary.get("E_NN")
        epoch = summary.get("epoch")
        if e_nn is None:
            e_nn = enn_fallback(mol, R)

        # The MC total needs E_NN; without it the point cannot be placed on the PES.
        ema_E = None
        if e_nn is not None:
            em = _read_ema(bl)
            if em is not None:
                ema_E = em + e_nn

        rows.append(dict(R=R, grid_E=grid_E, ema_E=ema_E, epoch=epoch))

    # Explicit columns: pd.DataFrame([]) has no "R" column, so sorting would raise.
    return (pd.DataFrame(rows, columns=columns)
            .sort_values("R", kind="stable")
            .reset_index(drop=True))
170
+
171
+
172
def short(method: str) -> str:
    """Shorten a method directory name for use as a legend label.

    Removes, in this order, the name fragments shared by every run:
    "tf_w_lam0.2_", "_dopri8" and "_sched_MIX".
    """
    label = method
    for fragment in ("tf_w_lam0.2_", "_dopri8", "_sched_MIX"):
        label = label.replace(fragment, "")
    return label
177
+
178
+
179
def _binding_row(method: str, r, ref_ema, ref_grid) -> dict:
    """Build one binding_{mol}.csv row from a row ``r`` of gather()'s frame.

    ΔE = Σ E(atom) - E_total (>0 = bound). The MC and grid column groups are
    filled only when both that energy and its atomic reference exist.
    """
    row = {"method": method, "R_bohr": float(r["R"]), "epoch": r["epoch"]}
    if pd.notna(r["ema_E"]) and ref_ema is not None:
        e_ab = float(r["ema_E"])
        row["E_atoms_mc"] = ref_ema
        row["E_AB_mc"] = e_ab
        row["dE_mc_Ha"] = ref_ema - e_ab
    if pd.notna(r["grid_E"]) and ref_grid is not None:
        e_ab = float(r["grid_E"])
        row["E_atoms_grid"] = ref_grid
        row["E_AB_grid"] = e_ab
        row["dE_grid_Ha"] = ref_grid - e_ab
    return row


def plot_mol(mol: str):
    """Plot the PES (left) and binding energy (right) of ``mol`` for each method.

    Reads ``root/mol/<method>/bl_*`` via gather(). Saves the figure as PNG and
    SVG (``Results/{mol}/pes_{mol}.*`` or ``--out``) and writes
    ``binding_{mol}.csv`` next to it. Prints a message and returns early when
    the molecule directory is missing or there is no data to plot.
    """
    mdir = root / mol
    if not mdir.is_dir():
        print(f"[{mol}] not found at {mdir}")
        return

    if args.method:
        methods = [mdir / args.method]
    else:
        methods = sorted(d for d in mdir.iterdir() if d.is_dir())
    # With several methods on the same axes, tag legend entries with the
    # (shortened) method name and colour the integration points per method;
    # otherwise every curve would read "MC"/"Integration" in orange.
    multi = len(methods) > 1

    # Atomic references for the binding-energy panel (same CSV for every method).
    atom_df = load_atom_csv(mdir)
    ref_ema = atom_ref(atom_df, mol, "E_ema_Ha")
    ref_grid = atom_ref(atom_df, mol, "E_grid_Ha")
    if ref_ema is None and ref_grid is None:
        print(f"[{mol}] atom_energies.csv not found / missing atoms — binding panel empty")

    colors = plt.cm.tab10.colors
    fig, (axL, axR) = plt.subplots(1, 2, figsize=(14, 6))
    any_data = False
    any_binding = False
    bind_rows = []

    for i, md in enumerate(methods):
        if not md.is_dir():
            print(f"[{mol}] method dir not found: {md.name}")
            continue
        df = gather(md, mol)
        if df.empty:
            continue
        c = colors[i % len(colors)]
        tag = f" ({short(md.name)})" if multi else ""
        mc_style = dict(color=c, lw=2, ms=12, label="MC" + tag)
        int_style = dict(color=c if multi else "orange", marker="p", s=100,
                         edgecolors="none", linewidths=0.6, zorder=5,
                         label="Integration" + tag)

        ema = df.dropna(subset=["ema_E"])
        grid = df.dropna(subset=["grid_E"])

        # ── left: PES ─────────────────────────────────────────────────────────
        if not ema.empty:
            any_data = True
            axL.plot(ema["R"], ema["ema_E"], "-o", **mc_style)
        if not grid.empty:
            any_data = True
            axL.scatter(grid["R"], grid["grid_E"], **int_style)

        # ── right: ΔE = E(A) + E(B) - E(AB) = Σ E(atom) - E_total (>0 = bound)
        if not ema.empty and ref_ema is not None:
            any_binding = True
            axR.plot(ema["R"], ref_ema - ema["ema_E"], "-o", **mc_style)
        if not grid.empty and ref_grid is not None:
            any_binding = True
            axR.scatter(grid["R"], ref_grid - grid["grid_E"], **int_style)

        bind_rows.extend(_binding_row(md.name, r, ref_ema, ref_grid)
                         for _, r in df.iterrows())

    if not any_data:
        print(f"[{mol}] no data — run quadrature_scan.py first "
              f"(creates energy_summary.json in each bl_*).")
        plt.close(fig)
        return

    axL.set_xlabel("R [Bohr]")
    axL.set_ylabel(r"E[$\rho$] + V$_{NN}$(R) [a.u.]")
    axL.legend(fontsize=9)

    axR.axhline(0, color="k", lw=0.8, ls="--")
    axR.set_xlabel("R [Bohr]")
    axR.set_ylabel(r"$\Delta$E = E(A) + E(B) - E(AB) [a.u.]")
    if any_binding:
        # An empty panel would only trigger matplotlib's "no artists" warning.
        axR.legend(fontsize=9)

    fig.tight_layout()

    out = Path(args.out).resolve() if args.out else mdir / f"pes_{mol}.png"
    if not out.suffix:
        # savefig would append ".png" on its own; do it here so the printed
        # path and the .svg / .csv siblings match the files actually written.
        out = out.with_suffix(".png")
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out, dpi=150)
    fig.savefig(out.with_suffix(".svg"))
    plt.close(fig)
    print(f"[{mol}] saved → {out}")
    print(f"[{mol}] saved → {out.with_suffix('.svg')}")

    if bind_rows:
        bdf = (pd.DataFrame(bind_rows)
               .sort_values(["method", "R_bohr"]).reset_index(drop=True))
        bcsv = out.with_name(f"binding_{mol}.csv")
        bdf.to_csv(bcsv, index=False, float_format="%.8f")
        print(f"[{mol}] saved → {bcsv}")
277
+
278
+
279
# Render each requested molecule, in command-line order.
for mol_name in selected:
    plot_mol(mol_name)
@@ -0,0 +1,27 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2025 AlexandreDeCamargo
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
# Package version string.
__version__ = "0.1.0"

# Re-export the density class at package level.
from .promolecular_dist import ProMolecularDensity
+ )