boltzmann-generators 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- boltzmann_generators/__init__.py +18 -0
- boltzmann_generators/analysis.py +36 -0
- boltzmann_generators/base/__init__.py +6 -0
- boltzmann_generators/base/density.py +36 -0
- boltzmann_generators/base/energy.py +20 -0
- boltzmann_generators/energies/__init__.py +5 -0
- boltzmann_generators/energies/dipeptide.py +81 -0
- boltzmann_generators/energies/double_well.py +56 -0
- boltzmann_generators/energies/muller.py +65 -0
- boltzmann_generators/extensions/__init__.py +1 -0
- boltzmann_generators/extensions/equivariant.py +7 -0
- boltzmann_generators/extensions/molecular.py +7 -0
- boltzmann_generators/extensions/transferable.py +7 -0
- boltzmann_generators/flows/__init__.py +20 -0
- boltzmann_generators/flows/base.py +68 -0
- boltzmann_generators/flows/cnf.py +266 -0
- boltzmann_generators/flows/coupling.py +89 -0
- boltzmann_generators/flows/periodic.py +27 -0
- boltzmann_generators/flows/realnvp.py +79 -0
- boltzmann_generators/io.py +34 -0
- boltzmann_generators/losses.py +70 -0
- boltzmann_generators/mcmc.py +68 -0
- boltzmann_generators/py.typed +0 -0
- boltzmann_generators/sampling.py +54 -0
- boltzmann_generators/services/__init__.py +7 -0
- boltzmann_generators/services/analysis.py +65 -0
- boltzmann_generators/services/checkpoint.py +43 -0
- boltzmann_generators/services/sampling.py +74 -0
- boltzmann_generators/train.py +40 -0
- boltzmann_generators/training/__init__.py +19 -0
- boltzmann_generators/training/loss_strategies.py +113 -0
- boltzmann_generators/training/trainer.py +94 -0
- boltzmann_generators-0.2.1.dist-info/METADATA +146 -0
- boltzmann_generators-0.2.1.dist-info/RECORD +36 -0
- boltzmann_generators-0.2.1.dist-info/WHEEL +4 -0
- boltzmann_generators-0.2.1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
"""Boltzmann Generators — from-scratch PyTorch implementation."""
|
|
2
|
+
|
|
3
|
+
from . import analysis, base, energies, flows, io, losses, mcmc, sampling, services, train, training
|
|
4
|
+
|
|
5
|
+
__version__ = "0.2.1"
|
|
6
|
+
__all__ = [
|
|
7
|
+
"analysis",
|
|
8
|
+
"base",
|
|
9
|
+
"energies",
|
|
10
|
+
"flows",
|
|
11
|
+
"io",
|
|
12
|
+
"losses",
|
|
13
|
+
"mcmc",
|
|
14
|
+
"sampling",
|
|
15
|
+
"services",
|
|
16
|
+
"train",
|
|
17
|
+
"training",
|
|
18
|
+
]
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
"""Analysis helpers for weighted and unweighted population estimates."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
|
|
7
|
+
from torch import Tensor
|
|
8
|
+
|
|
9
|
+
from .services.analysis import AnalysisSuite
|
|
10
|
+
|
|
11
|
+
# Type alias for a region predicate: a callable taking a tensor of points and
# returning a tensor (presumably a per-point membership mask — confirm against
# AnalysisSuite).
RegionFn = Callable[[Tensor], Tensor]

# Names exported by ``from boltzmann_generators.analysis import *``.
__all__ = ["AnalysisSuite", "RegionFn", "basin_populations", "rectangular_region"]

# Module-level AnalysisSuite instance that basin_populations() delegates to.
_suite = AnalysisSuite()
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def basin_populations(
    x: Tensor,
    region_fns: dict[str, RegionFn],
    *,
    log_w: Tensor | None = None,
) -> dict[str, float]:
    """Estimate basin populations for the points ``x``.

    Functional convenience wrapper: the call is forwarded unchanged to the
    module-level ``AnalysisSuite`` instance.

    Args:
        x: Tensor of sample points handed to each region predicate.
        region_fns: Mapping from basin name to its region predicate.
        log_w: Optional log importance weights (keyword-only). ``None`` is
            forwarded as-is.

    Returns:
        The mapping produced by ``AnalysisSuite.basin_populations``
        (annotated as basin name -> population).
    """
    populations = _suite.basin_populations(x, region_fns, log_w=log_w)
    return populations
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def rectangular_region(
    *,
    x_min: float,
    x_max: float,
    y_min: float,
    y_max: float,
) -> RegionFn:
    """Build a region predicate for an axis-aligned rectangle in 2D.

    All bounds are keyword-only and are forwarded verbatim to
    ``AnalysisSuite.rectangular_region``, which constructs the predicate.

    Args:
        x_min: Lower bound along the first coordinate.
        x_max: Upper bound along the first coordinate.
        y_min: Lower bound along the second coordinate.
        y_max: Upper bound along the second coordinate.

    Returns:
        The predicate returned by ``AnalysisSuite.rectangular_region``.
    """
    bounds = {"x_min": x_min, "x_max": x_max, "y_min": y_min, "y_max": y_max}
    return AnalysisSuite.rectangular_region(**bounds)
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
"""Abstract base class for Boltzmann generator density models."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from typing import Protocol, runtime_checkable
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from torch import Tensor, nn
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@runtime_checkable
|
|
13
|
+
class DensityModel(Protocol):
|
|
14
|
+
"""Structural typing contract for density models used in losses and sampling."""
|
|
15
|
+
|
|
16
|
+
def sample(self, n: int, device: torch.device | str = "cpu") -> tuple[Tensor, Tensor]:
|
|
17
|
+
"""Draw ``n`` samples; return ``(x, log_q(x))``."""
|
|
18
|
+
|
|
19
|
+
def log_prob(self, x: Tensor) -> Tensor:
|
|
20
|
+
"""Log-density ``log q(x)`` for batch ``x``."""
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class BaseDensityModel(nn.Module, ABC):
|
|
24
|
+
"""PyTorch module implementing a tractable approximate Boltzmann density."""
|
|
25
|
+
|
|
26
|
+
@abstractmethod
|
|
27
|
+
def sample(self, n: int, device: torch.device | str = "cpu") -> tuple[Tensor, Tensor]:
|
|
28
|
+
"""Draw ``n`` samples in data space; return ``(x, log_q(x))``."""
|
|
29
|
+
|
|
30
|
+
@abstractmethod
|
|
31
|
+
def log_prob(self, x: Tensor) -> Tensor:
|
|
32
|
+
"""Log-density ``log q(x)`` for batch ``x``."""
|
|
33
|
+
|
|
34
|
+
def nll(self, x: Tensor) -> Tensor:
|
|
35
|
+
"""Negative log-likelihood (forward KL up to constant). Mean over batch."""
|
|
36
|
+
return -self.log_prob(x).mean()
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Abstract base class for benchmark energy functions."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
|
|
7
|
+
from torch import Tensor
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class EnergyModel(ABC):
    """Abstract reduced energy ``u(x) = U(x) / (kT)`` over ``dim`` coordinates.

    Instances are callable: ``model(x)`` is shorthand for ``model.energy(x)``.
    """

    # Dimensionality of the configuration space; concrete subclasses set it.
    dim: int

    @abstractmethod
    def energy(self, x: Tensor) -> Tensor:
        """Return the reduced energy of a batch ``x`` shaped ``(..., dim)``."""

    def __call__(self, x: Tensor) -> Tensor:
        """Delegate to :meth:`energy` so the model can be used as a function."""
        value = self.energy(x)
        return value
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
"""Synthetic Ramachandran-like potential — a stand-in for alanine dipeptide.
|
|
2
|
+
|
|
3
|
+
Real alanine dipeptide has a 22-atom (66-coordinate) Cartesian description and a CHARMM/AMBER
|
|
4
|
+
force field. Implementing that here would pull in OpenMM or a full from-scratch
|
|
5
|
+
MM energy, both of which are out of scope.
|
|
6
|
+
|
|
7
|
+
Instead we model the *free-energy surface* in dihedral space (phi, psi) directly,
|
|
8
|
+
using a sum of Gaussian wells at the canonical alanine dipeptide minima. This
|
|
9
|
+
gives a 2D periodic potential with the right qualitative structure (alpha_R,
|
|
10
|
+
alpha_L, beta/PPII, C5/C7eq) so we can demonstrate BG/CFM training on a
|
|
11
|
+
molecular-flavored target without molecular machinery.
|
|
12
|
+
|
|
13
|
+
Coordinates: (phi, psi) in degrees, periodic in [-180, 180].
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from torch import Tensor
|
|
20
|
+
|
|
21
|
+
from ..base.energy import EnergyModel
|
|
22
|
+
|
|
23
|
+
# Approximate alanine dipeptide minima (degrees), depths in kT.
# Entry k of each tensor below parameterises well k of the synthetic surface.
# Each row of _MINIMA is a well centre given as (phi, psi) in degrees.
_MINIMA = torch.tensor(
    [
        [-65.0, -40.0],  # alpha_R (deepest)
        [-150.0, 155.0],  # C5 / beta
        [-80.0, 80.0],  # PPII / C7eq
        [65.0, 40.0],  # alpha_L (shallow)
    ]
)
# Per-well depth; energy() uses it as the additive offset inside the log-sum-exp.
_DEPTHS = torch.tensor([6.0, 5.0, 4.5, 2.5])  # in kT
# Per-well width; energy() divides the squared angular distance by width**2.
_WIDTHS = torch.tensor([22.0, 28.0, 30.0, 30.0])  # degrees
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _wrap_deg(d: Tensor) -> Tensor:
    """Wrap angle differences in degrees into the half-open interval [-180, 180).

    The modulo takes the sign of the divisor, so the shifted value lies in
    ``[0, 360)`` and the result in ``[-180, 180)``. In particular an input of
    exactly ``+180`` is returned as ``-180`` (the same point on the circle).

    Args:
        d: Tensor of angle differences in degrees; any shape.

    Returns:
        Tensor of the same shape with every element wrapped element-wise.
    """
    return (d + 180.0) % 360.0 - 180.0
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class RamachandranDipeptide(EnergyModel):
    """Synthetic 2D dipeptide free-energy surface in (phi, psi) degrees.

        u(phi, psi) = -log sum_k exp(depth_k - d_k^2 / w_k^2)

    where ``d_k^2 = wrap(phi - phi_k)^2 + wrap(psi - psi_k)^2`` is the squared
    periodic angular distance to well centre ``k``, ``w_k`` is the well width
    and ``depth_k`` the well depth (module constants ``_MINIMA``, ``_WIDTHS``
    and ``_DEPTHS``).

    Each well is a Gaussian in periodic-angle distance. The total potential is
    a smooth log-sum-exp combination of the wells.
    """

    def __init__(self) -> None:
        # Two coordinates: the backbone dihedrals (phi, psi).
        self.dim = 2

    def __call__(self, x: Tensor) -> Tensor:
        """Alias for :meth:`energy`."""
        return self.energy(x)

    def energy(self, x: Tensor) -> Tensor:
        """Evaluate the reduced energy (kT units) at dihedral angles ``x``.

        Args:
            x: Tensor of shape ``(..., 2)`` holding ``(phi, psi)`` in degrees.

        Returns:
            Tensor of shape ``x.shape[:-1]`` with the reduced energy.
        """
        # The well parameters live on the CPU at module level; move them to
        # wherever the input is.
        device = x.device
        minima = _MINIMA.to(device)
        depths = _DEPTHS.to(device)
        widths = _WIDTHS.to(device)
        # Periodic squared distance to each minimum. Slicing with 0:1 / 1:2
        # keeps a trailing axis so the difference broadcasts against the K wells.
        dphi = _wrap_deg(x[..., 0:1] - minima[:, 0])  # (..., K)
        dpsi = _wrap_deg(x[..., 1:2] - minima[:, 1])
        d2 = (dphi.pow(2) + dpsi.pow(2)) / widths.pow(2)
        # log-sum-exp combination: u = -log sum_k exp(depth_k - d2_k)
        logits = depths - d2
        u = -torch.logsumexp(logits, dim=-1)
        return u

    def grid(self, n: int = 200) -> tuple[Tensor, Tensor, Tensor]:
        """Evaluate the energy on an ``n`` x ``n`` grid over [-180, 180]^2.

        Args:
            n: Number of grid points per axis.

        Returns:
            ``(gx, gy, u)``, each of shape ``(n, n)``: the phi and psi
            coordinate meshes (``indexing="xy"``) and the energy on them.
        """
        xs = torch.linspace(-180, 180, n)
        ys = torch.linspace(-180, 180, n)
        gx, gy = torch.meshgrid(xs, ys, indexing="xy")
        grid = torch.stack([gx.flatten(), gy.flatten()], dim=-1)
        u = self.energy(grid).reshape(n, n)
        return gx, gy, u

    @property
    def minima(self) -> Tensor:
        """Copy of the well centres, so callers cannot mutate the module constant."""
        return _MINIMA.clone()
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""Double-well potentials in 1D and 2D.
|
|
2
|
+
|
|
3
|
+
All energies are returned as reduced (unitless) energies u(x) = U(x)/(kT).
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from torch import Tensor
|
|
10
|
+
|
|
11
|
+
from ..base.energy import EnergyModel
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class DoubleWell2D(EnergyModel):
    """u(x, y) = a*(x^2 - 1)^2 + 0.5/sigma^2 * y^2.

    Two minima at (±1, 0). The y direction is a harmonic well of width sigma.
    Parameter `a` controls the barrier height: barrier ≈ a (in kT).
    """

    def __init__(self, a: float = 4.0, sigma_y: float = 0.5) -> None:
        """Store the quartic prefactor ``a`` and the harmonic width ``sigma_y``."""
        self.a = a
        self.sigma_y = sigma_y
        self.dim = 2

    def __call__(self, x: Tensor) -> Tensor:
        """Alias for :meth:`energy`."""
        return self.energy(x)

    def energy(self, x: Tensor) -> Tensor:
        """Evaluate the reduced energy on a batch of 2D points.

        Args:
            x: Tensor of shape ``(..., 2)`` holding ``(x, y)`` coordinates.

        Returns:
            Tensor of shape ``x.shape[:-1]`` with the reduced energy.

        Raises:
            ValueError: If the last dimension of ``x`` is not 2.
        """
        # Explicit check rather than ``assert``: assertions are removed under
        # ``python -O``, which would let wrongly shaped input through silently.
        if x.shape[-1] != 2:
            raise ValueError(
                f"DoubleWell2D.energy expects last dimension 2, got shape {tuple(x.shape)}"
            )
        xx = x[..., 0]
        yy = x[..., 1]
        return self.a * (xx.pow(2) - 1.0).pow(2) + 0.5 * (yy / self.sigma_y).pow(2)

    def grid(self, n: int = 200, span: float = 2.5) -> tuple[Tensor, Tensor, Tensor]:
        """Evaluate the energy on an ``n`` x ``n`` grid over [-span, span]^2.

        Args:
            n: Number of grid points per axis.
            span: Half-width of the square plotting window.

        Returns:
            ``(gx, gy, u)``, each of shape ``(n, n)``: the coordinate meshes
            (``indexing="xy"``) and the energy evaluated on them.
        """
        xs = torch.linspace(-span, span, n)
        ys = torch.linspace(-span, span, n)
        gx, gy = torch.meshgrid(xs, ys, indexing="xy")
        grid = torch.stack([gx.flatten(), gy.flatten()], dim=-1)
        u = self.energy(grid).reshape(n, n)
        return gx, gy, u
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class DoubleWell1D(EnergyModel):
    """Quartic double well u(x) = a*(x^2 - 1)^2 with minima at ±1.

    The barrier at x = 0 has height ``a`` (in kT).
    """

    def __init__(self, a: float = 4.0) -> None:
        """Store the barrier-height parameter ``a``."""
        self.dim = 1
        self.a = a

    def __call__(self, x: Tensor) -> Tensor:
        """Alias for :meth:`energy`."""
        return self.energy(x)

    def energy(self, x: Tensor) -> Tensor:
        """Evaluate the reduced energy.

        Inputs with more than one axis are read from index 0 of the last axis;
        0-d and 1-d inputs are used as the coordinate values directly.
        """
        if x.ndim > 1:
            coord = x[..., 0]
        else:
            coord = x
        offset = coord.pow(2) - 1.0
        return self.a * offset.pow(2)
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
"""Müller-Brown potential.
|
|
2
|
+
|
|
3
|
+
Standard 2D benchmark for rare-event sampling. Three minima with two saddle
|
|
4
|
+
points connecting them. Parameters from Müller & Brown (1979).
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
from torch import Tensor
|
|
11
|
+
|
|
12
|
+
from ..base.energy import EnergyModel
|
|
13
|
+
|
|
14
|
+
# Müller-Brown parameters, one entry per exponential term k = 0..3.
# MullerBrown.energy() evaluates
#   U(x, y) = sum_k A_k * exp(a_k*(x - x0_k)^2 + b_k*(x - x0_k)*(y - y0_k) + c_k*(y - y0_k)^2)
# Amplitude of each term.
_A = torch.tensor([-200.0, -100.0, -170.0, 15.0])
# Coefficient of the (x - x0)^2 term in the exponent.
_a = torch.tensor([-1.0, -1.0, -6.5, 0.7])
# Coefficient of the (x - x0)*(y - y0) cross term in the exponent.
_b = torch.tensor([0.0, 0.0, 11.0, 0.6])
# Coefficient of the (y - y0)^2 term in the exponent.
_c = torch.tensor([-10.0, -10.0, -6.5, 0.7])
# x-coordinate of each term's centre.
_x0 = torch.tensor([1.0, 0.0, -0.5, -1.0])
# y-coordinate of each term's centre.
_y0 = torch.tensor([0.0, 0.5, 1.5, 1.0])
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class MullerBrown(EnergyModel):
    """Reduced energy u(x) = U(x) / scale, with scale tuning barrier heights.

    Native U has barriers ~100-200 (arbitrary units). For BG training we need
    barriers of moderate height (a few kT), so we rescale by `scale`. Default
    scale=20 yields barriers ~5-10 kT, reasonable for training.
    """

    def __init__(self, scale: float = 20.0) -> None:
        """Store the divisor ``scale`` applied to the native potential."""
        self.scale = scale
        self.dim = 2

    def __call__(self, x: Tensor) -> Tensor:
        """Alias for :meth:`energy`."""
        return self.energy(x)

    def energy(self, x: Tensor) -> Tensor:
        """Evaluate the rescaled Müller-Brown potential on a batch of 2D points.

        Args:
            x: Tensor of shape ``(..., 2)`` holding ``(x, y)`` coordinates.

        Returns:
            Tensor of shape ``x.shape[:-1]`` with ``U(x) / scale``.

        Raises:
            ValueError: If the last dimension of ``x`` is not 2.
        """
        # Explicit check rather than ``assert``: assertions are removed under
        # ``python -O``, which would let wrongly shaped input through silently.
        if x.shape[-1] != 2:
            raise ValueError(
                f"MullerBrown.energy expects last dimension 2, got shape {tuple(x.shape)}"
            )
        # Keep a trailing axis of size 1 so coordinates broadcast against the
        # four per-term parameter vectors.
        xx = x[..., 0:1]  # (..., 1)
        yy = x[..., 1:2]
        # The parameters live on the CPU at module level; move them to the
        # input's device.
        device = x.device
        A = _A.to(device)
        a = _a.to(device)
        b = _b.to(device)
        c = _c.to(device)
        x0 = _x0.to(device)
        y0 = _y0.to(device)
        dx = xx - x0
        dy = yy - y0
        terms = A * torch.exp(a * dx.pow(2) + b * dx * dy + c * dy.pow(2))
        # Sum the four exponential terms, then rescale to reduced units.
        U = terms.sum(dim=-1)
        return U / self.scale

    def grid(
        self,
        n: int = 200,
        x_span: tuple[float, float] = (-1.7, 1.2),
        y_span: tuple[float, float] = (-0.4, 2.1),
    ) -> tuple[Tensor, Tensor, Tensor]:
        """Evaluate the energy on an ``n`` x ``n`` grid over the given window.

        Args:
            n: Number of grid points per axis.
            x_span: ``(min, max)`` of the x axis.
            y_span: ``(min, max)`` of the y axis.

        Returns:
            ``(gx, gy, u)``, each of shape ``(n, n)``: the coordinate meshes
            (``indexing="xy"``) and the energy evaluated on them.
        """
        xs = torch.linspace(*x_span, n)
        ys = torch.linspace(*y_span, n)
        gx, gy = torch.meshgrid(xs, ys, indexing="xy")
        grid = torch.stack([gx.flatten(), gy.flatten()], dim=-1)
        u = self.energy(grid).reshape(n, n)
        return gx, gy, u
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Extension points for future transferable and molecular BG capabilities."""
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from .base import Flow, FlowModel, GaussianPrior
|
|
2
|
+
from .cnf import CNFFlowModel, CNFModel, VelocityField
|
|
3
|
+
from .coupling import AffineCoupling
|
|
4
|
+
from .periodic import PeriodicEmbedding, periodic_inverse
|
|
5
|
+
from .realnvp import RealNVP, alternating_mask, halves_mask
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"Flow",
|
|
9
|
+
"FlowModel",
|
|
10
|
+
"GaussianPrior",
|
|
11
|
+
"AffineCoupling",
|
|
12
|
+
"RealNVP",
|
|
13
|
+
"alternating_mask",
|
|
14
|
+
"halves_mask",
|
|
15
|
+
"CNFModel",
|
|
16
|
+
"CNFFlowModel",
|
|
17
|
+
"VelocityField",
|
|
18
|
+
"PeriodicEmbedding",
|
|
19
|
+
"periodic_inverse",
|
|
20
|
+
]
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
"""Base classes for normalizing flows.
|
|
2
|
+
|
|
3
|
+
Convention used everywhere in this package:
|
|
4
|
+
|
|
5
|
+
- `forward(z)` maps prior space → data space (sampling direction).
|
|
6
|
+
Returns `(x, log_det)` where `log_det = log|det df/dz|`.
|
|
7
|
+
- `inverse(x)` maps data space → prior space (density direction).
|
|
8
|
+
Returns `(z, log_det)` where `log_det = log|det df^-1/dx| = -log|det df/dz|`.
|
|
9
|
+
|
|
10
|
+
With this convention:
|
|
11
|
+
log p_X(x) = log p_Z(z) + log|det df^-1/dx| # the inverse log-det
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import math
|
|
17
|
+
from abc import ABC, abstractmethod
|
|
18
|
+
|
|
19
|
+
import torch
|
|
20
|
+
from torch import Tensor, nn
|
|
21
|
+
|
|
22
|
+
from ..base.density import BaseDensityModel
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class Flow(nn.Module, ABC):
    """Abstract invertible map between prior space ``z`` and data space ``x``.

    Both directions return the transformed tensor together with the
    log-determinant of that direction's Jacobian.
    """

    @abstractmethod
    def forward(self, z: Tensor) -> tuple[Tensor, Tensor]:
        """Map ``z`` to data space; return ``(x, log|det df/dz|)``."""

    @abstractmethod
    def inverse(self, x: Tensor) -> tuple[Tensor, Tensor]:
        """Map ``x`` to prior space; return ``(z, log|det df^-1/dx|)``."""
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class GaussianPrior(nn.Module):
|
|
36
|
+
"""Standard normal prior N(0, I) of given dimension."""
|
|
37
|
+
|
|
38
|
+
def __init__(self, dim: int) -> None:
|
|
39
|
+
super().__init__()
|
|
40
|
+
self.dim = dim
|
|
41
|
+
self.register_buffer("_log_norm", torch.tensor(0.5 * dim * math.log(2 * math.pi)))
|
|
42
|
+
|
|
43
|
+
def sample(self, n: int, device: torch.device | str = "cpu") -> Tensor:
|
|
44
|
+
return torch.randn(n, self.dim, device=device)
|
|
45
|
+
|
|
46
|
+
def log_prob(self, z: Tensor) -> Tensor:
|
|
47
|
+
return -0.5 * z.pow(2).sum(dim=-1) - self._log_norm
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class FlowModel(BaseDensityModel):
    """Density model formed by pushing a prior through an invertible flow.

    ``sample`` runs the flow in the forward (prior -> data) direction and
    ``log_prob`` in the inverse direction; ``nll`` is inherited.
    """

    def __init__(self, prior: GaussianPrior, flow: Flow) -> None:
        """Register the prior and the flow as submodules."""
        super().__init__()
        self.prior = prior
        self.flow = flow

    def sample(self, n: int, device: torch.device | str = "cpu") -> tuple[Tensor, Tensor]:
        """Draw ``n`` data-space samples.

        Returns:
            ``(x, log_prob_x)`` where
            ``log_prob_x = log p_Z(z) - log|det df/dz|``.
        """
        latent = self.prior.sample(n, device=device)
        # Evaluate the prior density before the latent is handed to the flow.
        latent_log_prob = self.prior.log_prob(latent)
        x, forward_log_det = self.flow.forward(latent)
        return x, latent_log_prob - forward_log_det

    def log_prob(self, x: Tensor) -> Tensor:
        """Return ``log p_X(x) = log p_Z(z) + log|det df^-1/dx|``."""
        latent, inverse_log_det = self.flow.inverse(x)
        latent_log_prob = self.prior.log_prob(latent)
        return latent_log_prob + inverse_log_det
|