off 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- off/__init__.py +23 -0
- off/atom_energies.py +151 -0
- off/config/_config.py +108 -0
- off/dft_distrax/__init__.py +27 -0
- off/dft_distrax/dft_distrax.py +216 -0
- off/flow/__init__.py +29 -0
- off/flow/equiv_flows.py +99 -0
- off/functionals/__init__.py +35 -0
- off/functionals/core_correction.py +84 -0
- off/functionals/exchange_correlation.py +174 -0
- off/functionals/external.py +49 -0
- off/functionals/functional.py +129 -0
- off/functionals/hartree.py +62 -0
- off/functionals/kinetic.py +87 -0
- off/main.py +172 -0
- off/ode_solver/__init__.py +32 -0
- off/ode_solver/eqx_ode.py +76 -0
- off/plot_binding_csv.py +63 -0
- off/plot_pes_ema.py +259 -0
- off/plot_pes_mpl.py +280 -0
- off/promolecular/__init__.py +27 -0
- off/promolecular/promolecular_dist.py +465 -0
- off/quadrature.py +261 -0
- off/quadrature_scan.py +188 -0
- off/scan_pes.py +133 -0
- off/test_fwd_rev.py +290 -0
- off/train/__init__.py +44 -0
- off/train/loop.py +228 -0
- off/train/loss.py +149 -0
- off/train/utils.py +38 -0
- off/utils.py +618 -0
- off-0.1.0.dist-info/METADATA +154 -0
- off-0.1.0.dist-info/RECORD +37 -0
- off-0.1.0.dist-info/WHEEL +5 -0
- off-0.1.0.dist-info/entry_points.txt +3 -0
- off-0.1.0.dist-info/licenses/LICENSE +21 -0
- off-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,465 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
import distrax
|
|
3
|
+
import equinox as eqx
|
|
4
|
+
import jax
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
from jax import lax
|
|
7
|
+
from distrax import MultivariateNormalDiag, Categorical
|
|
8
|
+
from jaxtyping import Array, Float, Int, Scalar
|
|
9
|
+
from ..dft_distrax import DFTDistribution
|
|
10
|
+
from pyscf import gto, dft
|
|
11
|
+
import jax.random as jrnd
|
|
12
|
+
|
|
13
|
+
# Multiplicative factor applied to coordinates (and Gaussian widths) supplied
# in Angstrom to express them in Bohr.
AAtoBohr = 1.8897259886
|
|
14
|
+
|
|
15
|
+
class ProMolecularDensity(distrax.Distribution):
    r"""Gaussian-mixture distribution centred on the atoms of a molecule.

    One diagonal-covariance Gaussian is placed on every row of ``loc`` and
    the components are mixed with weights ``z / ||z||_1``.

    Attributes
    ----------
    dim : Int[Scalar, ""]
        Dimension of the system, default is 3 dimensions.
    loc : Array
        Component means: the input coordinates with a singleton axis
        inserted at position 1, multiplied by ``AAtoBohr`` when ``units``
        is ``'aa'`` or ``'angstrom'`` (case-insensitive).
    scale_diag : Array
        Diagonal scale of every component, stored with the same layout as
        ``loc``. Defaults to ones; it is rescaled together with ``loc``
        when the input units are Angstrom.
    units : str
        Unit label passed by the caller, by default 'Bohr'.
    logits : array-like
        The raw ``z`` argument, stored unchanged (these are *not*
        log-probabilities despite the name).
    probs, mixture_probs : Array
        Mixture weights ``z / ||z||_1``.
    mixture_dist : distrax.Categorical
        Categorical distribution over the components.
    components_dist : distrax.MultivariateNormalDiag
        The per-atom Gaussian components.
    """

    def __init__(
        self,
        z: Int[Scalar, ""],
        loc: Float[Array, "z dim"],
        dim: Int[Scalar, ""] = 3,
        scale_diag: Optional[Array]=None,
        units: str = 'Bohr',
    ):
        """Build the mixture.

        Parameters
        ----------
        z
            Per-atom weights (atomic numbers); normalised by their L1 norm
            to obtain the mixture probabilities.
        loc
            Molecular coordinates, one row per atom.
        dim
            Dimension of the system.
        scale_diag
            Optional per-atom diagonal scales with the same layout as
            ``loc``; ones are used when omitted.
        units
            ``'aa'`` / ``'angstrom'`` (any case) triggers conversion to
            Bohr; any other value leaves the inputs untouched.
        """
        self.dim = dim
        # Singleton axis at position 1 so that a batch of points broadcasts
        # against the per-atom components.
        self.loc = lax.expand_dims(loc, dimensions=(1,))
        self.units = units

        if scale_diag is None:
            # Default: unit scale for every component (a Z-dependent width,
            # sigma = 0.5 / z, was tried previously and is disabled).
            self.scale_diag = jnp.ones_like(self.loc)
        else:
            self.scale_diag = lax.expand_dims(scale_diag, dimensions=(1,))

        if self.units.lower() in ('aa', 'angstrom'):
            self.loc = self.loc * AAtoBohr
            self.scale_diag = self.scale_diag * AAtoBohr

        self.logits = z
        self.probs = z / jnp.linalg.norm(z, ord=1)
        self.mixture_dist = Categorical(probs=self.probs)
        self.mixture_probs = self.mixture_dist.probs
        self.components_dist = MultivariateNormalDiag(
            loc=self.loc, scale_diag=self.scale_diag)

    @jax.jit
    def prob(self, value):
        """Return the mixture density at ``value`` as a column vector.

        The per-component densities are transposed so that components lie
        on the last axis and are then contracted with the mixture weights.
        """
        log_px_components_dist = self.components_dist.log_prob(value).T
        px_components_dist = jnp.exp(log_px_components_dist)
        return px_components_dist @ self.mixture_probs[:, None]

    @jax.jit
    def log_prob(self, value):
        """Return the log mixture density at ``value`` as a column vector.

        Evaluated as a weighted log-sum-exp over the components instead of
        ``log(prob(value))``: far away from every atom the component
        densities underflow to zero, which would give ``-inf`` here and NaN
        gradients in :meth:`score`.
        """
        log_px_components_dist = self.components_dist.log_prob(value).T
        # log(sum_k w_k * exp(log p_k)); ``b`` carries the mixture weights.
        return jax.nn.logsumexp(
            log_px_components_dist,
            axis=-1,
            b=self.mixture_probs,
            keepdims=True,
        )

    def _sample_n(self, key, n):
        """Draw ``n`` samples from the mixture.

        Every component is sampled ``n`` times and a one-hot mask built
        from the categorical draw keeps, for each sample, the value of the
        selected component.
        """
        _, key_mixt, key_comp = jax.random.split(key, 3)
        samples_mixt = self.mixture_dist._sample_n(key_mixt, n)
        samples_mixt_one_hot = jax.nn.one_hot(
            samples_mixt, self.mixture_probs.shape[-1])

        samples_comp = self.components_dist.sample(
            seed=key_comp, sample_shape=n)
        # Drop the singleton axis that __init__ inserted into ``loc``.
        samples_comp = jnp.squeeze(samples_comp, axis=-2)

        # i: sample, j: component, l: coordinate.
        samples = jnp.einsum('ijl,ij->il', samples_comp, samples_mixt_one_hot)
        return samples

    def event_shape(self):
        """Placeholder override; returns ``None``.

        NOTE(review): distrax declares ``event_shape`` as a property that
        returns a tuple -- confirm nothing relies on it before using
        base-class helpers that read it.
        """
        pass

    @jax.jit
    def score(self, values):
        """Return the gradient of ``log_prob`` at every row of ``values``."""
        return jax.vmap(jax.grad(lambda x:
                                 self.log_prob(x).sum()))(values)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
# class RadialDensityDistribution:
|
|
99
|
+
# """Distribution based on radial density sampling."""
|
|
100
|
+
|
|
101
|
+
# def __init__(self, db_prior, z, coords, grid_range=(-3.01, 3.0), n_grid_points=1000):
|
|
102
|
+
# self.db_prior = db_prior
|
|
103
|
+
# self.z = z
|
|
104
|
+
# self.coords = coords
|
|
105
|
+
# self.n_grid_points = n_grid_points
|
|
106
|
+
|
|
107
|
+
# x = jnp.linspace(-3, 3, 100)
|
|
108
|
+
# y = jnp.linspace(-3, 3, 100)
|
|
109
|
+
|
|
110
|
+
# X, Y = jnp.meshgrid(x, y)
|
|
111
|
+
# Z = jnp.zeros_like(X)
|
|
112
|
+
|
|
113
|
+
# points = jnp.array([X.flatten(), Y.flatten(), Z.flatten()]).T
|
|
114
|
+
# # Pre-compute the radial grid
|
|
115
|
+
# self.rad_grid = jnp.linspace(grid_range[0], grid_range[1], n_grid_points)
|
|
116
|
+
|
|
117
|
+
# # Compute density along z-axis
|
|
118
|
+
# grid_points = jnp.array([
|
|
119
|
+
# jnp.zeros_like(self.rad_grid),
|
|
120
|
+
# jnp.zeros_like(self.rad_grid),
|
|
121
|
+
# self.rad_grid
|
|
122
|
+
# ]).T
|
|
123
|
+
|
|
124
|
+
# self.promol_dens = self.db_prior.density(grid_points)
|
|
125
|
+
# # self.p = self.promol_dens
|
|
126
|
+
# self.p = self.promol_dens /jnp.sum(self.promol_dens)
|
|
127
|
+
|
|
128
|
+
# # Pre-compute gradient
|
|
129
|
+
# self.promol_grad = self.db_prior.gradient(grid_points)
|
|
130
|
+
|
|
131
|
+
# # Compute score (gradient / density)
|
|
132
|
+
# self.promol_score = self.promol_grad / self.promol_dens.reshape(-1, 1)
|
|
133
|
+
|
|
134
|
+
# # Store last sampled indices for efficient lookup
|
|
135
|
+
# self._last_indices = None
|
|
136
|
+
|
|
137
|
+
# def sample(self, seed, sample_shape):
|
|
138
|
+
# """Sample positions from the radial density."""
|
|
139
|
+
# if isinstance(sample_shape, int):
|
|
140
|
+
# n_samples = sample_shape
|
|
141
|
+
# else:
|
|
142
|
+
# n_samples = sample_shape[0] if len(sample_shape) > 0 else 1
|
|
143
|
+
|
|
144
|
+
# # Sample indices according to density - shape (n_samples, 3)
|
|
145
|
+
# sampled_indices = jax.random.choice(
|
|
146
|
+
# seed,
|
|
147
|
+
# a=len(self.rad_grid),
|
|
148
|
+
# shape=(n_samples, 3),
|
|
149
|
+
# replace=True,
|
|
150
|
+
# p=self.p
|
|
151
|
+
# )
|
|
152
|
+
|
|
153
|
+
# # Store indices for later lookup
|
|
154
|
+
# self._last_indices = sampled_indices
|
|
155
|
+
|
|
156
|
+
# # Map indices to radial values - shape (n_samples, 3)
|
|
157
|
+
# samples = self.rad_grid[sampled_indices]
|
|
158
|
+
# return samples
|
|
159
|
+
|
|
160
|
+
# def log_prob(self, value):
|
|
161
|
+
# """Compute log probability of samples."""
|
|
162
|
+
# # Find which grid index each value corresponds to
|
|
163
|
+
# # Since samples come from rad_grid, find closest match
|
|
164
|
+
# indices = jnp.argmin(jnp.abs(value[:, 2:3] - self.rad_grid[None, :]), axis=1)
|
|
165
|
+
|
|
166
|
+
# # Get densities for those indices
|
|
167
|
+
# densities = self.p[indices]
|
|
168
|
+
# log_probs = jnp.log(densities)
|
|
169
|
+
|
|
170
|
+
# return log_probs[:, None] # Shape (batch_size, 1)
|
|
171
|
+
|
|
172
|
+
# def score(self, values):
|
|
173
|
+
# """Compute score (gradient of log probability)."""
|
|
174
|
+
# # Find which grid index each value corresponds to
|
|
175
|
+
# indices = jnp.argmin(jnp.abs(values[:, 2:3] - self.rad_grid[None, :]), axis=1)
|
|
176
|
+
|
|
177
|
+
# # Get pre-computed scores for those indices
|
|
178
|
+
# scores = self.promol_score[indices]
|
|
179
|
+
|
|
180
|
+
# return scores # Shape (batch_size, 3)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
# class DFTGridDistribution:
|
|
184
|
+
# """Distribution based on PySCF grid with atomdb density."""
|
|
185
|
+
|
|
186
|
+
# def __init__(self, db_prior, atoms, coords, basis='6-31G(d,p)', grid_level=3):
|
|
187
|
+
# self.db_prior = db_prior
|
|
188
|
+
# self.atoms = atoms
|
|
189
|
+
# self.coords = coords
|
|
190
|
+
|
|
191
|
+
# # Build atom string for PySCF (in Bohr units)
|
|
192
|
+
# atom_string = ""
|
|
193
|
+
# for atom, coord in zip(atoms, coords):
|
|
194
|
+
# atom_string += f"{atom} {coord[0]:.6f} {coord[1]:.6f} {coord[2]:.6f}; "
|
|
195
|
+
|
|
196
|
+
# # Create PySCF molecule and grid
|
|
197
|
+
# mol = gto.M(atom=atom_string,
|
|
198
|
+
# basis=basis,
|
|
199
|
+
# unit='B',
|
|
200
|
+
# spin=0)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
# pyscfgrid = dft.gen_grid.Grids(mol)
|
|
204
|
+
# pyscfgrid.level = grid_level
|
|
205
|
+
# pyscfgrid.build()
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
# self.grid_coords = jnp.array(pyscfgrid.coords)
|
|
209
|
+
# self.grid_weights = jnp.array(pyscfgrid.weights)
|
|
210
|
+
|
|
211
|
+
# self.promol_dens = self.db_prior.density(self.grid_coords)
|
|
212
|
+
# self.promol_grad = self.db_prior.gradient(self.grid_coords)
|
|
213
|
+
# self.promol_score = self.promol_grad / self.promol_dens.reshape(-1, 1)
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
# # ADD THIS: Normalize density for sampling
|
|
217
|
+
# # weighted_dens = self.promol_dens * self.grid_weights
|
|
218
|
+
# # self.p = weighted_dens / jnp.sum(weighted_dens)
|
|
219
|
+
# self.p = self.promol_dens /jnp.sum(self.promol_dens)
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
# def sample(self, seed, sample_shape):
|
|
223
|
+
# """Sample grid points weighted by density."""
|
|
224
|
+
# if isinstance(sample_shape, int):
|
|
225
|
+
# n_samples = sample_shape
|
|
226
|
+
# else:
|
|
227
|
+
# n_samples = sample_shape[0] if len(sample_shape) > 0 else 1
|
|
228
|
+
|
|
229
|
+
# indices = jax.random.choice(
|
|
230
|
+
# seed,
|
|
231
|
+
# a=len(self.grid_coords),
|
|
232
|
+
# shape=(n_samples,),
|
|
233
|
+
# # replace=True,
|
|
234
|
+
# p=self.p
|
|
235
|
+
# )
|
|
236
|
+
|
|
237
|
+
# return self.grid_coords[indices]
|
|
238
|
+
# def log_prob(self, value):
|
|
239
|
+
# """Compute log probability at points."""
|
|
240
|
+
# distances = jnp.linalg.norm(
|
|
241
|
+
# value[:, None, :] - self.grid_coords[None, :, :],
|
|
242
|
+
# axis=2
|
|
243
|
+
# )
|
|
244
|
+
# indices = jnp.argmin(distances, axis=1)
|
|
245
|
+
|
|
246
|
+
# probs = self.promol_dens[indices]
|
|
247
|
+
# log_probs = jnp.log(probs)
|
|
248
|
+
|
|
249
|
+
# return log_probs[:, None]
|
|
250
|
+
|
|
251
|
+
# def score(self, values):
|
|
252
|
+
# """Compute score at points."""
|
|
253
|
+
# distances = jnp.linalg.norm(
|
|
254
|
+
# values[:, None, :] - self.grid_coords[None, :, :],
|
|
255
|
+
# axis=2
|
|
256
|
+
# )
|
|
257
|
+
# indices = jnp.argmin(distances, axis=1)
|
|
258
|
+
|
|
259
|
+
# return self.promol_score[indices]
|
|
260
|
+
|
|
261
|
+
class AtomDBDistribution:
    """Distribution based on atomdb density, with **direct sampling** via
    per-atom inverse-CDF tables built at init time.

    Each sample picks an atom (weighted by that atom's integrated radial
    density), draws a radius from the atom's tabulated inverse CDF and a
    uniformly random direction, and offsets the atom centre accordingly.
    Per the original design note, no SIR / importance reweighting is
    required because the samples are drawn from the promolecular sum of
    atomic densities directly.

    Attributes
    ----------
    db_prior
        Object exposing ``atoms`` (each with ``dens_func()``), ``density``
        and ``gradient``.
    z : Array
        Flattened copy of the ``z`` argument.
    coords : Array
        Atom centres, one row per atom.
    Ne
        Normalisation constant dividing the density in ``prob`` /
        ``log_prob``.
    n_atoms : int
        Number of rows in ``coords``.
    inv_cdf_tables : Array
        ``(n_atoms, n_radial)`` radii tabulated on ``u_grid``.
    u_grid : Array
        ``(n_radial,)`` uniform grid on [0, 1].
    atom_weights, atom_probs : Array
        Integrated radial density per atom and its normalised version.
    """

    def __init__(self, db_prior, z, coords, Ne,
                 n_radial: int = 4096, r_max: float = 20.0):
        """Store the inputs and build the per-atom inverse-CDF tables.

        Parameters
        ----------
        db_prior
            Density provider (see class docstring).
        z
            Atomic numbers; stored flattened.
        coords
            Atom centres, one row per atom.
        Ne
            Normalisation constant for the density.
        n_radial
            Number of radial grid points per table.
        r_max
            Largest radius covered by the tables.
        """
        self.db_prior = db_prior
        self.z = jnp.asarray(z).ravel()
        self.coords = jnp.asarray(coords, dtype=jnp.float64)
        self.Ne = Ne
        self.n_atoms = int(self.coords.shape[0])
        self._build_per_atom_invcdf(n_radial, r_max)

    # -- Inverse-CDF table construction (numpy/scipy, one-shot at init) -----
    def _build_per_atom_invcdf(self, n_radial: int, r_max: float):
        """Tabulate, for every atom, the radius as a function of CDF value."""
        import numpy as np
        from scipy.integrate import cumulative_trapezoid

        r_grid = np.linspace(1e-6, r_max, n_radial)  # avoid r=0 singularity
        u_grid = np.linspace(0.0, 1.0, n_radial)

        inv_cdf_tables = []
        atom_weights = []
        for atom_species in self.db_prior.atoms:
            dens_spline = atom_species.dens_func()  # DensitySpline
            # Clip negative values before integrating.
            rho = np.maximum(np.asarray(dens_spline(r_grid)), 0.0)

            integrand = 4.0 * np.pi * r_grid**2 * rho  # radial probability mass
            # Prepending 0 gives a CDF of exactly n_radial entries.
            cdf = np.concatenate(
                [[0.0], cumulative_trapezoid(integrand, r_grid)])

            n_electrons_atom = float(cdf[-1])
            atom_weights.append(n_electrons_atom)
            if n_electrons_atom <= 0.0:
                # degenerate (shouldn't happen for real atoms); make uniform
                r_at_u = np.copy(r_grid)
            else:
                cdf_norm = cdf / n_electrons_atom
                # invert: for each u in u_grid, find r such that F(r) = u
                r_at_u = np.interp(u_grid, cdf_norm, r_grid)
            inv_cdf_tables.append(r_at_u)

        self.inv_cdf_tables = jnp.asarray(np.stack(inv_cdf_tables),
                                          dtype=jnp.float64)  # (n_atoms, n_radial)
        self.u_grid = jnp.asarray(u_grid, dtype=jnp.float64)  # (n_radial,)
        self.atom_weights = jnp.asarray(atom_weights, dtype=jnp.float64)
        self.atom_probs = self.atom_weights / self.atom_weights.sum()

    # -- Direct sampling ----------------------------------------------------
    def _sample_n(self, key, n: int):
        """Return ``n`` positions of shape ``(n, 3)``."""
        k_atom, k_r, k_dir = jrnd.split(key, 3)

        # 1. Pick an atom for each sample, weighted by per-atom electron count
        atom_idx = jrnd.categorical(k_atom, jnp.log(self.atom_probs), shape=(n,))

        # 2. Uniform u for inverse-CDF lookup
        u = jrnd.uniform(k_r, shape=(n,))

        # 3. Invert per-sample using the chosen atom's table
        def _invert_one(idx, u_val):
            return jnp.interp(u_val, self.u_grid, self.inv_cdf_tables[idx])
        r = jax.vmap(_invert_one)(atom_idx, u)  # (n,)

        # 4. Isotropic direction on the sphere
        d = jrnd.normal(k_dir, shape=(n, 3))
        d = d / jnp.linalg.norm(d, axis=1, keepdims=True)

        # 5. Position = atom_center + r * direction
        centers = self.coords[atom_idx]  # (n, 3)
        return centers + r[:, None] * d

    def sample(self, seed, sample_shape):
        """Draw samples.

        Parameters
        ----------
        seed
            JAX PRNG key.
        sample_shape
            Either an integer-like scalar (builtin ``int``, NumPy integer,
            0-d array) or a sequence whose first entry is the number of
            samples; an empty sequence yields a single sample.

        Returns
        -------
        Array of shape ``(n, 3)``.
        """
        try:
            n = int(sample_shape[0]) if len(sample_shape) > 0 else 1
        except TypeError:
            # Unsized scalar: plain int, NumPy integer, 0-d array, ...
            n = int(sample_shape)
        return self._sample_n(seed, n)

    # -- log_prob / score ---------------------------------------------------
    def log_prob(self, value):
        """Return ``log(density / Ne)`` as a column vector."""
        # Clip floor to avoid log(0) -> -inf for samples that landed beyond
        # the AtomDB spline's support (rare; happens for u->1 tails).
        log_probs = jnp.log(jnp.maximum(self.prob(value), 1e-30))
        return log_probs[:, None]

    def prob(self, value):
        """Return ``density / Ne`` at ``value`` (no clipping applied)."""
        return self.db_prior.density(value) / self.Ne

    def score(self, values):
        """Return the gradient of the log-density at ``values``."""
        # score = grad log p(x) = grad log(rho_total/Ne) = grad rho_total / rho_total
        # (the constant 1/Ne factor cancels in the gradient -- do NOT divide
        # rho by Ne here)
        density = self.db_prior.density(values)
        gradient = self.db_prior.gradient(values)
        score = gradient / jnp.maximum(density.reshape(-1, 1), 1e-30)
        return jnp.nan_to_num(score, nan=0.0, posinf=0.0, neginf=0.0)
|
|
369
|
+
|
|
370
|
+
class SIRDistribution:
    """
    Sampling Importance Resampling (SIR) distribution.
    Uses a base distribution as proposal and reweights to target distribution.
    """

    def __init__(
        self,
        base_distribution,    # ProMolecularDensity (Gaussian mixture)
        target_distribution,  # AtomDBDistribution (atomdb)
        oversampling_factor: int = 10
    ):
        """
        Parameters
        ----------
        base_distribution : ProMolecularDensity
            Base distribution (Gaussian mixture) - easy to sample from.
            Must provide ``sample(seed=..., sample_shape=...)`` and ``prob``.
        target_distribution : AtomDBDistribution
            Target distribution (atomdb) - what we actually want.
            Must provide ``prob``, ``log_prob`` and ``score``.
        oversampling_factor : int
            How many proposal samples per final sample (higher = better quality)
        """
        self.base_distribution = base_distribution
        self.target_distribution = target_distribution
        self.oversampling_factor = oversampling_factor

    def sample(self, seed, sample_shape):
        """
        Sample using SIR: sample from base, reweight, resample.

        Parameters
        ----------
        seed
            JAX PRNG key.
        sample_shape
            Integer-like scalar, or a sequence whose first entry is the
            number of samples (an empty sequence yields one sample).

        Returns
        -------
        The selected rows of the proposal samples, one per requested sample.
        """
        try:
            n_final_samples = (
                int(sample_shape[0]) if len(sample_shape) > 0 else 1)
        except TypeError:
            # Unsized scalar: plain int, NumPy integer, 0-d array, ...
            n_final_samples = int(sample_shape)

        key_sample, key_resample = jrnd.split(seed)

        # Generate oversampled proposals from base distribution
        n_proposal_samples = self.oversampling_factor * n_final_samples
        proposal_samples = self.base_distribution.sample(
            seed=key_sample,
            sample_shape=n_proposal_samples
        )

        # Flatten both densities to 1-D: a column-shaped (n, 1) result would
        # otherwise broadcast against the (n,) one into an (n, n) matrix.
        density_base = jnp.ravel(self.base_distribution.prob(proposal_samples))
        density_target = jnp.ravel(
            self.target_distribution.prob(proposal_samples))

        # Importance weights: w = p_target / p_base
        importance_weights = density_target / density_base
        importance_weights = jnp.nan_to_num(
            importance_weights, nan=0.0, posinf=0.0, neginf=0.0)
        importance_weights = jnp.maximum(importance_weights, 0.0)

        # Normalise; if every weight vanished fall back to uniform weights
        # rather than handing NaN probabilities to the resampler.
        total_weight = jnp.sum(importance_weights)
        uniform_weights = (
            jnp.ones_like(importance_weights) / max(n_proposal_samples, 1))
        importance_weights = jnp.where(
            total_weight > 0.0,
            importance_weights / total_weight,
            uniform_weights,
        )

        # Resample according to weights
        resampled_indices = jrnd.choice(
            key_resample,
            a=n_proposal_samples,
            shape=(n_final_samples,),
            p=importance_weights,
            replace=True
        )

        # Diagnostic: effective sample size of the normalised weights
        effective_sample_size = 1.0 / jnp.sum(importance_weights**2)
        if effective_sample_size < n_final_samples:
            print(f"Warning: Low effective sample size ({effective_sample_size:.1f}) for {n_final_samples} final samples")

        final_samples = proposal_samples[resampled_indices]
        return final_samples

    def log_prob(self, value):
        """
        Log probability uses the target distribution.
        """
        return self.target_distribution.log_prob(value)

    def score(self, values):
        """
        Score uses the target distribution.
        """
        return self.target_distribution.score(values)
|