pyautoencoder 1.1.0__tar.gz → 1.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/PKG-INFO +7 -3
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/README.md +6 -2
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/source/api/models.rst +11 -1
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/source/architecture.rst +1 -1
- pyautoencoder-1.1.2/pyautoencoder/_version.py +24 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/loss/base.py +9 -7
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/variational/__init__.py +2 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/variational/stochastic_layers.py +12 -8
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/variational/vae.py +208 -1
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder.egg-info/PKG-INFO +7 -3
- pyautoencoder-1.1.2/test/loss/test_base_loss.py +347 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/test/variational/test_stochastic_layers.py +140 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/test/variational/test_vae.py +286 -0
- pyautoencoder-1.1.0/pyautoencoder/_version.py +0 -34
- pyautoencoder-1.1.0/test/loss/test_base_loss.py +0 -170
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/.github/workflows/ci.yml +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/.github/workflows/publish.yml +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/.gitignore +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/LICENSE +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/assets/logo_nobackground.png +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/Makefile +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/make.bat +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/requirements.txt +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/source/api/base.rst +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/source/api/index.rst +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/source/api/losses.rst +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/source/api/stochastic.rst +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/source/conf.py +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/source/examples/ae_mnist.rst +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/source/examples/vae_mnist.rst +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/source/examples.rst +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/source/getting_started.rst +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/source/index.rst +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/examples/mnist_ae.py +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/examples/mnist_vae_kingma2013.py +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/__init__.py +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/_base/__init__.py +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/_base/base.py +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/experimental/__init__.py +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/experimental/benchmark_datasets/__init__.py +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/experimental/benchmark_datasets/disentanglement.py +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/experimental/hypernetworks.py +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/loss/__init__.py +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/vanilla/__init__.py +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/vanilla/autoencoder.py +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder.egg-info/SOURCES.txt +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder.egg-info/dependency_links.txt +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder.egg-info/requires.txt +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder.egg-info/top_level.txt +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyproject.toml +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/readthedocs.yaml +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/setup.cfg +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/test/_base/test_base.py +0 -0
- {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/test/vanilla/test_autoencoder.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pyautoencoder
|
|
3
|
-
Version: 1.1.0
|
|
3
|
+
Version: 1.1.2
|
|
4
4
|
Summary: A Python package offering implementations of state-of-the-art autoencoder architectures in PyTorch.
|
|
5
5
|
Author: Andrea Pollastro
|
|
6
6
|
License: MIT
|
|
@@ -67,6 +67,7 @@ PyAutoencoder is designed to offer **simple and easy access to autoencoder frame
|
|
|
67
67
|
**Currently implemented**:
|
|
68
68
|
- Autoencoder (AE)
|
|
69
69
|
- Variational Autoencoder (VAE)
|
|
70
|
+
- Adaptive Group Variational Autoencoder (Ada-GVAE)
|
|
70
71
|
|
|
71
72
|
---
|
|
72
73
|
|
|
@@ -123,8 +124,8 @@ for x in dataloader:
|
|
|
123
124
|
loss_results.objective.backward() # negative ELBO
|
|
124
125
|
optimizer.step()
|
|
125
126
|
# optional: log components
|
|
126
|
-
log_likelihood = loss_results.
|
|
127
|
-
kl_divergence = loss_results.
|
|
127
|
+
log_likelihood = loss_results.diagnostics["log_likelihood"]
|
|
128
|
+
kl_divergence = loss_results.diagnostics["kl_divergence"]
|
|
128
129
|
```
|
|
129
130
|
|
|
130
131
|
## Examples
|
|
@@ -156,3 +157,6 @@ If you use this package in academic work, please cite:
|
|
|
156
157
|
publisher={Elsevier}
|
|
157
158
|
}
|
|
158
159
|
```
|
|
160
|
+
## Acknowledgments
|
|
161
|
+
|
|
162
|
+
This work was funded by the PNRR MUR project PE0000013-FAIR (CUP: E63C25000630006).
|
|
@@ -39,6 +39,7 @@ PyAutoencoder is designed to offer **simple and easy access to autoencoder frame
|
|
|
39
39
|
**Currently implemented**:
|
|
40
40
|
- Autoencoder (AE)
|
|
41
41
|
- Variational Autoencoder (VAE)
|
|
42
|
+
- Adaptive Group Variational Autoencoder (Ada-GVAE)
|
|
42
43
|
|
|
43
44
|
---
|
|
44
45
|
|
|
@@ -95,8 +96,8 @@ for x in dataloader:
|
|
|
95
96
|
loss_results.objective.backward() # negative ELBO
|
|
96
97
|
optimizer.step()
|
|
97
98
|
# optional: log components
|
|
98
|
-
log_likelihood = loss_results.
|
|
99
|
-
kl_divergence = loss_results.
|
|
99
|
+
log_likelihood = loss_results.diagnostics["log_likelihood"]
|
|
100
|
+
kl_divergence = loss_results.diagnostics["kl_divergence"]
|
|
100
101
|
```
|
|
101
102
|
|
|
102
103
|
## Examples
|
|
@@ -128,3 +129,6 @@ If you use this package in academic work, please cite:
|
|
|
128
129
|
publisher={Elsevier}
|
|
129
130
|
}
|
|
130
131
|
```
|
|
132
|
+
## Acknowledgments
|
|
133
|
+
|
|
134
|
+
This work was funded by the PNRR MUR project PE0000013-FAIR (CUP: E63C25000630006).
|
|
@@ -90,4 +90,14 @@ Variational Autoencoder
|
|
|
90
90
|
|
|
91
91
|
.. autoclass:: pyautoencoder.variational.VAEOutput
|
|
92
92
|
:members:
|
|
93
|
-
:no-index:
|
|
93
|
+
:no-index:
|
|
94
|
+
|
|
95
|
+
Adaptive Group Variational Autoencoder
|
|
96
|
+
---------------------------------------
|
|
97
|
+
|
|
98
|
+
.. autoclass:: pyautoencoder.variational.AdaGVAE
|
|
99
|
+
:members:
|
|
100
|
+
:undoc-members:
|
|
101
|
+
:show-inheritance:
|
|
102
|
+
:exclude-members: build, _encode, _decode
|
|
103
|
+
:special-members: __init__
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# file generated by vcs-versioning
# don't change, don't track in version control
from __future__ import annotations

# Names re-exported by ``from pyautoencoder._version import *``.
__all__ = [
    "__version__",
    "__version_tuple__",
    "version",
    "version_tuple",
    "__commit_id__",
    "commit_id",
]

# Declared types of the exported metadata (annotations are lazy strings
# because of the ``__future__`` import above).
version: str
__version__: str
__version_tuple__: tuple[int | str, ...]
version_tuple: tuple[int | str, ...]
commit_id: str | None
__commit_id__: str | None

# Each public name and its dunder alias are bound to the same object.
version = '1.1.2'
__version__ = version

version_tuple = (1, 1, 2)
__version_tuple__ = version_tuple

commit_id = 'g2d2766837'
__commit_id__ = commit_id
|
|
@@ -82,7 +82,7 @@ def log_likelihood(x: torch.Tensor,
|
|
|
82
82
|
.. math::
|
|
83
83
|
|
|
84
84
|
\log p(x \mid \hat{x}) =
|
|
85
|
-
-\tfrac{1}{2}
|
|
85
|
+
-\tfrac{1}{2} (x - \hat{x})^2.
|
|
86
86
|
|
|
87
87
|
The output has the same shape as ``x``. Summing over feature dimensions
|
|
88
88
|
gives per-sample log-likelihoods.
|
|
@@ -134,8 +134,7 @@ def log_likelihood(x: torch.Tensor,
|
|
|
134
134
|
|
|
135
135
|
elif likelihood == LikelihoodType.GAUSSIAN:
|
|
136
136
|
squared_error = (x_hat - x) ** 2
|
|
137
|
-
|
|
138
|
-
return -0.5 * (squared_error + log_2pi)
|
|
137
|
+
return -0.5 * squared_error
|
|
139
138
|
|
|
140
139
|
else:
|
|
141
140
|
raise ValueError(f"Unsupported likelihood: {likelihood}")
|
|
@@ -144,8 +143,8 @@ def kl_divergence_diag_gaussian(
|
|
|
144
143
|
mu_q: torch.Tensor,
|
|
145
144
|
log_var_q: torch.Tensor,
|
|
146
145
|
mu_p: Optional[torch.Tensor] = None,
|
|
147
|
-
log_var_p: Optional[torch.Tensor] = None
|
|
148
|
-
) -> torch.Tensor:
|
|
146
|
+
log_var_p: Optional[torch.Tensor] = None,
|
|
147
|
+
reduce_sum: bool = True) -> torch.Tensor:
|
|
149
148
|
r"""Compute the KL divergence :math:`\mathrm{KL}(q(z \mid x) \,\|\, p(z))`
|
|
150
149
|
between two diagonal Gaussian distributions.
|
|
151
150
|
|
|
@@ -175,6 +174,8 @@ def kl_divergence_diag_gaussian(
|
|
|
175
174
|
Mean of the second distribution ``[B, D_z]``. Defaults to 0.
|
|
176
175
|
log_var_p : torch.Tensor, optional
|
|
177
176
|
Log-variance of the second distribution ``[B, D_z]``. Defaults to 0.
|
|
177
|
+
reduce_sum: bool, optional
|
|
178
|
+
Sum over the dimensions. Default to True
|
|
178
179
|
|
|
179
180
|
Returns
|
|
180
181
|
-------
|
|
@@ -194,5 +195,6 @@ def kl_divergence_diag_gaussian(
|
|
|
194
195
|
|
|
195
196
|
term1 = log_var_p - log_var_q
|
|
196
197
|
term2 = (var_q + (mu_q - mu_p).pow(2)) / var_p
|
|
197
|
-
|
|
198
|
-
|
|
198
|
+
if reduce_sum:
|
|
199
|
+
return 0.5 * torch.sum(term1 + term2 - 1, dim=-1)
|
|
200
|
+
return 0.5 * (term1 + term2 - 1)
|
|
@@ -118,20 +118,24 @@ class FullyFactorizedGaussian(nn.Module):
|
|
|
118
118
|
if S < 1:
|
|
119
119
|
raise ValueError("S must be >= 1.")
|
|
120
120
|
|
|
121
|
-
mu = self.mu(x) # type: ignore
|
|
122
|
-
log_var = self.log_var(x) # type: ignore
|
|
121
|
+
mu = self.mu(x) # type: ignore # [B, Dz]
|
|
122
|
+
log_var = self.log_var(x) # type: ignore # [B, Dz]
|
|
123
123
|
|
|
124
124
|
if self.training:
|
|
125
|
-
|
|
126
|
-
mu_e = mu.unsqueeze(1).expand(-1, S, -1) # [B, S, Dz]
|
|
127
|
-
std_e = std.unsqueeze(1).expand(-1, S, -1) # [B, S, Dz]
|
|
128
|
-
eps = torch.randn_like(std_e)
|
|
129
|
-
z = mu_e + std_e * eps # [B, S, Dz]
|
|
125
|
+
z = self.reparametrize(mu=mu, log_var=log_var, S=S) # [B, S, Dz]
|
|
130
126
|
else:
|
|
131
|
-
z = mu.unsqueeze(1).expand(-1, S, -1)
|
|
127
|
+
z = mu.unsqueeze(1).expand(-1, S, -1) # [B, S, Dz]
|
|
132
128
|
|
|
133
129
|
return z, mu, log_var
|
|
134
130
|
|
|
131
|
+
def reparametrize(self, mu: torch.Tensor, log_var: torch.Tensor, S: int = 1) -> torch.Tensor:
    """Draw ``S`` reparameterized samples ``z = mu + std * eps`` per input.

    Parameters
    ----------
    mu : torch.Tensor
        Posterior means, expected shape ``[B, Dz]``.
    log_var : torch.Tensor
        Posterior log-variances, same shape as ``mu``.
    S : int, optional
        Number of samples drawn per input. Must be >= 1. Defaults to 1.

    Returns
    -------
    torch.Tensor
        Samples of shape ``[B, S, Dz]``, with ``eps ~ N(0, I)`` drawn
        independently for every sample.

    Raises
    ------
    ValueError
        If ``S < 1`` (mirrors the check performed in ``forward``, since this
        method can also be called directly, e.g. by ``AdaGVAE``).
    """
    if S < 1:
        raise ValueError("S must be >= 1.")

    std = torch.exp(0.5 * log_var)  # [B, Dz]
    # Broadcast the per-input parameters across the sample axis without copying.
    mu_e = mu.unsqueeze(1).expand(-1, S, -1)  # [B, S, Dz]
    std_e = std.unsqueeze(1).expand(-1, S, -1)  # [B, S, Dz]
    eps = torch.randn_like(std_e)  # fresh noise for each of the S samples
    return mu_e + std_e * eps  # [B, S, Dz]
|
|
138
|
+
|
|
135
139
|
@property
|
|
136
140
|
def built(self) -> bool:
|
|
137
141
|
"""Whether the module has been successfully built.
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
import torch.nn as nn
|
|
3
3
|
from dataclasses import dataclass
|
|
4
|
-
from typing import Union,
|
|
4
|
+
from typing import Union, Tuple
|
|
5
5
|
|
|
6
6
|
from ..loss.base import (
|
|
7
7
|
LikelihoodType,
|
|
@@ -318,3 +318,210 @@ class VAE(BaseAutoencoder):
|
|
|
318
318
|
'kl_divergence': kl_q_p.mean().item(),
|
|
319
319
|
}
|
|
320
320
|
)
|
|
321
|
+
|
|
322
|
+
class AdaGVAE(VAE):
    r"""Adaptive Group Variational Autoencoder (Ada-GVAE), from Locatello et al. (2020).

    This class extends the VAE class and enables feature disentanglement in the
    latent space. It is trained on *pairs* of inputs: for every latent
    dimension it decides whether the two posteriors are replaced by their
    average (shared factor) or kept independent.

    For inference, use the ``.encode()`` and ``.decode()`` methods, as the
    forward method expects pairs of images, following the formulation
    introduced by Locatello et al.
    """

    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        latent_dim: int,
    ):
        """Construct an AdaGVAE from an encoder, decoder, and latent size.

        Notes
        -----
        The encoder and decoder are identical to those in a standard VAE. The
        adaptive grouping mechanism is applied during the encoding step when
        processing paired inputs.

        Parameters
        ----------
        encoder : nn.Module
            Maps input ``x`` to a feature vector ``f(x)`` with shape ``[B, F]``.
        decoder : nn.Module
            Maps latent samples ``z`` to reconstructions ``x_hat``.
        latent_dim : int
            Dimensionality ``D_z`` of the latent space.
        """
        super().__init__(encoder=encoder, decoder=decoder, latent_dim=latent_dim)

    # --- paired encoding with adaptive grouping (Ada-GVAE specific) ---
    def _encode_pair(self, x1: torch.Tensor, x2: torch.Tensor, S: int = 1) -> Tuple[VAEEncodeOutput, VAEEncodeOutput]:
        r"""Encode a pair of inputs with adaptive posterior alignment.

        As described in Locatello et al., this method:

        1. Encodes both inputs independently to obtain posterior parameters ``(mu1, log_var1)`` and ``(mu2, log_var2)``.
        2. Computes element-wise KL divergence between the two posteriors: ``KL(q1||q2) → [B, D_z]``.
        3. Computes a per-sample threshold ``tau`` as the midpoint between the
           smallest and largest per-dimension divergence.
        4. For each dimension, selects aligned (shared) or independent posteriors:
            - If ``KL(q1_d||q2_d) < tau``: uses the average distribution ``q_tilde``
              (mean of the means, mean of the variances).
            - If ``KL(q1_d||q2_d) ≥ tau``: uses the original independent distribution.
        5. Samples from the resulting (mixed) posteriors.

        Parameters
        ----------
        x1 : torch.Tensor
            First input batch of shape ``[B, ...]``.
        x2 : torch.Tensor
            Second input batch of shape ``[B, ...]``.
        S : int
            Number of latent samples per input. Must be >= 1.

        Returns
        -------
        Tuple[VAEEncodeOutput, VAEEncodeOutput]
            A pair of ``VAEEncodeOutput`` objects, each containing:

            - ``z`` of shape ``[B, S, D_z]``: samples from the adapted posteriors.
            - ``mu`` of shape ``[B, D_z]``: the (adapted) means.
            - ``log_var`` of shape ``[B, D_z]``: the (adapted) log-variances.

        Raises
        ------
        ValueError
            If ``S < 1``.

        Notes
        -----
        The thresholding mechanism promotes learning of shared latent factors
        while allowing independent variation for high-divergence dimensions.
        This encourages disentanglement and structured representations.

        ``z`` is always drawn through ``sampling_layer.reparametrize``; this
        method does not consult ``self.training``.
        """
        import math

        if S < 1:
            raise ValueError("S must be >= 1.")

        # Only the posterior parameters are needed; any sample the sampling
        # layer returns is discarded and redrawn from the adapted posteriors.
        _, mu1, log_var1 = self.sampling_layer(self.encoder(x1))
        _, mu2, log_var2 = self.sampling_layer(self.encoder(x2))

        # KL(q1||q2) per latent dimension -> [B, latents]
        kl_q1_q2 = kl_divergence_diag_gaussian(mu1, log_var1, mu2, log_var2, reduce_sum=False)

        # Per-sample threshold tau: midpoint of the divergence range.
        max_delta = kl_q1_q2.max(dim=1, keepdim=True).values
        min_delta = kl_q1_q2.min(dim=1, keepdim=True).values
        tau = 0.5 * (max_delta + min_delta)

        # Averaged posterior q_tilde.
        # log(0.5 * (exp(a) + exp(b))) == logaddexp(a, b) - log(2); the
        # logaddexp form avoids exp() overflow/underflow for extreme log-variances.
        mu_mean = 0.5 * (mu1 + mu2)
        log_var_mean = torch.logaddexp(log_var1, log_var2) - math.log(2.0)

        # Dimensions below the threshold are treated as shared between x1 and x2.
        shared = kl_q1_q2 < tau
        mu_tilde1 = torch.where(shared, mu_mean, mu1)
        mu_tilde2 = torch.where(shared, mu_mean, mu2)
        log_var_tilde1 = torch.where(shared, log_var_mean, log_var1)
        log_var_tilde2 = torch.where(shared, log_var_mean, log_var2)

        z1 = self.sampling_layer.reparametrize(mu=mu_tilde1, log_var=log_var_tilde1, S=S)
        z2 = self.sampling_layer.reparametrize(mu=mu_tilde2, log_var=log_var_tilde2, S=S)

        return (
            VAEEncodeOutput(z=z1, mu=mu_tilde1, log_var=log_var_tilde1),
            VAEEncodeOutput(z=z2, mu=mu_tilde2, log_var=log_var_tilde2),
        )

    def forward(self, x1: torch.Tensor, x2: torch.Tensor, S: int = 1) -> Tuple[VAEOutput, VAEOutput]:
        """Full AdaGVAE forward pass: encode pairs with adaptive grouping, sample, and decode.

        Parameters
        ----------
        x1 : torch.Tensor
            First input batch of shape ``[B, ...]``.
        x2 : torch.Tensor
            Second input batch of shape ``[B, ...]``.
        S : int
            Number of latent samples for Monte Carlo estimates. Must be >= 1.

        Returns
        -------
        Tuple[VAEOutput, VAEOutput]
            A pair of VAE outputs, each containing:

            - ``x_hat``: reconstructions from the adapted latent samples.
            - ``z``: latent samples from the adapted posteriors.
            - ``mu``: (adapted) posterior means.
            - ``log_var``: (adapted) posterior log-variances.
        """
        x1_enc, x2_enc = self._encode_pair(x1, x2, S=S)
        x1_dec = self._decode(x1_enc.z)
        x2_dec = self._decode(x2_enc.z)
        return (
            VAEOutput(x_hat=x1_dec.x_hat, z=x1_enc.z, mu=x1_enc.mu, log_var=x1_enc.log_var),
            VAEOutput(x_hat=x2_dec.x_hat, z=x2_enc.z, mu=x2_enc.mu, log_var=x2_enc.log_var),
        )

    def compute_loss(self,
                     x1: torch.Tensor,
                     x1_vae_output: VAEOutput,
                     x2: torch.Tensor,
                     x2_vae_output: VAEOutput,
                     beta: float = 1,
                     likelihood: Union[str, LikelihoodType] = LikelihoodType.GAUSSIAN) -> LossResult:
        r"""Compute the combined ELBO for a pair of inputs with adaptive posteriors.

        This method computes the sum of the standard VAE ELBOs for both inputs:

        .. math::

            \mathcal{L}(x_1, x_2; \beta)
            = \left[ \mathbb{E}_{q(\hat{z} \mid x_1)}[\log p(x_1 \mid \hat{z})]
            \;-\; \beta \, \mathrm{KL}(q(\hat{z} \mid x_1) \,\|\, p(\hat{z})) \right]
            + \left[ \mathbb{E}_{q(\hat{z} \mid x_2)}[\log p(x_2 \mid \hat{z})]
            \;-\; \beta \, \mathrm{KL}(q(\hat{z} \mid x_2) \,\|\, p(\hat{z})) \right].

        The key difference from standard VAE is that the posteriors :math:`q(\hat{z} | x_1)` and
        :math:`q(\hat{z} | x_2)` are obtained from the adaptive grouping mechanism, which can
        share dimensions based on KL divergence thresholds.

        Parameters
        ----------
        x1 : torch.Tensor
            First input batch of shape ``[B, ...]``.
        x1_vae_output : VAEOutput
            Output from the forward pass for ``x1``. Expected fields:

            - ``x_hat`` (torch.Tensor): Reconstructions, shape ``[B, ...]`` or ``[B, S, ...]``.
            - ``mu`` (torch.Tensor): (Adapted) posterior mean, shape ``[B, D_z]``.
            - ``log_var`` (torch.Tensor): (Adapted) posterior log-variance, shape ``[B, D_z]``.

        x2 : torch.Tensor
            Second input batch of shape ``[B, ...]``.
        x2_vae_output : VAEOutput
            Output from the forward pass for ``x2``. Same structure as ``x1_vae_output``.
        beta : float, optional
            Weighting factor for the KL term (beta-VAE).
            ``beta = 1`` yields the standard objective. Defaults to 1.
        likelihood : Union[str, LikelihoodType], optional
            Likelihood model for the reconstruction term.
            Can be 'gaussian' or 'bernoulli'. Defaults to Gaussian.

        Returns
        -------
        LossResult
            Result containing:

            * **objective** – Sum of the two per-input objectives (negative ELBOs).
            * **diagnostics** – Dictionary with:

              - ``"elbo"``: Sum of ELBOs for both inputs.
              - ``"log_likelihood_x1"``: Reconstruction term for ``x1``.
              - ``"log_likelihood_x2"``: Reconstruction term for ``x2``.
              - ``"kl_divergence_x1"``: KL divergence for ``x1``'s posterior.
              - ``"kl_divergence_x2"``: KL divergence for ``x2``'s posterior.

        Notes
        -----
        - Each per-input term is whatever ``VAE.compute_loss`` reports for that
          input; this method only sums or relabels them.
        - Gradients flow through both reconstructions; neither input is detached.
        - The adaptive grouping introduces implicit structure learning through
          the selective sharing of posterior dimensions.
        """
        x1_loss_info = super().compute_loss(x=x1, vae_output=x1_vae_output, beta=beta, likelihood=likelihood)
        x2_loss_info = super().compute_loss(x=x2, vae_output=x2_vae_output, beta=beta, likelihood=likelihood)

        return LossResult(
            objective=x1_loss_info.objective + x2_loss_info.objective,
            diagnostics={
                'elbo': x1_loss_info.diagnostics['elbo'] + x2_loss_info.diagnostics['elbo'],
                'log_likelihood_x1': x1_loss_info.diagnostics['log_likelihood'],
                'log_likelihood_x2': x2_loss_info.diagnostics['log_likelihood'],
                'kl_divergence_x1': x1_loss_info.diagnostics['kl_divergence'],
                'kl_divergence_x2': x2_loss_info.diagnostics['kl_divergence'],
            }
        )
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pyautoencoder
|
|
3
|
-
Version: 1.1.0
|
|
3
|
+
Version: 1.1.2
|
|
4
4
|
Summary: A Python package offering implementations of state-of-the-art autoencoder architectures in PyTorch.
|
|
5
5
|
Author: Andrea Pollastro
|
|
6
6
|
License: MIT
|
|
@@ -67,6 +67,7 @@ PyAutoencoder is designed to offer **simple and easy access to autoencoder frame
|
|
|
67
67
|
**Currently implemented**:
|
|
68
68
|
- Autoencoder (AE)
|
|
69
69
|
- Variational Autoencoder (VAE)
|
|
70
|
+
- Adaptive Group Variational Autoencoder (Ada-GVAE)
|
|
70
71
|
|
|
71
72
|
---
|
|
72
73
|
|
|
@@ -123,8 +124,8 @@ for x in dataloader:
|
|
|
123
124
|
loss_results.objective.backward() # negative ELBO
|
|
124
125
|
optimizer.step()
|
|
125
126
|
# optional: log components
|
|
126
|
-
log_likelihood = loss_results.
|
|
127
|
-
kl_divergence = loss_results.
|
|
127
|
+
log_likelihood = loss_results.diagnostics["log_likelihood"]
|
|
128
|
+
kl_divergence = loss_results.diagnostics["kl_divergence"]
|
|
128
129
|
```
|
|
129
130
|
|
|
130
131
|
## Examples
|
|
@@ -156,3 +157,6 @@ If you use this package in academic work, please cite:
|
|
|
156
157
|
publisher={Elsevier}
|
|
157
158
|
}
|
|
158
159
|
```
|
|
160
|
+
## Acknowledgments
|
|
161
|
+
|
|
162
|
+
This work was funded by the PNRR MUR project PE0000013-FAIR (CUP: E63C25000630006).
|