off 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- off/__init__.py +23 -0
- off/atom_energies.py +151 -0
- off/config/_config.py +108 -0
- off/dft_distrax/__init__.py +27 -0
- off/dft_distrax/dft_distrax.py +216 -0
- off/flow/__init__.py +29 -0
- off/flow/equiv_flows.py +99 -0
- off/functionals/__init__.py +35 -0
- off/functionals/core_correction.py +84 -0
- off/functionals/exchange_correlation.py +174 -0
- off/functionals/external.py +49 -0
- off/functionals/functional.py +129 -0
- off/functionals/hartree.py +62 -0
- off/functionals/kinetic.py +87 -0
- off/main.py +172 -0
- off/ode_solver/__init__.py +32 -0
- off/ode_solver/eqx_ode.py +76 -0
- off/plot_binding_csv.py +63 -0
- off/plot_pes_ema.py +259 -0
- off/plot_pes_mpl.py +280 -0
- off/promolecular/__init__.py +27 -0
- off/promolecular/promolecular_dist.py +465 -0
- off/quadrature.py +261 -0
- off/quadrature_scan.py +188 -0
- off/scan_pes.py +133 -0
- off/test_fwd_rev.py +290 -0
- off/train/__init__.py +44 -0
- off/train/loop.py +228 -0
- off/train/loss.py +149 -0
- off/train/utils.py +38 -0
- off/utils.py +618 -0
- off-0.1.0.dist-info/METADATA +154 -0
- off-0.1.0.dist-info/RECORD +37 -0
- off-0.1.0.dist-info/WHEEL +5 -0
- off-0.1.0.dist-info/entry_points.txt +3 -0
- off-0.1.0.dist-info/licenses/LICENSE +21 -0
- off-0.1.0.dist-info/top_level.txt +1 -0
off/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""OFF — orbital-free DFT with continuous normalizing flows."""
|
|
2
|
+
|
|
3
|
+
__version__ = "0.1.0"
|
|
4
|
+
|
|
5
|
+
from .functionals.functional import (
|
|
6
|
+
Functional,
|
|
7
|
+
CompositeFunctional,
|
|
8
|
+
EnergyFunctional,
|
|
9
|
+
FunctionalInputs,
|
|
10
|
+
unit_coefficient,
|
|
11
|
+
)
|
|
12
|
+
from .functionals import (
|
|
13
|
+
tf, weizsacker, tf_weizsacker,
|
|
14
|
+
lda, b88, vwn, pw92, lda_b88,
|
|
15
|
+
NuclearPotential, CoulombPotential, CoulombPotential_,
|
|
16
|
+
KatoCondition, HutcheonCuspCondition,
|
|
17
|
+
)
|
|
18
|
+
from .quadrature import (
|
|
19
|
+
getGrid, get_grid, build_grid,
|
|
20
|
+
grid_energy, grid_energy_from_checkpoint,
|
|
21
|
+
)
|
|
22
|
+
from .train.loss import build_energy_functional, create_loss_function
|
|
23
|
+
from .train.loop import training
|
off/atom_energies.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Grid-integrated total energy vs EMA energy, side by side, for a set of atoms.
|
|
3
|
+
|
|
4
|
+
For each atom it:
|
|
5
|
+
1. grid-integrates the LAST checkpoint via ``quadrature.grid_energy_from_checkpoint``
|
|
6
|
+
-> E_grid (= E_total; for a single atom E_NN = 0, so this is the electronic
|
|
7
|
+
total energy).
|
|
8
|
+
2. averages the last --window rows of training_metrics_ema.csv (E + CC) -> E_ema.
|
|
9
|
+
3. writes both numbers side by side to a CSV (+ prints a table).
|
|
10
|
+
|
|
11
|
+
Because E_NN = 0 for an isolated atom, E_grid and E_ema are the *same* physical
|
|
12
|
+
quantity (total atomic energy); the only difference is grid quadrature vs the
|
|
13
|
+
Monte-Carlo EMA estimate during training.
|
|
14
|
+
|
|
15
|
+
Directory layout assumed (same as main.py):
|
|
16
|
+
{results_root}/{atom}/{method}/bl_0.0000/
|
|
17
|
+
Checkpoints/checkpoint_*.eqx
|
|
18
|
+
training_metrics_ema.csv
|
|
19
|
+
job_params.json
|
|
20
|
+
|
|
21
|
+
Usage
|
|
22
|
+
-----
|
|
23
|
+
python atom_energies.py --method tf_w_lam0.2_none_lda_none_dopri8_db_sir_sched_MIX
|
|
24
|
+
python atom_energies.py --method <tag> --atoms B Be C F H He Li N Ne O \
|
|
25
|
+
--window 1000 --recompute --results_root /path/to/Results
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
import sys, os
|
|
29
|
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
|
30
|
+
|
|
31
|
+
import argparse
|
|
32
|
+
from pathlib import Path
|
|
33
|
+
|
|
34
|
+
import pandas as pd
|
|
35
|
+
|
|
36
|
+
from quadrature import grid_energy_from_checkpoint
|
|
37
|
+
|
|
38
|
+
# atomic number == electron count for a neutral atom
Z_TABLE = {"H": 1, "He": 2, "Li": 3, "Be": 4, "B": 5,
           "C": 6, "N": 7, "O": 8, "F": 9, "Ne": 10}

# ── CLI ───────────────────────────────────────────────────────────────────────
parser = argparse.ArgumentParser(
    description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
parser.add_argument("--method", type=str, required=True,
                    help="Method directory name, e.g. "
                         "tf_w_lam0.2_none_lda_none_dopri8_db_sir_sched_MIX")
parser.add_argument("--atoms", type=str, nargs="+",
                    default=["B", "Be", "C", "F", "H", "He", "Li", "N", "Ne", "O"],
                    help="Atoms to process (default: the full set).")
parser.add_argument("--results_root", type=str, default="Results",
                    help="Root dir holding {atom}/{method}/bl_0.0000 (default: Results)")
parser.add_argument("--window", type=int, default=1000,
                    help="Average the last N rows of training_metrics_ema.csv (default: 1000)")
parser.add_argument("--bs", type=int, default=256, help="Grid chunk size")
parser.add_argument("--grid_level", type=int, default=3, help="PySCF grid level")
parser.add_argument("--recompute", action="store_true",
                    help="Re-run grid integration even if energy_summary.json is cached")
parser.add_argument("--out", type=str, default="atom_energies.csv",
                    help="Output CSV path (default: atom_energies.csv)")
args = parser.parse_args()

# Reject non-positive sizes up front. With window <= 0, DataFrame.tail()
# returns either no rows (a silent NaN energy) or all-but-the-first rows.
# A non-positive grid chunk size is never meaningful either.
if args.window < 1:
    parser.error(f"--window must be >= 1 (got {args.window})")
if args.bs < 1:
    parser.error(f"--bs must be >= 1 (got {args.bs})")
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def read_last_ema(bl_dir: Path, window: int):
    """Mean of the last `window` rows of training_metrics_ema.csv.

    E = E + CC (no nuclear repulsion; for an atom this is the total energy).

    Parameters
    ----------
    bl_dir : Path or str
        Run directory expected to contain ``training_metrics_ema.csv``.
    window : int
        Number of trailing rows to average; must be >= 1.

    Returns
    -------
    tuple
        ``(E, epoch)``:
        - ``E`` is the mean of column ``E`` over the window, plus the mean of
          column ``CC`` when that column exists.
        - ``epoch`` is the ``epoch`` value of the final row, or None when that
          column is absent or empty.
        - ``(None, None)`` is returned when the file is missing, empty, or has
          no ``E`` column.

    Raises
    ------
    ValueError
        If ``window`` < 1. Without this check pandas silently returns NaN for
        0 and drops the leading rows for negative values.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")

    csv = Path(bl_dir) / "training_metrics_ema.csv"
    if not csv.exists():
        return None, None
    try:
        df = pd.read_csv(csv)
    except pd.errors.EmptyDataError:
        return None, None
    # A header-only file or a file without the energy column carries no data.
    if df.empty or "E" not in df.columns:
        return None, None

    tail = df.tail(window)
    E = float(tail["E"].mean())
    if "CC" in tail.columns:
        E += float(tail["CC"].mean())

    # Report the epoch of the final row; tolerate a missing or blank column.
    epoch = None
    if "epoch" in df.columns and pd.notna(df["epoch"].iloc[-1]):
        epoch = int(df["epoch"].iloc[-1])
    return E, epoch
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# ── main loop ─────────────────────────────────────────────────────────────────
# For every requested atom:
#   1. grid-integrate the last checkpoint (E_grid),
#   2. average the tail of the EMA metrics CSV (E_ema),
#   3. collect one row per atom.
# Afterwards, print a table and write it to --out.
root = Path(args.results_root).resolve()
print(f"Results root : {root}")
print(f"Method : {args.method}")
print(f"EMA window : last {args.window} rows\n")

rows = []
for atom in args.atoms:
    atom_dir = root / atom / args.method / "bl_0.0000"
    print(f"[{atom}] {atom_dir}")

    # job_params.json is used as the marker that this run directory exists;
    # atoms without it contribute no row at all.
    if not (atom_dir / "job_params.json").exists():
        print(" SKIP — no job_params.json (directory missing?)\n")
        continue

    # grid integration of the last checkpoint (via the OFF quadrature module)
    # Deliberately best-effort: any failure (including a missing key in the
    # returned dict) is reported and the atom is kept with None grid values,
    # so one broken run does not abort the whole batch.
    try:
        data = grid_energy_from_checkpoint(
            atom_dir, grid_level=args.grid_level, chunk=args.bs, recompute=args.recompute)
        E_grid = data["E_total"]
        grid_epoch = data["epoch"]
        Ne_int = data["Ne_integral"]
    except Exception as e:
        print(f" grid integration FAILED: {e}")
        E_grid = grid_epoch = Ne_int = None

    # EMA mean of last `window` rows
    # read_last_ema also returns None for an empty CSV, not only a missing one.
    E_ema, ema_epoch = read_last_ema(atom_dir, args.window)
    if E_ema is None:
        print(" no training_metrics_ema.csv")

    # The difference is only defined when both estimates are available.
    diff = (E_grid - E_ema) if (E_grid is not None and E_ema is not None) else None
    rows.append({
        "atom": atom,
        # Electron count of the neutral atom; None for atoms not in Z_TABLE.
        "Ne": Z_TABLE.get(atom),
        "E_grid_Ha": E_grid,
        "grid_epoch": grid_epoch,
        "E_ema_Ha": E_ema,
        "ema_epoch": ema_epoch,
        "diff_Ha": diff,
        # Value reported as "Ne_integral" by the quadrature module; the table
        # below labels it ∫ρ.
        "Ne_int": Ne_int,
    })
    if E_grid is not None and E_ema is not None:
        print(f" E_grid={E_grid:+.6f} E_ema={E_ema:+.6f} Δ={diff:+.6f} Ha")
    print()

# Only reachable when every requested atom was skipped for lack of job_params.json.
if not rows:
    raise RuntimeError("No atoms processed — check --results_root / --method.")

df = pd.DataFrame(rows)

# ── print + save ──────────────────────────────────────────────────────────────
print("=" * 80)
print(f"{'atom':>4} {'Ne':>3} {'E_grid [Ha]':>15} {'E_ema [Ha]':>15} "
      f"{'Δ(grid-ema)':>13} {'∫ρ':>8}")
print("-" * 80)
for _, r in df.iterrows():
    # Missing values (None/NaN) are rendered as a right-aligned dash of the
    # same width as the numeric column.
    eg = f"{r['E_grid_Ha']:+15.6f}" if pd.notna(r['E_grid_Ha']) else f"{'—':>15}"
    em = f"{r['E_ema_Ha']:+15.6f}" if pd.notna(r['E_ema_Ha']) else f"{'—':>15}"
    dd = f"{r['diff_Ha']:+13.6f}" if pd.notna(r['diff_Ha']) else f"{'—':>13}"
    ni = f"{r['Ne_int']:8.4f}" if pd.notna(r['Ne_int']) else f"{'—':>8}"
    ne = f"{int(r['Ne'])}" if pd.notna(r['Ne']) else "?"
    print(f"{r['atom']:>4} {ne:>3} {eg} {em} {dd} {ni}")
print("=" * 80)

out_path = Path(args.out).resolve()
df.to_csv(out_path, index=False, float_format="%.8f")
print(f"\nSaved → {out_path}")
|
off/config/_config.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Global configuration module for storing runtime parameters and directories.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
class Config:
    """Global configuration class to store all runtime parameters.

    Every setting lives in a class attribute, so the class itself acts as a
    process-wide namespace. Populate it with ``from_args`` and
    ``set_directories``, then read it back through the class; no instance is
    required.
    """

    # Model parameters
    mol_name = None
    epochs = None
    bs = None
    hl = None
    lr = None
    prior = None

    # Functionals
    kin = None
    nuc = None
    hart = None
    x = None
    c = None
    cc = None

    # Training settings
    sched = None
    solver = None
    ckpt_freq = None

    # Directories (set during runtime)
    results_dir = None
    ckpt_dir = None

    @classmethod
    def from_args(cls, args):
        """
        Initialize configuration from argparse arguments.

        Args:
            args: argparse.Namespace (or any object) exposing the attributes
                mol_name, epochs, bs, hl, lr, prior, kin, nuc, hart, x, c,
                cc, sched, solver and ckpt_freq.
        """
        # Model parameters
        cls.mol_name = args.mol_name
        cls.epochs = args.epochs
        cls.bs = args.bs
        cls.hl = args.hl
        cls.lr = args.lr
        cls.prior = args.prior

        # Functionals
        cls.kin = args.kin
        cls.nuc = args.nuc
        cls.hart = args.hart
        cls.x = args.x
        cls.c = args.c
        cls.cc = args.cc

        # Training settings
        cls.sched = args.sched
        cls.solver = args.solver
        cls.ckpt_freq = args.ckpt_freq

    @classmethod
    def set_directories(cls, results_dir, ckpt_dir):
        """
        Set the results and checkpoint directories.

        Args:
            results_dir: Path to results directory
            ckpt_dir: Path to checkpoint directory
        """
        cls.results_dir = results_dir
        cls.ckpt_dir = ckpt_dir

    @classmethod
    def get_model_params(cls):
        """Return dictionary of model parameters."""
        return {
            'mol_name': cls.mol_name,
            'epochs': cls.epochs,
            'bs': cls.bs,
            'hl': cls.hl,
            'lr': cls.lr,
            'prior': cls.prior
        }

    @classmethod
    def get_functionals(cls):
        """Return dictionary of functional parameters."""
        return {
            'kin': cls.kin,
            'nuc': cls.nuc,
            'hart': cls.hart,
            'x': cls.x,
            'c': cls.c,
            'cc': cls.cc
        }

    @classmethod
    def __repr__(cls):
        """String representation of configuration.

        Lists every stored setting, including the nuclear and Hartree
        functionals and the training settings, which earlier versions
        omitted. Because this is a classmethod it always reports class-level
        state. Call it as ``Config.__repr__()`` or ``repr(Config())``;
        ``repr(Config)`` shows the default class repr.
        """
        return (
            f"Config(\n"
            f"  Model: mol_name={cls.mol_name}, epochs={cls.epochs}, "
            f"bs={cls.bs}, hl={cls.hl}, lr={cls.lr}, prior={cls.prior}\n"
            f"  Functionals: kin={cls.kin}, nuc={cls.nuc}, hart={cls.hart}, "
            f"x={cls.x}, c={cls.c}, cc={cls.cc}\n"
            f"  Training: sched={cls.sched}, solver={cls.solver}, "
            f"ckpt_freq={cls.ckpt_freq}\n"
            f"  Directories: results={cls.results_dir}, ckpt={cls.ckpt_dir}\n"
            f")"
        )
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
# MIT License
|
|
2
|
+
|
|
3
|
+
# Copyright (c) 2025 AlexandreDeCamargo
|
|
4
|
+
|
|
5
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
# in the Software without restriction, including without limitation the rights
|
|
8
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
# furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
# copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
# SOFTWARE.
|
|
22
|
+
|
|
23
|
+
__version__ = "0.1.0"
|
|
24
|
+
|
|
25
|
+
from .dft_distrax import (
|
|
26
|
+
DFTDistribution
|
|
27
|
+
)
|
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
from functools import partial
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
import jax
|
|
7
|
+
import jax.numpy as jnp
|
|
8
|
+
from jaxtyping import Array, Float
|
|
9
|
+
|
|
10
|
+
from pyscf import gto, dft
|
|
11
|
+
from pyscf.dft import numint
|
|
12
|
+
from pyscf.data.nist import BOHR
|
|
13
|
+
|
|
14
|
+
import distrax
|
|
15
|
+
from distrax import MultivariateNormalDiag, Categorical
|
|
16
|
+
from distrax._src.distributions.distribution import Array
|
|
17
|
+
|
|
18
|
+
Dtype = Any
|
|
19
|
+
|
|
20
|
+
class MixGaussian(distrax.Distribution):
    """Mixture of diagonal-covariance Gaussian components with categorical weights."""

    def __init__(self, loc: Array , scale_diag: Array , probs: Array):
        r"""
        Creates a distribution with a mixture of Gaussian components.

        Parameters
        ----------
        loc : Array
            Component means (molecular coordinates). ``_sample_n`` squeezes
            axis -2 of the component samples, so this is presumably shaped
            (K, 1, 3) for K components -- TODO confirm.
        scale_diag : Array
            Per-component diagonal standard deviations (broadcast with ``loc``).
        probs : Array
            Mixture probabilities, presumably of shape (K,).
        """

        self.loc = loc
        self.scale_diag = scale_diag
        self.probs = probs
        self.mixture_dist = Categorical(probs=probs)
        self.components_dist = MultivariateNormalDiag(loc=self.loc, scale_diag=self.scale_diag)

    @jax.jit
    def prob(self, value: Array) -> jax.Array:
        """
        Calculates the probability of an event.

        Parameters
        ----------
        value : Array
            An event.

        Returns
        -------
        jax.Array
            The probability of the event, shape (n, 1).
        """
        # (n, K): density of each point under each component.
        log_px_components_dist = self.components_dist.log_prob(value).T
        px_components_dist = jnp.exp(log_px_components_dist)
        # Weighted sum over components -> (n, 1).
        px = px_components_dist @ self.probs[:, None]
        return px

    @jax.jit
    def log_prob(self, value: Array) -> jax.Array:
        """
        Calculates the log probability of an event.

        Computed as ``log sum_k probs_k * exp(log N_k(value))`` with a
        log-sum-exp rather than ``log(prob(value))``. The latter underflows to
        ``log(0) = -inf`` far from every component mean, which also yields NaN
        gradients in ``score``. Mathematically the two forms are equal.

        Parameters
        ----------
        value : Array
            An event.

        Returns
        -------
        jax.Array
            The log probability of the event, shape (n, 1), the same shape as
            ``jnp.log(self.prob(value))``.
        """
        # (n, K): log-density of each point under each component.
        log_px_components_dist = self.components_dist.log_prob(value).T
        # `b` applies the mixture weights without taking log(probs), so zero
        # weights do not introduce -inf / NaN gradients.
        return jax.nn.logsumexp(log_px_components_dist, axis=-1,
                                b=self.probs, keepdims=True)

    def _sample_n(self, key: jax.random.PRNGKey, n: int) -> jax.Array:
        """
        Returns 'n' samples.

        Draws a component index per sample, draws one candidate from every
        component, and keeps the candidate of the chosen component.

        Parameters
        ----------
        key : PRNGKey
            Random key.
        n : int
            Number of samples to generate.

        Returns
        -------
        jax.Array
            An array of 'n' samples.
        """
        _, key_mixt, key_comp = jax.random.split(key, 3)
        samples_mixt = self.mixture_dist._sample_n(key_mixt, n)
        samples_mixt_one_hot = jax.nn.one_hot(samples_mixt, self.probs.shape[-1])
        samples_comp = self.components_dist.sample(seed=key_comp, sample_shape=n)
        samples_comp = jnp.squeeze(samples_comp, axis=-2)

        # Select, per sample i, the candidate of its drawn component j.
        samples = jnp.einsum('ijl,ij->il', samples_comp, samples_mixt_one_hot)
        return samples

    def event_shape(self):
        pass

    @jax.jit
    def score(self, values):
        """Per-point gradient of ``log_prob`` (the score function)."""
        return jax.vmap(jax.grad(lambda x:
                                 self.log_prob(x).sum()))(values)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class DFTDistribution(distrax.Distribution):
    """Reference electron density obtained from a PySCF Kohn-Sham calculation.

    Construction does the following:
    - builds a PySCF molecule from ``atoms`` / ``geometry``, with coordinates
      interpreted in Bohr (``unit='B'``);
    - builds a level-5 integration grid;
    - runs a (Newton-accelerated) RKS calculation and keeps its one-particle
      density matrix.

    ``prob`` then evaluates the density at arbitrary points.

    NOTE(review):
    - ``exc`` is stored but never used; ``_dft`` hard-codes LDA + B88 exchange
      with VWN correlation.
    - ``log_prob``, ``_sample_n``, ``event_shape`` and
      ``_sample_n_and_log_prob`` are stubs that return None.
    """

    def __init__(self, atoms: Any, geometry: Any, basis_set: str = '6-31G(d,p)', exc: str = 'b3lyp', dtype_: Dtype = jnp.float32):

        self.atoms = atoms
        self.geometry = geometry
        self.basis_set = basis_set
        self.exc = exc  # NOTE(review): unused by _dft
        self.dtype_ = dtype_

        self._grid_level = 5 # change this for larger molecules
        self.mol = self._mol()
        # NOTE(review): this grid is only used for self.coords / self.weights.
        # It is not assigned to the SCF object in _dft, so the SCF runs on
        # PySCF's own default grid.
        self.grids = dft.gen_grid.Grids(self.mol)
        self.grids.level = self._grid_level
        self.grids.build()
        self.Ne = self.mol.tot_electrons()
        self.dft, self.rdm1 = self._dft()

        # Grid points and quadrature weights as JAX arrays.
        self.coords = jnp.array(self.grids.coords)
        self.weights = jnp.array(self.grids.weights)

    def get_molecule(self):
        """Build a PySCF atom string with one "Sym x y z " line per atom.

        Side effect: prints each (atom, coordinates) pair to stdout.
        """
        m_ = ""
        for a, xi in zip(self.atoms, self.geometry):
            print(a, xi)
            mi_ = f'{a} '
            mxi_ = ""
            for xii in xi:
                mxi_ += str(xii) + " "
            mi_ += mxi_ + '\n'
            m_ += mi_
        return m_

    def _mol(self):
        """Create the ``gto.Mole``; single atoms get their ground-state spin."""
        atoms = self.get_molecule()
        # Number of unpaired electrons (2S) for neutral first/second-row atoms.
        spin_dict = {
            'H': 1, 'He': 0,
            'Li': 1, 'Be': 0, 'B': 1, 'C': 2, 'N': 3, 'O': 2, 'F': 1, 'Ne': 0
        }
        # If single atom, use spin_dict; else, compute total spin (more complex)
        if len(self.atoms) == 1:
            spin = spin_dict.get(self.atoms[0], 0)  # Default to 0 if not found
            mol = gto.M(atom=atoms, basis=self.basis_set, unit='B', spin=spin)
        else:
            # Molecules: no spin is passed, so PySCF's default is used.
            mol = gto.M(atom=atoms, basis=self.basis_set,
                        unit='B')
        return mol

    def _dft(self):
        """Run the KS-DFT calculation; return (solver, density matrix)."""

        mf_hf = dft.RKS(self.mol)
        # Weights of the exchange/correlation pieces in the XC string below.
        LDA_X = 1.
        B88_X = 1.
        VWN_C = 1.

        mf_hf.xc = f'{LDA_X:} * LDA + {B88_X:} * B88, {VWN_C:} * VWN'
        mf_hf = mf_hf.newton() # second-order algorithm
        mf_hf.kernel()
        dm = mf_hf.make_rdm1()
        return mf_hf, dm

    @partial(jax.custom_vjp, nondiff_argnums=(0,))
    def prob(self, value):
        """Electron density at ``value``, evaluated with PySCF.

        ``value`` is converted to NumPy, so this presumably cannot be called
        on traced values inside ``jax.jit``. Assumes shape (n, 3); TODO
        confirm. Returns shape (n, 1).

        NOTE(review): this primal returns rho *without* dividing by
        ``self.Ne`` (the division is commented out). ``prob_fwd``, which JAX
        uses whenever this function is differentiated, returns rho / Ne and a
        gradient of rho / Ne. The returned value therefore differs by a factor
        of Ne depending on whether it is evaluated under ``jax.grad`` /
        ``jax.vjp``. Decide the intended normalisation and make both paths
        agree.
        """
        coords = np.array(value)
        ao_value = numint.eval_ao(self.mol, coords, deriv=1)
        # GGA-style evaluation: row 0 is rho, rows 1: are its spatial gradient.
        rho_and_grho = numint.eval_rho(
            self.mol, ao_value, self.rdm1, xctype='GGA')
        rho = jnp.asarray(rho_and_grho[0], dtype=self.dtype_) # /self.Ne
        return rho[:, None] # includes batch dimension

    def prob_fwd(self, value):
        """custom_vjp forward pass: (rho / Ne, residual d(rho / Ne)/dx)."""
        coords = value
        ao_value = numint.eval_ao(self.mol, coords, deriv=1)
        rho_and_grho = numint.eval_rho(
            self.mol, ao_value, self.rdm1, xctype='GGA')
        rho = jnp.array(rho_and_grho[0], dtype=self.dtype_)/self.Ne
        # Gradient rows transposed to one row of spatial derivatives per point.
        drho_dx = jnp.array(
            rho_and_grho[1:, :].T, dtype=self.dtype_)/self.Ne
        return rho[:, None], (drho_dx)

    def prob_bwd(self, res, g):
        """custom_vjp backward pass.

        Each output density depends only on its own point, so the VJP is the
        per-point gradient scaled by that point's cotangent ``g`` (broadcast
        over the coordinate axis).
        """
        drho_dx = res
        return (drho_dx*g,)
        # return (jnp.matmul(g.T, drho_dx),)

    prob.defvjp(prob_fwd, prob_bwd)

    def log_prob(self, value: Any) -> Array:
        # Not implemented: returns None.
        pass

    def _sample_n(self, key, n):
        # Not implemented: returns None.
        pass

    def event_shape(self):
        # Not implemented: returns None.
        pass

    def _sample_n_and_log_prob(self, key, n):
        # Not implemented: returns None.
        pass
|
|
216
|
+
|
off/flow/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# MIT License
|
|
2
|
+
|
|
3
|
+
# Copyright (c) 2025 AlexandreDeCamargo
|
|
4
|
+
|
|
5
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
# in the Software without restriction, including without limitation the rights
|
|
8
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
# furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
# copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
# SOFTWARE.
|
|
22
|
+
|
|
23
|
+
__version__ = "0.1.0"
|
|
24
|
+
|
|
25
|
+
from .equiv_flows import (
|
|
26
|
+
_Flow,
|
|
27
|
+
Radial_MLP,
|
|
28
|
+
CNF
|
|
29
|
+
)
|
off/flow/equiv_flows.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
from jax import lax, numpy as jnp
|
|
3
|
+
import equinox as eqx
|
|
4
|
+
from jaxtyping import Array, Float, Int, Scalar
|
|
5
|
+
|
|
6
|
+
class _Flow(eqx.Module):
    """Small tanh MLP shared by every nucleus of the radial velocity field.

    Input features are ``[t, |r - R|, one_hot(Z)]``. The layers are:
    - ``linear_in``: Linear(din + 1 -> dim) followed by tanh;
    - ``blocks``: three hidden Linear(dim -> dim) layers, each followed by
      tanh;
    - ``linear_out``: Linear(dim -> din).
    """

    # Field order defines the pytree layout (and serialised checkpoints).
    linear_in: eqx.nn.Linear
    linear_out: eqx.nn.Linear
    blocks: list[eqx.nn.Linear]

    def __init__(
        self,
        din: Int[Scalar, ""],
        dim: Int[Scalar, ""],
        key: jax.random.PRNGKey
    ):
        # Five keys are split to keep the original RNG stream; the last two
        # are unused.
        k_in, k_hidden, k_out, _, _ = jax.random.split(key, 5)
        hidden_keys = jax.random.split(k_hidden, 3)

        self.linear_in = eqx.nn.Linear(din + 1, dim, key=k_in)
        self.blocks = [eqx.nn.Linear(dim, dim, key=hk) for hk in hidden_keys]
        self.linear_out = eqx.nn.Linear(dim, din, key=k_out)

    def __call__(
        self,
        t,
        xi_norm,
        zi_one_hot
    ):
        # Concatenate time, distance and nucleus one-hot into a single vector.
        features = jnp.hstack(
            (jnp.reshape(t, (1,)), jnp.reshape(xi_norm, (1,)), zi_one_hot))
        hidden = jnp.tanh(self.linear_in(features))
        for layer in self.blocks:
            hidden = jnp.tanh(layer(hidden))
        return self.linear_out(hidden)
|
|
39
|
+
|
|
40
|
+
class Radial_MLP(eqx.Module):
    """Velocity field built from a radial MLP evaluated once per nucleus.

    For each nucleus i the model:
    - forms the displacement ``z_i = x - R_i`` and its norm;
    - evaluates the shared ``_Flow`` on ``(t, |z_i|, one_hot_i)``;
    - contracts the displacements with those outputs via
      ``einsum('ijk,ij->k')``.
    """

    xyz_nuclei : Array
    z_one_hot : Array
    flow: _Flow

    def __init__(
        self,
        dim: int,
        key,
        xyz_nuclei,
        z_one_hot
    ):
        # Store nuclei as (n_nuclei, 1, 3) so they broadcast against a
        # leading-axis-expanded state.
        self.xyz_nuclei = xyz_nuclei[:, None, :]
        self.z_one_hot = z_one_hot
        # MLP outputs one value plus one per one-hot channel.
        self.flow = _Flow(1 + z_one_hot.shape[-1], dim, key=key)

    def __call__(self, states, t):
        # states is presumably a single 3-D point; TODO confirm at call sites.
        displacement = lax.expand_dims(states, dimensions=(0,)) - self.xyz_nuclei
        distance = jnp.linalg.norm(displacement, axis=-1)
        # Map over nuclei: shared time, per-nucleus distance and one-hot.
        per_nucleus = jax.vmap(self.flow, in_axes=(None, 0, 0))(
            t, distance, self.z_one_hot)
        return jnp.einsum('ijk,ij->k', displacement, per_nucleus)
|
|
64
|
+
|
|
65
|
+
class CNF(eqx.Module):
    """Continuous normalizing flow dynamics with log-density and score tracking.

    ``__call__(states, t)`` returns the time derivative of the augmented state
    ``[x (3), log p (1), score (3)]``:
    - ``dx/dt     = f(x, t)``
    - ``dlogp/dt  = -tr(df/dx)``
    - ``dscore/dt = -(df/dx)^T score - grad_x tr(df/dx)``

    The score term is obtained from a single VJP of (f, -tr J).
    """
    flow: Radial_MLP

    def __init__(
        self,
        din: Int[Scalar, ""],
        dim: Int[Scalar, ""],
        mu: Array,
        one_hot: Array,
        key: jax.random.PRNGKey
    ):
        # NOTE(review): `din` is accepted but unused.
        # `mu` becomes the nuclear coordinates of Radial_MLP.
        self.flow = Radial_MLP(dim, key, mu, one_hot)

    def __call__(self, states, t):
        data_dim = 3 #Hardcoded for 3 dimensions

        # NOTE(review): both helpers are re-created and re-wrapped in jax.jit
        # on every call. Outside an enclosing jit this presumably retraces
        # each time.
        @jax.jit
        def _f_ode(self, states, t):
            # Called with states = [x, log p]; log_px and score are unpacked
            # but unused, so the outputs depend on x only.
            x, log_px, score = states[:data_dim], states[data_dim:data_dim+1], states[data_dim+1:]
            jac = jax.jacrev(self.flow)(x, t)
            dtrJ = -1. * jnp.trace(jac)
            dz = self.flow(x, t)
            return dz, dtrJ

        @jax.jit
        def f_ode(self,states,t):
            # Split off the trailing score; `state` holds [x, log p].
            state, score = states[:-data_dim], states[-data_dim:]
            dx_and_dlopz, _f_vjp = jax.vjp(
                lambda state: _f_ode(self,state, t), state)
            dx, dlopz = dx_and_dlopz
            # Cotangent (score, -1) gives score^T J + grad_x tr(J) in the x
            # slots. The last slot is the log p input, which _f_ode ignores,
            # so grad_div is presumably always zero.
            (vjp_all,) = _f_vjp((score, -1.))
            score_vjp, grad_div = vjp_all[:-1], vjp_all[-1]
            dscore = -score_vjp + grad_div
            # Flatten back into [dx, dlogp, dscore].
            return jnp.append(jnp.append(dx, dlopz), dscore)
        return f_ode(self, states, t)
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
# MIT License
|
|
2
|
+
|
|
3
|
+
# Copyright (c) 2025 AlexandreDeCamargo
|
|
4
|
+
|
|
5
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
# in the Software without restriction, including without limitation the rights
|
|
8
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
# furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
# copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
# SOFTWARE.
|
|
22
|
+
|
|
23
|
+
__version__ = "0.1.0"
|
|
24
|
+
|
|
25
|
+
from .functional import (
|
|
26
|
+
Functional,
|
|
27
|
+
CompositeFunctional,
|
|
28
|
+
EnergyFunctional,
|
|
29
|
+
FunctionalInputs,
|
|
30
|
+
)
|
|
31
|
+
from .kinetic import tf, weizsacker, tf_weizsacker
|
|
32
|
+
from .exchange_correlation import lda, b88, vwn, pw92, lda_b88
|
|
33
|
+
from .external import NuclearPotential
|
|
34
|
+
from .hartree import CoulombPotential, CoulombPotential_
|
|
35
|
+
from .core_correction import KatoCondition, HutcheonCuspCondition
|