pyautoencoder 1.0.4__tar.gz → 1.0.6__tar.gz

Files changed (31)
  1. pyautoencoder-1.0.6/PKG-INFO +162 -0
  2. pyautoencoder-1.0.6/README.md +131 -0
  3. pyautoencoder-1.0.6/pyautoencoder/__init__.py +15 -0
  4. pyautoencoder-1.0.6/pyautoencoder/loss/__init__.py +12 -0
  5. pyautoencoder-1.0.6/pyautoencoder/loss/base.py +67 -0
  6. pyautoencoder-1.0.6/pyautoencoder/loss/vae.py +88 -0
  7. pyautoencoder-1.0.6/pyautoencoder/loss/wrapper.py +213 -0
  8. pyautoencoder-1.0.6/pyautoencoder/models/__init__.py +4 -0
  9. pyautoencoder-1.0.6/pyautoencoder/models/autoencoder.py +90 -0
  10. pyautoencoder-1.0.6/pyautoencoder/models/base.py +152 -0
  11. pyautoencoder-1.0.6/pyautoencoder/models/variational/__init__.py +3 -0
  12. pyautoencoder-1.0.6/pyautoencoder/models/variational/stochastic_layers.py +70 -0
  13. pyautoencoder-1.0.6/pyautoencoder/models/variational/vae.py +142 -0
  14. pyautoencoder-1.0.6/pyautoencoder.egg-info/PKG-INFO +162 -0
  15. {pyautoencoder-1.0.4 → pyautoencoder-1.0.6}/pyautoencoder.egg-info/SOURCES.txt +5 -1
  16. {pyautoencoder-1.0.4 → pyautoencoder-1.0.6}/setup.py +1 -1
  17. pyautoencoder-1.0.4/PKG-INFO +0 -99
  18. pyautoencoder-1.0.4/README.md +0 -68
  19. pyautoencoder-1.0.4/pyautoencoder/__init__.py +0 -4
  20. pyautoencoder-1.0.4/pyautoencoder/loss.py +0 -92
  21. pyautoencoder-1.0.4/pyautoencoder/models/__init__.py +0 -4
  22. pyautoencoder-1.0.4/pyautoencoder/models/autoencoder.py +0 -39
  23. pyautoencoder-1.0.4/pyautoencoder/models/variational/__init__.py +0 -3
  24. pyautoencoder-1.0.4/pyautoencoder/models/variational/stochastic_layers.py +0 -33
  25. pyautoencoder-1.0.4/pyautoencoder/models/variational/vae.py +0 -65
  26. pyautoencoder-1.0.4/pyautoencoder.egg-info/PKG-INFO +0 -99
  27. {pyautoencoder-1.0.4 → pyautoencoder-1.0.6}/LICENSE +0 -0
  28. {pyautoencoder-1.0.4 → pyautoencoder-1.0.6}/pyautoencoder.egg-info/dependency_links.txt +0 -0
  29. {pyautoencoder-1.0.4 → pyautoencoder-1.0.6}/pyautoencoder.egg-info/requires.txt +0 -0
  30. {pyautoencoder-1.0.4 → pyautoencoder-1.0.6}/pyautoencoder.egg-info/top_level.txt +0 -0
  31. {pyautoencoder-1.0.4 → pyautoencoder-1.0.6}/setup.cfg +0 -0
@@ -0,0 +1,162 @@
+ Metadata-Version: 2.4
+ Name: pyautoencoder
+ Version: 1.0.6
+ Summary: A Python package offering implementations of state-of-the-art autoencoder architectures in PyTorch.
+ Home-page: https://github.com/andrea-pollastro/pyautoencoder
+ Author: Andrea Pollastro
+ License: MIT
+ Keywords: autoencoder,pytorch,deep learning,machine learning,representation learning,dimensionality reduction,generative models
+ Classifier: Operating System :: OS Independent
+ Classifier: Intended Audience :: Science/Research
+ Classifier: Intended Audience :: Developers
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
+ Requires-Python: >=3.7
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: torch>=2.0.0
+ Dynamic: author
+ Dynamic: classifier
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: home-page
+ Dynamic: keywords
+ Dynamic: license
+ Dynamic: license-file
+ Dynamic: requires-dist
+ Dynamic: requires-python
+ Dynamic: summary
+
+ # PyAutoencoder
+
+ A clean, modular PyTorch library for building and training autoencoders.
+
+ <!-- <p align="center">
+   <img src="assets/logo_nobackground.png" alt="pyautoencoder logo" width="220"/>
+ </p> -->
+
+ ![logo](https://raw.githubusercontent.com/andrea-pollastro/pyautoencoder/main/assets/logo_nobackground.png)
+
+ <p align="center">
+   <a href="https://pypi.org/project/pyautoencoder/"><img alt="PyPI" src="https://img.shields.io/pypi/v/pyautoencoder.svg"></a>
+   <a href="https://github.com/andrea-pollastro/pyautoencoder/blob/main/LICENSE"><img alt="License: MIT" src="https://img.shields.io/badge/License-MIT-blue.svg"></a>
+   <a href="https://github.com/andrea-pollastro/pyautoencoder/stargazers"><img alt="Stars" src="https://img.shields.io/github/stars/andrea-pollastro/pyautoencoder?style=social"></a>
+ </p>
+
+ ---
+
+ ## Highlights
+
+ PyAutoencoder is designed to offer **simple and easy access to autoencoder frameworks**. Here's what it offers:
+
+ - **Minimal, composable API**
+   You don't have to inherit from complicated base classes or learn a new training loop. Provide your own PyTorch `nn.Module` encoder and decoder, and plug them into the ready‑to‑use autoencoder wrappers. This makes it easy to experiment with different architectures (e.g. MLPs, CNNs) while reusing the same training pipeline.
+
+ - **Ready‑to‑use autoencoders**
+   The library ships with working implementations of autoencoders, each paired with its loss function. You can start training in a few lines, without re‑implementing reconstruction likelihoods, KL divergence, or other boilerplate.
+
+ - **PyTorch compatibility**
+   The library is fully compatible with the PyTorch ecosystem, so models integrate naturally with modules, tensors, optimizers, and schedulers.
+
+ - **Lightweight, research‑oriented**
+   The library is intentionally minimal: no training loop frameworks, no heavy abstractions. This makes it well suited for research prototypes where you want control and transparency.
+
+ > **Status**: The project is in an early but usable stage. Contributions, issues, and feedback are highly encouraged!
+
+ **Currently implemented**:
+ - Autoencoder (AE)
+ - Variational Autoencoder (VAE)
+
+ ---
+
+ ## Installation
+
+ ```bash
+ pip install pyautoencoder
+ ```
+
+ Or install from source for development:
+
+ ```bash
+ git clone https://github.com/andrea-pollastro/pyautoencoder.git
+ cd pyautoencoder
+ pip install -e .
+ ```
+
+ ## Quick start
+
+ ```python
+ import torch
+ import torch.nn as nn
+ from pyautoencoder import VAE, VAELoss
+
+ # Define encoder/decoder
+ encoder = nn.Sequential(
+     nn.Linear(784, 512),
+     nn.ReLU(),
+     nn.Linear(512, 256)
+ )
+
+ decoder = nn.Sequential(
+     nn.Linear(256, 512),
+     nn.ReLU(),
+     nn.Linear(512, 784)
+ )
+
+ # Model
+ vae = VAE(encoder=encoder, decoder=decoder, latent_dim=32)
+
+ # Loss
+ criterion = VAELoss(beta=1.0, likelihood="gaussian")
+ optimizer = torch.optim.Adam(vae.parameters())
+ for x in dataloader:  # dataloader: any iterable of input batches
+     optimizer.zero_grad()
+     out = vae(x)
+     losses = criterion(x, out)  # VAELoss.__call__(x, model_output)
+     losses.total.backward()     # negative ELBO
+     optimizer.step()
+
+     # optional: log components
+     nll = losses.components["negative_log_likelihood"]
+     beta_kl = losses.components["beta_kl_divergence"]
+ ```
+
+ ## Built‑in models
+
+ - **`AE`**: standard Autoencoder
+   ```python
+   from pyautoencoder import AE, AELoss
+   ae = AE(encoder=encoder, decoder=decoder)
+   criterion = AELoss(likelihood="gaussian")  # or "bernoulli"
+   ```
+
+ - **`VAE`**: Variational Autoencoder
+   ```python
+   from pyautoencoder import VAE, VAELoss
+   vae = VAE(encoder=encoder, decoder=decoder, latent_dim=32)
+   criterion = VAELoss(beta=1.0, likelihood="gaussian")  # or "bernoulli"
+   ```
+
+ ## Examples
+
+ See the [`examples/`](examples/) folder for runnable scripts showing example usage.
+
+ ## License
+
+ This project is released under the **MIT License**. See [LICENSE](LICENSE).
+
+ ## Citation
+
+ If you use this package in academic work, please cite:
+
+ ```bibtex
+ @misc{pollastro2025pyautoencoder,
+   author = {Andrea Pollastro},
+   title = {pyautoencoder},
+   year = {2025},
+   howpublished = {GitHub repository},
+   url = {https://github.com/andrea-pollastro/pyautoencoder}
+ }
+ ```
@@ -0,0 +1,131 @@
+ # PyAutoencoder
+
+ A clean, modular PyTorch library for building and training autoencoders.
+
+ <!-- <p align="center">
+   <img src="assets/logo_nobackground.png" alt="pyautoencoder logo" width="220"/>
+ </p> -->
+
+ ![logo](https://raw.githubusercontent.com/andrea-pollastro/pyautoencoder/main/assets/logo_nobackground.png)
+
+ <p align="center">
+   <a href="https://pypi.org/project/pyautoencoder/"><img alt="PyPI" src="https://img.shields.io/pypi/v/pyautoencoder.svg"></a>
+   <a href="https://github.com/andrea-pollastro/pyautoencoder/blob/main/LICENSE"><img alt="License: MIT" src="https://img.shields.io/badge/License-MIT-blue.svg"></a>
+   <a href="https://github.com/andrea-pollastro/pyautoencoder/stargazers"><img alt="Stars" src="https://img.shields.io/github/stars/andrea-pollastro/pyautoencoder?style=social"></a>
+ </p>
+
+ ---
+
+ ## Highlights
+
+ PyAutoencoder is designed to offer **simple and easy access to autoencoder frameworks**. Here's what it offers:
+
+ - **Minimal, composable API**
+   You don't have to inherit from complicated base classes or learn a new training loop. Provide your own PyTorch `nn.Module` encoder and decoder, and plug them into the ready‑to‑use autoencoder wrappers. This makes it easy to experiment with different architectures (e.g. MLPs, CNNs) while reusing the same training pipeline.
+
+ - **Ready‑to‑use autoencoders**
+   The library ships with working implementations of autoencoders, each paired with its loss function. You can start training in a few lines, without re‑implementing reconstruction likelihoods, KL divergence, or other boilerplate.
+
+ - **PyTorch compatibility**
+   The library is fully compatible with the PyTorch ecosystem, so models integrate naturally with modules, tensors, optimizers, and schedulers.
+
+ - **Lightweight, research‑oriented**
+   The library is intentionally minimal: no training loop frameworks, no heavy abstractions. This makes it well suited for research prototypes where you want control and transparency.
+
+ > **Status**: The project is in an early but usable stage. Contributions, issues, and feedback are highly encouraged!
+
+ **Currently implemented**:
+ - Autoencoder (AE)
+ - Variational Autoencoder (VAE)
+
+ ---
+
+ ## Installation
+
+ ```bash
+ pip install pyautoencoder
+ ```
+
+ Or install from source for development:
+
+ ```bash
+ git clone https://github.com/andrea-pollastro/pyautoencoder.git
+ cd pyautoencoder
+ pip install -e .
+ ```
+
+ ## Quick start
+
+ ```python
+ import torch
+ import torch.nn as nn
+ from pyautoencoder import VAE, VAELoss
+
+ # Define encoder/decoder
+ encoder = nn.Sequential(
+     nn.Linear(784, 512),
+     nn.ReLU(),
+     nn.Linear(512, 256)
+ )
+
+ decoder = nn.Sequential(
+     nn.Linear(256, 512),
+     nn.ReLU(),
+     nn.Linear(512, 784)
+ )
+
+ # Model
+ vae = VAE(encoder=encoder, decoder=decoder, latent_dim=32)
+
+ # Loss
+ criterion = VAELoss(beta=1.0, likelihood="gaussian")
+ optimizer = torch.optim.Adam(vae.parameters())
+ for x in dataloader:  # dataloader: any iterable of input batches
+     optimizer.zero_grad()
+     out = vae(x)
+     losses = criterion(x, out)  # VAELoss.__call__(x, model_output)
+     losses.total.backward()     # negative ELBO
+     optimizer.step()
+
+     # optional: log components
+     nll = losses.components["negative_log_likelihood"]
+     beta_kl = losses.components["beta_kl_divergence"]
+ ```
+
+ ## Built‑in models
+
+ - **`AE`**: standard Autoencoder
+   ```python
+   from pyautoencoder import AE, AELoss
+   ae = AE(encoder=encoder, decoder=decoder)
+   criterion = AELoss(likelihood="gaussian")  # or "bernoulli"
+   ```
+
+ - **`VAE`**: Variational Autoencoder
+   ```python
+   from pyautoencoder import VAE, VAELoss
+   vae = VAE(encoder=encoder, decoder=decoder, latent_dim=32)
+   criterion = VAELoss(beta=1.0, likelihood="gaussian")  # or "bernoulli"
+   ```
+
+ ## Examples
+
+ See the [`examples/`](examples/) folder for runnable scripts showing example usage.
+
+ ## License
+
+ This project is released under the **MIT License**. See [LICENSE](LICENSE).
+
+ ## Citation
+
+ If you use this package in academic work, please cite:
+
+ ```bibtex
+ @misc{pollastro2025pyautoencoder,
+   author = {Andrea Pollastro},
+   title = {pyautoencoder},
+   year = {2025},
+   howpublished = {GitHub repository},
+   url = {https://github.com/andrea-pollastro/pyautoencoder}
+ }
+ ```
@@ -0,0 +1,15 @@
+ """PyAutoencoder: A clean, modular PyTorch library for autoencoder models."""
+
+ from .models.autoencoder import AE
+ from .models.variational.vae import VAE
+ from .loss.wrapper import AELoss, VAELoss, LossComponents
+
+ __version__ = "0.1.0"
+
+ __all__ = [
+     'AE',
+     'VAE',
+     'AELoss',
+     'VAELoss',
+     'LossComponents'
+ ]
@@ -0,0 +1,12 @@
+ """Loss functions and wrappers for autoencoders."""
+ from .wrapper import AELoss, VAELoss, LossComponents
+ from .base import log_likelihood
+ from .vae import kl_divergence_gaussian
+
+ __all__ = [
+     'AELoss',
+     'VAELoss',
+     'LossComponents',
+     'log_likelihood',
+     'kl_divergence_gaussian'
+ ]
@@ -0,0 +1,67 @@
+ """Base loss functions for autoencoders."""
+ import math
+ import torch
+ import torch.nn.functional as F
+ from typing import Union
+ from enum import Enum
+
+ class LikelihoodType(Enum):
+     """Supported likelihood types for the decoder distribution p(x|z)."""
+     GAUSSIAN = 'gaussian'
+     BERNOULLI = 'bernoulli'
+
+ # Cache for log(2π) constants per (device, dtype)
+ _LOG2PI_CACHE = {}
+
+ def _get_log2pi(x: torch.Tensor) -> torch.Tensor:
+     """Return log(2π) cached for the given device/dtype."""
+     key = (x.device, x.dtype)
+     if key not in _LOG2PI_CACHE:
+         _LOG2PI_CACHE[key] = torch.tensor(2.0 * math.pi, device=x.device, dtype=x.dtype).log()
+     return _LOG2PI_CACHE[key]
+
+ def log_likelihood(x: torch.Tensor,
+                    x_hat: torch.Tensor,
+                    likelihood: Union[str, LikelihoodType] = LikelihoodType.GAUSSIAN) -> torch.Tensor:
+     """
+     Computes elementwise log-likelihood log p(x|x_hat) under different likelihood assumptions.
+
+     For continuous data:
+         Gaussian (σ² = 1):
+             log p(x|x_hat) = -0.5 * [ (x - x_hat)^2 + log(2π) ]
+         Each dimension contributes independently. To obtain per-sample log-likelihoods,
+         sum over feature dimensions.
+
+     For discrete data:
+         Bernoulli:
+             log p(x|x_hat) = x * log σ(x_hat) + (1 - x) * log(1 - σ(x_hat)),
+         where σ is the sigmoid function and x_hat are logits.
+
+     Args:
+         x (torch.Tensor): Ground truth tensor.
+         x_hat (torch.Tensor): Reconstructed tensor. For Bernoulli, values are logits.
+         likelihood (Union[str, LikelihoodType]): Choice of likelihood model. Defaults to Gaussian.
+
+     Returns:
+         torch.Tensor: Elementwise log-likelihood with the same shape as `x`.
+             For multi-dimensional inputs, reduce across feature dimensions
+             to obtain per-sample log-likelihoods.
+
+     Notes:
+         - Bernoulli case uses a numerically stable BCE implementation in log-space.
+         - Gaussian case assumes fixed unit variance (σ²=1) and includes the normalization constant.
+         - log(2π) is cached per (device, dtype) for efficiency.
+     """
+     if isinstance(likelihood, str):
+         likelihood = LikelihoodType(likelihood.lower())
+
+     if likelihood == LikelihoodType.BERNOULLI:
+         return -F.binary_cross_entropy_with_logits(x_hat, x, reduction='none')
+
+     elif likelihood == LikelihoodType.GAUSSIAN:
+         squared_error = (x_hat - x) ** 2
+         log_2pi = _get_log2pi(x)
+         return -0.5 * (squared_error + log_2pi)
+
+     else:
+         raise ValueError(f"Unsupported likelihood: {likelihood}")
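Reviewer note: the two elementwise formulas in the `log_likelihood` docstring can be checked by hand with a minimal sketch in plain Python (no PyTorch). The helper names `gaussian_ll`, `bernoulli_ll`, and `softplus` are hypothetical and not part of the package; the Bernoulli branch uses the softplus identity as one numerically stable formulation, which is an assumption about the approach rather than the library's internals.

```python
import math

def gaussian_ll(x: float, x_hat: float) -> float:
    """Unit-variance Gaussian log-density: -0.5 * [(x - x_hat)^2 + log(2*pi)]."""
    return -0.5 * ((x - x_hat) ** 2 + math.log(2.0 * math.pi))

def softplus(t: float) -> float:
    """Numerically stable log(1 + exp(t))."""
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))

def bernoulli_ll(x: float, logit: float) -> float:
    """Bernoulli log-likelihood with a logit input.

    Uses log(sigmoid(l)) = -softplus(-l) and log(1 - sigmoid(l)) = -softplus(l).
    """
    return -x * softplus(-logit) - (1.0 - x) * softplus(logit)

# Per-sample log-likelihood: sum the elementwise terms over feature dimensions.
x = [0.0, 1.0, 0.5]
x_hat = [0.0, 0.0, 0.5]
per_sample = sum(gaussian_ll(a, b) for a, b in zip(x, x_hat))
```

Only the middle feature has a nonzero residual, so `per_sample` equals three copies of the normalization term `-0.5 * log(2π)` plus `-0.5` for the single squared error of 1.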
@@ -0,0 +1,88 @@
+ """Loss functions for variational autoencoders with rigorous mathematical implementations."""
+ import torch
+ from typing import Union, NamedTuple
+
+ from .base import log_likelihood, LikelihoodType
+
+ class ELBOComponents(NamedTuple):
+     """Components of the ELBO computation."""
+     elbo: torch.Tensor                # scalar: batch-mean ELBO (with grad)
+     log_likelihood: torch.Tensor      # scalar: batch-mean E_q[log p(x|z)]
+     beta_kl_divergence: torch.Tensor  # scalar: batch-mean β * KL(q||p)
+
+ def kl_divergence_gaussian(mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
+     """
+     Computes the KL divergence KL(q(z|x) || p(z)) between the approximate
+     posterior q(z|x) = N(μ, σ²) and the standard normal prior p(z) = N(0, I).
+
+     Args:
+         mu (torch.Tensor): Mean of q(z|x), shape [B, D_z].
+         log_var (torch.Tensor): Log-variance of q(z|x), shape [B, D_z].
+
+     Returns:
+         torch.Tensor: KL divergence per sample, shape [B].
+             Reduction over latent dimensions is performed inside,
+             but not over the batch.
+     """
+     return -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1)
+
+ def compute_ELBO(
+     x: torch.Tensor,
+     x_hat: torch.Tensor,
+     mu: torch.Tensor,
+     log_var: torch.Tensor,
+     likelihood: Union[str, LikelihoodType] = LikelihoodType.GAUSSIAN,
+     beta: float = 1.0,
+ ) -> ELBOComponents:
+     """
+     Computes the Evidence Lower Bound (ELBO) for a Variational Autoencoder
+     using the β-VAE formulation.
+
+     Args:
+         x (torch.Tensor): Ground truth inputs, shape [B, ...].
+         x_hat (torch.Tensor): Reconstructed samples, shape [B, ...] or [B, S, ...],
+             where S is the number of Monte Carlo samples from q(z|x).
+         mu (torch.Tensor): Mean of q(z|x), shape [B, D_z].
+         log_var (torch.Tensor): Log-variance of q(z|x), shape [B, D_z].
+         likelihood (Union[str, LikelihoodType]): Likelihood model for p(x|z).
+         beta (float): Weighting factor for the KL term (β-VAE).
+
+     Returns:
+         ELBOComponents: NamedTuple containing:
+             - elbo (torch.Tensor): Scalar, mean ELBO over the batch.
+             - log_likelihood (torch.Tensor): Scalar, mean reconstruction term
+               E_q[log p(x|z)] over the batch.
+             - beta_kl_divergence (torch.Tensor): Scalar, β * mean KL divergence over the batch.
+
+     Notes:
+         - If x_hat has no sample dimension, it is assumed to contain a single sample (S=1).
+         - log p(x|z) is computed using the `log_likelihood` function, which already
+           includes the Gaussian normalization constant for σ²=1 or the stable BCE for Bernoulli.
+         - All outputs are averaged over the batch for reporting and optimization.
+     """
+     # Ensure a sample dimension S exists -> [B, S, ...]
+     if x_hat.dim() == x.dim():
+         x_hat = x_hat.unsqueeze(1)  # S = 1
+     B, S = x_hat.size(0), x_hat.size(1)
+
+     # log p(x|z): elementwise -> sum over features => [B, S]
+     log_px_z = log_likelihood(x.unsqueeze(1), x_hat, likelihood=likelihood)
+     log_px_z = log_px_z.reshape(B, S, -1).sum(-1)
+
+     # E_q[log p(x|z)] via Monte Carlo average across S: [B]
+     E_log_px_z = log_px_z.mean(dim=1)
+
+     # KL(q||p): [B]
+     kl_q_p = kl_divergence_gaussian(mu, log_var)
+
+     # ELBO per sample and batch means (retain grads)
+     elbo_per_sample = E_log_px_z - beta * kl_q_p  # [B]
+     elbo = elbo_per_sample.mean()                 # scalar
+     E_log_px_z_mean = E_log_px_z.mean()           # scalar
+     beta_kl_q_p_mean = beta * kl_q_p.mean()       # scalar
+
+     return ELBOComponents(
+         elbo=elbo,
+         log_likelihood=E_log_px_z_mean,
+         beta_kl_divergence=beta_kl_q_p_mean,
+     )
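Reviewer note: the closed-form KL term and the β-weighted ELBO in this file can be sanity-checked with small numbers. The following is a minimal sketch in plain Python for a single input; `kl_to_standard_normal` and `elbo` are hypothetical helper names, not functions from the package, and the log-likelihood samples are made-up values chosen for easy arithmetic.

```python
import math

def kl_to_standard_normal(mu, log_var):
    """KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over latent dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, log_var))

def elbo(log_px_z_samples, mu, log_var, beta=1.0):
    """Single-input ELBO: Monte Carlo mean of log p(x|z) minus beta * KL."""
    expected_ll = sum(log_px_z_samples) / len(log_px_z_samples)
    return expected_ll - beta * kl_to_standard_normal(mu, log_var)

# Posterior equal to the prior: the KL term vanishes.
kl_zero = kl_to_standard_normal([0.0, 0.0], [0.0, 0.0])

# Mean shifted by 1 in one dimension, unit variance: KL = 0.5 * 1^2.
kl_shift = kl_to_standard_normal([1.0, 0.0], [0.0, 0.0])

# Two Monte Carlo samples with log p(x|z) = -3 and -5, beta = 2.
value = elbo([-3.0, -5.0], [1.0, 0.0], [0.0, 0.0], beta=2.0)
```

With these inputs the Monte Carlo mean is -4 and the weighted KL is 2 × 0.5 = 1, so `value` is -5; the training loss in `VAELoss` is the negative of this quantity averaged over the batch.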
@@ -0,0 +1,213 @@
+ """Loss wrappers for easy loss computation and tracking."""
+ from dataclasses import dataclass
+ from typing import Dict, Optional, Union
+ import math
+ import torch
+
+ from .base import LikelihoodType, log_likelihood
+ from .vae import compute_ELBO
+ from ..models.autoencoder import AEOutput
+ from ..models.variational.vae import VAEOutput
+
+ LN2 = math.log(2.0)
+ LOG_2PI = math.log(2.0 * math.pi)  # for Gaussian σ²=1 diagnostics
+
+ @dataclass
+ class LossComponents:
+     """
+     Container for loss components with detailed metrics.
+
+     Args:
+         total (torch.Tensor): Scalar loss to optimize (already reduced over batch).
+         components (Dict[str, torch.Tensor]): Named scalar terms that compose the loss
+             (e.g., 'negative_log_likelihood', 'beta_kl_divergence').
+         metrics (Optional[Dict[str, torch.Tensor]]): Additional scalar diagnostics
+             (e.g., per-dimension metrics in nats/bits, KL per latent dimension).
+
+     Notes:
+         - All values are batch means unless specified otherwise.
+         - Metrics are intended for logging/monitoring and do not affect optimization directly.
+     """
+     total: torch.Tensor
+     components: Dict[str, torch.Tensor]
+     metrics: Optional[Dict[str, torch.Tensor]] = None
+
+ class BaseLoss:
+     """Base class for all loss functions."""
+     def __call__(self, *args, **kwargs) -> LossComponents:
+         raise NotImplementedError
+
+ class VAELoss(BaseLoss):
+     def __init__(
+         self,
+         beta: float = 1.0,
+         likelihood: Union[str, LikelihoodType] = LikelihoodType.GAUSSIAN,
+     ):
+         """
+         Loss function for Variational Autoencoders (β-VAE style).
+         The optimized loss is the *negative ELBO*.
+         Uses negative log-likelihood (NLL) of reconstructions:
+             - Gaussian (σ²=1): per-dim NLL = 0.5·[(x − x_hat)² + log(2π)].
+             - Bernoulli (logits): per-dim NLL = BCEWithLogits(x_hat, x).
+
+         Args:
+             beta (float): Weighting factor for the KL term (β-VAE).
+             likelihood (Union[str, LikelihoodType]): Likelihood model for p(x|z).
+                 Supported: 'gaussian' (σ²=1) or 'bernoulli' (logits).
+         """
+         self.beta = beta
+         self.likelihood = LikelihoodType(likelihood.lower()) if isinstance(likelihood, str) else likelihood
+
+     def __call__(self, x: torch.Tensor, model_output: VAEOutput) -> LossComponents:
+         """
+         Computes VAE loss components and size-normalized diagnostics.
+
+         Args:
+             x (torch.Tensor): Ground truth inputs, shape [B, ...].
+             model_output (VAEOutput): from the VAE forward pass, dataclass with:
+                 - x_hat (torch.Tensor): Reconstructed samples, shape [B, S, ...],
+                   where S is the number of Monte Carlo samples from q(z|x).
+                 - z (torch.Tensor): Latent samples (unused here).
+                 - mu (torch.Tensor): Mean of q(z|x), shape [B, D_z].
+                 - log_var (torch.Tensor): Log-variance of q(z|x), shape [B, D_z].
+
+         Returns:
+             LossComponents: Named container with:
+                 - total (torch.Tensor): Scalar, negative ELBO (to minimize).
+                 - components (Dict[str, torch.Tensor]):
+                     * 'negative_log_likelihood': Scalar, batch-mean -E_q[log p(x|z)] (nats).
+                     * 'beta_kl_divergence': Scalar, batch-mean β * KL(q||p) (nats).
+                 - metrics (Dict[str, torch.Tensor]):
+                     * 'elbo': Scalar, batch-mean ELBO (nats).
+                     * 'nll_per_dim_nats': Scalar, -E_q[log p(x|z)] / D_x (nats/dim).
+                     * 'nll_per_dim_bits': Scalar, bits per dimension = nll_per_dim_nats / ln(2) (bits/dim).
+                     * 'beta_kl_per_latent_dim_nats': Scalar, β * KL / D_z (nats per latent dim).
+                     * 'beta_kl_per_latent_dim_bits': Scalar, beta_kl_per_latent_dim_nats / ln(2) (bits per latent dim).
+                     * 'mse_per_dim' (optional): Scalar, derived from Gaussian σ²=1 identity.
+
+         Notes:
+             - Reductions follow: sum over feature dimensions → mean over MC samples (if any) → mean over batch.
+             - log p(x|z) is computed by `log_likelihood`:
+                 * Gaussian (σ²=1): per-dim NLL = 0.5·MSE + 0.5·log(2π).
+                 * Bernoulli (logits): per-dim NLL = BCEWithLogits.
+             - For Gaussian (σ²=1), 'mse_per_dim' is computed via:
+                 MSE_per_dim = 2·NLL_per_dim − log(2π), clamped to be ≥ 0.
+         """
+         x_hat = model_output.x_hat
+         mu = model_output.mu
+         log_var = model_output.log_var
+
+         elbo_components = compute_ELBO(
+             x=x,
+             x_hat=x_hat,
+             mu=mu,
+             log_var=log_var,
+             likelihood=self.likelihood,
+             beta=self.beta,
+         )
+
+         D_x = x[0].numel()
+         D_z = mu.size(-1)
+
+         # Per-dimension / per-latent-dimension metrics
+         nll_per_dim_nats = -elbo_components.log_likelihood / D_x  # nats/dim
+         nll_per_dim_bits = nll_per_dim_nats / LN2                 # bits/dim
+
+         beta_kl_per_latent_dim_nats = elbo_components.beta_kl_divergence / D_z  # nats/latent-dim
+         beta_kl_per_latent_dim_bits = beta_kl_per_latent_dim_nats / LN2         # bits/latent-dim
+
+         metrics: Dict[str, torch.Tensor] = {
+             'elbo': elbo_components.elbo.detach().cpu(),
+             'nll_per_dim_nats': nll_per_dim_nats.detach().cpu(),
+             'nll_per_dim_bits': nll_per_dim_bits.detach().cpu(),
+             'beta_kl_per_latent_dim_nats': beta_kl_per_latent_dim_nats.detach().cpu(),
+             'beta_kl_per_latent_dim_bits': beta_kl_per_latent_dim_bits.detach().cpu(),
+         }
+
+         # Extra: derive MSE/dim for Gaussian(σ²=1)
+         if self.likelihood == LikelihoodType.GAUSSIAN:
+             # NLL_per_dim = 0.5*MSE_per_dim + 0.5*log(2π) ⇒ MSE_per_dim = 2*NLL_per_dim − log(2π)
+             mse_per_dim = torch.clamp(2.0 * nll_per_dim_nats - LOG_2PI, min=0.0)
+             metrics['mse_per_dim'] = mse_per_dim.detach().cpu()
+
+         return LossComponents(
+             total=-elbo_components.elbo,  # minimize negative ELBO
+             components={
+                 'negative_log_likelihood': -elbo_components.log_likelihood,
+                 'beta_kl_divergence': elbo_components.beta_kl_divergence,
+             },
+             metrics=metrics,
+         )
+
+ class AELoss(BaseLoss):
+     def __init__(self, likelihood: Union[str, LikelihoodType] = LikelihoodType.GAUSSIAN):
+         """
+         Loss function for standard Autoencoders.
+
+         Uses negative log-likelihood (NLL) of reconstructions:
+             - Gaussian (σ²=1): per-dim NLL = 0.5·[(x − x_hat)² + log(2π)].
+             - Bernoulli (logits): per-dim NLL = BCEWithLogits(x_hat, x).
+
+         Args:
+             likelihood (Union[str, LikelihoodType]): Likelihood model for p(x|z).
+                 Supported: 'gaussian' (σ²=1) or 'bernoulli' (logits).
+         """
+         self.likelihood = LikelihoodType(likelihood.lower()) if isinstance(likelihood, str) else likelihood
+
+     def __call__(self, x: torch.Tensor, model_output: AEOutput) -> LossComponents:
+         """
+         Computes Autoencoder reconstruction loss and size-normalized diagnostics.
+
+         Args:
+             x (torch.Tensor): Ground truth inputs, shape [B, ...].
+             model_output (AEOutput): from the AE forward pass, dataclass containing:
+                 - x_hat (torch.Tensor): Reconstructions, shape [B, ...].
+                 - z (torch.Tensor): Latent samples (unused here).
+
+         Returns:
+             LossComponents: Named container with:
+                 - total (torch.Tensor): Scalar, batch-mean reconstruction loss (NLL in nats).
+                 - components (Dict[str, torch.Tensor]):
+                     * 'negative_log_likelihood': Scalar, same as total.
+                 - metrics (Dict[str, torch.Tensor]):
+                     * 'nll_per_dim_nats': Scalar, NLL / D_x (nats/dim).
+                     * 'nll_per_dim_bits': Scalar, bits per dimension = nll_per_dim_nats / ln(2) (bits/dim).
+                     * 'mse_per_dim' (optional): Scalar, derived for Gaussian σ²=1.
+
+         Notes:
+             - Reductions follow: elementwise log-likelihood → sum over feature dimensions
+               → mean over batch.
+             - For Gaussian (σ²=1), 'mse_per_dim' is computed via:
+                 MSE_per_dim = 2·(NLL_per_dim) − log(2π), clamped to be ≥ 0.
+             - Ensure inputs match the likelihood’s expected scale:
+                 * Gaussian: continuous data (typically standardized).
+                 * Bernoulli: targets in [0, 1], predictions given as logits.
+         """
+         x_hat = model_output.x_hat
+
+         B = x.size(0)
+         D_x = x[0].numel()
+
+         # Elementwise log-likelihood → per-sample sum → batch mean
+         ll_elem = log_likelihood(x, x_hat, likelihood=self.likelihood)  # [B, ...]
+         ll_per_sample = ll_elem.reshape(B, -1).sum(-1)                  # [B]
+         nll = (-ll_per_sample).mean()                                   # scalar NLL
+
+         # Per-dim diagnostics
+         nll_per_dim_nats = nll / D_x               # nats/dim
+         nll_per_dim_bits = nll_per_dim_nats / LN2  # bits/dim
+
+         metrics: Dict[str, torch.Tensor] = {
+             'nll_per_dim_nats': nll_per_dim_nats.detach().cpu(),
+             'nll_per_dim_bits': nll_per_dim_bits.detach().cpu(),
+         }
+
+         if self.likelihood == LikelihoodType.GAUSSIAN:
+             mse_per_dim = torch.clamp(2.0 * nll_per_dim_nats - LOG_2PI, min=0.0)
+             metrics['mse_per_dim'] = mse_per_dim.detach().cpu()
+
+         return LossComponents(
+             total=nll,
+             components={'negative_log_likelihood': nll},
+             metrics=metrics,
+         )
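Reviewer note: the size-normalized diagnostics documented in both loss wrappers reduce to two conversions, nats to bits and (for the unit-variance Gaussian only) NLL to MSE. Below is a minimal sketch of that arithmetic in plain Python; `per_dim_metrics` is a hypothetical helper, not part of the package, and the example values are invented for the check.

```python
import math

LN2 = math.log(2.0)
LOG_2PI = math.log(2.0 * math.pi)

def per_dim_metrics(nll_nats: float, d_x: int) -> dict:
    """Per-dimension diagnostics from a per-sample NLL given in nats.

    The MSE identity holds only for a Gaussian likelihood with variance 1:
    NLL_per_dim = 0.5 * MSE + 0.5 * log(2*pi).
    """
    nll_per_dim_nats = nll_nats / d_x
    return {
        "nll_per_dim_nats": nll_per_dim_nats,
        "nll_per_dim_bits": nll_per_dim_nats / LN2,
        # Clamp at zero to absorb small negative values from rounding.
        "mse_per_dim": max(2.0 * nll_per_dim_nats - LOG_2PI, 0.0),
    }

# Example: 4 features with a mean squared error of 0.25 per feature.
d_x, mse = 4, 0.25
nll = d_x * 0.5 * (mse + LOG_2PI)
metrics = per_dim_metrics(nll, d_x)
```

Feeding the NLL built from an MSE of 0.25 back through the identity recovers 0.25 in `metrics["mse_per_dim"]`, which is the round trip the `mse_per_dim` metric relies on.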