off 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,465 @@
1
+ from typing import Optional
2
+ import distrax
3
+ import equinox as eqx
4
+ import jax
5
+ import jax.numpy as jnp
6
+ from jax import lax
7
+ from distrax import MultivariateNormalDiag, Categorical
8
+ from jaxtyping import Array, Float, Int, Scalar
9
+ from ..dft_distrax import DFTDistribution
10
+ from pyscf import gto, dft
11
+ import jax.random as jrnd
12
+
13
AAtoBohr = 1.8897259886  # Bohr per Angstrom: multiply a length in Angstrom by this to get Bohr
14
+
15
class ProMolecularDensity(distrax.Distribution):
    r"""Gaussian-mixture approximation of a promolecular density.

    One diagonal Gaussian is centred on each atom and the mixture weights are
    the atomic numbers normalised to sum to one.

    Parameters
    ----------
    z : array_like
        Atomic numbers of the atoms in the molecule (used as mixture weights).
    loc : Float[Array, "z dim"]
        Molecular coordinates, one row per atom.
    dim : int, optional
        Dimension of the system, 3 by default.
    scale_diag : Optional[Array], optional
        Per-atom diagonal scales (presumably shaped like ``loc``); unit scales
        are used when None (the default).
    units : str, optional
        Unit of ``loc`` / ``scale_diag``: 'Bohr' (default), or 'AA' /
        'Angstrom' (case-insensitive), in which case both are converted to
        Bohr with ``AAtoBohr``.
    """

    def __init__(
        self,
        z: Int[Scalar, ""],
        loc: Float[Array, "z dim"],
        dim: Int[Scalar, ""] = 3,
        scale_diag: Optional[Array] = None,
        units: str = 'Bohr',
    ):
        self.dim = dim
        # Insert a singleton axis after the atom axis so every atom's Gaussian
        # broadcasts against a whole batch of query points in log_prob.
        self.loc = lax.expand_dims(loc, dimensions=(1,))
        self.units = units

        if scale_diag is None:
            self.scale_diag = jnp.ones_like(self.loc)
        else:
            self.scale_diag = lax.expand_dims(scale_diag, dimensions=(1,))

        if self.units.lower() in ('aa', 'angstrom'):
            self.loc = self.loc * AAtoBohr
            self.scale_diag = self.scale_diag * AAtoBohr

        # NOTE: despite the name, these are the raw atomic numbers, not logits.
        self.logits = z
        z_arr = jnp.asarray(z)  # also accept plain Python sequences
        # Dividing by the L1 norm makes the weights sum to one (for z >= 0).
        self.probs = z_arr / jnp.linalg.norm(z_arr, ord=1)
        self.mixture_dist = Categorical(probs=self.probs)
        self.mixture_probs = self.mixture_dist.probs
        self.components_dist = MultivariateNormalDiag(
            loc=self.loc, scale_diag=self.scale_diag)

    @jax.jit
    def prob(self, value):
        """Mixture density ``sum_i w_i N_i(value)`` returned as a column.

        For an ``(N, dim)`` input the result has shape ``(N, 1)``.
        """
        px_components = jnp.exp(self.components_dist.log_prob(value).T)
        return px_components @ self.mixture_probs[:, None]

    @jax.jit
    def log_prob(self, value):
        """Log mixture density at ``value``; same shape as ``prob``.

        Evaluated in log-space with logsumexp: far from every atom the
        individual component densities underflow to 0, which would make
        ``log(prob)`` equal -inf and ``score`` NaN.
        """
        log_px_components = self.components_dist.log_prob(value).T
        log_weights = jnp.log(self.mixture_probs)
        return jax.nn.logsumexp(log_px_components + log_weights,
                                axis=-1, keepdims=True)

    def _sample_n(self, key, n):
        """Draw ``n`` samples: pick an atom per sample, then draw from its Gaussian."""
        _, key_mixt, key_comp = jax.random.split(key, 3)
        samples_mixt = self.mixture_dist._sample_n(key_mixt, n)
        samples_mixt_one_hot = jax.nn.one_hot(
            samples_mixt, self.mixture_probs.shape[-1])

        # One draw from every component per sample; drop the singleton axis
        # inserted in __init__.
        samples_comp = self.components_dist.sample(
            seed=key_comp, sample_shape=n)
        samples_comp = jnp.squeeze(samples_comp, axis=-2)

        # The one-hot mask keeps only the selected component's draw.
        samples = jnp.einsum('ijl,ij->il', samples_comp, samples_mixt_one_hot)
        return samples

    @property
    def event_shape(self):
        """Shape of a single sample, ``(dim,)`` (abstract property in distrax)."""
        return (self.dim,)

    @jax.jit
    def score(self, values):
        """Gradient of ``log_prob`` with respect to each row of ``values``."""
        return jax.vmap(jax.grad(lambda x:
                                 self.log_prob(x).sum()))(values)
96
+
97
+
98
+ # class RadialDensityDistribution:
99
+ # """Distribution based on radial density sampling."""
100
+
101
+ # def __init__(self, db_prior, z, coords, grid_range=(-3.01, 3.0), n_grid_points=1000):
102
+ # self.db_prior = db_prior
103
+ # self.z = z
104
+ # self.coords = coords
105
+ # self.n_grid_points = n_grid_points
106
+
107
+ # x = jnp.linspace(-3, 3, 100)
108
+ # y = jnp.linspace(-3, 3, 100)
109
+
110
+ # X, Y = jnp.meshgrid(x, y)
111
+ # Z = jnp.zeros_like(X)
112
+
113
+ # points = jnp.array([X.flatten(), Y.flatten(), Z.flatten()]).T
114
+ # # Pre-compute the radial grid
115
+ # self.rad_grid = jnp.linspace(grid_range[0], grid_range[1], n_grid_points)
116
+
117
+ # # Compute density along z-axis
118
+ # grid_points = jnp.array([
119
+ # jnp.zeros_like(self.rad_grid),
120
+ # jnp.zeros_like(self.rad_grid),
121
+ # self.rad_grid
122
+ # ]).T
123
+
124
+ # self.promol_dens = self.db_prior.density(grid_points)
125
+ # # self.p = self.promol_dens
126
+ # self.p = self.promol_dens /jnp.sum(self.promol_dens)
127
+
128
+ # # Pre-compute gradient
129
+ # self.promol_grad = self.db_prior.gradient(grid_points)
130
+
131
+ # # Compute score (gradient / density)
132
+ # self.promol_score = self.promol_grad / self.promol_dens.reshape(-1, 1)
133
+
134
+ # # Store last sampled indices for efficient lookup
135
+ # self._last_indices = None
136
+
137
+ # def sample(self, seed, sample_shape):
138
+ # """Sample positions from the radial density."""
139
+ # if isinstance(sample_shape, int):
140
+ # n_samples = sample_shape
141
+ # else:
142
+ # n_samples = sample_shape[0] if len(sample_shape) > 0 else 1
143
+
144
+ # # Sample indices according to density - shape (n_samples, 3)
145
+ # sampled_indices = jax.random.choice(
146
+ # seed,
147
+ # a=len(self.rad_grid),
148
+ # shape=(n_samples, 3),
149
+ # replace=True,
150
+ # p=self.p
151
+ # )
152
+
153
+ # # Store indices for later lookup
154
+ # self._last_indices = sampled_indices
155
+
156
+ # # Map indices to radial values - shape (n_samples, 3)
157
+ # samples = self.rad_grid[sampled_indices]
158
+ # return samples
159
+
160
+ # def log_prob(self, value):
161
+ # """Compute log probability of samples."""
162
+ # # Find which grid index each value corresponds to
163
+ # # Since samples come from rad_grid, find closest match
164
+ # indices = jnp.argmin(jnp.abs(value[:, 2:3] - self.rad_grid[None, :]), axis=1)
165
+
166
+ # # Get densities for those indices
167
+ # densities = self.p[indices]
168
+ # log_probs = jnp.log(densities)
169
+
170
+ # return log_probs[:, None] # Shape (batch_size, 1)
171
+
172
+ # def score(self, values):
173
+ # """Compute score (gradient of log probability)."""
174
+ # # Find which grid index each value corresponds to
175
+ # indices = jnp.argmin(jnp.abs(values[:, 2:3] - self.rad_grid[None, :]), axis=1)
176
+
177
+ # # Get pre-computed scores for those indices
178
+ # scores = self.promol_score[indices]
179
+
180
+ # return scores # Shape (batch_size, 3)
181
+
182
+
183
+ # class DFTGridDistribution:
184
+ # """Distribution based on PySCF grid with atomdb density."""
185
+
186
+ # def __init__(self, db_prior, atoms, coords, basis='6-31G(d,p)', grid_level=3):
187
+ # self.db_prior = db_prior
188
+ # self.atoms = atoms
189
+ # self.coords = coords
190
+
191
+ # # Build atom string for PySCF (in Bohr units)
192
+ # atom_string = ""
193
+ # for atom, coord in zip(atoms, coords):
194
+ # atom_string += f"{atom} {coord[0]:.6f} {coord[1]:.6f} {coord[2]:.6f}; "
195
+
196
+ # # Create PySCF molecule and grid
197
+ # mol = gto.M(atom=atom_string,
198
+ # basis=basis,
199
+ # unit='B',
200
+ # spin=0)
201
+
202
+
203
+ # pyscfgrid = dft.gen_grid.Grids(mol)
204
+ # pyscfgrid.level = grid_level
205
+ # pyscfgrid.build()
206
+
207
+
208
+ # self.grid_coords = jnp.array(pyscfgrid.coords)
209
+ # self.grid_weights = jnp.array(pyscfgrid.weights)
210
+
211
+ # self.promol_dens = self.db_prior.density(self.grid_coords)
212
+ # self.promol_grad = self.db_prior.gradient(self.grid_coords)
213
+ # self.promol_score = self.promol_grad / self.promol_dens.reshape(-1, 1)
214
+
215
+
216
+ # # ADD THIS: Normalize density for sampling
217
+ # # weighted_dens = self.promol_dens * self.grid_weights
218
+ # # self.p = weighted_dens / jnp.sum(weighted_dens)
219
+ # self.p = self.promol_dens /jnp.sum(self.promol_dens)
220
+
221
+
222
+ # def sample(self, seed, sample_shape):
223
+ # """Sample grid points weighted by density."""
224
+ # if isinstance(sample_shape, int):
225
+ # n_samples = sample_shape
226
+ # else:
227
+ # n_samples = sample_shape[0] if len(sample_shape) > 0 else 1
228
+
229
+ # indices = jax.random.choice(
230
+ # seed,
231
+ # a=len(self.grid_coords),
232
+ # shape=(n_samples,),
233
+ # # replace=True,
234
+ # p=self.p
235
+ # )
236
+
237
+ # return self.grid_coords[indices]
238
+ # def log_prob(self, value):
239
+ # """Compute log probability at points."""
240
+ # distances = jnp.linalg.norm(
241
+ # value[:, None, :] - self.grid_coords[None, :, :],
242
+ # axis=2
243
+ # )
244
+ # indices = jnp.argmin(distances, axis=1)
245
+
246
+ # probs = self.promol_dens[indices]
247
+ # log_probs = jnp.log(probs)
248
+
249
+ # return log_probs[:, None]
250
+
251
+ # def score(self, values):
252
+ # """Compute score at points."""
253
+ # distances = jnp.linalg.norm(
254
+ # values[:, None, :] - self.grid_coords[None, :, :],
255
+ # axis=2
256
+ # )
257
+ # indices = jnp.argmin(distances, axis=1)
258
+
259
+ # return self.promol_score[indices]
260
+
261
class AtomDBDistribution:
    """Distribution based on atomdb density, with **direct sampling** via
    per-atom inverse-CDF tables built at init time.

    Sampling needs no SIR / importance reweighting. Each sample:

    1. picks an atom with probability proportional to its integrated radial
       density;
    2. draws a radius from that atom's tabulated inverse CDF;
    3. draws a direction uniformly on the sphere.

    Parameters
    ----------
    db_prior
        Promolecular prior. It must expose ``atoms`` (each with a
        ``dens_func()`` returning a callable radial density), plus
        ``density(x)`` and ``gradient(x)``.
    z
        Atomic numbers (stored flattened as ``self.z``; unused by the methods
        here).
    coords
        Atomic centres, one row per atom.
    Ne
        Normaliser applied to ``db_prior.density`` in ``prob`` / ``log_prob``
        (presumably the total electron count).
    n_radial
        Number of points in each radial grid / inverse-CDF table.
    r_max
        Outer radius of the radial grid.
    """

    def __init__(self, db_prior, z, coords, Ne,
                 n_radial: int = 4096, r_max: float = 20.0):
        self.db_prior = db_prior
        self.z = jnp.asarray(z).ravel()
        self.coords = jnp.asarray(coords, dtype=jnp.float64)
        self.Ne = Ne
        self.n_atoms = int(self.coords.shape[0])
        self._build_per_atom_invcdf(n_radial, r_max)

    # ── Inverse-CDF table construction (numpy/scipy, one-shot at init) ───────
    def _build_per_atom_invcdf(self, n_radial: int, r_max: float):
        """Tabulate the inverse radial CDF r(u) of every atom on a uniform u grid.

        Sets ``inv_cdf_tables`` (one row per entry of ``db_prior.atoms``),
        ``u_grid``, ``atom_weights`` (integrated radial density per atom) and
        ``atom_probs`` (the weights normalised to sum to one).
        """
        import numpy as np
        from scipy.integrate import cumulative_trapezoid

        r_grid = np.linspace(1e-6, r_max, n_radial)  # avoid r=0 singularity
        u_grid = np.linspace(0.0, 1.0, n_radial)

        inv_cdf_tables = []
        atom_weights = []
        # NOTE(review): _sample_n pairs table i with coords[i], so this assumes
        # db_prior.atoms has one entry per atom, in the same order as coords.
        for atom_species in self.db_prior.atoms:
            dens_spline = atom_species.dens_func()  # DensitySpline
            rho = np.maximum(np.asarray(dens_spline(r_grid)), 0.0)

            # Radial probability mass 4*pi*r^2*rho(r) and its running integral.
            integrand = 4.0 * np.pi * r_grid**2 * rho
            cdf = cumulative_trapezoid(integrand, r_grid, initial=0.0)

            n_electrons_atom = float(cdf[-1])
            atom_weights.append(n_electrons_atom)
            if n_electrons_atom <= 0.0:
                # Degenerate density (shouldn't happen for real atoms): use
                # r(u) = r_grid, i.e. uniform in r.
                r_at_u = np.copy(r_grid)
            else:
                cdf_norm = cdf / n_electrons_atom
                # np.interp requires strictly increasing xp, but the CDF is
                # flat wherever rho == 0 (e.g. beyond the spline's support).
                # Keep only the first point of each flat run so u -> 1 maps to
                # the edge of the support instead of r_max.
                cdf_knots, first_idx = np.unique(cdf_norm, return_index=True)
                r_at_u = np.interp(u_grid, cdf_knots, r_grid[first_idx])
            inv_cdf_tables.append(r_at_u)

        self.inv_cdf_tables = jnp.asarray(np.stack(inv_cdf_tables),
                                          dtype=jnp.float64)  # (n_tables, n_radial)
        self.u_grid = jnp.asarray(u_grid, dtype=jnp.float64)  # (n_radial,)
        self.atom_weights = jnp.asarray(atom_weights, dtype=jnp.float64)
        self.atom_probs = self.atom_weights / self.atom_weights.sum()

    # ── Direct sampling ──────────────────────────────────────────────────────
    def _sample_n(self, key, n: int):
        """Draw ``n`` points, returned as an ``(n, 3)`` array."""
        k_atom, k_r, k_dir = jrnd.split(key, 3)

        # 1. Pick an atom for each sample, weighted by its integrated density.
        atom_idx = jrnd.categorical(k_atom, jnp.log(self.atom_probs), shape=(n,))

        # 2. Uniform u for the inverse-CDF lookup.
        u = jrnd.uniform(k_r, shape=(n,))

        # 3. Invert per sample using the chosen atom's table.
        def _invert_one(idx, u_val):
            return jnp.interp(u_val, self.u_grid, self.inv_cdf_tables[idx])
        r = jax.vmap(_invert_one)(atom_idx, u)  # (n,)

        # 4. Isotropic direction: a normalised standard-normal vector.
        d = jrnd.normal(k_dir, shape=(n, 3))
        d = d / jnp.linalg.norm(d, axis=1, keepdims=True)

        # 5. Position = atom centre + r * direction.
        centers = self.coords[atom_idx]  # (n, 3)
        return centers + r[:, None] * d

    def sample(self, seed, sample_shape):
        """Draw samples.

        An int gives that many samples. For a tuple, only the first entry is
        used, and an empty tuple means one sample.
        """
        if isinstance(sample_shape, int):
            n = sample_shape
        elif len(sample_shape) == 0:
            n = 1
        else:
            n = int(sample_shape[0])
        return self._sample_n(seed, n)

    # ── log_prob / score ─────────────────────────────────────────────────────
    def log_prob(self, value):
        """log(density / Ne) with a trailing axis added (``[:, None]``)."""
        density = self.db_prior.density(value)
        normalized_density = density / self.Ne
        # Clip floor to avoid log(0) -> -inf where the AtomDB spline density
        # vanishes (e.g. points beyond its support).
        log_probs = jnp.log(jnp.maximum(normalized_density, 1e-30))
        return log_probs[:, None]

    def prob(self, value):
        """density / Ne, in the shape returned by ``db_prior.density``.

        Unlike ``log_prob`` no trailing axis is added; SIRDistribution relies
        on this flat shape.
        """
        density = self.db_prior.density(value)
        normalized_density = density / self.Ne
        return normalized_density

    def score(self, values):
        """Score of the density at ``values``; non-finite entries become 0."""
        # score = ∇log p(x) = ∇log(ρ_total/Ne) = ∇ρ_total/ρ_total
        # (the constant 1/Ne factor cancels in the gradient — do NOT divide ρ by Ne here)
        density = self.db_prior.density(values)
        gradient = self.db_prior.gradient(values)
        score = gradient / jnp.maximum(density.reshape(-1, 1), 1e-30)
        score = jnp.nan_to_num(score, nan=0.0, posinf=0.0, neginf=0.0)
        return score
369
+
370
class SIRDistribution:
    """
    Sampling Importance Resampling (SIR) distribution.

    Proposals are drawn from a base distribution, weighted by
    ``target.prob / base.prob`` and resampled with replacement, so the
    returned points approximately follow the target distribution.
    ``log_prob`` and ``score`` are delegated to the target.
    """

    def __init__(
        self,
        base_distribution,    # ProMolecularDensity (Gaussian mixture)
        target_distribution,  # AtomDBDistribution (atomdb)
        oversampling_factor: int = 10
    ):
        """
        Parameters
        ----------
        base_distribution : ProMolecularDensity
            Proposal distribution (Gaussian mixture), easy to sample from.
            Needs ``sample(seed=..., sample_shape=...)`` and ``prob``.
        target_distribution : AtomDBDistribution
            Target distribution (atomdb), what we actually want.
            Needs ``prob``, ``log_prob`` and ``score``.
        oversampling_factor : int
            How many proposal samples per final sample (higher = better quality).
        """
        self.base_distribution = base_distribution
        self.target_distribution = target_distribution
        self.oversampling_factor = oversampling_factor

    def sample(self, seed, sample_shape):
        """
        Sample using SIR: sample from base, reweight, resample.

        Parameters
        ----------
        seed : PRNG key
        sample_shape : int or tuple
            Number of final samples. For a tuple, only the first entry is
            used, and an empty tuple means one sample.

        Returns
        -------
        The resampled proposal rows, ``n_final_samples`` of them.

        Raises
        ------
        ValueError
            If no proposal receives a positive, finite importance weight.
        """
        import math
        import warnings

        if isinstance(sample_shape, int):
            n_final_samples = sample_shape
        else:
            n_final_samples = sample_shape[0] if len(sample_shape) > 0 else 1

        key_sample, key_resample = jrnd.split(seed)

        # Generate oversampled proposals from the base distribution.
        n_proposal_samples = self.oversampling_factor * n_final_samples
        proposal_samples = self.base_distribution.sample(
            seed=key_sample,
            sample_shape=n_proposal_samples
        )

        # Flatten both densities to (n_proposal_samples,). The Gaussian
        # mixture returns a column, and mixing (N,) with (N, 1) would
        # broadcast the ratio below into an (N, N) matrix.
        density_base = jnp.ravel(self.base_distribution.prob(proposal_samples))
        density_target = jnp.ravel(self.target_distribution.prob(proposal_samples))

        # Importance weights w = p_target / p_base. Undefined or infinite ratios
        # (e.g. p_base underflowed to 0) get weight 0.
        importance_weights = density_target / density_base
        importance_weights = jnp.nan_to_num(importance_weights, nan=0.0, posinf=0.0, neginf=0.0)
        importance_weights = jnp.maximum(importance_weights, 0.0)
        total_weight = float(jnp.sum(importance_weights))
        if not (math.isfinite(total_weight) and total_weight > 0.0):
            raise ValueError(
                "SIR resampling failed: no proposal has a positive, finite "
                "importance weight (target density is zero at every proposal)")
        importance_weights = importance_weights / total_weight

        # Resample according to the normalised weights.
        resampled_indices = jrnd.choice(
            key_resample,
            a=n_proposal_samples,
            shape=(n_final_samples,),
            p=importance_weights,
            replace=True
        )

        # Kish effective sample size of the normalised weights, as a quality check.
        effective_sample_size = float(1.0 / jnp.sum(importance_weights**2))
        if effective_sample_size < n_final_samples:
            warnings.warn(
                f"Low effective sample size ({effective_sample_size:.1f}) "
                f"for {n_final_samples} final samples",
                RuntimeWarning,
                stacklevel=2,
            )

        final_samples = proposal_samples[resampled_indices]
        return final_samples

    def log_prob(self, value):
        """
        Log probability uses the target distribution.
        """
        return self.target_distribution.log_prob(value)

    def score(self, values):
        """
        Score uses the target distribution.
        """
        return self.target_distribution.score(values)