pyautoencoder 1.1.0__tar.gz → 1.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/PKG-INFO +7 -3
  2. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/README.md +6 -2
  3. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/source/api/models.rst +11 -1
  4. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/source/architecture.rst +1 -1
  5. pyautoencoder-1.1.2/pyautoencoder/_version.py +24 -0
  6. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/loss/base.py +9 -7
  7. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/variational/__init__.py +2 -0
  8. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/variational/stochastic_layers.py +12 -8
  9. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/variational/vae.py +208 -1
  10. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder.egg-info/PKG-INFO +7 -3
  11. pyautoencoder-1.1.2/test/loss/test_base_loss.py +347 -0
  12. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/test/variational/test_stochastic_layers.py +140 -0
  13. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/test/variational/test_vae.py +286 -0
  14. pyautoencoder-1.1.0/pyautoencoder/_version.py +0 -34
  15. pyautoencoder-1.1.0/test/loss/test_base_loss.py +0 -170
  16. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/.github/workflows/ci.yml +0 -0
  17. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/.github/workflows/publish.yml +0 -0
  18. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/.gitignore +0 -0
  19. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/LICENSE +0 -0
  20. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/assets/logo_nobackground.png +0 -0
  21. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/Makefile +0 -0
  22. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/make.bat +0 -0
  23. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/requirements.txt +0 -0
  24. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/source/api/base.rst +0 -0
  25. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/source/api/index.rst +0 -0
  26. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/source/api/losses.rst +0 -0
  27. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/source/api/stochastic.rst +0 -0
  28. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/source/conf.py +0 -0
  29. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/source/examples/ae_mnist.rst +0 -0
  30. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/source/examples/vae_mnist.rst +0 -0
  31. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/source/examples.rst +0 -0
  32. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/source/getting_started.rst +0 -0
  33. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/docs/source/index.rst +0 -0
  34. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/examples/mnist_ae.py +0 -0
  35. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/examples/mnist_vae_kingma2013.py +0 -0
  36. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/__init__.py +0 -0
  37. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/_base/__init__.py +0 -0
  38. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/_base/base.py +0 -0
  39. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/experimental/__init__.py +0 -0
  40. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/experimental/benchmark_datasets/__init__.py +0 -0
  41. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/experimental/benchmark_datasets/disentanglement.py +0 -0
  42. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/experimental/hypernetworks.py +0 -0
  43. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/loss/__init__.py +0 -0
  44. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/vanilla/__init__.py +0 -0
  45. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder/vanilla/autoencoder.py +0 -0
  46. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder.egg-info/SOURCES.txt +0 -0
  47. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder.egg-info/dependency_links.txt +0 -0
  48. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder.egg-info/requires.txt +0 -0
  49. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyautoencoder.egg-info/top_level.txt +0 -0
  50. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/pyproject.toml +0 -0
  51. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/readthedocs.yaml +0 -0
  52. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/setup.cfg +0 -0
  53. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/test/_base/test_base.py +0 -0
  54. {pyautoencoder-1.1.0 → pyautoencoder-1.1.2}/test/vanilla/test_autoencoder.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pyautoencoder
3
- Version: 1.1.0
3
+ Version: 1.1.2
4
4
  Summary: A Python package offering implementations of state-of-the-art autoencoder architectures in PyTorch.
5
5
  Author: Andrea Pollastro
6
6
  License: MIT
@@ -67,6 +67,7 @@ PyAutoencoder is designed to offer **simple and easy access to autoencoder frame
67
67
  **Currently implemented**:
68
68
  - Autoencoder (AE)
69
69
  - Variational Autoencoder (VAE)
70
+ - Adaptive Group Variational Autoencoder (Ada-GVAE)
70
71
 
71
72
  ---
72
73
 
@@ -123,8 +124,8 @@ for x in dataloader:
123
124
  loss_results.objective.backward() # negative ELBO
124
125
  optimizer.step()
125
126
  # optional: log components
126
- log_likelihood = loss_results.components["log_likelihood"]
127
- kl_divergence = loss_results.components["kl_divergence"]
127
+ log_likelihood = loss_results.diagnostics["log_likelihood"]
128
+ kl_divergence = loss_results.diagnostics["kl_divergence"]
128
129
  ```
129
130
 
130
131
  ## Examples
@@ -156,3 +157,6 @@ If you use this package in academic work, please cite:
156
157
  publisher={Elsevier}
157
158
  }
158
159
  ```
160
+ ## Acknowledgments
161
+
162
+ This work was funded by the PNRR MUR project PE0000013-FAIR (CUP: E63C25000630006).
@@ -39,6 +39,7 @@ PyAutoencoder is designed to offer **simple and easy access to autoencoder frame
39
39
  **Currently implemented**:
40
40
  - Autoencoder (AE)
41
41
  - Variational Autoencoder (VAE)
42
+ - Adaptive Group Variational Autoencoder (Ada-GVAE)
42
43
 
43
44
  ---
44
45
 
@@ -95,8 +96,8 @@ for x in dataloader:
95
96
  loss_results.objective.backward() # negative ELBO
96
97
  optimizer.step()
97
98
  # optional: log components
98
- log_likelihood = loss_results.components["log_likelihood"]
99
- kl_divergence = loss_results.components["kl_divergence"]
99
+ log_likelihood = loss_results.diagnostics["log_likelihood"]
100
+ kl_divergence = loss_results.diagnostics["kl_divergence"]
100
101
  ```
101
102
 
102
103
  ## Examples
@@ -128,3 +129,6 @@ If you use this package in academic work, please cite:
128
129
  publisher={Elsevier}
129
130
  }
130
131
  ```
132
+ ## Acknowledgments
133
+
134
+ This work was funded by the PNRR MUR project PE0000013-FAIR (CUP: E63C25000630006).
@@ -90,4 +90,14 @@ Variational Autoencoder
90
90
 
91
91
  .. autoclass:: pyautoencoder.variational.VAEOutput
92
92
  :members:
93
- :no-index:
93
+ :no-index:
94
+
95
+ Adaptive Group Variational Autoencoder
96
+ ---------------------------------------
97
+
98
+ .. autoclass:: pyautoencoder.variational.AdaGVAE
99
+ :members:
100
+ :undoc-members:
101
+ :show-inheritance:
102
+ :exclude-members: build, _encode, _decode
103
+ :special-members: __init__
@@ -137,7 +137,7 @@ Supported likelihoods:
137
137
 
138
138
  .. math::
139
139
 
140
- \text{NLL} = \frac{1}{2}[(x-\hat{x})^2 + \log(2\pi)]
140
+ \text{NLL} = \frac{1}{2}(x-\hat{x})^2
141
141
 
142
142
  - **Bernoulli** – Discrete/binary data (logits)
143
143
 
@@ -0,0 +1,24 @@
1
+ # file generated by vcs-versioning
2
+ # don't change, don't track in version control
3
+ from __future__ import annotations
4
+
5
+ __all__ = [
6
+ "__version__",
7
+ "__version_tuple__",
8
+ "version",
9
+ "version_tuple",
10
+ "__commit_id__",
11
+ "commit_id",
12
+ ]
13
+
14
+ version: str
15
+ __version__: str
16
+ __version_tuple__: tuple[int | str, ...]
17
+ version_tuple: tuple[int | str, ...]
18
+ commit_id: str | None
19
+ __commit_id__: str | None
20
+
21
+ __version__ = version = '1.1.2'
22
+ __version_tuple__ = version_tuple = (1, 1, 2)
23
+
24
+ __commit_id__ = commit_id = 'g2d2766837'
@@ -82,7 +82,7 @@ def log_likelihood(x: torch.Tensor,
82
82
  .. math::
83
83
 
84
84
  \log p(x \mid \hat{x}) =
85
- -\tfrac{1}{2} \left[ (x - \hat{x})^2 + \log(2\pi) \right].
85
+ -\tfrac{1}{2} (x - \hat{x})^2.
86
86
 
87
87
  The output has the same shape as ``x``. Summing over feature dimensions
88
88
  gives per-sample log-likelihoods.
@@ -134,8 +134,7 @@ def log_likelihood(x: torch.Tensor,
134
134
 
135
135
  elif likelihood == LikelihoodType.GAUSSIAN:
136
136
  squared_error = (x_hat - x) ** 2
137
- log_2pi = _get_log2pi(x)
138
- return -0.5 * (squared_error + log_2pi)
137
+ return -0.5 * squared_error
139
138
 
140
139
  else:
141
140
  raise ValueError(f"Unsupported likelihood: {likelihood}")
@@ -144,8 +143,8 @@ def kl_divergence_diag_gaussian(
144
143
  mu_q: torch.Tensor,
145
144
  log_var_q: torch.Tensor,
146
145
  mu_p: Optional[torch.Tensor] = None,
147
- log_var_p: Optional[torch.Tensor] = None
148
- ) -> torch.Tensor:
146
+ log_var_p: Optional[torch.Tensor] = None,
147
+ reduce_sum: bool = True) -> torch.Tensor:
149
148
  r"""Compute the KL divergence :math:`\mathrm{KL}(q(z \mid x) \,\|\, p(z))`
150
149
  between two diagonal Gaussian distributions.
151
150
 
@@ -175,6 +174,8 @@ def kl_divergence_diag_gaussian(
175
174
  Mean of the second distribution ``[B, D_z]``. Defaults to 0.
176
175
  log_var_p : torch.Tensor, optional
177
176
  Log-variance of the second distribution ``[B, D_z]``. Defaults to 0.
177
+ reduce_sum: bool, optional
178
+ Sum over the dimensions. Default to True
178
179
 
179
180
  Returns
180
181
  -------
@@ -194,5 +195,6 @@ def kl_divergence_diag_gaussian(
194
195
 
195
196
  term1 = log_var_p - log_var_q
196
197
  term2 = (var_q + (mu_q - mu_p).pow(2)) / var_p
197
-
198
- return 0.5 * torch.sum(term1 + term2 - 1, dim=-1)
198
+ if reduce_sum:
199
+ return 0.5 * torch.sum(term1 + term2 - 1, dim=-1)
200
+ return 0.5 * (term1 + term2 - 1)
@@ -1,4 +1,5 @@
1
1
  from .vae import (
2
+ AdaGVAE,
2
3
  VAE,
3
4
  VAEDecodeOutput,
4
5
  VAEEncodeOutput,
@@ -6,6 +7,7 @@ from .vae import (
6
7
  )
7
8
 
8
9
  __all__ = [
10
+ 'AdaGVAE',
9
11
  'VAE',
10
12
  'VAEDecodeOutput',
11
13
  'VAEEncodeOutput',
@@ -118,20 +118,24 @@ class FullyFactorizedGaussian(nn.Module):
118
118
  if S < 1:
119
119
  raise ValueError("S must be >= 1.")
120
120
 
121
- mu = self.mu(x) # type: ignore # [B, Dz]
122
- log_var = self.log_var(x) # type: ignore # [B, Dz]
121
+ mu = self.mu(x) # type: ignore # [B, Dz]
122
+ log_var = self.log_var(x) # type: ignore # [B, Dz]
123
123
 
124
124
  if self.training:
125
- std = torch.exp(0.5 * log_var) # [B, Dz]
126
- mu_e = mu.unsqueeze(1).expand(-1, S, -1) # [B, S, Dz]
127
- std_e = std.unsqueeze(1).expand(-1, S, -1) # [B, S, Dz]
128
- eps = torch.randn_like(std_e)
129
- z = mu_e + std_e * eps # [B, S, Dz]
125
+ z = self.reparametrize(mu=mu, log_var=log_var, S=S) # [B, S, Dz]
130
126
  else:
131
- z = mu.unsqueeze(1).expand(-1, S, -1) # [B, S, Dz]
127
+ z = mu.unsqueeze(1).expand(-1, S, -1) # [B, S, Dz]
132
128
 
133
129
  return z, mu, log_var
134
130
 
131
+ def reparametrize(self, mu: torch.Tensor, log_var: torch.Tensor, S: int = 1):
132
+ std = torch.exp(0.5 * log_var) # [B, Dz]
133
+ mu_e = mu.unsqueeze(1).expand(-1, S, -1) # [B, S, Dz]
134
+ std_e = std.unsqueeze(1).expand(-1, S, -1) # [B, S, Dz]
135
+ eps = torch.randn_like(std_e)
136
+ z = mu_e + std_e * eps # [B, S, Dz]
137
+ return z
138
+
135
139
  @property
136
140
  def built(self) -> bool:
137
141
  """Whether the module has been successfully built.
@@ -1,7 +1,7 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
3
  from dataclasses import dataclass
4
- from typing import Union, Dict
4
+ from typing import Union, Tuple
5
5
 
6
6
  from ..loss.base import (
7
7
  LikelihoodType,
@@ -318,3 +318,210 @@ class VAE(BaseAutoencoder):
318
318
  'kl_divergence': kl_q_p.mean().item(),
319
319
  }
320
320
  )
321
+
322
+ class AdaGVAE(VAE):
323
+ r"""Adaptive Group Variational Autoencoder (Ada-GVAE), from Locatello et al. (2020).
324
+
325
+ This class extends the VAE class and enables feature disentanglement in the latent space.
326
+ For inference, use the .encode() and .decode() methods, as the forward method expects pairs of images,
327
+ following the formulation introduced by Locatello et al.
328
+ """
329
+
330
+ def __init__(
331
+ self,
332
+ encoder: nn.Module,
333
+ decoder: nn.Module,
334
+ latent_dim: int,
335
+ ):
336
+ """Construct an AdaGVAE from an encoder, decoder, and latent size.
337
+
338
+ Notes
339
+ -----
340
+ The encoder and decoder are identical to those in a standard VAE. The adaptive
341
+ grouping mechanism is applied during the encoding step when processing paired inputs.
342
+
343
+ Parameters
344
+ ----------
345
+ encoder : nn.Module
346
+ Maps input ``x`` to a feature vector ``f(x)`` with shape ``[B, F]``.
347
+ decoder : nn.Module
348
+ Maps latent samples ``z`` to reconstructions ``x_hat``.
349
+ latent_dim : int
350
+ Dimensionality ``D_z`` of the latent space.
351
+ """
352
+ super().__init__(encoder=encoder, decoder=decoder, latent_dim=latent_dim)
353
+
354
+ # --- training-time hooks required by BaseAutoencoder ---
355
+ def _encode_pair(self, x1: torch.Tensor, x2: torch.Tensor, S: int = 1) -> Tuple[VAEEncodeOutput, VAEEncodeOutput]:
356
+ r"""Encode a pair of inputs with adaptive posterior alignment.
357
+
358
+ As described in Locatello et al., this method:
359
+
360
+ 1. Encodes both inputs independently to obtain posterior parameters ``(mu1, log_var1)`` and ``(mu2, log_var2)``.
361
+ 2. Computes element-wise KL divergence between the two posteriors: ``KL(q1||q2) → [B, D_z]``.
362
+ 3. Computes a per-sample threshold ``tau`` based on KL divergences.
363
+ 4. For each dimension, selects aligned (shared) or independent posteriors:
364
+ - If ``KL(q1_d||q2_d) < tau``: uses average distribution ``q_tilde``.
365
+ - If ``KL(q1_d||q2_d) ≥ tau``: uses original independent distribution.
366
+ 5. Samples from the resulting (mixed) posteriors.
367
+
368
+ Parameters
369
+ ----------
370
+ x1 : torch.Tensor
371
+ First input batch of shape ``[B, ...]``.
372
+ x2 : torch.Tensor
373
+ Second input batch of shape ``[B, ...]``.
374
+ S : int
375
+ Number of latent samples per input.
376
+
377
+ Returns
378
+ -------
379
+ Tuple[VAEEncodeOutput, VAEEncodeOutput]
380
+ A pair of ``VAEEncodeOutput`` objects, each containing:
381
+
382
+ - ``z`` of shape ``[B, S, D_z]``: samples from the adapted posteriors.
383
+ - ``mu`` of shape ``[B, D_z]``: the (adapted) means.
384
+ - ``log_var`` of shape ``[B, D_z]``: the (adapted) log-variances.
385
+
386
+ Notes
387
+ -----
388
+ The thresholding mechanism promotes learning of shared latent factors
389
+ while allowing independent variation for high-divergence dimensions.
390
+ This encourages disentanglement and structured representations.
391
+ """
392
+ _, mu1, log_var1 = self.sampling_layer(self.encoder(x1))
393
+ _, mu2, log_var2 = self.sampling_layer(self.encoder(x2))
394
+
395
+ # KL(q1||q2) -> [B, latents]
396
+ kl_q1_q2 = kl_divergence_diag_gaussian(mu1, log_var1, mu2, log_var2, reduce_sum=False)
397
+
398
+ # Computing threshold tau
399
+ max_delta = torch.max(kl_q1_q2, dim=1, keepdim=True)[0]
400
+ min_delta = torch.min(kl_q1_q2, dim=1, keepdim=True)[0]
401
+ tau = 0.5 * (max_delta + min_delta)
402
+
403
+ # Computing q_tilde1 and q_tilde2
404
+ mu_mean = 0.5*(mu1 + mu2)
405
+ var_mean = 0.5*(torch.exp(log_var1) + torch.exp(log_var2))
406
+ log_var_mean = torch.log(var_mean)
407
+
408
+ mask = kl_q1_q2 < tau
409
+ mu_tilde1 = torch.where(mask, mu_mean, mu1)
410
+ mu_tilde2 = torch.where(mask, mu_mean, mu2)
411
+ log_var_tilde1 = torch.where(mask, log_var_mean, log_var1)
412
+ log_var_tilde2 = torch.where(mask, log_var_mean, log_var2)
413
+
414
+ z1 = self.sampling_layer.reparametrize(mu=mu_tilde1, log_var=log_var_tilde1, S=S)
415
+ z2 = self.sampling_layer.reparametrize(mu=mu_tilde2, log_var=log_var_tilde2, S=S)
416
+
417
+ return VAEEncodeOutput(z=z1, mu=mu_tilde1, log_var=log_var_tilde1), \
418
+ VAEEncodeOutput(z=z2, mu=mu_tilde2, log_var=log_var_tilde2)
419
+
420
+
421
+ def forward(self, x1: torch.Tensor, x2: torch.Tensor, S: int = 1) -> Tuple[VAEOutput, VAEOutput]:
422
+ """Full AdaGVAE forward pass: encode pairs with adaptive grouping, sample, and decode.
423
+
424
+ Parameters
425
+ ----------
426
+ x1 : torch.Tensor
427
+ First input batch of shape ``[B, ...]``.
428
+ x2 : torch.Tensor
429
+ Second input batch of shape ``[B, ...]``.
430
+ S : int
431
+ Number of latent samples for Monte Carlo estimates.
432
+
433
+ Returns
434
+ -------
435
+ Tuple[VAEOutput, VAEOutput]
436
+ A pair of VAE outputs, each containing:
437
+
438
+ - ``x_hat``: reconstructions from the adapted latent samples.
439
+ - ``z``: latent samples from the adapted posteriors.
440
+ - ``mu``: (adapted) posterior means.
441
+ - ``log_var``: (adapted) posterior log-variances.
442
+ """
443
+ x1_enc, x2_enc = self._encode_pair(x1, x2, S=S)
444
+ x1_dec = self._decode(x1_enc.z)
445
+ x2_dec = self._decode(x2_enc.z)
446
+ return VAEOutput(x_hat=x1_dec.x_hat, z=x1_enc.z, mu=x1_enc.mu, log_var=x1_enc.log_var), \
447
+ VAEOutput(x_hat=x2_dec.x_hat, z=x2_enc.z, mu=x2_enc.mu, log_var=x2_enc.log_var)
448
+
449
+ def compute_loss(self,
450
+ x1: torch.Tensor,
451
+ x1_vae_output: VAEOutput,
452
+ x2: torch.Tensor,
453
+ x2_vae_output: VAEOutput,
454
+ beta: float = 1,
455
+ likelihood: Union[str, LikelihoodType] = LikelihoodType.GAUSSIAN) -> LossResult:
456
+ r"""Compute the combined ELBO for a pair of inputs with adaptive posteriors.
457
+
458
+ This method computes the sum of the standard VAE ELBOs for both inputs:
459
+
460
+ .. math::
461
+
462
+ \mathcal{L}(x_1, x_2; \beta)
463
+ = \left[ \mathbb{E}_{q(\hat{z} \mid x_1)}[\log p(x_1 \mid \hat{z})]
464
+ \;-\; \beta \, \mathrm{KL}(q(\hat{z} \mid x_1) \,\|\, p(\hat{z})) \right]
465
+ + \left[ \mathbb{E}_{q(\hat{z} \mid x_2)}[\log p(x_2 \mid \hat{z})]
466
+ \;-\; \beta \, \mathrm{KL}(q(\hat{z} \mid x_2) \,\|\, p(\hat{z})) \right].
467
+
468
+ The key difference from standard VAE is that the posteriors :math:`q(\hat{z} | x_1)` and
469
+ :math:`q(\hat{z} | x_2)` are obtained from the adaptive grouping mechanism, which can
470
+ share dimensions based on KL divergence thresholds.
471
+
472
+ Parameters
473
+ ----------
474
+ x1 : torch.Tensor
475
+ First input batch of shape ``[B, ...]``.
476
+ x1_vae_output : VAEOutput
477
+ Output from the forward pass for ``x1``. Expected fields:
478
+
479
+ - ``x_hat`` (torch.Tensor): Reconstructions, shape ``[B, ...]`` or ``[B, S, ...]``.
480
+ - ``mu`` (torch.Tensor): (Adapted) posterior mean, shape ``[B, D_z]``.
481
+ - ``log_var`` (torch.Tensor): (Adapted) posterior log-variance, shape ``[B, D_z]``.
482
+
483
+ x2 : torch.Tensor
484
+ Second input batch of shape ``[B, ...]``.
485
+ x2_vae_output : VAEOutput
486
+ Output from the forward pass for ``x2``. Same structure as ``x1_vae_output``.
487
+ likelihood : Union[str, LikelihoodType], optional
488
+ Likelihood model for the reconstruction term.
489
+ Can be 'gaussian' or 'bernoulli'. Defaults to Gaussian.
490
+ beta : float, optional
491
+ Weighting factor for the KL term (beta-VAE).
492
+ ``beta = 1`` yields the standard objective. Defaults to 1.
493
+
494
+ Returns
495
+ -------
496
+ LossResult
497
+ Result containing:
498
+
499
+ * **objective** – Sum of negative ELBOs for both inputs (scalar).
500
+ * **diagnostics** – Dictionary with:
501
+
502
+ - ``"elbo"``: Sum of ELBOs for both inputs.
503
+ - ``"log_likelihood_x1"``: Mean reconstruction term for ``x1``.
504
+ - ``"log_likelihood_x2"``: Mean reconstruction term for ``x2``.
505
+ - ``"kl_divergence_x1"``: Mean KL divergence for ``x1``'s posterior.
506
+ - ``"kl_divergence_x2"``: Mean KL divergence for ``x2``'s posterior.
507
+
508
+ Notes
509
+ -----
510
+ - All diagnostics are **batch means** (per-sample losses averaged over ``B``).
511
+ - Gradients flow through both decoders; neither input is detached.
512
+ - The adaptive grouping introduces implicit structure learning through
513
+ the selective sharing of posterior dimensions.
514
+ """
515
+ x1_loss_info = super().compute_loss(x=x1, vae_output=x1_vae_output, beta=beta, likelihood=likelihood)
516
+ x2_loss_info = super().compute_loss(x=x2, vae_output=x2_vae_output, beta=beta, likelihood=likelihood)
517
+
518
+ return LossResult(
519
+ objective = x1_loss_info.objective + x2_loss_info.objective,
520
+ diagnostics = {
521
+ 'elbo': x1_loss_info.diagnostics['elbo'] + x2_loss_info.diagnostics['elbo'],
522
+ 'log_likelihood_x1': x1_loss_info.diagnostics['log_likelihood'],
523
+ 'log_likelihood_x2': x2_loss_info.diagnostics['log_likelihood'],
524
+ 'kl_divergence_x1': x1_loss_info.diagnostics['kl_divergence'],
525
+ 'kl_divergence_x2': x2_loss_info.diagnostics['kl_divergence'],
526
+ }
527
+ )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pyautoencoder
3
- Version: 1.1.0
3
+ Version: 1.1.2
4
4
  Summary: A Python package offering implementations of state-of-the-art autoencoder architectures in PyTorch.
5
5
  Author: Andrea Pollastro
6
6
  License: MIT
@@ -67,6 +67,7 @@ PyAutoencoder is designed to offer **simple and easy access to autoencoder frame
67
67
  **Currently implemented**:
68
68
  - Autoencoder (AE)
69
69
  - Variational Autoencoder (VAE)
70
+ - Adaptive Group Variational Autoencoder (Ada-GVAE)
70
71
 
71
72
  ---
72
73
 
@@ -123,8 +124,8 @@ for x in dataloader:
123
124
  loss_results.objective.backward() # negative ELBO
124
125
  optimizer.step()
125
126
  # optional: log components
126
- log_likelihood = loss_results.components["log_likelihood"]
127
- kl_divergence = loss_results.components["kl_divergence"]
127
+ log_likelihood = loss_results.diagnostics["log_likelihood"]
128
+ kl_divergence = loss_results.diagnostics["kl_divergence"]
128
129
  ```
129
130
 
130
131
  ## Examples
@@ -156,3 +157,6 @@ If you use this package in academic work, please cite:
156
157
  publisher={Elsevier}
157
158
  }
158
159
  ```
160
+ ## Acknowledgments
161
+
162
+ This work was funded by the PNRR MUR project PE0000013-FAIR (CUP: E63C25000630006).