boltzmann-generators 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. boltzmann_generators/__init__.py +18 -0
  2. boltzmann_generators/analysis.py +36 -0
  3. boltzmann_generators/base/__init__.py +6 -0
  4. boltzmann_generators/base/density.py +36 -0
  5. boltzmann_generators/base/energy.py +20 -0
  6. boltzmann_generators/energies/__init__.py +5 -0
  7. boltzmann_generators/energies/dipeptide.py +81 -0
  8. boltzmann_generators/energies/double_well.py +56 -0
  9. boltzmann_generators/energies/muller.py +65 -0
  10. boltzmann_generators/extensions/__init__.py +1 -0
  11. boltzmann_generators/extensions/equivariant.py +7 -0
  12. boltzmann_generators/extensions/molecular.py +7 -0
  13. boltzmann_generators/extensions/transferable.py +7 -0
  14. boltzmann_generators/flows/__init__.py +20 -0
  15. boltzmann_generators/flows/base.py +68 -0
  16. boltzmann_generators/flows/cnf.py +266 -0
  17. boltzmann_generators/flows/coupling.py +89 -0
  18. boltzmann_generators/flows/periodic.py +27 -0
  19. boltzmann_generators/flows/realnvp.py +79 -0
  20. boltzmann_generators/io.py +34 -0
  21. boltzmann_generators/losses.py +70 -0
  22. boltzmann_generators/mcmc.py +68 -0
  23. boltzmann_generators/py.typed +0 -0
  24. boltzmann_generators/sampling.py +54 -0
  25. boltzmann_generators/services/__init__.py +7 -0
  26. boltzmann_generators/services/analysis.py +65 -0
  27. boltzmann_generators/services/checkpoint.py +43 -0
  28. boltzmann_generators/services/sampling.py +74 -0
  29. boltzmann_generators/train.py +40 -0
  30. boltzmann_generators/training/__init__.py +19 -0
  31. boltzmann_generators/training/loss_strategies.py +113 -0
  32. boltzmann_generators/training/trainer.py +94 -0
  33. boltzmann_generators-0.2.1.dist-info/METADATA +146 -0
  34. boltzmann_generators-0.2.1.dist-info/RECORD +36 -0
  35. boltzmann_generators-0.2.1.dist-info/WHEEL +4 -0
  36. boltzmann_generators-0.2.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,18 @@
1
+ """Boltzmann Generators — from-scratch PyTorch implementation."""
2
+
3
+ from . import analysis, base, energies, flows, io, losses, mcmc, sampling, services, train, training
4
+
5
+ __version__ = "0.2.1"
6
+ __all__ = [
7
+ "analysis",
8
+ "base",
9
+ "energies",
10
+ "flows",
11
+ "io",
12
+ "losses",
13
+ "mcmc",
14
+ "sampling",
15
+ "services",
16
+ "train",
17
+ "training",
18
+ ]
@@ -0,0 +1,36 @@
1
+ """Analysis helpers for weighted and unweighted population estimates."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Callable
6
+
7
+ from torch import Tensor
8
+
9
+ from .services.analysis import AnalysisSuite
10
+
11
+ RegionFn = Callable[[Tensor], Tensor]
12
+
13
+ __all__ = ["AnalysisSuite", "RegionFn", "basin_populations", "rectangular_region"]
14
+
15
+ _suite = AnalysisSuite()
16
+
17
+
18
def basin_populations(
    x: Tensor,
    region_fns: dict[str, RegionFn],
    *,
    log_w: Tensor | None = None,
) -> dict[str, float]:
    """Return per-basin population estimates for the points ``x``.

    Module-level convenience wrapper: the call is forwarded unchanged to the
    shared ``AnalysisSuite`` instance created at import time.

    Args:
        x: Batch of points to classify.
        region_fns: Mapping from basin name to a predicate over ``x``.
        log_w: Optional log importance weights; ``None`` requests the
            unweighted estimate.

    Returns:
        Mapping from basin name to its estimated population.
    """
    populations = _suite.basin_populations(x, region_fns, log_w=log_w)
    return populations
26
+
27
+
28
def rectangular_region(
    *,
    x_min: float,
    x_max: float,
    y_min: float,
    y_max: float,
) -> RegionFn:
    """Build a predicate selecting points inside an axis-aligned 2D rectangle.

    All bounds are keyword-only and are handed straight to
    ``AnalysisSuite.rectangular_region``, which constructs the predicate.
    """
    bounds = {"x_min": x_min, "x_max": x_max, "y_min": y_min, "y_max": y_max}
    return AnalysisSuite.rectangular_region(**bounds)
@@ -0,0 +1,6 @@
1
+ """Core abstract base classes for energies and density models."""
2
+
3
+ from .density import BaseDensityModel, DensityModel
4
+ from .energy import EnergyModel
5
+
6
+ __all__ = ["BaseDensityModel", "DensityModel", "EnergyModel"]
@@ -0,0 +1,36 @@
1
+ """Abstract base class for Boltzmann generator density models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from abc import ABC, abstractmethod
6
+ from typing import Protocol, runtime_checkable
7
+
8
+ import torch
9
+ from torch import Tensor, nn
10
+
11
+
12
@runtime_checkable
class DensityModel(Protocol):
    """Duck-typed interface expected from density models by losses and samplers.

    Being ``runtime_checkable``, ``isinstance(obj, DensityModel)`` only checks
    that both methods exist; signatures are not verified at runtime.
    """

    def sample(self, n: int, device: torch.device | str = "cpu") -> tuple[Tensor, Tensor]:
        """Generate ``n`` samples and return the pair ``(x, log_q(x))``."""
        ...

    def log_prob(self, x: Tensor) -> Tensor:
        """Evaluate the model log-density ``log q(x)`` on the batch ``x``."""
        ...
21
+
22
+
23
class BaseDensityModel(nn.Module, ABC):
    """Abstract ``nn.Module`` for a tractable approximation of a Boltzmann density.

    Subclasses supply ``sample`` and ``log_prob``; ``nll`` is derived from
    ``log_prob``.
    """

    @abstractmethod
    def sample(self, n: int, device: torch.device | str = "cpu") -> tuple[Tensor, Tensor]:
        """Generate ``n`` data-space samples and return ``(x, log_q(x))``."""

    @abstractmethod
    def log_prob(self, x: Tensor) -> Tensor:
        """Evaluate the model log-density ``log q(x)`` on the batch ``x``."""

    def nll(self, x: Tensor) -> Tensor:
        """Batch-averaged negative log-likelihood of ``x``.

        This equals the forward KL objective up to an additive constant.
        """
        log_q = self.log_prob(x)
        return -log_q.mean()
@@ -0,0 +1,20 @@
1
+ """Abstract base class for benchmark energy functions."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from abc import ABC, abstractmethod
6
+
7
+ from torch import Tensor
8
+
9
+
10
class EnergyModel(ABC):
    """Interface for a reduced energy u(x) = U(x) / (kT) over ``dim`` coordinates.

    Instances are callable: ``model(x)`` is shorthand for ``model.energy(x)``.
    """

    # Dimensionality of the configuration space; set by concrete subclasses.
    dim: int

    @abstractmethod
    def energy(self, x: Tensor) -> Tensor:
        """Return the reduced energy for a batch ``x`` shaped ``(..., dim)``."""

    def __call__(self, x: Tensor) -> Tensor:
        """Delegate to :meth:`energy`."""
        reduced = self.energy(x)
        return reduced
@@ -0,0 +1,5 @@
1
+ from .dipeptide import RamachandranDipeptide
2
+ from .double_well import DoubleWell1D, DoubleWell2D
3
+ from .muller import MullerBrown
4
+
5
+ __all__ = ["DoubleWell1D", "DoubleWell2D", "MullerBrown", "RamachandranDipeptide"]
@@ -0,0 +1,81 @@
1
+ """Synthetic Ramachandran-like potential — a stand-in for alanine dipeptide.
2
+
3
+ Real alanine dipeptide has a 22-atom Cartesian description and a CHARMM/AMBER
4
+ force field. Implementing that here would pull in OpenMM or a full from-scratch
5
+ MM energy, both of which are out of scope.
6
+
7
+ Instead we model the *free-energy surface* in dihedral space (phi, psi) directly,
8
+ using a sum of Gaussian wells at the canonical alanine dipeptide minima. This
9
+ gives a 2D periodic potential with the right qualitative structure (alpha_R,
10
+ alpha_L, beta/PPII, C5/C7eq) so we can demonstrate BG/CFM training on a
11
+ molecular-flavored target without molecular machinery.
12
+
13
+ Coordinates: (phi, psi) in degrees, periodic in [-180, 180].
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import torch
19
+ from torch import Tensor
20
+
21
+ from ..base.energy import EnergyModel
22
+
23
# Approximate alanine dipeptide wells, one row per basin:
# (phi [deg], psi [deg], depth [kT], width [deg]).
_WELLS = (
    (-65.0, -40.0, 6.0, 22.0),  # alpha_R (deepest)
    (-150.0, 155.0, 5.0, 28.0),  # C5 / beta
    (-80.0, 80.0, 4.5, 30.0),  # PPII / C7eq
    (65.0, 40.0, 2.5, 30.0),  # alpha_L (shallow)
)
_MINIMA = torch.tensor([[phi, psi] for phi, psi, _depth, _width in _WELLS])
_DEPTHS = torch.tensor([depth for _phi, _psi, depth, _width in _WELLS])  # in kT
_WIDTHS = torch.tensor([width for _phi, _psi, _depth, width in _WELLS])  # degrees
34
+
35
+
36
def _wrap_deg(d: Tensor) -> Tensor:
    """Wrap angle differences (degrees) into the half-open interval [-180, 180).

    The modulo takes the sign of the divisor, so an input of exactly +180
    maps to -180 (not +180). Callers that only use the squared difference
    are unaffected by which endpoint is chosen.
    """
    return (d + 180.0) % 360.0 - 180.0
39
+
40
+
41
class RamachandranDipeptide(EnergyModel):
    """Synthetic 2D dipeptide free-energy surface in (phi, psi) degrees.

    u(phi, psi) = -log sum_k exp(depth_k - d_k(phi, psi)^2 / w_k^2)

    where ``d_k`` is the periodic (wrapped) angular distance to minimum ``k``,
    ``w_k`` is that well's width in degrees and ``depth_k`` its depth in kT.
    Each term is a Gaussian in periodic-angle distance; they are combined with
    a smooth log-sum-exp. At a minimum where the other wells contribute
    negligibly, ``u`` is approximately ``-depth_k``.
    """

    def __init__(self) -> None:
        self.dim = 2

    def __call__(self, x: Tensor) -> Tensor:
        """Shorthand for :meth:`energy`."""
        return self.energy(x)

    def energy(self, x: Tensor) -> Tensor:
        """Evaluate the reduced energy (kT units).

        Args:
            x: Tensor of shape ``(..., 2)`` holding ``(phi, psi)`` in degrees.

        Returns:
            Tensor of shape ``(...)`` with the reduced energy of each point.
        """
        # Module constants live on CPU; move them next to the input.
        device = x.device
        minima = _MINIMA.to(device)
        depths = _DEPTHS.to(device)
        widths = _WIDTHS.to(device)
        # Keep the last axis (size 1) so the difference broadcasts against the
        # K minima, giving wrapped offsets of shape (..., K).
        dphi = _wrap_deg(x[..., 0:1] - minima[:, 0])
        dpsi = _wrap_deg(x[..., 1:2] - minima[:, 1])
        # Squared periodic distance to each minimum, in units of its width.
        d2 = (dphi.pow(2) + dpsi.pow(2)) / widths.pow(2)
        # u = -log sum_k exp(depth_k - d2_k); the depth enters the exponent
        # directly (not as log depth), so it sets the well depth in kT.
        logits = depths - d2
        u = -torch.logsumexp(logits, dim=-1)
        return u

    def grid(self, n: int = 200) -> tuple[Tensor, Tensor, Tensor]:
        """Evaluate the energy on an ``n x n`` grid over [-180, 180]^2 degrees.

        Returns:
            ``(gx, gy, u)``, each of shape ``(n, n)``. With ``indexing="xy"``,
            ``gx`` (phi) varies along the last axis and ``gy`` (psi) along the
            first.
        """
        xs = torch.linspace(-180, 180, n)
        ys = torch.linspace(-180, 180, n)
        gx, gy = torch.meshgrid(xs, ys, indexing="xy")
        grid = torch.stack([gx.flatten(), gy.flatten()], dim=-1)
        u = self.energy(grid).reshape(n, n)
        return gx, gy, u

    @property
    def minima(self) -> Tensor:
        """Copy of the ``(K, 2)`` well centres in degrees; safe to modify."""
        return _MINIMA.clone()
@@ -0,0 +1,56 @@
1
+ """Double-well potentials in 1D and 2D.
2
+
3
+ All energies are returned as reduced (unitless) energies u(x) = U(x)/(kT).
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import torch
9
+ from torch import Tensor
10
+
11
+ from ..base.energy import EnergyModel
12
+
13
+
14
class DoubleWell2D(EnergyModel):
    """u(x, y) = a*(x^2 - 1)^2 + 0.5 * (y / sigma_y)^2.

    The quartic term in x has minima at x = ±1 separated by a barrier of
    height ``a`` (in kT) at x = 0; y is confined by a harmonic well of width
    ``sigma_y``. The two minima of the surface are therefore at (±1, 0).
    """

    def __init__(self, a: float = 4.0, sigma_y: float = 0.5) -> None:
        self.dim = 2
        self.a = a
        self.sigma_y = sigma_y

    def __call__(self, x: Tensor) -> Tensor:
        """Shorthand for :meth:`energy`."""
        return self.energy(x)

    def energy(self, x: Tensor) -> Tensor:
        """Reduced energy for a batch ``x`` of shape ``(..., 2)``; returns shape ``(...)``."""
        assert x.shape[-1] == 2
        x_coord, y_coord = x[..., 0], x[..., 1]
        quartic = self.a * (x_coord.pow(2) - 1.0).pow(2)
        harmonic = 0.5 * (y_coord / self.sigma_y).pow(2)
        return quartic + harmonic

    def grid(self, n: int = 200, span: float = 2.5) -> tuple[Tensor, Tensor, Tensor]:
        """Evaluate the energy on an ``n x n`` grid over ``[-span, span]^2``.

        Returns ``(gx, gy, u)``, each of shape ``(n, n)``.
        """
        x_axis = torch.linspace(-span, span, n)
        y_axis = torch.linspace(-span, span, n)
        gx, gy = torch.meshgrid(x_axis, y_axis, indexing="xy")
        points = torch.stack((gx.flatten(), gy.flatten()), dim=-1)
        return gx, gy, self.energy(points).reshape(n, n)
42
+
43
+
44
class DoubleWell1D(EnergyModel):
    """u(x) = a*(x^2 - 1)^2, with minima at x = ±1 and a barrier of height ``a`` (kT)."""

    def __init__(self, a: float = 4.0) -> None:
        self.dim = 1
        self.a = a

    def __call__(self, x: Tensor) -> Tensor:
        """Shorthand for :meth:`energy`."""
        return self.energy(x)

    def energy(self, x: Tensor) -> Tensor:
        """Reduced energy of ``x``.

        Inputs with more than one axis are read as ``(..., 1)`` and the first
        component of the last axis is used; 0-D and 1-D inputs are treated as
        raw coordinate values.
        """
        if x.ndim > 1:
            coord = x[..., 0]
        else:
            coord = x
        well = (coord.pow(2) - 1.0).pow(2)
        return self.a * well
@@ -0,0 +1,65 @@
1
+ """Müller-Brown potential.
2
+
3
+ Standard 2D benchmark for rare-event sampling. Three minima with two saddle
4
+ points connecting them. Parameters from Müller & Brown (1979).
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import torch
10
+ from torch import Tensor
11
+
12
+ from ..base.energy import EnergyModel
13
+
14
# One row per exponential term of the Müller-Brown potential:
# (A, a, b, c, x0, y0) in
#   U(x, y) = sum_k A_k * exp(a_k*(x - x0_k)^2 + b_k*(x - x0_k)*(y - y0_k) + c_k*(y - y0_k)^2)
_TERMS = (
    (-200.0, -1.0, 0.0, -10.0, 1.0, 0.0),
    (-100.0, -1.0, 0.0, -10.0, 0.0, 0.5),
    (-170.0, -6.5, 11.0, -6.5, -0.5, 1.5),
    (15.0, 0.7, 0.6, 0.7, -1.0, 1.0),
)
# Transpose the table so that each parameter becomes its own 1-D tensor.
_A, _a, _b, _c, _x0, _y0 = (torch.tensor(column) for column in zip(*_TERMS))
20
+
21
+
22
class MullerBrown(EnergyModel):
    """Müller-Brown surface as a reduced energy u(x) = U(x) / scale.

    The native potential has barriers of roughly 100-200 (arbitrary units),
    which is too steep for BG training. Dividing by ``scale`` brings them down
    to a few kT; the default ``scale=20`` gives barriers of about 5-10 kT.
    """

    def __init__(self, scale: float = 20.0) -> None:
        self.dim = 2
        self.scale = scale

    def __call__(self, x: Tensor) -> Tensor:
        """Shorthand for :meth:`energy`."""
        return self.energy(x)

    def energy(self, x: Tensor) -> Tensor:
        """Reduced energy for a batch ``x`` of shape ``(..., 2)``; returns shape ``(...)``."""
        assert x.shape[-1] == 2
        device = x.device
        # Module constants live on CPU; move them next to the input.
        amp, a_xx, a_xy, a_yy, cx, cy = (
            t.to(device) for t in (_A, _a, _b, _c, _x0, _y0)
        )
        # Keep a trailing axis of size 1 so offsets broadcast over the 4 terms.
        dx = x[..., 0:1] - cx
        dy = x[..., 1:2] - cy
        exponent = a_xx * dx.pow(2) + a_xy * dx * dy + a_yy * dy.pow(2)
        U = (amp * torch.exp(exponent)).sum(dim=-1)
        return U / self.scale

    def grid(
        self,
        n: int = 200,
        x_span: tuple[float, float] = (-1.7, 1.2),
        y_span: tuple[float, float] = (-0.4, 2.1),
    ) -> tuple[Tensor, Tensor, Tensor]:
        """Evaluate the energy on an ``n x n`` grid over ``x_span`` x ``y_span``.

        Returns ``(gx, gy, u)``, each of shape ``(n, n)``.
        """
        x_axis = torch.linspace(*x_span, n)
        y_axis = torch.linspace(*y_span, n)
        gx, gy = torch.meshgrid(x_axis, y_axis, indexing="xy")
        points = torch.stack((gx.flatten(), gy.flatten()), dim=-1)
        return gx, gy, self.energy(points).reshape(n, n)
@@ -0,0 +1 @@
1
+ """Extension points for future transferable and molecular BG capabilities."""
@@ -0,0 +1,7 @@
1
+ """Roadmap stub: E(3)-equivariant flow architectures.
2
+
3
+ Planned direction:
4
+ - EGNN-based velocity fields.
5
+ - SO(3)/E(3)-aware coupling transformations.
6
+ - Support for coordinates + atom/token conditioning.
7
+ """
@@ -0,0 +1,7 @@
1
+ """Roadmap stub: molecular-system interfaces.
2
+
3
+ Planned direction:
4
+ - OpenMM-backed energy wrappers.
5
+ - Internal-coordinate preprocessing (bond/angle/dihedral).
6
+ - Dataset adapters for trajectory sources (e.g., mdshare).
7
+ """
@@ -0,0 +1,7 @@
1
+ """Roadmap stub: transferable Boltzmann Generator components.
2
+
3
+ Planned direction:
4
+ - Tokenized chemistry conditioning.
5
+ - Shared backbone across molecules.
6
+ - Joint CFM + reweighting pipelines for zero-shot transfer.
7
+ """
@@ -0,0 +1,20 @@
1
+ from .base import Flow, FlowModel, GaussianPrior
2
+ from .cnf import CNFFlowModel, CNFModel, VelocityField
3
+ from .coupling import AffineCoupling
4
+ from .periodic import PeriodicEmbedding, periodic_inverse
5
+ from .realnvp import RealNVP, alternating_mask, halves_mask
6
+
7
+ __all__ = [
8
+ "Flow",
9
+ "FlowModel",
10
+ "GaussianPrior",
11
+ "AffineCoupling",
12
+ "RealNVP",
13
+ "alternating_mask",
14
+ "halves_mask",
15
+ "CNFModel",
16
+ "CNFFlowModel",
17
+ "VelocityField",
18
+ "PeriodicEmbedding",
19
+ "periodic_inverse",
20
+ ]
@@ -0,0 +1,68 @@
1
+ """Base classes for normalizing flows.
2
+
3
+ Convention used everywhere in this package:
4
+
5
+ - `forward(z)` maps prior space → data space (sampling direction).
6
+ Returns `(x, log_det)` where `log_det = log|det df/dz|`.
7
+ - `inverse(x)` maps data space → prior space (density direction).
8
+ Returns `(z, log_det)` where `log_det = log|det df^-1/dx| = -log|det df/dz|`.
9
+
10
+ With this convention:
11
+ log p_X(x) = log p_Z(z) + log|det df^-1/dx| # the inverse log-det
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import math
17
+ from abc import ABC, abstractmethod
18
+
19
+ import torch
20
+ from torch import Tensor, nn
21
+
22
+ from ..base.density import BaseDensityModel
23
+
24
+
25
class Flow(nn.Module, ABC):
    """Abstract bijection between prior space (z) and data space (x).

    Both directions return the transformed batch together with the
    log-determinant of the Jacobian of the map that was applied.
    """

    @abstractmethod
    def forward(self, z: Tensor) -> tuple[Tensor, Tensor]:
        """Map ``z`` to data space; return ``(x, log|det df/dz|)``."""

    @abstractmethod
    def inverse(self, x: Tensor) -> tuple[Tensor, Tensor]:
        """Map ``x`` to prior space; return ``(z, log|det df^-1/dx|)``."""
33
+
34
+
35
class GaussianPrior(nn.Module):
    """Isotropic standard normal N(0, I) over ``dim`` coordinates."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim
        # Normalisation term (dim / 2) * log(2*pi), stored as a buffer so it
        # follows the module across devices.
        log_norm = 0.5 * dim * math.log(2 * math.pi)
        self.register_buffer("_log_norm", torch.tensor(log_norm))

    def sample(self, n: int, device: torch.device | str = "cpu") -> Tensor:
        """Draw ``n`` latent vectors of shape ``(n, dim)`` on ``device``."""
        shape = (n, self.dim)
        return torch.randn(*shape, device=device)

    def log_prob(self, z: Tensor) -> Tensor:
        """Log-density of ``z``, reduced over the last axis."""
        squared_norm = z.pow(2).sum(dim=-1)
        return -0.5 * squared_norm - self._log_norm
48
+
49
+
50
class FlowModel(BaseDensityModel):
    """Density model built from a latent prior and an invertible flow.

    Implements ``sample`` and ``log_prob`` via the change-of-variables formula.
    """

    def __init__(self, prior: GaussianPrior, flow: Flow) -> None:
        super().__init__()
        self.prior = prior
        self.flow = flow

    def sample(self, n: int, device: torch.device | str = "cpu") -> tuple[Tensor, Tensor]:
        """Draw ``n`` data-space samples together with their model log-density.

        Returns:
            ``(x, log_q(x))`` where ``log_q(x) = log p_Z(z) - log|det df/dz|``.
        """
        latent = self.prior.sample(n, device=device)
        log_p_latent = self.prior.log_prob(latent)
        x, log_det_fwd = self.flow.forward(latent)
        # Pushing z through f divides the density by |det df/dz|.
        return x, log_p_latent - log_det_fwd

    def log_prob(self, x: Tensor) -> Tensor:
        """Model log-density: ``log p_Z(f^-1(x)) + log|det df^-1/dx|``."""
        latent, log_det_inv = self.flow.inverse(x)
        log_p_latent = self.prior.log_prob(latent)
        return log_p_latent + log_det_inv