off 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
off/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ """OFF — orbital-free DFT with continuous normalizing flows."""
2
+
3
+ __version__ = "0.1.0"
4
+
5
+ from .functionals.functional import (
6
+ Functional,
7
+ CompositeFunctional,
8
+ EnergyFunctional,
9
+ FunctionalInputs,
10
+ unit_coefficient,
11
+ )
12
+ from .functionals import (
13
+ tf, weizsacker, tf_weizsacker,
14
+ lda, b88, vwn, pw92, lda_b88,
15
+ NuclearPotential, CoulombPotential, CoulombPotential_,
16
+ KatoCondition, HutcheonCuspCondition,
17
+ )
18
+ from .quadrature import (
19
+ getGrid, get_grid, build_grid,
20
+ grid_energy, grid_energy_from_checkpoint,
21
+ )
22
+ from .train.loss import build_energy_functional, create_loss_function
23
+ from .train.loop import training
off/atom_energies.py ADDED
@@ -0,0 +1,151 @@
1
+ """
2
+ Grid-integrated total energy vs EMA energy, side by side, for a set of atoms.
3
+
4
+ For each atom it:
5
+ 1. grid-integrates the LAST checkpoint via ``quadrature.grid_energy_from_checkpoint``
6
+ -> E_grid (= E_total; for a single atom E_NN = 0, so this is the electronic
7
+ total energy).
8
+ 2. averages the last --window rows of training_metrics_ema.csv (E + CC) -> E_ema.
9
+ 3. writes both numbers side by side to a CSV (+ prints a table).
10
+
11
+ Because E_NN = 0 for an isolated atom, E_grid and E_ema are the *same* physical
12
+ quantity (total atomic energy); the only difference is grid quadrature vs the
13
+ Monte-Carlo EMA estimate during training.
14
+
15
+ Directory layout assumed (same as main.py):
16
+ {results_root}/{atom}/{method}/bl_0.0000/
17
+ Checkpoints/checkpoint_*.eqx
18
+ training_metrics_ema.csv
19
+ job_params.json
20
+
21
+ Usage
22
+ -----
23
+ python atom_energies.py --method tf_w_lam0.2_none_lda_none_dopri8_db_sir_sched_MIX
24
+ python atom_energies.py --method <tag> --atoms B Be C F H He Li N Ne O \
25
+ --window 1000 --recompute --results_root /path/to/Results
26
+ """
27
+
28
+ import sys, os
29
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
30
+
31
+ import argparse
32
+ from pathlib import Path
33
+
34
+ import pandas as pd
35
+
36
+ from quadrature import grid_energy_from_checkpoint
37
+
38
# For a neutral atom the electron count equals the atomic number.
Z_TABLE = dict(H=1, He=2, Li=3, Be=4, B=5, C=6, N=7, O=8, F=9, Ne=10)

# ── CLI ───────────────────────────────────────────────────────────────────────
# One (flag, add_argument keyword arguments) pair per option.  The options are
# registered in this order, which is also the order shown by --help.
_CLI_OPTIONS = (
    ("--method", dict(
        type=str, required=True,
        help="Method directory name, e.g. "
             "tf_w_lam0.2_none_lda_none_dopri8_db_sir_sched_MIX")),
    ("--atoms", dict(
        type=str, nargs="+",
        default=["B", "Be", "C", "F", "H", "He", "Li", "N", "Ne", "O"],
        help="Atoms to process (default: the full set).")),
    ("--results_root", dict(
        type=str, default="Results",
        help="Root dir holding {atom}/{method}/bl_0.0000 (default: Results)")),
    ("--window", dict(
        type=int, default=1000,
        help="Average the last N rows of training_metrics_ema.csv (default: 1000)")),
    ("--bs", dict(type=int, default=256, help="Grid chunk size")),
    ("--grid_level", dict(type=int, default=3, help="PySCF grid level")),
    ("--recompute", dict(
        action="store_true",
        help="Re-run grid integration even if energy_summary.json is cached")),
    ("--out", dict(
        type=str, default="atom_energies.csv",
        help="Output CSV path (default: atom_energies.csv)")),
)

# The module docstring doubles as the --help description, so keep its layout.
parser = argparse.ArgumentParser(
    description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)
args = parser.parse_args()
62
+
63
+
64
def read_last_ema(bl_dir: Path, window: int):
    """Average the tail of ``training_metrics_ema.csv`` found in ``bl_dir``.

    The returned energy is ``mean(E) + mean(CC)`` over the last ``window``
    rows; the ``CC`` term is added only when that column is present.

    Args:
        bl_dir: Run directory expected to contain ``training_metrics_ema.csv``.
        window: Number of trailing rows to average; must be at least 1.

    Returns:
        ``(energy, epoch)``, where ``epoch`` is the ``epoch`` value of the
        final row.  ``(None, None)`` is returned when the file is missing,
        has no data rows, or lacks the ``E`` / ``epoch`` columns.

    Raises:
        ValueError: If ``window`` is smaller than 1.
    """
    # DataFrame.tail(0) is empty (mean -> NaN) and tail(-n) drops the *first*
    # n rows, so a non-positive window would silently give a wrong number.
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")

    csv_path = bl_dir / "training_metrics_ema.csv"
    if not csv_path.exists():
        return None, None

    try:
        df = pd.read_csv(csv_path)
    except pd.errors.EmptyDataError:
        # Zero-byte file, e.g. a run that stopped before its first write.
        return None, None

    # A metrics file without the required columns is treated as unusable
    # rather than raising KeyError and aborting the caller's loop.
    if df.empty or not {"E", "epoch"}.issubset(df.columns):
        return None, None

    tail = df.tail(window)
    energy = float(tail["E"].mean())
    if "CC" in tail.columns:
        energy += float(tail["CC"].mean())
    return energy, int(df["epoch"].iloc[-1])
82
+
83
+
84
# ── main loop ─────────────────────────────────────────────────────────────────
root = Path(args.results_root).resolve()
print(f"Results root : {root}")
print(f"Method : {args.method}")
print(f"EMA window : last {args.window} rows\n")

rows = []
for atom in args.atoms:
    atom_dir = root / atom / args.method / "bl_0.0000"
    print(f"[{atom}] {atom_dir}")

    # job_params.json is used as the marker that the run directory exists.
    if not (atom_dir / "job_params.json").exists():
        print(" SKIP — no job_params.json (directory missing?)\n")
        continue

    # Grid integration of the last checkpoint (via the OFF quadrature module).
    # Deliberately best-effort: a failure for one atom is reported and leaves
    # empty cells in the table instead of aborting the remaining atoms.
    try:
        data = grid_energy_from_checkpoint(
            atom_dir, grid_level=args.grid_level, chunk=args.bs, recompute=args.recompute)
        E_grid = data["E_total"]
        grid_epoch = data["epoch"]
        Ne_int = data["Ne_integral"]
    except Exception as e:
        print(f" grid integration FAILED: {e}")
        E_grid = grid_epoch = Ne_int = None

    # EMA mean of the last `window` rows
    E_ema, ema_epoch = read_last_ema(atom_dir, args.window)
    if E_ema is None:
        print(" no training_metrics_ema.csv")

    # The difference is only defined when both estimates are available.
    diff = (E_grid - E_ema) if (E_grid is not None and E_ema is not None) else None
    rows.append({
        "atom": atom,
        "Ne": Z_TABLE.get(atom),
        "E_grid_Ha": E_grid,
        "grid_epoch": grid_epoch,
        "E_ema_Ha": E_ema,
        "ema_epoch": ema_epoch,
        "diff_Ha": diff,
        "Ne_int": Ne_int,
    })
    if E_grid is not None and E_ema is not None:
        print(f" E_grid={E_grid:+.6f} E_ema={E_ema:+.6f} Δ={diff:+.6f} Ha")
    print()

if not rows:
    raise RuntimeError("No atoms processed — check --results_root / --method.")

df = pd.DataFrame(rows)

# ── print + save ──────────────────────────────────────────────────────────────
print("=" * 80)
print(f"{'atom':>4} {'Ne':>3} {'E_grid [Ha]':>15} {'E_ema [Ha]':>15} "
      f"{'Δ(grid-ema)':>13} {'∫ρ':>8}")
print("-" * 80)
for _, r in df.iterrows():
    # Missing values (None/NaN) are rendered as a dash of the same width.
    eg = f"{r['E_grid_Ha']:+15.6f}" if pd.notna(r['E_grid_Ha']) else f"{'—':>15}"
    em = f"{r['E_ema_Ha']:+15.6f}" if pd.notna(r['E_ema_Ha']) else f"{'—':>15}"
    dd = f"{r['diff_Ha']:+13.6f}" if pd.notna(r['diff_Ha']) else f"{'—':>13}"
    ni = f"{r['Ne_int']:8.4f}" if pd.notna(r['Ne_int']) else f"{'—':>8}"
    ne = f"{int(r['Ne'])}" if pd.notna(r['Ne']) else "?"
    print(f"{r['atom']:>4} {ne:>3} {eg} {em} {dd} {ni}")
print("=" * 80)

out_path = Path(args.out).resolve()
# Create the target directory up front so a long run is not lost at the very
# last step because --out points into a directory that does not exist yet.
out_path.parent.mkdir(parents=True, exist_ok=True)

# A single missing value turns an integer column into float64, and
# float_format would then write e.g. "3.00000000".  The nullable integer dtype
# keeps these columns integral, with missing entries written as empty cells.
# "grid_epoch" is left untouched because its type is decided by the
# quadrature module, which is not visible from this script.
csv_df = df.copy()
for _int_col in ("Ne", "ema_epoch"):
    csv_df[_int_col] = pd.to_numeric(csv_df[_int_col], errors="coerce").astype("Int64")
csv_df.to_csv(out_path, index=False, float_format="%.8f")
print(f"\nSaved → {out_path}")
off/config/_config.py ADDED
@@ -0,0 +1,108 @@
1
+ """
2
+ Global configuration module for storing runtime parameters and directories.
3
+ """
4
+
5
class _ConfigMeta(type):
    """Metaclass that makes ``repr(Config)`` / ``print(Config)`` show the settings."""

    def __repr__(cls):
        # repr() of a class is resolved on its metaclass, so a __repr__ defined
        # on the class itself is never consulted for the class object.
        return cls._summary()


class Config(metaclass=_ConfigMeta):
    """Global configuration class to store all runtime parameters.

    Every value is a class attribute that is written through classmethods, so
    the class object itself acts as the shared store; no instance is needed.
    All attributes are ``None`` until ``from_args`` / ``set_directories`` run.
    """

    # Model parameters
    mol_name = None
    epochs = None
    bs = None
    hl = None
    lr = None
    prior = None

    # Functionals
    kin = None
    nuc = None
    hart = None
    x = None
    c = None
    cc = None

    # Training settings
    sched = None
    solver = None
    ckpt_freq = None

    # Directories (set during runtime)
    results_dir = None
    ckpt_dir = None

    @classmethod
    def from_args(cls, args):
        """
        Initialize configuration from argparse arguments.

        Directories are not touched here; use ``set_directories`` for those.

        Args:
            args: argparse.Namespace object containing parsed arguments. It
                must provide every attribute copied below.

        Raises:
            AttributeError: If ``args`` lacks one of the copied attributes.
        """
        # Model parameters
        cls.mol_name = args.mol_name
        cls.epochs = args.epochs
        cls.bs = args.bs
        cls.hl = args.hl
        cls.lr = args.lr
        cls.prior = args.prior

        # Functionals
        cls.kin = args.kin
        cls.nuc = args.nuc
        cls.hart = args.hart
        cls.x = args.x
        cls.c = args.c
        cls.cc = args.cc

        # Training settings
        cls.sched = args.sched
        cls.solver = args.solver
        cls.ckpt_freq = args.ckpt_freq

    @classmethod
    def set_directories(cls, results_dir, ckpt_dir):
        """
        Set the results and checkpoint directories.

        Args:
            results_dir: Path to results directory
            ckpt_dir: Path to checkpoint directory
        """
        cls.results_dir = results_dir
        cls.ckpt_dir = ckpt_dir

    @classmethod
    def get_model_params(cls):
        """Return dictionary of model parameters."""
        return {
            'mol_name': cls.mol_name,
            'epochs': cls.epochs,
            'bs': cls.bs,
            'hl': cls.hl,
            'lr': cls.lr,
            'prior': cls.prior
        }

    @classmethod
    def get_functionals(cls):
        """Return dictionary of functional parameters."""
        return {
            'kin': cls.kin,
            'nuc': cls.nuc,
            'hart': cls.hart,
            'x': cls.x,
            'c': cls.c,
            'cc': cls.cc
        }

    @classmethod
    def _summary(cls):
        """Build the multi-line text shared by the class and instance reprs."""
        return (
            f"Config(\n"
            f" Model: mol_name={cls.mol_name}, epochs={cls.epochs}, "
            f"bs={cls.bs}, hl={cls.hl}, lr={cls.lr}, prior={cls.prior}\n"
            f" Functionals: kin={cls.kin}, x={cls.x}, c={cls.c}, cc={cls.cc}\n"
            f" Directories: results={cls.results_dir}, ckpt={cls.ckpt_dir}\n"
            f")"
        )

    @classmethod
    def __repr__(cls):
        """String representation of configuration."""
        # Kept as a classmethod so that the existing call forms
        # ``Config.__repr__()`` and ``repr(Config())`` keep working.
        return cls._summary()
@@ -0,0 +1,27 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2025 AlexandreDeCamargo
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ __version__ = "0.1.0"
24
+
25
+ from .dft_distrax import (
26
+ DFTDistribution
27
+ )
@@ -0,0 +1,216 @@
1
+ from functools import partial
2
+ from typing import Any
3
+
4
+ import numpy as np
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ from jaxtyping import Array, Float
9
+
10
+ from pyscf import gto, dft
11
+ from pyscf.dft import numint
12
+ from pyscf.data.nist import BOHR
13
+
14
+ import distrax
15
+ from distrax import MultivariateNormalDiag, Categorical
16
+ from distrax._src.distributions.distribution import Array
17
+
18
+ Dtype = Any
19
+
20
class MixGaussian(distrax.Distribution):
    """Mixture of diagonal Gaussians with fixed mixture probabilities."""

    def __init__(self, loc: Array, scale_diag: Array, probs: Array):
        r"""
        Creates a distribution with a mixture of Gaussian components.

        Parameters
        ----------
        loc : Array
            Component means (the molecular coordinates).
        scale_diag : Array
            Diagonal scale of every component.
        probs : Array
            Mixture probabilities.
        """
        self.loc = loc
        self.scale_diag = scale_diag
        self.probs = probs
        self.mixture_dist = Categorical(probs=probs)
        self.components_dist = MultivariateNormalDiag(
            loc=self.loc, scale_diag=self.scale_diag)

    @jax.jit
    def prob(self, value: Array) -> jax.Array:
        """
        Calculates the probability of an event.

        Parameters
        ----------
        value : Array
            An event.

        Returns
        -------
        jax.Array
            The probability of the event, with a trailing axis of length 1.
        """
        # Component densities, transposed so that the component axis is last
        # and can be contracted with the mixture probabilities.
        log_px_components = self.components_dist.log_prob(value).T
        px_components = jnp.exp(log_px_components)
        return px_components @ self.probs[:, None]

    @jax.jit
    def log_prob(self, value: Array) -> jax.Array:
        """
        Calculates the log probability of an event.

        The mixture is reduced in log space (weighted log-sum-exp).  Taking
        ``log(prob(value))`` instead returns ``-inf``, and NaN gradients in
        ``score``, as soon as every component density underflows to zero.

        Parameters
        ----------
        value : Array
            An event.

        Returns
        -------
        jax.Array
            The log probability of the event, with the same shape as the
            result of ``prob``.
        """
        log_px_components = self.components_dist.log_prob(value).T
        # keepdims mirrors the trailing length-1 axis produced by the matrix
        # product in ``prob``.
        return jax.nn.logsumexp(
            log_px_components, axis=-1, b=self.probs, keepdims=True)

    def _sample_n(self, key: jax.random.PRNGKey, n: int) -> jax.Array:
        """
        Returns 'n' samples.

        Every component is sampled ``n`` times and a one-hot draw from the
        mixture distribution selects which component each sample comes from.

        Parameters
        ----------
        key : PRNGKey
            Random key.
        n : int
            Number of samples to generate.

        Returns
        -------
        jax.Array
            An array of 'n' samples.
        """
        # The first sub-key is intentionally unused; the three-way split is
        # kept so that the random stream is unchanged.
        _, key_mixt, key_comp = jax.random.split(key, 3)
        samples_mixt = self.mixture_dist._sample_n(key_mixt, n)
        samples_mixt_one_hot = jax.nn.one_hot(samples_mixt, self.probs.shape[-1])
        samples_comp = self.components_dist.sample(seed=key_comp, sample_shape=n)
        # NOTE(review): the squeeze requires a length-1 axis at position -2,
        # which suggests ``loc`` is shaped (n_components, 1, dim) — confirm
        # against the callers.
        samples_comp = jnp.squeeze(samples_comp, axis=-2)

        # Keep, for every sample i, the draw of its selected component j.
        samples = jnp.einsum('ijl,ij->il', samples_comp, samples_mixt_one_hot)
        return samples

    def event_shape(self):
        # Placeholder: returns None.
        # NOTE(review): presumably only present to satisfy the base class —
        # confirm whether a real event shape is needed.
        pass

    @jax.jit
    def score(self, values):
        """Return the gradient of ``log_prob`` with respect to each row of ``values``."""
        return jax.vmap(jax.grad(lambda x:
                                 self.log_prob(x).sum()))(values)
112
+
113
+
114
class DFTDistribution(distrax.Distribution):
    """Electron density of a PySCF Kohn-Sham calculation, exposed through ``prob``.

    The self-consistent calculation runs once, in the constructor; ``prob``
    afterwards evaluates the converged density at arbitrary coordinates.
    """

    def __init__(self, atoms: Any, geometry: Any, basis_set: str = '6-31G(d,p)', exc: str = 'b3lyp', dtype_: Dtype = jnp.float32):
        """Build the molecule and integration grid, then run the SCF calculation.

        Args:
            atoms: Sequence of element symbols.
            geometry: One coordinate sequence per atom; passed to PySCF with
                ``unit='B'``.
            basis_set: Basis set name forwarded to ``gto.M``.
            exc: Stored on the instance only.
                NOTE(review): ``_dft`` hard-codes its functional and never
                reads this value — confirm whether that is intended.
            dtype_: dtype of the arrays returned by ``prob``.
        """
        self.atoms = atoms
        self.geometry = geometry
        self.basis_set = basis_set
        self.exc = exc
        self.dtype_ = dtype_

        self._grid_level = 5  # change this for larger molecules
        self.mol = self._mol()
        self.grids = dft.gen_grid.Grids(self.mol)
        self.grids.level = self._grid_level
        self.grids.build()
        self.Ne = self.mol.tot_electrons()
        self.dft, self.rdm1 = self._dft()

        self.coords = jnp.array(self.grids.coords)
        self.weights = jnp.array(self.grids.weights)

    def get_molecule(self):
        """Return the geometry as text: one ``symbol x y z`` line per atom.

        The debug ``print`` of every atom was removed; it fired on each
        construction of the distribution.
        """
        atom_lines = []
        for symbol, position in zip(self.atoms, self.geometry):
            # Every coordinate is followed by one space, as before.
            coord_text = "".join(str(component) + " " for component in position)
            atom_lines.append(f'{symbol} ' + coord_text + '\n')
        return "".join(atom_lines)

    def _mol(self):
        """Create the PySCF molecule (coordinates passed with ``unit='B'``)."""
        atoms = self.get_molecule()
        # Number of unpaired electrons used for a single isolated atom.
        spin_dict = {
            'H': 1, 'He': 0,
            'Li': 1, 'Be': 0, 'B': 1, 'C': 2, 'N': 3, 'O': 2, 'F': 1, 'Ne': 0
        }
        # If single atom, use spin_dict; else, leave the spin at PySCF's default.
        if len(self.atoms) == 1:
            spin = spin_dict.get(self.atoms[0], 0)  # Default to 0 if not found
            mol = gto.M(atom=atoms, basis=self.basis_set, unit='B', spin=spin)
        else:
            mol = gto.M(atom=atoms, basis=self.basis_set,
                        unit='B')
        return mol

    def _dft(self):
        """Run the Kohn-Sham calculation; return ``(mean_field, density_matrix)``."""
        mf_hf = dft.RKS(self.mol)
        # Unit weights for the exchange / correlation components.
        LDA_X = 1.
        B88_X = 1.
        VWN_C = 1.

        mf_hf.xc = f'{LDA_X:} * LDA + {B88_X:} * B88, {VWN_C:} * VWN'
        mf_hf = mf_hf.newton()  # second-order algorithm
        mf_hf.kernel()
        dm = mf_hf.make_rdm1()
        return mf_hf, dm

    # NOTE(review): this primal returns the raw density, while ``prob_fwd``
    # returns the density divided by ``Ne``, so the value differs depending on
    # whether the call is being differentiated.  Left unchanged because the
    # intended normalisation cannot be determined from this file.
    @partial(jax.custom_vjp, nondiff_argnums=(0,))
    def prob(self, value):
        """Evaluate the electron density at ``value``; returns a column of shape (n, 1)."""
        coords = np.array(value)
        ao_value = numint.eval_ao(self.mol, coords, deriv=1)
        rho_and_grho = numint.eval_rho(
            self.mol, ao_value, self.rdm1, xctype='GGA')
        rho = jnp.asarray(rho_and_grho[0], dtype=self.dtype_)  # /self.Ne
        return rho[:, None]  # includes batch dimension

    def prob_fwd(self, value):
        """Forward rule: density / Ne, plus its spatial gradient / Ne as residual."""
        coords = value
        ao_value = numint.eval_ao(self.mol, coords, deriv=1)
        rho_and_grho = numint.eval_rho(
            self.mol, ao_value, self.rdm1, xctype='GGA')
        rho = jnp.array(rho_and_grho[0], dtype=self.dtype_)/self.Ne
        # Rows 1: hold the gradient components; transpose to one row per point.
        drho_dx = jnp.array(
            rho_and_grho[1:, :].T, dtype=self.dtype_)/self.Ne
        return rho[:, None], (drho_dx)

    def prob_bwd(self, res, g):
        """Backward rule: scale the stored density gradient by the cotangent."""
        drho_dx = res
        return (drho_dx*g,)

    prob.defvjp(prob_fwd, prob_bwd)

    # The remaining members are unimplemented placeholders that return None.
    def log_prob(self, value: Any) -> Array:
        pass

    def _sample_n(self, key, n):
        pass

    def event_shape(self):
        pass

    def _sample_n_and_log_prob(self, key, n):
        pass
216
+
off/flow/__init__.py ADDED
@@ -0,0 +1,29 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2025 AlexandreDeCamargo
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ __version__ = "0.1.0"
24
+
25
+ from .equiv_flows import (
26
+ _Flow,
27
+ Radial_MLP,
28
+ CNF
29
+ )
@@ -0,0 +1,99 @@
1
+ import jax
2
+ from jax import lax, numpy as jnp
3
+ import equinox as eqx
4
+ from jaxtyping import Array, Float, Int, Scalar
5
+
6
class _Flow(eqx.Module):
    """Tanh MLP: an input layer, three hidden layers and a linear output layer."""
    linear_in: eqx.nn.Linear
    linear_out: eqx.nn.Linear
    blocks: list[eqx.nn.Linear]

    def __init__(
        self,
        din: Int[Scalar, ""],
        dim: Int[Scalar, ""],
        key: jax.random.PRNGKey
    ):
        """Create the layers.

        Args:
            din: Output width; the input layer takes ``din + 1`` features.
            dim: Hidden width.
            key: PRNG key used to initialise all layers.
        """
        # Five sub-keys are drawn although only three are consumed; the split
        # count is kept so that the initial weights stay the same.
        key_in, key_hidden, key_out = jax.random.split(key, 5)[:3]
        hidden_keys = jax.random.split(key_hidden, 3)

        self.linear_in = eqx.nn.Linear(din + 1, dim, key=key_in)
        self.blocks = [eqx.nn.Linear(dim, dim, key=hidden_key)
                       for hidden_key in hidden_keys]
        self.linear_out = eqx.nn.Linear(dim, din, key=key_out)

    def __call__(
        self,
        t,
        xi_norm,
        zi_one_hot
    ):
        """Evaluate the network on the stacked features ``(t, xi_norm, zi_one_hot)``."""
        features = jnp.hstack((
            jnp.reshape(t, (1,)),
            jnp.reshape(xi_norm, (1,)),
            zi_one_hot,
        ))
        hidden = jnp.tanh(self.linear_in(features))
        for layer in self.blocks:
            hidden = jnp.tanh(layer(hidden))
        return self.linear_out(hidden)
39
+
40
class Radial_MLP(eqx.Module):
    """Vector field built from per-nucleus displacements and a shared ``_Flow``."""
    xyz_nuclei : Array
    z_one_hot : Array
    flow: _Flow

    def __init__(
        self,
        dim: int,
        key,
        xyz_nuclei,
        z_one_hot
    ):
        """Store the nuclear data and create the shared radial network.

        Args:
            dim: Hidden width of the radial network.
            key: PRNG key for the network weights.
            xyz_nuclei: Nuclear positions; stored with an extra middle axis.
            z_one_hot: One row of one-hot features per nucleus.
        """
        # The network input is the distance plus the one-hot features (the
        # time is accounted for inside _Flow itself).
        n_features = z_one_hot.shape[-1] + 1
        self.flow = _Flow(n_features, dim, key=key)
        self.z_one_hot = z_one_hot
        self.xyz_nuclei = xyz_nuclei[:,None,:]

    def __call__(self, states, t):
        """Return the field at ``states`` and time ``t``."""
        # Displacement of the point from every nucleus, and its length.
        displacement = lax.expand_dims(states, dimensions=(0,)) - self.xyz_nuclei
        distance = jnp.linalg.norm(displacement, axis=-1)
        # Shared network applied per nucleus: time is broadcast, while the
        # distances and one-hot rows are mapped over their leading axis.
        per_nucleus_flow = jax.vmap(self.flow, in_axes=(None, 0, 0))
        weights = per_nucleus_flow(t, distance, self.z_one_hot)
        # Weight each displacement and sum over the first two axes.
        return jnp.einsum('ijk,ij->k', displacement, weights)
64
+
65
class CNF(eqx.Module):
    """Right-hand side of the augmented ODE built on a ``Radial_MLP`` field."""
    flow: Radial_MLP

    def __init__(
        self,
        din: Int[Scalar, ""],
        dim: Int[Scalar, ""],
        mu: any,
        one_hot:any,
        key: jax.random.PRNGKey
    ):
        """Create the underlying field.

        ``din`` is accepted but not used; ``mu`` and ``one_hot`` are forwarded
        to ``Radial_MLP`` as the nuclear positions and one-hot features.
        """
        self.flow = Radial_MLP(dim, key, mu, one_hot)

    def __call__(self, states, t):
        """Return the time derivative of the flattened augmented state.

        The leading entries of ``states`` hold the position and the log-density
        term; the last three entries hold the score.
        """
        n_dim = 3  # Hardcoded for 3 dimensions

        @jax.jit
        def position_and_logp_rhs(model, aug_state, time):
            # Only the position part of the state drives the field.
            position = aug_state[:n_dim]
            velocity = model.flow(position, time)
            jacobian = jax.jacrev(model.flow)(position, time)
            # Negative trace of the field Jacobian.
            return velocity, -1. * jnp.trace(jacobian)

        @jax.jit
        def full_rhs(model, full_state, time):
            aug_state = full_state[:-n_dim]
            score = full_state[-n_dim:]
            (d_position, d_logp), pullback = jax.vjp(
                lambda inner: position_and_logp_rhs(model, inner, time), aug_state)
            # Pull the cotangent (score, -1) back through the inner function.
            (cotangent,) = pullback((score, -1.))
            d_score = -cotangent[:-1] + cotangent[-1]
            return jnp.append(jnp.append(d_position, d_logp), d_score)

        return full_rhs(self, states, t)
@@ -0,0 +1,35 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) 2025 AlexandreDeCamargo
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ __version__ = "0.1.0"
24
+
25
+ from .functional import (
26
+ Functional,
27
+ CompositeFunctional,
28
+ EnergyFunctional,
29
+ FunctionalInputs,
30
+ )
31
+ from .kinetic import tf, weizsacker, tf_weizsacker
32
+ from .exchange_correlation import lda, b88, vwn, pw92, lda_b88
33
+ from .external import NuclearPotential
34
+ from .hartree import CoulombPotential, CoulombPotential_
35
+ from .core_correction import KatoCondition, HutcheonCuspCondition