pyautoencoder 1.1.2__tar.gz → 1.1.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/PKG-INFO +1 -1
  2. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder/_base/base.py +1 -1
  3. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder/_version.py +3 -3
  4. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder/loss/base.py +3 -31
  5. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder/variational/stochastic_layers.py +23 -2
  6. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder/variational/vae.py +15 -12
  7. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder.egg-info/PKG-INFO +1 -1
  8. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder.egg-info/SOURCES.txt +0 -4
  9. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/test/_base/test_base.py +0 -1
  10. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/test/loss/test_base_loss.py +7 -24
  11. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/test/vanilla/test_autoencoder.py +20 -14
  12. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/test/variational/test_stochastic_layers.py +19 -0
  13. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/test/variational/test_vae.py +115 -50
  14. pyautoencoder-1.1.2/pyautoencoder/experimental/__init__.py +0 -0
  15. pyautoencoder-1.1.2/pyautoencoder/experimental/benchmark_datasets/__init__.py +0 -0
  16. pyautoencoder-1.1.2/pyautoencoder/experimental/benchmark_datasets/disentanglement.py +0 -84
  17. pyautoencoder-1.1.2/pyautoencoder/experimental/hypernetworks.py +0 -261
  18. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/.github/workflows/ci.yml +0 -0
  19. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/.github/workflows/publish.yml +0 -0
  20. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/.gitignore +0 -0
  21. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/LICENSE +0 -0
  22. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/README.md +0 -0
  23. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/assets/logo_nobackground.png +0 -0
  24. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/Makefile +0 -0
  25. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/make.bat +0 -0
  26. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/requirements.txt +0 -0
  27. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/source/api/base.rst +0 -0
  28. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/source/api/index.rst +0 -0
  29. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/source/api/losses.rst +0 -0
  30. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/source/api/models.rst +0 -0
  31. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/source/api/stochastic.rst +0 -0
  32. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/source/architecture.rst +0 -0
  33. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/source/conf.py +0 -0
  34. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/source/examples/ae_mnist.rst +0 -0
  35. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/source/examples/vae_mnist.rst +0 -0
  36. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/source/examples.rst +0 -0
  37. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/source/getting_started.rst +0 -0
  38. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/source/index.rst +0 -0
  39. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/examples/mnist_ae.py +0 -0
  40. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/examples/mnist_vae_kingma2013.py +0 -0
  41. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder/__init__.py +0 -0
  42. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder/_base/__init__.py +0 -0
  43. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder/loss/__init__.py +0 -0
  44. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder/vanilla/__init__.py +0 -0
  45. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder/vanilla/autoencoder.py +0 -0
  46. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder/variational/__init__.py +0 -0
  47. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder.egg-info/dependency_links.txt +0 -0
  48. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder.egg-info/requires.txt +0 -0
  49. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder.egg-info/top_level.txt +0 -0
  50. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyproject.toml +0 -0
  51. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/readthedocs.yaml +0 -0
  52. {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pyautoencoder
3
- Version: 1.1.2
3
+ Version: 1.1.4
4
4
  Summary: A Python package offering implementations of state-of-the-art autoencoder architectures in PyTorch.
5
5
  Author: Andrea Pollastro
6
6
  License: MIT
@@ -156,7 +156,7 @@ class BuildGuardMixin(ABC):
156
156
  @wraps(_orig_build)
157
157
  def _wrapped_build(self, *args: Any, **kwargs: Any) -> None:
158
158
 
159
- if getattr(self, "_built", True):
159
+ if getattr(self, "_built", False):
160
160
  return
161
161
 
162
162
  with torch.no_grad():
@@ -18,7 +18,7 @@ version_tuple: tuple[int | str, ...]
18
18
  commit_id: str | None
19
19
  __commit_id__: str | None
20
20
 
21
- __version__ = version = '1.1.2'
22
- __version_tuple__ = version_tuple = (1, 1, 2)
21
+ __version__ = version = '1.1.4'
22
+ __version_tuple__ = version_tuple = (1, 1, 4)
23
23
 
24
- __commit_id__ = commit_id = 'g2d2766837'
24
+ __commit_id__ = commit_id = 'g21e483730'
@@ -1,4 +1,3 @@
1
- import math
2
1
  import torch
3
2
  import torch.nn.functional as F
4
3
  from dataclasses import dataclass
@@ -20,9 +19,6 @@ class LikelihoodType(Enum):
20
19
  GAUSSIAN = 'gaussian'
21
20
  BERNOULLI = 'bernoulli'
22
21
 
23
- # Cache for log(2pi) constants per (device, dtype)
24
- _LOG2PI_CACHE = {}
25
-
26
22
  @dataclass(slots=True, repr=True)
27
23
  class LossResult:
28
24
  r"""Container for loss computation results with objective and diagnostics.
@@ -45,30 +41,6 @@ class LossResult:
45
41
  objective: torch.Tensor
46
42
  diagnostics: Dict[str, float]
47
43
 
48
- def _get_log2pi(x: torch.Tensor) -> torch.Tensor:
49
- r"""Return a cached value of :math:`\log(2\pi)` for the given device and dtype.
50
-
51
- This avoids repeatedly allocating the constant for different devices or
52
- precisions. A separate tensor is cached for each ``(device, dtype)`` pair.
53
-
54
- Parameters
55
- ----------
56
- x : torch.Tensor
57
- A tensor whose ``device`` and ``dtype`` determine which cached value is
58
- returned or created.
59
-
60
- Returns
61
- -------
62
- torch.Tensor
63
- A scalar tensor equal to :math:`\log(2\pi)` with the same device and
64
- dtype as ``x``.
65
- """
66
-
67
- key = (x.device, x.dtype)
68
- if key not in _LOG2PI_CACHE:
69
- _LOG2PI_CACHE[key] = torch.tensor(2.0 * math.pi, device=x.device, dtype=x.dtype).log()
70
- return _LOG2PI_CACHE[key]
71
-
72
44
  def log_likelihood(x: torch.Tensor,
73
45
  x_hat: torch.Tensor,
74
46
  likelihood: Union[str, LikelihoodType] = LikelihoodType.GAUSSIAN) -> torch.Tensor:
@@ -118,9 +90,9 @@ def log_likelihood(x: torch.Tensor,
118
90
 
119
91
  Notes
120
92
  -----
121
- - The Gaussian case includes the normalization constant
122
- :math:`\log(2\pi)`, cached per ``(device, dtype)`` with
123
- :func:`_get_log2pi`.
93
+ - The Gaussian case omits the normalization constant
94
+ :math:`-\tfrac{1}{2}\log(2\pi)`, which is constant with respect to
95
+ the model parameters and has no effect on optimization.
124
96
  - The Bernoulli case is fully numerically stable because it operates
125
97
  directly in log-space.
126
98
  """
@@ -62,7 +62,7 @@ class FullyFactorizedGaussian(nn.Module):
62
62
  raise TypeError("build(x) expects a torch.Tensor.")
63
63
  if input_sample.ndim != 2:
64
64
  raise ValueError(f"build(x): expected shape [B, F], got {tuple(input_sample.shape)}. Flatten upstream.")
65
- if input_sample.shape[1] <= 0:
65
+ if input_sample.shape[1] == 0:
66
66
  raise ValueError("build(x): F (feature dimension) must be > 0.")
67
67
 
68
68
  in_features = int(input_sample.shape[1])
@@ -128,7 +128,28 @@ class FullyFactorizedGaussian(nn.Module):
128
128
 
129
129
  return z, mu, log_var
130
130
 
131
- def reparametrize(self, mu: torch.Tensor, log_var: torch.Tensor, S: int = 1):
131
+ def reparametrize(self, mu: torch.Tensor, log_var: torch.Tensor, S: int = 1) -> torch.Tensor:
132
+ r"""Draw ``S`` latent samples via the reparameterization trick.
133
+
134
+ .. math::
135
+
136
+ z^{(s)} = \mu + \sigma \odot \epsilon^{(s)},
137
+ \qquad \epsilon^{(s)} \sim \mathcal{N}(0, I).
138
+
139
+ Parameters
140
+ ----------
141
+ mu : torch.Tensor
142
+ Mean of the posterior, shape ``[B, D_z]``.
143
+ log_var : torch.Tensor
144
+ Log-variance of the posterior, shape ``[B, D_z]``.
145
+ S : int, optional
146
+ Number of samples to draw. Defaults to ``1``.
147
+
148
+ Returns
149
+ -------
150
+ torch.Tensor
151
+ Sampled latent codes of shape ``[B, S, D_z]``.
152
+ """
132
153
  std = torch.exp(0.5 * log_var) # [B, Dz]
133
154
  mu_e = mu.unsqueeze(1).expand(-1, S, -1) # [B, S, Dz]
134
155
  std_e = std.unsqueeze(1).expand(-1, S, -1) # [B, S, Dz]
@@ -322,11 +322,14 @@ class VAE(BaseAutoencoder):
322
322
  class AdaGVAE(VAE):
323
323
  r"""Adaptive Group Variational Autoencoder (Ada-GVAE), from Locatello et al. (2020).
324
324
 
325
- This class extends the VAE class and enables feature disentanglement in the latent space.
326
- For inference, use the .encode() and .decode() methods, as the forward method expects pairs of images,
327
- following the formulation introduced by Locatello et al.
325
+ This class extends the VAE class and enables feature disentanglement in the latent space.
326
+ Use :meth:`forward_pair` for the paired training pass and :meth:`compute_pair_loss` for
327
+ its loss. The inherited :meth:`encode` / :meth:`decode` methods work normally for inference
328
+ on single inputs.
328
329
  """
329
330
 
331
+ _GUARDED = VAE._GUARDED | {"_encode_pair"}
332
+
330
333
  def __init__(
331
334
  self,
332
335
  encoder: nn.Module,
@@ -418,7 +421,7 @@ class AdaGVAE(VAE):
418
421
  VAEEncodeOutput(z=z2, mu=mu_tilde2, log_var=log_var_tilde2)
419
422
 
420
423
 
421
- def forward(self, x1: torch.Tensor, x2: torch.Tensor, S: int = 1) -> Tuple[VAEOutput, VAEOutput]:
424
+ def forward_pair(self, x1: torch.Tensor, x2: torch.Tensor, S: int = 1) -> Tuple[VAEOutput, VAEOutput]:
422
425
  """Full AdaGVAE forward pass: encode pairs with adaptive grouping, sample, and decode.
423
426
 
424
427
  Parameters
@@ -445,14 +448,14 @@ class AdaGVAE(VAE):
445
448
  x2_dec = self._decode(x2_enc.z)
446
449
  return VAEOutput(x_hat=x1_dec.x_hat, z=x1_enc.z, mu=x1_enc.mu, log_var=x1_enc.log_var), \
447
450
  VAEOutput(x_hat=x2_dec.x_hat, z=x2_enc.z, mu=x2_enc.mu, log_var=x2_enc.log_var)
448
-
449
- def compute_loss(self,
450
- x1: torch.Tensor,
451
- x1_vae_output: VAEOutput,
452
- x2: torch.Tensor,
453
- x2_vae_output: VAEOutput,
454
- beta: float = 1,
455
- likelihood: Union[str, LikelihoodType] = LikelihoodType.GAUSSIAN) -> LossResult:
451
+
452
+ def compute_pair_loss(self,
453
+ x1: torch.Tensor,
454
+ x1_vae_output: VAEOutput,
455
+ x2: torch.Tensor,
456
+ x2_vae_output: VAEOutput,
457
+ beta: float = 1,
458
+ likelihood: Union[str, LikelihoodType] = LikelihoodType.GAUSSIAN) -> LossResult:
456
459
  r"""Compute the combined ELBO for a pair of inputs with adaptive posteriors.
457
460
 
458
461
  This method computes the sum of the standard VAE ELBOs for both inputs:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pyautoencoder
3
- Version: 1.1.2
3
+ Version: 1.1.4
4
4
  Summary: A Python package offering implementations of state-of-the-art autoencoder architectures in PyTorch.
5
5
  Author: Andrea Pollastro
6
6
  License: MIT
@@ -32,10 +32,6 @@ pyautoencoder.egg-info/requires.txt
32
32
  pyautoencoder.egg-info/top_level.txt
33
33
  pyautoencoder/_base/__init__.py
34
34
  pyautoencoder/_base/base.py
35
- pyautoencoder/experimental/__init__.py
36
- pyautoencoder/experimental/hypernetworks.py
37
- pyautoencoder/experimental/benchmark_datasets/__init__.py
38
- pyautoencoder/experimental/benchmark_datasets/disentanglement.py
39
35
  pyautoencoder/loss/__init__.py
40
36
  pyautoencoder/loss/base.py
41
37
  pyautoencoder/vanilla/__init__.py
@@ -14,7 +14,6 @@ from pyautoencoder._base.base import (
14
14
  # ================= ModelOutput =================
15
15
 
16
16
  def test_model_output_repr_tensors_and_non_tensors():
17
- @torch.no_grad()
18
17
  @dataclass(slots=True, repr=False)
19
18
  class MyOutput(ModelOutput):
20
19
  logits: torch.Tensor
@@ -1,5 +1,4 @@
1
1
  import pytest
2
- import math
3
2
  import torch
4
3
  import torch.nn.functional as F
5
4
 
@@ -252,22 +251,22 @@ def test_kl_divergence_preserves_device():
252
251
  assert kl.device == mu_q.device
253
252
 
254
253
 
255
- def test_kl_divergence_symmetric_when_p_and_q_swapped_with_custom_prior():
256
- """Test asymmetry property: KL(q||p) != KL(p||q) in general."""
254
+ def test_kl_divergence_is_asymmetric():
255
+ """KL(q||p) != KL(p||q) in general — KL is not a symmetric distance."""
257
256
  B, Dz = 2, 3
257
+ torch.manual_seed(0)
258
258
  mu_q = torch.randn(B, Dz)
259
259
  log_var_q = torch.randn(B, Dz)
260
260
  mu_p = torch.randn(B, Dz)
261
261
  log_var_p = torch.randn(B, Dz)
262
-
262
+
263
263
  kl_q_p = kl_divergence_diag_gaussian(mu_q, log_var_q, mu_p, log_var_p)
264
264
  kl_p_q = kl_divergence_diag_gaussian(mu_p, log_var_p, mu_q, log_var_q)
265
-
266
- # KL divergence is asymmetric in general
267
- # (unless the distributions happen to be very similar)
268
- # Just check both are valid
265
+
269
266
  assert torch.isfinite(kl_q_p).all()
270
267
  assert torch.isfinite(kl_p_q).all()
268
+ # For random, distinct Gaussians this will virtually never be equal
269
+ assert not torch.allclose(kl_q_p, kl_p_q)
271
270
 
272
271
 
273
272
  def test_kl_divergence_backward_flows_gradients():
@@ -306,22 +305,6 @@ def test_kl_divergence_with_custom_prior_backward():
306
305
  assert log_var_p.grad is not None
307
306
 
308
307
 
309
- def test_kl_divergence_matches_pytorch_implementation():
310
- """Compare with a reference PyTorch implementation."""
311
- B, Dz = 3, 4
312
- mu_q = torch.randn(B, Dz)
313
- log_var_q = torch.randn(B, Dz)
314
-
315
- # Our implementation
316
- kl_ours = kl_divergence_diag_gaussian(mu_q, log_var_q)
317
-
318
- # Reference implementation (standard VAE KL)
319
- var_q = log_var_q.exp()
320
- kl_ref = 0.5 * torch.sum(-log_var_q + var_q + mu_q.pow(2) - 1, dim=-1)
321
-
322
- assert torch.allclose(kl_ours, kl_ref, atol=1e-5)
323
-
324
-
325
308
  def test_kl_divergence_large_batch():
326
309
  """Test with larger batch size."""
327
310
  B, Dz = 128, 32
@@ -247,8 +247,8 @@ def test_ae_compute_loss_gaussian_likelihood_returns_correct_type():
247
247
  assert isinstance(loss_result.diagnostics['log_likelihood'], float)
248
248
 
249
249
 
250
- def test_ae_compute_loss_gaussian_likelihood_is_nonnegative():
251
- """Test that NLL (objective) is non-negative for Gaussian likelihood."""
250
+ def test_ae_compute_loss_gaussian_nll_equals_half_mse():
251
+ """Test that the Gaussian NLL objective equals 0.5 * mean MSE."""
252
252
  batch_size = 5
253
253
  in_features = 6
254
254
  latent_features = 2
@@ -265,12 +265,15 @@ def test_ae_compute_loss_gaussian_likelihood_is_nonnegative():
265
265
 
266
266
  loss_result = ae.compute_loss(x, ae_output, likelihood='gaussian')
267
267
 
268
- # NLL should be non-negative (it's -log_likelihood)
269
- assert loss_result.objective.item() >= 0
268
+ # Gaussian NLL (without normalization constant) == 0.5 * per-sample MSE, batch-averaged
269
+ expected = 0.5 * ((ae_output.x_hat - x) ** 2).reshape(batch_size, -1).sum(-1).mean()
270
+ assert torch.allclose(loss_result.objective, expected, atol=1e-6)
270
271
 
271
272
 
272
- def test_ae_compute_loss_bernoulli_likelihood():
273
- """Test compute_loss with Bernoulli likelihood."""
273
+ def test_ae_compute_loss_bernoulli_nll_equals_bce():
274
+ """Bernoulli NLL equals sum-over-features BCE, batch-averaged."""
275
+ import torch.nn.functional as F
276
+
274
277
  batch_size = 4
275
278
  in_features = 8
276
279
  latent_features = 3
@@ -279,7 +282,7 @@ def test_ae_compute_loss_bernoulli_likelihood():
279
282
  decoder = SimpleDecoder(latent_features=latent_features, out_features=in_features)
280
283
  ae = AE(encoder=encoder, decoder=decoder)
281
284
 
282
- x = torch.sigmoid(torch.randn(batch_size, in_features)) # Bernoulli needs [0, 1]
285
+ x = torch.sigmoid(torch.randn(batch_size, in_features)) # targets in [0, 1]
283
286
  ae.build(x)
284
287
 
285
288
  torch.set_grad_enabled(True)
@@ -287,16 +290,18 @@ def test_ae_compute_loss_bernoulli_likelihood():
287
290
 
288
291
  loss_result = ae.compute_loss(x, ae_output, likelihood='bernoulli')
289
292
 
290
- # Check return structure
291
- from pyautoencoder.loss.base import LossResult
292
- assert isinstance(loss_result, LossResult)
293
+ # NLL = mean over batch of (sum over features of BCE)
294
+ expected = F.binary_cross_entropy_with_logits(
295
+ ae_output.x_hat, x, reduction='none'
296
+ ).reshape(batch_size, -1).sum(-1).mean()
297
+
293
298
  assert loss_result.objective.dim() == 0
299
+ assert torch.allclose(loss_result.objective, expected, atol=1e-6)
294
300
  assert 'log_likelihood' in loss_result.diagnostics
295
- assert loss_result.objective.item() >= 0
296
301
 
297
302
 
298
- def test_ae_compute_loss_backward_flows_through_x_hat():
299
- """Test that gradients flow properly through the loss."""
303
+ def test_ae_compute_loss_backward_flows_through_all_params():
304
+ """Test that gradients flow through both encoder and decoder."""
300
305
  batch_size = 2
301
306
  in_features = 4
302
307
  latent_features = 2
@@ -314,8 +319,9 @@ def test_ae_compute_loss_backward_flows_through_x_hat():
314
319
  loss_result = ae.compute_loss(x, ae_output)
315
320
  loss_result.objective.backward()
316
321
 
317
- # Check that decoder params have gradients
322
+ enc_grads = [p.grad for p in encoder.parameters() if p.requires_grad]
318
323
  dec_grads = [p.grad for p in decoder.parameters() if p.requires_grad]
324
+ assert any(g is not None and torch.any(g != 0) for g in enc_grads)
319
325
  assert any(g is not None and torch.any(g != 0) for g in dec_grads)
320
326
 
321
327
 
@@ -87,6 +87,23 @@ def test_ffg_build_can_be_called_twice_with_same_feature_dim():
87
87
  assert isinstance(head.mu, nn.Linear)
88
88
  assert head.mu.in_features == F
89
89
 
90
+
91
+ def test_ffg_build_replaces_layers_on_different_feature_dim():
92
+ """Rebuilding with a different F must replace mu and log_var layers."""
93
+ latent_dim = 3
94
+ head = FullyFactorizedGaussian(latent_dim=latent_dim)
95
+
96
+ head.build(torch.randn(2, 5))
97
+ assert head.in_features == 5
98
+ assert isinstance(head.mu, nn.Linear) and head.mu.in_features == 5
99
+ assert isinstance(head.log_var, nn.Linear) and head.log_var.in_features == 5
100
+
101
+ head.build(torch.randn(2, 8))
102
+ assert head.in_features == 8
103
+ assert isinstance(head.mu, nn.Linear) and head.mu.in_features == 8
104
+ assert isinstance(head.log_var, nn.Linear) and head.log_var.in_features == 8
105
+ assert head.built is True
106
+
90
107
  def test_ffg_forward_raises_if_not_built():
91
108
  head = FullyFactorizedGaussian(latent_dim=3)
92
109
  x = torch.randn(2, 5)
@@ -178,6 +195,8 @@ def test_ffg_eval_forward_respects_default_S_equals_1():
178
195
  z, mu, log_var = head(x) # S default = 1
179
196
 
180
197
  assert z.shape == (B, 1, Dz)
198
+ assert mu.shape == (B, Dz)
199
+ assert log_var.shape == (B, Dz)
181
200
  expected_z = mu.unsqueeze(1) # [B, 1, Dz]
182
201
  assert torch.allclose(z, expected_z)
183
202
 
@@ -323,6 +323,22 @@ def test_vae_output_repr_uses_modeloutput_smart_repr():
323
323
  assert f"shape={tuple(log_var.shape)}" in s
324
324
 
325
325
 
326
+ def test_vae_build_runs_encoder_under_no_grad():
327
+ """The build wrapper executes under torch.no_grad(); encoder must see grad disabled."""
328
+ B, in_features, feat_dim, latent_dim = 3, 5, 7, 4
329
+ x = torch.randn(B, in_features)
330
+
331
+ encoder = DummyEncoder(in_features=in_features, feat_dim=feat_dim)
332
+ decoder = DummyDecoder(latent_dim=latent_dim, out_features=in_features)
333
+ vae = VAE(encoder=encoder, decoder=decoder, latent_dim=latent_dim)
334
+
335
+ torch.set_grad_enabled(True)
336
+ vae.build(x)
337
+
338
+ assert encoder.last_grad_enabled is False # encoder saw no_grad during build
339
+ assert torch.is_grad_enabled() is True # global state restored afterwards
340
+
341
+
326
342
  # ================= compute_loss =================
327
343
 
328
344
  def test_vae_compute_loss_gaussian_likelihood_returns_correct_type():
@@ -460,8 +476,8 @@ def test_vae_compute_loss_bernoulli_likelihood():
460
476
  assert 'kl_divergence' in loss_result.diagnostics
461
477
 
462
478
 
463
- def test_vae_compute_loss_multiple_samples():
464
- """Test compute_loss with S > 1 samples for Monte Carlo estimation."""
479
+ def test_vae_compute_loss_eval_mode_elbo_independent_of_S():
480
+ """In eval mode z = tiled mu, so ELBO must be identical for any S >= 1."""
465
481
  B, in_features, feat_dim, latent_dim = 2, 4, 6, 2
466
482
  x = torch.randn(B, in_features)
467
483
 
@@ -470,24 +486,19 @@ def test_vae_compute_loss_multiple_samples():
470
486
  vae = VAE(encoder=encoder, decoder=decoder, latent_dim=latent_dim)
471
487
  vae.build(x)
472
488
 
473
- vae.train()
474
- torch.set_grad_enabled(True)
489
+ vae.eval()
490
+ torch.set_grad_enabled(False)
475
491
 
476
- # Forward with S=1
477
- vae_output_s1 = vae.forward(x, S=1)
478
- loss_s1 = vae.compute_loss(x, vae_output_s1)
492
+ out_s1 = vae.forward(x, S=1)
493
+ loss_s1 = vae.compute_loss(x, out_s1)
479
494
 
480
- # Forward with S=5 (more MC samples)
481
- vae_output_s5 = vae.forward(x, S=5)
482
- loss_s5 = vae.compute_loss(x, vae_output_s5)
495
+ out_s5 = vae.forward(x, S=5)
496
+ loss_s5 = vae.compute_loss(x, out_s5)
483
497
 
484
- # Both should produce valid LossResult
485
- assert isinstance(loss_s1.objective, torch.Tensor)
486
- assert isinstance(loss_s5.objective, torch.Tensor)
487
-
488
- # Shapes should match input batch size
489
- assert vae_output_s1.x_hat.shape[0] == B
490
- assert vae_output_s5.x_hat.shape[0] == B
498
+ # All S copies of z are identical (tiled mu), so the MC average is exact
499
+ assert torch.allclose(loss_s1.objective, loss_s5.objective, atol=1e-5)
500
+ assert abs(loss_s1.diagnostics['elbo'] - loss_s5.diagnostics['elbo']) < 1e-5
501
+ assert abs(loss_s1.diagnostics['log_likelihood'] - loss_s5.diagnostics['log_likelihood']) < 1e-5
491
502
 
492
503
 
493
504
  def test_vae_compute_loss_backward_flows_through_all_params():
@@ -538,6 +549,29 @@ def test_vae_compute_loss_batch_size_one():
538
549
  assert not torch.isinf(loss_result.objective)
539
550
 
540
551
 
552
+ def test_vae_compute_loss_diagnostics_elbo_consistency():
553
+ """elbo diagnostic must equal log_likelihood - kl_divergence."""
554
+ B, in_features, feat_dim, latent_dim, S = 3, 5, 7, 2, 2
555
+ x = torch.randn(B, in_features)
556
+
557
+ encoder = DummyEncoder(in_features=in_features, feat_dim=feat_dim)
558
+ decoder = DummyDecoder(latent_dim=latent_dim, out_features=in_features)
559
+ vae = VAE(encoder=encoder, decoder=decoder, latent_dim=latent_dim)
560
+ vae.build(x)
561
+
562
+ vae.train()
563
+ torch.set_grad_enabled(True)
564
+
565
+ vae_output = vae.forward(x, S=S)
566
+ loss_result = vae.compute_loss(x, vae_output)
567
+
568
+ ll = loss_result.diagnostics['log_likelihood']
569
+ kl = loss_result.diagnostics['kl_divergence']
570
+ elbo = loss_result.diagnostics['elbo']
571
+
572
+ assert abs(elbo - (ll - kl)) < 1e-5
573
+
574
+
541
575
  def test_vae_compute_loss_with_different_likelihood_formats():
542
576
  """Test that compute_loss handles both string and LikelihoodType inputs."""
543
577
  B, in_features, feat_dim, latent_dim, S = 3, 5, 7, 2, 2
@@ -572,10 +606,10 @@ def test_adagvae_inherits_from_vae():
572
606
  """Test that AdaGVAE is a subclass of VAE."""
573
607
  from pyautoencoder.variational.vae import AdaGVAE
574
608
 
575
- B, in_features, latent_dim = 4, 6, 3
609
+ in_features, latent_dim = 6, 3
576
610
  encoder = DummyEncoder(in_features=in_features, feat_dim=10)
577
611
  decoder = DummyDecoder(latent_dim=latent_dim, out_features=in_features)
578
-
612
+
579
613
  adagvae = AdaGVAE(encoder=encoder, decoder=decoder, latent_dim=latent_dim)
580
614
 
581
615
  assert isinstance(adagvae, VAE)
@@ -594,10 +628,10 @@ def test_adagvae_raises_before_build():
594
628
  x1 = torch.randn(B, in_features)
595
629
  x2 = torch.randn(B, in_features)
596
630
 
597
- with pytest.raises(NotBuiltError):
598
- adagvae.forward(x1, x2)
631
+ with pytest.raises(NotBuiltError, match="Model is not built"):
632
+ adagvae.forward_pair(x1, x2)
599
633
 
600
- with pytest.raises(NotBuiltError):
634
+ with pytest.raises(NotBuiltError, match="Model is not built"):
601
635
  adagvae._encode_pair(x1, x2)
602
636
 
603
637
 
@@ -618,7 +652,7 @@ def test_adagvae_forward_pair_shapes_and_types():
618
652
  torch.set_grad_enabled(True)
619
653
 
620
654
  # Forward returns tuple of two VAEOutput
621
- out1, out2 = adagvae.forward(x1, x2, S=S)
655
+ out1, out2 = adagvae.forward_pair(x1, x2, S=S)
622
656
 
623
657
  assert isinstance(out1, VAEOutput)
624
658
  assert isinstance(out2, VAEOutput)
@@ -682,8 +716,8 @@ def test_adagvae_compute_loss_returns_correct_structure():
682
716
  adagvae.train()
683
717
  torch.set_grad_enabled(True)
684
718
 
685
- out1, out2 = adagvae.forward(x1, x2, S=S)
686
- loss_result = adagvae.compute_loss(x1, out1, x2, out2)
719
+ out1, out2 = adagvae.forward_pair(x1, x2, S=S)
720
+ loss_result = adagvae.compute_pair_loss(x1, out1, x2, out2)
687
721
 
688
722
  # Check return type
689
723
  assert isinstance(loss_result, LossResult)
@@ -726,8 +760,8 @@ def test_adagvae_compute_loss_backward_flows():
726
760
  adagvae.train()
727
761
  torch.set_grad_enabled(True)
728
762
 
729
- out1, out2 = adagvae.forward(x1, x2, S=S)
730
- loss_result = adagvae.compute_loss(x1, out1, x2, out2)
763
+ out1, out2 = adagvae.forward_pair(x1, x2, S=S)
764
+ loss_result = adagvae.compute_pair_loss(x1, out1, x2, out2)
731
765
  loss_result.objective.backward()
732
766
 
733
767
  # Check gradients in all components
@@ -756,13 +790,13 @@ def test_adagvae_compute_loss_with_beta():
756
790
  adagvae.train()
757
791
  torch.set_grad_enabled(True)
758
792
 
759
- out1, out2 = adagvae.forward(x1, x2, S=S)
793
+ out1, out2 = adagvae.forward_pair(x1, x2, S=S)
760
794
 
761
795
  # Compute with beta=1
762
- loss_beta1 = adagvae.compute_loss(x1, out1, x2, out2, beta=1.0)
796
+ loss_beta1 = adagvae.compute_pair_loss(x1, out1, x2, out2, beta=1.0)
763
797
 
764
798
  # Compute with beta=0.5
765
- loss_beta05 = adagvae.compute_loss(x1, out1, x2, out2, beta=0.5)
799
+ loss_beta05 = adagvae.compute_pair_loss(x1, out1, x2, out2, beta=0.5)
766
800
 
767
801
  # ELBO should be different
768
802
  elbo_beta1 = loss_beta1.diagnostics['elbo']
@@ -772,33 +806,34 @@ def test_adagvae_compute_loss_with_beta():
772
806
  assert elbo_beta05 > elbo_beta1
773
807
 
774
808
 
775
- def test_adagvae_adaptive_grouping_aligns_similar_inputs():
776
- """Test that AdaGVAE adaptive grouping works with similar inputs."""
809
+ def test_adagvae_identical_inputs_produce_no_grouping():
810
+ """When x1 == x2, KL(q1||q2) = 0 everywhere, tau = 0, mask is all-False.
811
+ The adapted posteriors must equal the individual (unadapted) posteriors."""
777
812
  from pyautoencoder.variational.vae import AdaGVAE
778
-
779
- B, in_features, feat_dim, latent_dim, S = 3, 5, 7, 2, 1
780
-
781
- # Create similar inputs (nearly identical)
782
- x_base = torch.randn(B, in_features)
783
- x1 = x_base.clone()
784
- x2 = x_base + 0.01 * torch.randn_like(x_base) # Add small noise
813
+
814
+ B, in_features, feat_dim, latent_dim = 3, 5, 7, 4
815
+ x = torch.randn(B, in_features)
785
816
 
786
817
  encoder = DummyEncoder(in_features=in_features, feat_dim=feat_dim)
787
818
  decoder = DummyDecoder(latent_dim=latent_dim, out_features=in_features)
788
819
  adagvae = AdaGVAE(encoder=encoder, decoder=decoder, latent_dim=latent_dim)
789
- adagvae.build(x1)
820
+ adagvae.build(x)
790
821
 
791
822
  adagvae.eval()
792
823
  torch.set_grad_enabled(False)
793
824
 
794
- out1, out2 = adagvae.forward(x1, x2, S=S)
825
+ # Identical inputs → mu1 == mu2, log_var1 == log_var2 → KL(q1||q2) = 0 everywhere
826
+ # → max_delta = min_delta = 0 → tau = 0 → mask = (0 < 0) = False
827
+ # → no grouping: adapted posteriors equal the original individual posteriors
828
+ enc1_pair, enc2_pair = adagvae._encode_pair(x, x.clone(), S=1)
795
829
 
796
- # For similar inputs, posteriors should be relatively close
797
- assert out1.mu.shape == (B, latent_dim)
798
- assert out2.mu.shape == (B, latent_dim)
799
-
800
- # The adaptive mechanism should produce outputs (shape check is the main test)
801
- assert out1.z.shape == out2.z.shape == (B, S, latent_dim)
830
+ # Reference: single-input encode
831
+ single_enc = adagvae._encode(x, S=1)
832
+
833
+ assert torch.allclose(enc1_pair.mu, single_enc.mu, atol=1e-6)
834
+ assert torch.allclose(enc1_pair.log_var, single_enc.log_var, atol=1e-6)
835
+ assert torch.allclose(enc2_pair.mu, single_enc.mu, atol=1e-6)
836
+ assert torch.allclose(enc2_pair.log_var, single_enc.log_var, atol=1e-6)
802
837
 
803
838
 
804
839
  def test_adagvae_encode_pair_with_different_s():
@@ -831,7 +866,7 @@ def test_adagvae_encode_pair_with_different_s():
831
866
  def test_adagvae_compute_loss_bernoulli_likelihood():
832
867
  """Test AdaGVAE compute_loss with Bernoulli likelihood."""
833
868
  from pyautoencoder.variational.vae import AdaGVAE
834
-
869
+
835
870
  B, in_features, feat_dim, latent_dim, S = 3, 5, 7, 2, 2
836
871
  x1 = torch.sigmoid(torch.randn(B, in_features))
837
872
  x2 = torch.sigmoid(torch.randn(B, in_features))
@@ -844,10 +879,40 @@ def test_adagvae_compute_loss_bernoulli_likelihood():
844
879
  adagvae.train()
845
880
  torch.set_grad_enabled(True)
846
881
 
847
- out1, out2 = adagvae.forward(x1, x2, S=S)
848
- loss_result = adagvae.compute_loss(x1, out1, x2, out2, likelihood='bernoulli')
882
+ out1, out2 = adagvae.forward_pair(x1, x2, S=S)
883
+ loss_result = adagvae.compute_pair_loss(x1, out1, x2, out2, likelihood='bernoulli')
849
884
 
850
885
  assert isinstance(loss_result.objective, torch.Tensor)
851
886
  assert loss_result.objective.dim() == 0
852
887
  assert 'elbo' in loss_result.diagnostics
853
888
 
889
+
890
+ def test_adagvae_compute_pair_loss_equals_sum_of_individual_losses():
891
+ """compute_pair_loss objective == sum of the two individual VAE ELBO losses."""
892
+ from pyautoencoder.variational.vae import AdaGVAE
893
+
894
+ B, in_features, feat_dim, latent_dim, S = 3, 5, 7, 2, 2
895
+ x1 = torch.randn(B, in_features)
896
+ x2 = torch.randn(B, in_features)
897
+
898
+ encoder = DummyEncoder(in_features=in_features, feat_dim=feat_dim)
899
+ decoder = DummyDecoder(latent_dim=latent_dim, out_features=in_features)
900
+ adagvae = AdaGVAE(encoder=encoder, decoder=decoder, latent_dim=latent_dim)
901
+ adagvae.build(x1)
902
+
903
+ adagvae.train()
904
+ torch.set_grad_enabled(True)
905
+
906
+ out1, out2 = adagvae.forward_pair(x1, x2, S=S)
907
+ pair_loss = adagvae.compute_pair_loss(x1, out1, x2, out2)
908
+
909
+ # compute_pair_loss calls VAE.compute_loss twice and adds objectives
910
+ loss1 = adagvae.compute_loss(x1, out1)
911
+ loss2 = adagvae.compute_loss(x2, out2)
912
+
913
+ assert torch.allclose(pair_loss.objective, loss1.objective + loss2.objective, atol=1e-5)
914
+ assert abs(
915
+ pair_loss.diagnostics['elbo']
916
+ - (loss1.diagnostics['elbo'] + loss2.diagnostics['elbo'])
917
+ ) < 1e-5
918
+
@@ -1,84 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import torch
4
- from torch.utils.data import Dataset
5
- from pathlib import Path
6
- from typing import Optional, Union, Tuple, Callable
7
- import numpy as np
8
- import wget
9
-
10
- class DSprite(Dataset):
11
- """PyTorch dataset wrapper for the dSprites factorized shapes dataset.
12
-
13
- Source:
14
- Matthey et al., "dSprites: Disentanglement test sprites."
15
- Original files hosted by DeepMind on GitHub: https://github.com/google-deepmind/dsprites-dataset
16
-
17
- Files:
18
- - dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz
19
-
20
- Contents (loaded into memory on init):
21
- - X (torch.Tensor[int8]): Binary images, shape [N, 64, 64].
22
- - latents_values (torch.Tensor[float64]): Continuous latent values per sample,
23
- shape [N, 6].
24
- - latents_classes (torch.Tensor[int64]): Discrete latent indices per sample,
25
- shape [N, 6].
26
-
27
- Latent factor order (size):
28
- [color (1), shape (3), scale (6), orientation (40), posX (32), posY (32)]
29
-
30
- Notes:
31
- - Images are binary (0/1) stored as int8; most models will want them converted
32
- to float and possibly normalized. Provide a `transform` to handle this.
33
- - All arrays are fully loaded into CPU memory at construction for fast access.
34
- - Set `download=True` (default) to fetch the NPZ if missing at `root`.
35
- """
36
-
37
- _NPZ_URL = "https://github.com/deepmind/dsprites-dataset/raw/master/"
38
- _NPZ_FILENAME = 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz'
39
-
40
- def __init__(self,
41
- root: Optional[Union[str, Path]] = None,
42
- transform: Optional[Callable] = None,
43
- download: bool = True):
44
- """Initialize the dSprites dataset.
45
-
46
- Args:
47
- root (str | pathlib.Path | None): Directory to store/find the NPZ file.
48
- Defaults to "./data/dSprites" when None.
49
- transform (Callable | None): Optional transform applied to each image.
50
- download (bool): If True and the dataset file is not present at `root`,
51
- it will be downloaded from the official GitHub URL.
52
- """
53
- # Assign default if no root is provided
54
- if root is None:
55
- root = Path('data') / 'dSprites'
56
- elif isinstance(root, str):
57
- root = Path(root)
58
-
59
- self.root = root
60
- self.root.mkdir(parents=True, exist_ok=True)
61
- self.filepath = self.root / DSprite._NPZ_FILENAME
62
-
63
- if download and not self.filepath.exists():
64
- url = DSprite._NPZ_URL + DSprite._NPZ_FILENAME
65
- print(f'Downloading dSprites from {url}')
66
- wget.download(url, out=str(self.filepath))
67
- print('\nDownload completed')
68
-
69
- data = np.load(self.filepath, allow_pickle=True)
70
- self.X = torch.as_tensor(data['imgs'], dtype=torch.int8).unsqueeze(1)
71
- self.latents_values = torch.as_tensor(data['latents_values'], dtype=torch.float64)
72
- self.latents_classes = torch.as_tensor(data['latents_classes'], dtype=torch.int64)
73
- self.transform = transform
74
-
75
- def __len__(self) -> int:
76
- return self.X.shape[0]
77
-
78
- def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
79
- img = self.X[idx]
80
- lv = self.latents_values[idx]
81
- lc = self.latents_classes[idx]
82
- if self.transform:
83
- img = self.transform(img)
84
- return img, lv, lc
@@ -1,261 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from typing import Iterable, Tuple, List, Dict, Optional, Any, Callable
4
- from dataclasses import fields, Field
5
- from .._base.base import BaseAutoencoder, BuildGuardMixin
6
-
7
- class HyperAE(BuildGuardMixin, nn.Module):
8
- """
9
- Hypernetwork wrapper around a base autoencoder.
10
-
11
- This module:
12
- - wraps a `BaseAutoencoder` and uses a hypernetwork (2 layered MLP with ReLU)
13
- to generate a subset of its parameters on a per-input basis
14
- - keeps the remaining parameters shared across the batch
15
- """
16
-
17
- _GUARDED = {"forward"}
18
-
19
- def __init__(
20
- self,
21
- base_ae: BaseAutoencoder,
22
- target_modules: Iterable[str] = ("encoder", "sampling_layer", "decoder"),
23
- hidden_dim: int = 256,
24
- ):
25
- """
26
- Initialize the HyperAE wrapper.
27
-
28
- Args:
29
- base_ae:
30
- The underlying autoencoder instance to be controlled by the hypernetwork.
31
- target_modules:
32
- Iterable of module name prefixes inside `base_ae` whose parameters
33
- will be generated by the hypernetwork (e.g. "encoder", "decoder").
34
- hidden_dim:
35
- Hidden dimension for the MLP hypernetwork.
36
- """
37
- super().__init__()
38
-
39
- self.base_ae = base_ae
40
- self.target_modules = tuple(target_modules)
41
- self.hidden_dim = hidden_dim
42
-
43
- # (param_name, original_shape, flat_start, flat_end)
44
- self._param_info: List[Tuple[str, torch.Size, int, int]] = []
45
-
46
- # Parameters that remain shared (usual trainable params).
47
- self._shared_param_dict: Dict[str, torch.Tensor] = {}
48
-
49
- # Total number of scalar parameters generated by the hypernetwork.
50
- self.total_generated_params: int = 0
51
-
52
- # Hypernetwork (built lazily in build()).
53
- self.hypernet: Optional[nn.Module] = None
54
-
55
- # Output type information inferred from a sample call in build().
56
- self._output_type: Optional[type] = None
57
- self._output_field_names: Optional[List[str]] = None
58
-
59
- # Cached ModelOutput field metadata (set in build()).
60
- self._output_fields: Optional[Tuple[Field, ...]] = None
61
-
62
- # Cached vmap metadata / function (set in build()).
63
- self._in_dims_params: Optional[Dict[str, Optional[int]]] = None
64
- self._vmapped_call: Optional[Callable[..., Any]] = None
65
-
66
-
67
- @torch.no_grad()
68
- def build(self, input_sample: torch.Tensor) -> None:
69
- """
70
- Build the hypernetwork and prepare parameter bookkeeping.
71
-
72
- This method:
73
- - builds the underlying `base_ae`
74
- - infers the output `ModelOutput` type and field names
75
- - splits `base_ae` parameters into generated vs shared sets
76
- - constructs the MLP hypernetwork that outputs all generated parameters
77
- Must be called once before using `forward` (enforced by `BuildGuardMixin`).
78
- """
79
- # Ensure the base autoencoder is built (and warmed up via its own build).
80
- self.base_ae.build(input_sample)
81
-
82
- # Inspect a sample output to record the output structure.
83
- sample_out = self.base_ae(input_sample)
84
- self._output_type = type(sample_out)
85
- self._output_fields = fields(sample_out)
86
- self._output_field_names = [f.name for f in self._output_fields]
87
-
88
- # Reset metadata containers.
89
- self._param_info.clear()
90
-
91
- # Walk over all base_ae parameters and mark them as generated or shared.
92
- flat_offset = 0
93
- shared_param_names = []
94
- for name, p in self.base_ae.named_parameters():
95
- if any(name.startswith(m + ".") for m in self.target_modules):
96
- # Parameters inside target modules are generated by the hypernet.
97
- numel = p.numel()
98
- self._param_info.append(
99
- (name, p.shape, flat_offset, flat_offset + numel)
100
- )
101
- flat_offset += numel
102
- # Generated parameters are "owned" by the hypernet, so we freeze them.
103
- p.requires_grad_(False)
104
- else:
105
- # Remaining parameters are shared and stay trainable as usual.
106
- shared_param_names.append(name)
107
- p.requires_grad_(True)
108
-
109
- self.total_generated_params = flat_offset
110
- if self.total_generated_params == 0:
111
- raise ValueError(
112
- f"Zero parameters defined by the hypernetwork. "
113
- f"Check target_modules: {self.target_modules}."
114
- )
115
-
116
- # Cache a dict of the shared parameters for quick access in forward().
117
- self._shared_param_dict = {
118
- name: p
119
- for name, p in self.base_ae.named_parameters()
120
- if name in shared_param_names
121
- }
122
-
123
- # Hypernetwork input dimension is the flattened per-sample input.
124
- in_dims = input_sample[0].numel()
125
- self.hypernet = nn.Sequential(
126
- nn.Flatten(1),
127
- nn.Linear(in_dims, self.hidden_dim),
128
- nn.ReLU(),
129
- nn.Linear(self.hidden_dim, self.hidden_dim),
130
- nn.ReLU(),
131
- nn.Linear(self.hidden_dim, self.total_generated_params),
132
- )
133
-
134
- # Precompute in_dims mapping for vmap: generated params vary on dim 0, shared are None.
135
- # Use the same key ordering as `params` in forward: shared first, then generated
136
- generated_params_names = [name for (name, _, _, _) in self._param_info]
137
-
138
- all_param_names = list(self._shared_param_dict.keys()) + generated_params_names
139
-
140
- self._in_dims_params = {
141
- name: (0 if name in generated_params_names else None)
142
- for name in all_param_names
143
- }
144
-
145
- # Prebuild the vmapped single-sample call.
146
- self._vmapped_call = torch.func.vmap(
147
- self._call_single,
148
- in_dims=(self._in_dims_params, 0),
149
- out_dims=0,
150
- randomness="different",
151
- )
152
-
153
- self._built = True
154
-
155
- def _generated_params(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
156
- """
157
- Run the hypernetwork and reshape its output into per-parameter tensors.
158
-
159
- Args:
160
- x:
161
- Input batch of shape (B, ...) used to condition the hypernetwork.
162
-
163
- Returns:
164
- A dictionary mapping parameter names (inside `target_modules`) to
165
- tensors of shape (B, *original_shape), i.e. one generated parameter
166
- tensor per sample in the batch.
167
- """
168
- B = x.size(0)
169
- flat = self.hypernet(x) # type: ignore -- (B, total_generated_params)
170
-
171
- gen: Dict[str, torch.Tensor] = {}
172
- for name, shape, start, end in self._param_info:
173
- slice_ = flat[:, start:end]
174
- gen[name] = slice_.view(B, *shape)
175
- return gen
176
-
177
- def _call_single(self, params_i: Dict[str, torch.Tensor], x_i: torch.Tensor, **kwargs: Any):
178
- """
179
- Apply the base autoencoder to a single sample with a single param set.
180
-
181
- Args:
182
- params_i:
183
- Parameter dict for this particular sample.
184
- x_i:
185
- Input tensor of shape (...,) for this sample.
186
-
187
- Returns:
188
- A tuple of output tensors corresponding to the `ModelOutput` fields,
189
- with the leading batch dimension (size 1) removed.
190
- """
191
- # print("encoder.1.weight in _call_single:", params_i["encoder.1.weight"].shape)
192
- # print("encoder.1.bias in _call_single:", params_i["encoder.1.bias"].shape)
193
- # print("decoder.0.weight in _call_single:", params_i["decoder.0.weight"].shape)
194
- # print("decoder.0.bias in _call_single:", params_i["decoder.0.bias"].shape)
195
- out = torch.func.functional_call(
196
- self.base_ae,
197
- params_i,
198
- (x_i.unsqueeze(0),),
199
- kwargs=kwargs,
200
- )
201
-
202
- # Construct a new output of the same type, but with squeezed tensors.
203
- squeezed_kwargs: Dict[str, Any] = {}
204
- for f in self._output_fields: # type: ignore
205
- v = getattr(out, f.name)
206
- if torch.is_tensor(v):
207
- squeezed_kwargs[f.name] = v.squeeze(0)
208
- else:
209
- squeezed_kwargs[f.name] = v
210
-
211
- out_squeezed = self._output_type(**squeezed_kwargs) # type: ignore
212
-
213
- # Convert the structured output into a tuple of tensors (for vmap).
214
- tensors_tuple = tuple(
215
- getattr(out_squeezed, name) for name in self._output_field_names # type: ignore
216
- )
217
- return tensors_tuple
218
-
219
- def forward(self, x: torch.Tensor, **kwargs: Any):
220
- """
221
- Forward pass with per-sample generated parameters.
222
-
223
- For each sample in the batch:
224
- - generate a distinct set of parameters for `target_modules`
225
- - combine them with the shared parameters
226
- - call the underlying `base_ae` via `torch.func.functional_call`
227
- The underlying `base_ae` is evaluated in a batched, vectorized way using
228
- `torch.func.vmap`.
229
-
230
- Args:
231
- x:
232
- Input batch of shape (B, ...).
233
- **kwargs:
234
- Additional keyword arguments forwarded to `base_ae.forward`.
235
-
236
- Returns:
237
- A `ModelOutput` instance of the same type/structure as produced by
238
- `base_ae`, but with each field batched over the leading dimension.
239
- """
240
- # Shared parameters are the same for all samples.
241
- shared = self._shared_param_dict
242
- # Generated parameters are per-sample (B, *shape)
243
- generated = self._generated_params(x)
244
-
245
- # Combined view of all parameters (names mapped to tensors).
246
- params: Dict[str, torch.Tensor] = {}
247
- params.update(shared)
248
- params.update(generated)
249
-
250
- # Vectorized application over batch of (params, x) using prebuilt vmap.
251
- batched_tensors_tuple = self._vmapped_call(params, x, **kwargs) # type: ignore
252
-
253
- # Reconstruct a batched `ModelOutput` of the same type as the base AE.
254
- batched_kwargs = {
255
- name: tensor
256
- for name, tensor in zip(self._output_field_names, batched_tensors_tuple) # type: ignore
257
- }
258
- batched_out = self._output_type(**batched_kwargs) # type: ignore
259
-
260
- return batched_out
261
-
File without changes
File without changes
File without changes
File without changes