pyautoencoder 1.1.2__tar.gz → 1.1.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/PKG-INFO +1 -1
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder/_base/base.py +1 -1
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder/_version.py +3 -3
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder/loss/base.py +3 -31
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder/variational/stochastic_layers.py +23 -2
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder/variational/vae.py +15 -12
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder.egg-info/PKG-INFO +1 -1
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder.egg-info/SOURCES.txt +0 -4
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/test/_base/test_base.py +0 -1
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/test/loss/test_base_loss.py +7 -24
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/test/vanilla/test_autoencoder.py +20 -14
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/test/variational/test_stochastic_layers.py +19 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/test/variational/test_vae.py +115 -50
- pyautoencoder-1.1.2/pyautoencoder/experimental/__init__.py +0 -0
- pyautoencoder-1.1.2/pyautoencoder/experimental/benchmark_datasets/__init__.py +0 -0
- pyautoencoder-1.1.2/pyautoencoder/experimental/benchmark_datasets/disentanglement.py +0 -84
- pyautoencoder-1.1.2/pyautoencoder/experimental/hypernetworks.py +0 -261
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/.github/workflows/ci.yml +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/.github/workflows/publish.yml +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/.gitignore +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/LICENSE +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/README.md +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/assets/logo_nobackground.png +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/Makefile +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/make.bat +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/requirements.txt +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/source/api/base.rst +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/source/api/index.rst +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/source/api/losses.rst +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/source/api/models.rst +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/source/api/stochastic.rst +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/source/architecture.rst +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/source/conf.py +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/source/examples/ae_mnist.rst +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/source/examples/vae_mnist.rst +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/source/examples.rst +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/source/getting_started.rst +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/docs/source/index.rst +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/examples/mnist_ae.py +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/examples/mnist_vae_kingma2013.py +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder/__init__.py +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder/_base/__init__.py +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder/loss/__init__.py +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder/vanilla/__init__.py +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder/vanilla/autoencoder.py +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder/variational/__init__.py +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder.egg-info/dependency_links.txt +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder.egg-info/requires.txt +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyautoencoder.egg-info/top_level.txt +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/pyproject.toml +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/readthedocs.yaml +0 -0
- {pyautoencoder-1.1.2 → pyautoencoder-1.1.4}/setup.cfg +0 -0
|
@@ -18,7 +18,7 @@ version_tuple: tuple[int | str, ...]
|
|
|
18
18
|
commit_id: str | None
|
|
19
19
|
__commit_id__: str | None
|
|
20
20
|
|
|
21
|
-
__version__ = version = '1.1.2'
|
|
22
|
-
__version_tuple__ = version_tuple = (1, 1, 2)
|
|
21
|
+
__version__ = version = '1.1.4'
|
|
22
|
+
__version_tuple__ = version_tuple = (1, 1, 4)
|
|
23
23
|
|
|
24
|
-
__commit_id__ = commit_id = '
|
|
24
|
+
__commit_id__ = commit_id = 'g21e483730'
|
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import math
|
|
2
1
|
import torch
|
|
3
2
|
import torch.nn.functional as F
|
|
4
3
|
from dataclasses import dataclass
|
|
@@ -20,9 +19,6 @@ class LikelihoodType(Enum):
|
|
|
20
19
|
GAUSSIAN = 'gaussian'
|
|
21
20
|
BERNOULLI = 'bernoulli'
|
|
22
21
|
|
|
23
|
-
# Cache for log(2pi) constants per (device, dtype)
|
|
24
|
-
_LOG2PI_CACHE = {}
|
|
25
|
-
|
|
26
22
|
@dataclass(slots=True, repr=True)
|
|
27
23
|
class LossResult:
|
|
28
24
|
r"""Container for loss computation results with objective and diagnostics.
|
|
@@ -45,30 +41,6 @@ class LossResult:
|
|
|
45
41
|
objective: torch.Tensor
|
|
46
42
|
diagnostics: Dict[str, float]
|
|
47
43
|
|
|
48
|
-
def _get_log2pi(x: torch.Tensor) -> torch.Tensor:
|
|
49
|
-
r"""Return a cached value of :math:`\log(2\pi)` for the given device and dtype.
|
|
50
|
-
|
|
51
|
-
This avoids repeatedly allocating the constant for different devices or
|
|
52
|
-
precisions. A separate tensor is cached for each ``(device, dtype)`` pair.
|
|
53
|
-
|
|
54
|
-
Parameters
|
|
55
|
-
----------
|
|
56
|
-
x : torch.Tensor
|
|
57
|
-
A tensor whose ``device`` and ``dtype`` determine which cached value is
|
|
58
|
-
returned or created.
|
|
59
|
-
|
|
60
|
-
Returns
|
|
61
|
-
-------
|
|
62
|
-
torch.Tensor
|
|
63
|
-
A scalar tensor equal to :math:`\log(2\pi)` with the same device and
|
|
64
|
-
dtype as ``x``.
|
|
65
|
-
"""
|
|
66
|
-
|
|
67
|
-
key = (x.device, x.dtype)
|
|
68
|
-
if key not in _LOG2PI_CACHE:
|
|
69
|
-
_LOG2PI_CACHE[key] = torch.tensor(2.0 * math.pi, device=x.device, dtype=x.dtype).log()
|
|
70
|
-
return _LOG2PI_CACHE[key]
|
|
71
|
-
|
|
72
44
|
def log_likelihood(x: torch.Tensor,
|
|
73
45
|
x_hat: torch.Tensor,
|
|
74
46
|
likelihood: Union[str, LikelihoodType] = LikelihoodType.GAUSSIAN) -> torch.Tensor:
|
|
@@ -118,9 +90,9 @@ def log_likelihood(x: torch.Tensor,
|
|
|
118
90
|
|
|
119
91
|
Notes
|
|
120
92
|
-----
|
|
121
|
-
- The Gaussian case
|
|
122
|
-
:math
|
|
123
|
-
|
|
93
|
+
- The Gaussian case omits the normalization constant
|
|
94
|
+
:math:`-\tfrac{1}{2}\log(2\pi)`, which is constant with respect to
|
|
95
|
+
the model parameters and has no effect on optimization.
|
|
124
96
|
- The Bernoulli case is fully numerically stable because it operates
|
|
125
97
|
directly in log-space.
|
|
126
98
|
"""
|
|
@@ -62,7 +62,7 @@ class FullyFactorizedGaussian(nn.Module):
|
|
|
62
62
|
raise TypeError("build(x) expects a torch.Tensor.")
|
|
63
63
|
if input_sample.ndim != 2:
|
|
64
64
|
raise ValueError(f"build(x): expected shape [B, F], got {tuple(input_sample.shape)}. Flatten upstream.")
|
|
65
|
-
if input_sample.shape[1]
|
|
65
|
+
if input_sample.shape[1] == 0:
|
|
66
66
|
raise ValueError("build(x): F (feature dimension) must be > 0.")
|
|
67
67
|
|
|
68
68
|
in_features = int(input_sample.shape[1])
|
|
@@ -128,7 +128,28 @@ class FullyFactorizedGaussian(nn.Module):
|
|
|
128
128
|
|
|
129
129
|
return z, mu, log_var
|
|
130
130
|
|
|
131
|
-
def reparametrize(self, mu: torch.Tensor, log_var: torch.Tensor, S: int = 1):
|
|
131
|
+
def reparametrize(self, mu: torch.Tensor, log_var: torch.Tensor, S: int = 1) -> torch.Tensor:
|
|
132
|
+
r"""Draw ``S`` latent samples via the reparameterization trick.
|
|
133
|
+
|
|
134
|
+
.. math::
|
|
135
|
+
|
|
136
|
+
z^{(s)} = \mu + \sigma \odot \epsilon^{(s)},
|
|
137
|
+
\qquad \epsilon^{(s)} \sim \mathcal{N}(0, I).
|
|
138
|
+
|
|
139
|
+
Parameters
|
|
140
|
+
----------
|
|
141
|
+
mu : torch.Tensor
|
|
142
|
+
Mean of the posterior, shape ``[B, D_z]``.
|
|
143
|
+
log_var : torch.Tensor
|
|
144
|
+
Log-variance of the posterior, shape ``[B, D_z]``.
|
|
145
|
+
S : int, optional
|
|
146
|
+
Number of samples to draw. Defaults to ``1``.
|
|
147
|
+
|
|
148
|
+
Returns
|
|
149
|
+
-------
|
|
150
|
+
torch.Tensor
|
|
151
|
+
Sampled latent codes of shape ``[B, S, D_z]``.
|
|
152
|
+
"""
|
|
132
153
|
std = torch.exp(0.5 * log_var) # [B, Dz]
|
|
133
154
|
mu_e = mu.unsqueeze(1).expand(-1, S, -1) # [B, S, Dz]
|
|
134
155
|
std_e = std.unsqueeze(1).expand(-1, S, -1) # [B, S, Dz]
|
|
@@ -322,11 +322,14 @@ class VAE(BaseAutoencoder):
|
|
|
322
322
|
class AdaGVAE(VAE):
|
|
323
323
|
r"""Adaptive Group Variational Autoencoder (Ada-GVAE), from Locatello et al. (2020).
|
|
324
324
|
|
|
325
|
-
This class extends the VAE class and enables feature disentanglement in the latent space.
|
|
326
|
-
|
|
327
|
-
|
|
325
|
+
This class extends the VAE class and enables feature disentanglement in the latent space.
|
|
326
|
+
Use :meth:`forward_pair` for the paired training pass and :meth:`compute_pair_loss` for
|
|
327
|
+
its loss. The inherited :meth:`encode` / :meth:`decode` methods work normally for inference
|
|
328
|
+
on single inputs.
|
|
328
329
|
"""
|
|
329
330
|
|
|
331
|
+
_GUARDED = VAE._GUARDED | {"_encode_pair"}
|
|
332
|
+
|
|
330
333
|
def __init__(
|
|
331
334
|
self,
|
|
332
335
|
encoder: nn.Module,
|
|
@@ -418,7 +421,7 @@ class AdaGVAE(VAE):
|
|
|
418
421
|
VAEEncodeOutput(z=z2, mu=mu_tilde2, log_var=log_var_tilde2)
|
|
419
422
|
|
|
420
423
|
|
|
421
|
-
def
|
|
424
|
+
def forward_pair(self, x1: torch.Tensor, x2: torch.Tensor, S: int = 1) -> Tuple[VAEOutput, VAEOutput]:
|
|
422
425
|
"""Full AdaGVAE forward pass: encode pairs with adaptive grouping, sample, and decode.
|
|
423
426
|
|
|
424
427
|
Parameters
|
|
@@ -445,14 +448,14 @@ class AdaGVAE(VAE):
|
|
|
445
448
|
x2_dec = self._decode(x2_enc.z)
|
|
446
449
|
return VAEOutput(x_hat=x1_dec.x_hat, z=x1_enc.z, mu=x1_enc.mu, log_var=x1_enc.log_var), \
|
|
447
450
|
VAEOutput(x_hat=x2_dec.x_hat, z=x2_enc.z, mu=x2_enc.mu, log_var=x2_enc.log_var)
|
|
448
|
-
|
|
449
|
-
def
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
451
|
+
|
|
452
|
+
def compute_pair_loss(self,
|
|
453
|
+
x1: torch.Tensor,
|
|
454
|
+
x1_vae_output: VAEOutput,
|
|
455
|
+
x2: torch.Tensor,
|
|
456
|
+
x2_vae_output: VAEOutput,
|
|
457
|
+
beta: float = 1,
|
|
458
|
+
likelihood: Union[str, LikelihoodType] = LikelihoodType.GAUSSIAN) -> LossResult:
|
|
456
459
|
r"""Compute the combined ELBO for a pair of inputs with adaptive posteriors.
|
|
457
460
|
|
|
458
461
|
This method computes the sum of the standard VAE ELBOs for both inputs:
|
|
@@ -32,10 +32,6 @@ pyautoencoder.egg-info/requires.txt
|
|
|
32
32
|
pyautoencoder.egg-info/top_level.txt
|
|
33
33
|
pyautoencoder/_base/__init__.py
|
|
34
34
|
pyautoencoder/_base/base.py
|
|
35
|
-
pyautoencoder/experimental/__init__.py
|
|
36
|
-
pyautoencoder/experimental/hypernetworks.py
|
|
37
|
-
pyautoencoder/experimental/benchmark_datasets/__init__.py
|
|
38
|
-
pyautoencoder/experimental/benchmark_datasets/disentanglement.py
|
|
39
35
|
pyautoencoder/loss/__init__.py
|
|
40
36
|
pyautoencoder/loss/base.py
|
|
41
37
|
pyautoencoder/vanilla/__init__.py
|
|
@@ -14,7 +14,6 @@ from pyautoencoder._base.base import (
|
|
|
14
14
|
# ================= ModelOutput =================
|
|
15
15
|
|
|
16
16
|
def test_model_output_repr_tensors_and_non_tensors():
|
|
17
|
-
@torch.no_grad()
|
|
18
17
|
@dataclass(slots=True, repr=False)
|
|
19
18
|
class MyOutput(ModelOutput):
|
|
20
19
|
logits: torch.Tensor
|
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import pytest
|
|
2
|
-
import math
|
|
3
2
|
import torch
|
|
4
3
|
import torch.nn.functional as F
|
|
5
4
|
|
|
@@ -252,22 +251,22 @@ def test_kl_divergence_preserves_device():
|
|
|
252
251
|
assert kl.device == mu_q.device
|
|
253
252
|
|
|
254
253
|
|
|
255
|
-
def
|
|
256
|
-
"""
|
|
254
|
+
def test_kl_divergence_is_asymmetric():
|
|
255
|
+
"""KL(q||p) != KL(p||q) in general — KL is not a symmetric distance."""
|
|
257
256
|
B, Dz = 2, 3
|
|
257
|
+
torch.manual_seed(0)
|
|
258
258
|
mu_q = torch.randn(B, Dz)
|
|
259
259
|
log_var_q = torch.randn(B, Dz)
|
|
260
260
|
mu_p = torch.randn(B, Dz)
|
|
261
261
|
log_var_p = torch.randn(B, Dz)
|
|
262
|
-
|
|
262
|
+
|
|
263
263
|
kl_q_p = kl_divergence_diag_gaussian(mu_q, log_var_q, mu_p, log_var_p)
|
|
264
264
|
kl_p_q = kl_divergence_diag_gaussian(mu_p, log_var_p, mu_q, log_var_q)
|
|
265
|
-
|
|
266
|
-
# KL divergence is asymmetric in general
|
|
267
|
-
# (unless the distributions happen to be very similar)
|
|
268
|
-
# Just check both are valid
|
|
265
|
+
|
|
269
266
|
assert torch.isfinite(kl_q_p).all()
|
|
270
267
|
assert torch.isfinite(kl_p_q).all()
|
|
268
|
+
# For random, distinct Gaussians this will virtually never be equal
|
|
269
|
+
assert not torch.allclose(kl_q_p, kl_p_q)
|
|
271
270
|
|
|
272
271
|
|
|
273
272
|
def test_kl_divergence_backward_flows_gradients():
|
|
@@ -306,22 +305,6 @@ def test_kl_divergence_with_custom_prior_backward():
|
|
|
306
305
|
assert log_var_p.grad is not None
|
|
307
306
|
|
|
308
307
|
|
|
309
|
-
def test_kl_divergence_matches_pytorch_implementation():
|
|
310
|
-
"""Compare with a reference PyTorch implementation."""
|
|
311
|
-
B, Dz = 3, 4
|
|
312
|
-
mu_q = torch.randn(B, Dz)
|
|
313
|
-
log_var_q = torch.randn(B, Dz)
|
|
314
|
-
|
|
315
|
-
# Our implementation
|
|
316
|
-
kl_ours = kl_divergence_diag_gaussian(mu_q, log_var_q)
|
|
317
|
-
|
|
318
|
-
# Reference implementation (standard VAE KL)
|
|
319
|
-
var_q = log_var_q.exp()
|
|
320
|
-
kl_ref = 0.5 * torch.sum(-log_var_q + var_q + mu_q.pow(2) - 1, dim=-1)
|
|
321
|
-
|
|
322
|
-
assert torch.allclose(kl_ours, kl_ref, atol=1e-5)
|
|
323
|
-
|
|
324
|
-
|
|
325
308
|
def test_kl_divergence_large_batch():
|
|
326
309
|
"""Test with larger batch size."""
|
|
327
310
|
B, Dz = 128, 32
|
|
@@ -247,8 +247,8 @@ def test_ae_compute_loss_gaussian_likelihood_returns_correct_type():
|
|
|
247
247
|
assert isinstance(loss_result.diagnostics['log_likelihood'], float)
|
|
248
248
|
|
|
249
249
|
|
|
250
|
-
def test_ae_compute_loss_gaussian_likelihood_is_nonnegative():
|
|
251
|
-
"""Test that NLL
|
|
250
|
+
def test_ae_compute_loss_gaussian_nll_equals_half_mse():
|
|
251
|
+
"""Test that the Gaussian NLL objective equals 0.5 * mean MSE."""
|
|
252
252
|
batch_size = 5
|
|
253
253
|
in_features = 6
|
|
254
254
|
latent_features = 2
|
|
@@ -265,12 +265,15 @@ def test_ae_compute_loss_gaussian_likelihood_is_nonnegative():
|
|
|
265
265
|
|
|
266
266
|
loss_result = ae.compute_loss(x, ae_output, likelihood='gaussian')
|
|
267
267
|
|
|
268
|
-
# NLL
|
|
269
|
-
|
|
268
|
+
# Gaussian NLL (without normalization constant) == 0.5 * per-sample MSE, batch-averaged
|
|
269
|
+
expected = 0.5 * ((ae_output.x_hat - x) ** 2).reshape(batch_size, -1).sum(-1).mean()
|
|
270
|
+
assert torch.allclose(loss_result.objective, expected, atol=1e-6)
|
|
270
271
|
|
|
271
272
|
|
|
272
|
-
def test_ae_compute_loss_bernoulli_likelihood():
|
|
273
|
-
"""
|
|
273
|
+
def test_ae_compute_loss_bernoulli_nll_equals_bce():
|
|
274
|
+
"""Bernoulli NLL equals sum-over-features BCE, batch-averaged."""
|
|
275
|
+
import torch.nn.functional as F
|
|
276
|
+
|
|
274
277
|
batch_size = 4
|
|
275
278
|
in_features = 8
|
|
276
279
|
latent_features = 3
|
|
@@ -279,7 +282,7 @@ def test_ae_compute_loss_bernoulli_likelihood():
|
|
|
279
282
|
decoder = SimpleDecoder(latent_features=latent_features, out_features=in_features)
|
|
280
283
|
ae = AE(encoder=encoder, decoder=decoder)
|
|
281
284
|
|
|
282
|
-
x = torch.sigmoid(torch.randn(batch_size, in_features)) #
|
|
285
|
+
x = torch.sigmoid(torch.randn(batch_size, in_features)) # targets in [0, 1]
|
|
283
286
|
ae.build(x)
|
|
284
287
|
|
|
285
288
|
torch.set_grad_enabled(True)
|
|
@@ -287,16 +290,18 @@ def test_ae_compute_loss_bernoulli_likelihood():
|
|
|
287
290
|
|
|
288
291
|
loss_result = ae.compute_loss(x, ae_output, likelihood='bernoulli')
|
|
289
292
|
|
|
290
|
-
#
|
|
291
|
-
|
|
292
|
-
|
|
293
|
+
# NLL = mean over batch of (sum over features of BCE)
|
|
294
|
+
expected = F.binary_cross_entropy_with_logits(
|
|
295
|
+
ae_output.x_hat, x, reduction='none'
|
|
296
|
+
).reshape(batch_size, -1).sum(-1).mean()
|
|
297
|
+
|
|
293
298
|
assert loss_result.objective.dim() == 0
|
|
299
|
+
assert torch.allclose(loss_result.objective, expected, atol=1e-6)
|
|
294
300
|
assert 'log_likelihood' in loss_result.diagnostics
|
|
295
|
-
assert loss_result.objective.item() >= 0
|
|
296
301
|
|
|
297
302
|
|
|
298
|
-
def test_ae_compute_loss_backward_flows_through_x_hat():
|
|
299
|
-
"""Test that gradients flow
|
|
303
|
+
def test_ae_compute_loss_backward_flows_through_all_params():
|
|
304
|
+
"""Test that gradients flow through both encoder and decoder."""
|
|
300
305
|
batch_size = 2
|
|
301
306
|
in_features = 4
|
|
302
307
|
latent_features = 2
|
|
@@ -314,8 +319,9 @@ def test_ae_compute_loss_backward_flows_through_x_hat():
|
|
|
314
319
|
loss_result = ae.compute_loss(x, ae_output)
|
|
315
320
|
loss_result.objective.backward()
|
|
316
321
|
|
|
317
|
-
|
|
322
|
+
enc_grads = [p.grad for p in encoder.parameters() if p.requires_grad]
|
|
318
323
|
dec_grads = [p.grad for p in decoder.parameters() if p.requires_grad]
|
|
324
|
+
assert any(g is not None and torch.any(g != 0) for g in enc_grads)
|
|
319
325
|
assert any(g is not None and torch.any(g != 0) for g in dec_grads)
|
|
320
326
|
|
|
321
327
|
|
|
@@ -87,6 +87,23 @@ def test_ffg_build_can_be_called_twice_with_same_feature_dim():
|
|
|
87
87
|
assert isinstance(head.mu, nn.Linear)
|
|
88
88
|
assert head.mu.in_features == F
|
|
89
89
|
|
|
90
|
+
|
|
91
|
+
def test_ffg_build_replaces_layers_on_different_feature_dim():
|
|
92
|
+
"""Rebuilding with a different F must replace mu and log_var layers."""
|
|
93
|
+
latent_dim = 3
|
|
94
|
+
head = FullyFactorizedGaussian(latent_dim=latent_dim)
|
|
95
|
+
|
|
96
|
+
head.build(torch.randn(2, 5))
|
|
97
|
+
assert head.in_features == 5
|
|
98
|
+
assert isinstance(head.mu, nn.Linear) and head.mu.in_features == 5
|
|
99
|
+
assert isinstance(head.log_var, nn.Linear) and head.log_var.in_features == 5
|
|
100
|
+
|
|
101
|
+
head.build(torch.randn(2, 8))
|
|
102
|
+
assert head.in_features == 8
|
|
103
|
+
assert isinstance(head.mu, nn.Linear) and head.mu.in_features == 8
|
|
104
|
+
assert isinstance(head.log_var, nn.Linear) and head.log_var.in_features == 8
|
|
105
|
+
assert head.built is True
|
|
106
|
+
|
|
90
107
|
def test_ffg_forward_raises_if_not_built():
|
|
91
108
|
head = FullyFactorizedGaussian(latent_dim=3)
|
|
92
109
|
x = torch.randn(2, 5)
|
|
@@ -178,6 +195,8 @@ def test_ffg_eval_forward_respects_default_S_equals_1():
|
|
|
178
195
|
z, mu, log_var = head(x) # S default = 1
|
|
179
196
|
|
|
180
197
|
assert z.shape == (B, 1, Dz)
|
|
198
|
+
assert mu.shape == (B, Dz)
|
|
199
|
+
assert log_var.shape == (B, Dz)
|
|
181
200
|
expected_z = mu.unsqueeze(1) # [B, 1, Dz]
|
|
182
201
|
assert torch.allclose(z, expected_z)
|
|
183
202
|
|
|
@@ -323,6 +323,22 @@ def test_vae_output_repr_uses_modeloutput_smart_repr():
|
|
|
323
323
|
assert f"shape={tuple(log_var.shape)}" in s
|
|
324
324
|
|
|
325
325
|
|
|
326
|
+
def test_vae_build_runs_encoder_under_no_grad():
|
|
327
|
+
"""The build wrapper executes under torch.no_grad(); encoder must see grad disabled."""
|
|
328
|
+
B, in_features, feat_dim, latent_dim = 3, 5, 7, 4
|
|
329
|
+
x = torch.randn(B, in_features)
|
|
330
|
+
|
|
331
|
+
encoder = DummyEncoder(in_features=in_features, feat_dim=feat_dim)
|
|
332
|
+
decoder = DummyDecoder(latent_dim=latent_dim, out_features=in_features)
|
|
333
|
+
vae = VAE(encoder=encoder, decoder=decoder, latent_dim=latent_dim)
|
|
334
|
+
|
|
335
|
+
torch.set_grad_enabled(True)
|
|
336
|
+
vae.build(x)
|
|
337
|
+
|
|
338
|
+
assert encoder.last_grad_enabled is False # encoder saw no_grad during build
|
|
339
|
+
assert torch.is_grad_enabled() is True # global state restored afterwards
|
|
340
|
+
|
|
341
|
+
|
|
326
342
|
# ================= compute_loss =================
|
|
327
343
|
|
|
328
344
|
def test_vae_compute_loss_gaussian_likelihood_returns_correct_type():
|
|
@@ -460,8 +476,8 @@ def test_vae_compute_loss_bernoulli_likelihood():
|
|
|
460
476
|
assert 'kl_divergence' in loss_result.diagnostics
|
|
461
477
|
|
|
462
478
|
|
|
463
|
-
def test_vae_compute_loss_multiple_samples():
|
|
464
|
-
"""
|
|
479
|
+
def test_vae_compute_loss_eval_mode_elbo_independent_of_S():
|
|
480
|
+
"""In eval mode z = tiled mu, so ELBO must be identical for any S >= 1."""
|
|
465
481
|
B, in_features, feat_dim, latent_dim = 2, 4, 6, 2
|
|
466
482
|
x = torch.randn(B, in_features)
|
|
467
483
|
|
|
@@ -470,24 +486,19 @@ def test_vae_compute_loss_multiple_samples():
|
|
|
470
486
|
vae = VAE(encoder=encoder, decoder=decoder, latent_dim=latent_dim)
|
|
471
487
|
vae.build(x)
|
|
472
488
|
|
|
473
|
-
vae.
|
|
474
|
-
torch.set_grad_enabled(
|
|
489
|
+
vae.eval()
|
|
490
|
+
torch.set_grad_enabled(False)
|
|
475
491
|
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
loss_s1 = vae.compute_loss(x, vae_output_s1)
|
|
492
|
+
out_s1 = vae.forward(x, S=1)
|
|
493
|
+
loss_s1 = vae.compute_loss(x, out_s1)
|
|
479
494
|
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
loss_s5 = vae.compute_loss(x, vae_output_s5)
|
|
495
|
+
out_s5 = vae.forward(x, S=5)
|
|
496
|
+
loss_s5 = vae.compute_loss(x, out_s5)
|
|
483
497
|
|
|
484
|
-
#
|
|
485
|
-
assert
|
|
486
|
-
assert
|
|
487
|
-
|
|
488
|
-
# Shapes should match input batch size
|
|
489
|
-
assert vae_output_s1.x_hat.shape[0] == B
|
|
490
|
-
assert vae_output_s5.x_hat.shape[0] == B
|
|
498
|
+
# All S copies of z are identical (tiled mu), so the MC average is exact
|
|
499
|
+
assert torch.allclose(loss_s1.objective, loss_s5.objective, atol=1e-5)
|
|
500
|
+
assert abs(loss_s1.diagnostics['elbo'] - loss_s5.diagnostics['elbo']) < 1e-5
|
|
501
|
+
assert abs(loss_s1.diagnostics['log_likelihood'] - loss_s5.diagnostics['log_likelihood']) < 1e-5
|
|
491
502
|
|
|
492
503
|
|
|
493
504
|
def test_vae_compute_loss_backward_flows_through_all_params():
|
|
@@ -538,6 +549,29 @@ def test_vae_compute_loss_batch_size_one():
|
|
|
538
549
|
assert not torch.isinf(loss_result.objective)
|
|
539
550
|
|
|
540
551
|
|
|
552
|
+
def test_vae_compute_loss_diagnostics_elbo_consistency():
|
|
553
|
+
"""elbo diagnostic must equal log_likelihood - kl_divergence."""
|
|
554
|
+
B, in_features, feat_dim, latent_dim, S = 3, 5, 7, 2, 2
|
|
555
|
+
x = torch.randn(B, in_features)
|
|
556
|
+
|
|
557
|
+
encoder = DummyEncoder(in_features=in_features, feat_dim=feat_dim)
|
|
558
|
+
decoder = DummyDecoder(latent_dim=latent_dim, out_features=in_features)
|
|
559
|
+
vae = VAE(encoder=encoder, decoder=decoder, latent_dim=latent_dim)
|
|
560
|
+
vae.build(x)
|
|
561
|
+
|
|
562
|
+
vae.train()
|
|
563
|
+
torch.set_grad_enabled(True)
|
|
564
|
+
|
|
565
|
+
vae_output = vae.forward(x, S=S)
|
|
566
|
+
loss_result = vae.compute_loss(x, vae_output)
|
|
567
|
+
|
|
568
|
+
ll = loss_result.diagnostics['log_likelihood']
|
|
569
|
+
kl = loss_result.diagnostics['kl_divergence']
|
|
570
|
+
elbo = loss_result.diagnostics['elbo']
|
|
571
|
+
|
|
572
|
+
assert abs(elbo - (ll - kl)) < 1e-5
|
|
573
|
+
|
|
574
|
+
|
|
541
575
|
def test_vae_compute_loss_with_different_likelihood_formats():
|
|
542
576
|
"""Test that compute_loss handles both string and LikelihoodType inputs."""
|
|
543
577
|
B, in_features, feat_dim, latent_dim, S = 3, 5, 7, 2, 2
|
|
@@ -572,10 +606,10 @@ def test_adagvae_inherits_from_vae():
|
|
|
572
606
|
"""Test that AdaGVAE is a subclass of VAE."""
|
|
573
607
|
from pyautoencoder.variational.vae import AdaGVAE
|
|
574
608
|
|
|
575
|
-
|
|
609
|
+
in_features, latent_dim = 6, 3
|
|
576
610
|
encoder = DummyEncoder(in_features=in_features, feat_dim=10)
|
|
577
611
|
decoder = DummyDecoder(latent_dim=latent_dim, out_features=in_features)
|
|
578
|
-
|
|
612
|
+
|
|
579
613
|
adagvae = AdaGVAE(encoder=encoder, decoder=decoder, latent_dim=latent_dim)
|
|
580
614
|
|
|
581
615
|
assert isinstance(adagvae, VAE)
|
|
@@ -594,10 +628,10 @@ def test_adagvae_raises_before_build():
|
|
|
594
628
|
x1 = torch.randn(B, in_features)
|
|
595
629
|
x2 = torch.randn(B, in_features)
|
|
596
630
|
|
|
597
|
-
with pytest.raises(NotBuiltError):
|
|
598
|
-
adagvae.
|
|
631
|
+
with pytest.raises(NotBuiltError, match="Model is not built"):
|
|
632
|
+
adagvae.forward_pair(x1, x2)
|
|
599
633
|
|
|
600
|
-
with pytest.raises(NotBuiltError):
|
|
634
|
+
with pytest.raises(NotBuiltError, match="Model is not built"):
|
|
601
635
|
adagvae._encode_pair(x1, x2)
|
|
602
636
|
|
|
603
637
|
|
|
@@ -618,7 +652,7 @@ def test_adagvae_forward_pair_shapes_and_types():
|
|
|
618
652
|
torch.set_grad_enabled(True)
|
|
619
653
|
|
|
620
654
|
# Forward returns tuple of two VAEOutput
|
|
621
|
-
out1, out2 = adagvae.
|
|
655
|
+
out1, out2 = adagvae.forward_pair(x1, x2, S=S)
|
|
622
656
|
|
|
623
657
|
assert isinstance(out1, VAEOutput)
|
|
624
658
|
assert isinstance(out2, VAEOutput)
|
|
@@ -682,8 +716,8 @@ def test_adagvae_compute_loss_returns_correct_structure():
|
|
|
682
716
|
adagvae.train()
|
|
683
717
|
torch.set_grad_enabled(True)
|
|
684
718
|
|
|
685
|
-
out1, out2 = adagvae.
|
|
686
|
-
loss_result = adagvae.
|
|
719
|
+
out1, out2 = adagvae.forward_pair(x1, x2, S=S)
|
|
720
|
+
loss_result = adagvae.compute_pair_loss(x1, out1, x2, out2)
|
|
687
721
|
|
|
688
722
|
# Check return type
|
|
689
723
|
assert isinstance(loss_result, LossResult)
|
|
@@ -726,8 +760,8 @@ def test_adagvae_compute_loss_backward_flows():
|
|
|
726
760
|
adagvae.train()
|
|
727
761
|
torch.set_grad_enabled(True)
|
|
728
762
|
|
|
729
|
-
out1, out2 = adagvae.
|
|
730
|
-
loss_result = adagvae.
|
|
763
|
+
out1, out2 = adagvae.forward_pair(x1, x2, S=S)
|
|
764
|
+
loss_result = adagvae.compute_pair_loss(x1, out1, x2, out2)
|
|
731
765
|
loss_result.objective.backward()
|
|
732
766
|
|
|
733
767
|
# Check gradients in all components
|
|
@@ -756,13 +790,13 @@ def test_adagvae_compute_loss_with_beta():
|
|
|
756
790
|
adagvae.train()
|
|
757
791
|
torch.set_grad_enabled(True)
|
|
758
792
|
|
|
759
|
-
out1, out2 = adagvae.
|
|
793
|
+
out1, out2 = adagvae.forward_pair(x1, x2, S=S)
|
|
760
794
|
|
|
761
795
|
# Compute with beta=1
|
|
762
|
-
loss_beta1 = adagvae.
|
|
796
|
+
loss_beta1 = adagvae.compute_pair_loss(x1, out1, x2, out2, beta=1.0)
|
|
763
797
|
|
|
764
798
|
# Compute with beta=0.5
|
|
765
|
-
loss_beta05 = adagvae.
|
|
799
|
+
loss_beta05 = adagvae.compute_pair_loss(x1, out1, x2, out2, beta=0.5)
|
|
766
800
|
|
|
767
801
|
# ELBO should be different
|
|
768
802
|
elbo_beta1 = loss_beta1.diagnostics['elbo']
|
|
@@ -772,33 +806,34 @@ def test_adagvae_compute_loss_with_beta():
|
|
|
772
806
|
assert elbo_beta05 > elbo_beta1
|
|
773
807
|
|
|
774
808
|
|
|
775
|
-
def
|
|
776
|
-
"""
|
|
809
|
+
def test_adagvae_identical_inputs_produce_no_grouping():
|
|
810
|
+
"""When x1 == x2, KL(q1||q2) = 0 everywhere, tau = 0, mask is all-False.
|
|
811
|
+
The adapted posteriors must equal the individual (unadapted) posteriors."""
|
|
777
812
|
from pyautoencoder.variational.vae import AdaGVAE
|
|
778
|
-
|
|
779
|
-
B, in_features, feat_dim, latent_dim
|
|
780
|
-
|
|
781
|
-
# Create similar inputs (nearly identical)
|
|
782
|
-
x_base = torch.randn(B, in_features)
|
|
783
|
-
x1 = x_base.clone()
|
|
784
|
-
x2 = x_base + 0.01 * torch.randn_like(x_base) # Add small noise
|
|
813
|
+
|
|
814
|
+
B, in_features, feat_dim, latent_dim = 3, 5, 7, 4
|
|
815
|
+
x = torch.randn(B, in_features)
|
|
785
816
|
|
|
786
817
|
encoder = DummyEncoder(in_features=in_features, feat_dim=feat_dim)
|
|
787
818
|
decoder = DummyDecoder(latent_dim=latent_dim, out_features=in_features)
|
|
788
819
|
adagvae = AdaGVAE(encoder=encoder, decoder=decoder, latent_dim=latent_dim)
|
|
789
|
-
adagvae.build(
|
|
820
|
+
adagvae.build(x)
|
|
790
821
|
|
|
791
822
|
adagvae.eval()
|
|
792
823
|
torch.set_grad_enabled(False)
|
|
793
824
|
|
|
794
|
-
|
|
825
|
+
# Identical inputs → mu1 == mu2, log_var1 == log_var2 → KL(q1||q2) = 0 everywhere
|
|
826
|
+
# → max_delta = min_delta = 0 → tau = 0 → mask = (0 < 0) = False
|
|
827
|
+
# → no grouping: adapted posteriors equal the original individual posteriors
|
|
828
|
+
enc1_pair, enc2_pair = adagvae._encode_pair(x, x.clone(), S=1)
|
|
795
829
|
|
|
796
|
-
#
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
assert
|
|
830
|
+
# Reference: single-input encode
|
|
831
|
+
single_enc = adagvae._encode(x, S=1)
|
|
832
|
+
|
|
833
|
+
assert torch.allclose(enc1_pair.mu, single_enc.mu, atol=1e-6)
|
|
834
|
+
assert torch.allclose(enc1_pair.log_var, single_enc.log_var, atol=1e-6)
|
|
835
|
+
assert torch.allclose(enc2_pair.mu, single_enc.mu, atol=1e-6)
|
|
836
|
+
assert torch.allclose(enc2_pair.log_var, single_enc.log_var, atol=1e-6)
|
|
802
837
|
|
|
803
838
|
|
|
804
839
|
def test_adagvae_encode_pair_with_different_s():
|
|
@@ -831,7 +866,7 @@ def test_adagvae_encode_pair_with_different_s():
|
|
|
831
866
|
def test_adagvae_compute_loss_bernoulli_likelihood():
|
|
832
867
|
"""Test AdaGVAE compute_loss with Bernoulli likelihood."""
|
|
833
868
|
from pyautoencoder.variational.vae import AdaGVAE
|
|
834
|
-
|
|
869
|
+
|
|
835
870
|
B, in_features, feat_dim, latent_dim, S = 3, 5, 7, 2, 2
|
|
836
871
|
x1 = torch.sigmoid(torch.randn(B, in_features))
|
|
837
872
|
x2 = torch.sigmoid(torch.randn(B, in_features))
|
|
@@ -844,10 +879,40 @@ def test_adagvae_compute_loss_bernoulli_likelihood():
|
|
|
844
879
|
adagvae.train()
|
|
845
880
|
torch.set_grad_enabled(True)
|
|
846
881
|
|
|
847
|
-
out1, out2 = adagvae.
|
|
848
|
-
loss_result = adagvae.
|
|
882
|
+
out1, out2 = adagvae.forward_pair(x1, x2, S=S)
|
|
883
|
+
loss_result = adagvae.compute_pair_loss(x1, out1, x2, out2, likelihood='bernoulli')
|
|
849
884
|
|
|
850
885
|
assert isinstance(loss_result.objective, torch.Tensor)
|
|
851
886
|
assert loss_result.objective.dim() == 0
|
|
852
887
|
assert 'elbo' in loss_result.diagnostics
|
|
853
888
|
|
|
889
|
+
|
|
890
|
+
def test_adagvae_compute_pair_loss_equals_sum_of_individual_losses():
|
|
891
|
+
"""compute_pair_loss objective == sum of the two individual VAE ELBO losses."""
|
|
892
|
+
from pyautoencoder.variational.vae import AdaGVAE
|
|
893
|
+
|
|
894
|
+
B, in_features, feat_dim, latent_dim, S = 3, 5, 7, 2, 2
|
|
895
|
+
x1 = torch.randn(B, in_features)
|
|
896
|
+
x2 = torch.randn(B, in_features)
|
|
897
|
+
|
|
898
|
+
encoder = DummyEncoder(in_features=in_features, feat_dim=feat_dim)
|
|
899
|
+
decoder = DummyDecoder(latent_dim=latent_dim, out_features=in_features)
|
|
900
|
+
adagvae = AdaGVAE(encoder=encoder, decoder=decoder, latent_dim=latent_dim)
|
|
901
|
+
adagvae.build(x1)
|
|
902
|
+
|
|
903
|
+
adagvae.train()
|
|
904
|
+
torch.set_grad_enabled(True)
|
|
905
|
+
|
|
906
|
+
out1, out2 = adagvae.forward_pair(x1, x2, S=S)
|
|
907
|
+
pair_loss = adagvae.compute_pair_loss(x1, out1, x2, out2)
|
|
908
|
+
|
|
909
|
+
# compute_pair_loss calls VAE.compute_loss twice and adds objectives
|
|
910
|
+
loss1 = adagvae.compute_loss(x1, out1)
|
|
911
|
+
loss2 = adagvae.compute_loss(x2, out2)
|
|
912
|
+
|
|
913
|
+
assert torch.allclose(pair_loss.objective, loss1.objective + loss2.objective, atol=1e-5)
|
|
914
|
+
assert abs(
|
|
915
|
+
pair_loss.diagnostics['elbo']
|
|
916
|
+
- (loss1.diagnostics['elbo'] + loss2.diagnostics['elbo'])
|
|
917
|
+
) < 1e-5
|
|
918
|
+
|
|
File without changes
|
|
File without changes
|
|
@@ -1,84 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
import torch
|
|
4
|
-
from torch.utils.data import Dataset
|
|
5
|
-
from pathlib import Path
|
|
6
|
-
from typing import Optional, Union, Tuple, Callable
|
|
7
|
-
import numpy as np
|
|
8
|
-
import wget
|
|
9
|
-
|
|
10
|
-
class DSprite(Dataset):
|
|
11
|
-
"""PyTorch dataset wrapper for the dSprites factorized shapes dataset.
|
|
12
|
-
|
|
13
|
-
Source:
|
|
14
|
-
Matthey et al., "dSprites: Disentanglement test sprites."
|
|
15
|
-
Original files hosted by DeepMind on GitHub: https://github.com/google-deepmind/dsprites-dataset
|
|
16
|
-
|
|
17
|
-
Files:
|
|
18
|
-
- dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz
|
|
19
|
-
|
|
20
|
-
Contents (loaded into memory on init):
|
|
21
|
-
- X (torch.Tensor[int8]): Binary images, shape [N, 64, 64].
|
|
22
|
-
- latents_values (torch.Tensor[float64]): Continuous latent values per sample,
|
|
23
|
-
shape [N, 6].
|
|
24
|
-
- latents_classes (torch.Tensor[int64]): Discrete latent indices per sample,
|
|
25
|
-
shape [N, 6].
|
|
26
|
-
|
|
27
|
-
Latent factor order (size):
|
|
28
|
-
[color (1), shape (3), scale (6), orientation (40), posX (32), posY (32)]
|
|
29
|
-
|
|
30
|
-
Notes:
|
|
31
|
-
- Images are binary (0/1) stored as int8; most models will want them converted
|
|
32
|
-
to float and possibly normalized. Provide a `transform` to handle this.
|
|
33
|
-
- All arrays are fully loaded into CPU memory at construction for fast access.
|
|
34
|
-
- Set `download=True` (default) to fetch the NPZ if missing at `root`.
|
|
35
|
-
"""
|
|
36
|
-
|
|
37
|
-
_NPZ_URL = "https://github.com/deepmind/dsprites-dataset/raw/master/"
|
|
38
|
-
_NPZ_FILENAME = 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz'
|
|
39
|
-
|
|
40
|
-
def __init__(self,
|
|
41
|
-
root: Optional[Union[str, Path]] = None,
|
|
42
|
-
transform: Optional[Callable] = None,
|
|
43
|
-
download: bool = True):
|
|
44
|
-
"""Initialize the dSprites dataset.
|
|
45
|
-
|
|
46
|
-
Args:
|
|
47
|
-
root (str | pathlib.Path | None): Directory to store/find the NPZ file.
|
|
48
|
-
Defaults to "./data/dSprites" when None.
|
|
49
|
-
transform (Callable | None): Optional transform applied to each image.
|
|
50
|
-
download (bool): If True and the dataset file is not present at `root`,
|
|
51
|
-
it will be downloaded from the official GitHub URL.
|
|
52
|
-
"""
|
|
53
|
-
# Assign default if no root is provided
|
|
54
|
-
if root is None:
|
|
55
|
-
root = Path('data') / 'dSprites'
|
|
56
|
-
elif isinstance(root, str):
|
|
57
|
-
root = Path(root)
|
|
58
|
-
|
|
59
|
-
self.root = root
|
|
60
|
-
self.root.mkdir(parents=True, exist_ok=True)
|
|
61
|
-
self.filepath = self.root / DSprite._NPZ_FILENAME
|
|
62
|
-
|
|
63
|
-
if download and not self.filepath.exists():
|
|
64
|
-
url = DSprite._NPZ_URL + DSprite._NPZ_FILENAME
|
|
65
|
-
print(f'Downloading dSprites from {url}')
|
|
66
|
-
wget.download(url, out=str(self.filepath))
|
|
67
|
-
print('\nDownload completed')
|
|
68
|
-
|
|
69
|
-
data = np.load(self.filepath, allow_pickle=True)
|
|
70
|
-
self.X = torch.as_tensor(data['imgs'], dtype=torch.int8).unsqueeze(1)
|
|
71
|
-
self.latents_values = torch.as_tensor(data['latents_values'], dtype=torch.float64)
|
|
72
|
-
self.latents_classes = torch.as_tensor(data['latents_classes'], dtype=torch.int64)
|
|
73
|
-
self.transform = transform
|
|
74
|
-
|
|
75
|
-
def __len__(self) -> int:
|
|
76
|
-
return self.X.shape[0]
|
|
77
|
-
|
|
78
|
-
def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
79
|
-
img = self.X[idx]
|
|
80
|
-
lv = self.latents_values[idx]
|
|
81
|
-
lc = self.latents_classes[idx]
|
|
82
|
-
if self.transform:
|
|
83
|
-
img = self.transform(img)
|
|
84
|
-
return img, lv, lc
|
|
@@ -1,261 +0,0 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
import torch.nn as nn
|
|
3
|
-
from typing import Iterable, Tuple, List, Dict, Optional, Any, Callable
|
|
4
|
-
from dataclasses import fields, Field
|
|
5
|
-
from .._base.base import BaseAutoencoder, BuildGuardMixin
|
|
6
|
-
|
|
7
|
-
class HyperAE(BuildGuardMixin, nn.Module):
|
|
8
|
-
"""
|
|
9
|
-
Hypernetwork wrapper around a base autoencoder.
|
|
10
|
-
|
|
11
|
-
This module:
|
|
12
|
-
- wraps a `BaseAutoencoder` and uses a hypernetwork (2 layered MLP with ReLU)
|
|
13
|
-
to generate a subset of its parameters on a per-input basis
|
|
14
|
-
- keeps the remaining parameters shared across the batch
|
|
15
|
-
"""
|
|
16
|
-
|
|
17
|
-
_GUARDED = {"forward"}
|
|
18
|
-
|
|
19
|
-
def __init__(
|
|
20
|
-
self,
|
|
21
|
-
base_ae: BaseAutoencoder,
|
|
22
|
-
target_modules: Iterable[str] = ("encoder", "sampling_layer", "decoder"),
|
|
23
|
-
hidden_dim: int = 256,
|
|
24
|
-
):
|
|
25
|
-
"""
|
|
26
|
-
Initialize the HyperAE wrapper.
|
|
27
|
-
|
|
28
|
-
Args:
|
|
29
|
-
base_ae:
|
|
30
|
-
The underlying autoencoder instance to be controlled by the hypernetwork.
|
|
31
|
-
target_modules:
|
|
32
|
-
Iterable of module name prefixes inside `base_ae` whose parameters
|
|
33
|
-
will be generated by the hypernetwork (e.g. "encoder", "decoder").
|
|
34
|
-
hidden_dim:
|
|
35
|
-
Hidden dimension for the MLP hypernetwork.
|
|
36
|
-
"""
|
|
37
|
-
super().__init__()
|
|
38
|
-
|
|
39
|
-
self.base_ae = base_ae
|
|
40
|
-
self.target_modules = tuple(target_modules)
|
|
41
|
-
self.hidden_dim = hidden_dim
|
|
42
|
-
|
|
43
|
-
# (param_name, original_shape, flat_start, flat_end)
|
|
44
|
-
self._param_info: List[Tuple[str, torch.Size, int, int]] = []
|
|
45
|
-
|
|
46
|
-
# Parameters that remain shared (usual trainable params).
|
|
47
|
-
self._shared_param_dict: Dict[str, torch.Tensor] = {}
|
|
48
|
-
|
|
49
|
-
# Total number of scalar parameters generated by the hypernetwork.
|
|
50
|
-
self.total_generated_params: int = 0
|
|
51
|
-
|
|
52
|
-
# Hypernetwork (built lazily in build()).
|
|
53
|
-
self.hypernet: Optional[nn.Module] = None
|
|
54
|
-
|
|
55
|
-
# Output type information inferred from a sample call in build().
|
|
56
|
-
self._output_type: Optional[type] = None
|
|
57
|
-
self._output_field_names: Optional[List[str]] = None
|
|
58
|
-
|
|
59
|
-
# Cached ModelOutput field metadata (set in build()).
|
|
60
|
-
self._output_fields: Optional[Tuple[Field, ...]] = None
|
|
61
|
-
|
|
62
|
-
# Cached vmap metadata / function (set in build()).
|
|
63
|
-
self._in_dims_params: Optional[Dict[str, Optional[int]]] = None
|
|
64
|
-
self._vmapped_call: Optional[Callable[..., Any]] = None
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
@torch.no_grad()
|
|
68
|
-
def build(self, input_sample: torch.Tensor) -> None:
|
|
69
|
-
"""
|
|
70
|
-
Build the hypernetwork and prepare parameter bookkeeping.
|
|
71
|
-
|
|
72
|
-
This method:
|
|
73
|
-
- builds the underlying `base_ae`
|
|
74
|
-
- infers the output `ModelOutput` type and field names
|
|
75
|
-
- splits `base_ae` parameters into generated vs shared sets
|
|
76
|
-
- constructs the MLP hypernetwork that outputs all generated parameters
|
|
77
|
-
Must be called once before using `forward` (enforced by `BuildGuardMixin`).
|
|
78
|
-
"""
|
|
79
|
-
# Ensure the base autoencoder is built (and warmed up via its own build).
|
|
80
|
-
self.base_ae.build(input_sample)
|
|
81
|
-
|
|
82
|
-
# Inspect a sample output to record the output structure.
|
|
83
|
-
sample_out = self.base_ae(input_sample)
|
|
84
|
-
self._output_type = type(sample_out)
|
|
85
|
-
self._output_fields = fields(sample_out)
|
|
86
|
-
self._output_field_names = [f.name for f in self._output_fields]
|
|
87
|
-
|
|
88
|
-
# Reset metadata containers.
|
|
89
|
-
self._param_info.clear()
|
|
90
|
-
|
|
91
|
-
# Walk over all base_ae parameters and mark them as generated or shared.
|
|
92
|
-
flat_offset = 0
|
|
93
|
-
shared_param_names = []
|
|
94
|
-
for name, p in self.base_ae.named_parameters():
|
|
95
|
-
if any(name.startswith(m + ".") for m in self.target_modules):
|
|
96
|
-
# Parameters inside target modules are generated by the hypernet.
|
|
97
|
-
numel = p.numel()
|
|
98
|
-
self._param_info.append(
|
|
99
|
-
(name, p.shape, flat_offset, flat_offset + numel)
|
|
100
|
-
)
|
|
101
|
-
flat_offset += numel
|
|
102
|
-
# Generated parameters are "owned" by the hypernet, so we freeze them.
|
|
103
|
-
p.requires_grad_(False)
|
|
104
|
-
else:
|
|
105
|
-
# Remaining parameters are shared and stay trainable as usual.
|
|
106
|
-
shared_param_names.append(name)
|
|
107
|
-
p.requires_grad_(True)
|
|
108
|
-
|
|
109
|
-
self.total_generated_params = flat_offset
|
|
110
|
-
if self.total_generated_params == 0:
|
|
111
|
-
raise ValueError(
|
|
112
|
-
f"Zero parameters defined by the hypernetwork. "
|
|
113
|
-
f"Check target_modules: {self.target_modules}."
|
|
114
|
-
)
|
|
115
|
-
|
|
116
|
-
# Cache a dict of the shared parameters for quick access in forward().
|
|
117
|
-
self._shared_param_dict = {
|
|
118
|
-
name: p
|
|
119
|
-
for name, p in self.base_ae.named_parameters()
|
|
120
|
-
if name in shared_param_names
|
|
121
|
-
}
|
|
122
|
-
|
|
123
|
-
# Hypernetwork input dimension is the flattened per-sample input.
|
|
124
|
-
in_dims = input_sample[0].numel()
|
|
125
|
-
self.hypernet = nn.Sequential(
|
|
126
|
-
nn.Flatten(1),
|
|
127
|
-
nn.Linear(in_dims, self.hidden_dim),
|
|
128
|
-
nn.ReLU(),
|
|
129
|
-
nn.Linear(self.hidden_dim, self.hidden_dim),
|
|
130
|
-
nn.ReLU(),
|
|
131
|
-
nn.Linear(self.hidden_dim, self.total_generated_params),
|
|
132
|
-
)
|
|
133
|
-
|
|
134
|
-
# Precompute in_dims mapping for vmap: generated params vary on dim 0, shared are None.
|
|
135
|
-
# Use the same key ordering as `params` in forward: shared first, then generated
|
|
136
|
-
generated_params_names = [name for (name, _, _, _) in self._param_info]
|
|
137
|
-
|
|
138
|
-
all_param_names = list(self._shared_param_dict.keys()) + generated_params_names
|
|
139
|
-
|
|
140
|
-
self._in_dims_params = {
|
|
141
|
-
name: (0 if name in generated_params_names else None)
|
|
142
|
-
for name in all_param_names
|
|
143
|
-
}
|
|
144
|
-
|
|
145
|
-
# Prebuild the vmapped single-sample call.
|
|
146
|
-
self._vmapped_call = torch.func.vmap(
|
|
147
|
-
self._call_single,
|
|
148
|
-
in_dims=(self._in_dims_params, 0),
|
|
149
|
-
out_dims=0,
|
|
150
|
-
randomness="different",
|
|
151
|
-
)
|
|
152
|
-
|
|
153
|
-
self._built = True
|
|
154
|
-
|
|
155
|
-
def _generated_params(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
|
|
156
|
-
"""
|
|
157
|
-
Run the hypernetwork and reshape its output into per-parameter tensors.
|
|
158
|
-
|
|
159
|
-
Args:
|
|
160
|
-
x:
|
|
161
|
-
Input batch of shape (B, ...) used to condition the hypernetwork.
|
|
162
|
-
|
|
163
|
-
Returns:
|
|
164
|
-
A dictionary mapping parameter names (inside `target_modules`) to
|
|
165
|
-
tensors of shape (B, *original_shape), i.e. one generated parameter
|
|
166
|
-
tensor per sample in the batch.
|
|
167
|
-
"""
|
|
168
|
-
B = x.size(0)
|
|
169
|
-
flat = self.hypernet(x) # type: ignore -- (B, total_generated_params)
|
|
170
|
-
|
|
171
|
-
gen: Dict[str, torch.Tensor] = {}
|
|
172
|
-
for name, shape, start, end in self._param_info:
|
|
173
|
-
slice_ = flat[:, start:end]
|
|
174
|
-
gen[name] = slice_.view(B, *shape)
|
|
175
|
-
return gen
|
|
176
|
-
|
|
177
|
-
def _call_single(self, params_i: Dict[str, torch.Tensor], x_i: torch.Tensor, **kwargs: Any):
|
|
178
|
-
"""
|
|
179
|
-
Apply the base autoencoder to a single sample with a single param set.
|
|
180
|
-
|
|
181
|
-
Args:
|
|
182
|
-
params_i:
|
|
183
|
-
Parameter dict for this particular sample.
|
|
184
|
-
x_i:
|
|
185
|
-
Input tensor of shape (...,) for this sample.
|
|
186
|
-
|
|
187
|
-
Returns:
|
|
188
|
-
A tuple of output tensors corresponding to the `ModelOutput` fields,
|
|
189
|
-
with the leading batch dimension (size 1) removed.
|
|
190
|
-
"""
|
|
191
|
-
# print("encoder.1.weight in _call_single:", params_i["encoder.1.weight"].shape)
|
|
192
|
-
# print("encoder.1.bias in _call_single:", params_i["encoder.1.bias"].shape)
|
|
193
|
-
# print("decoder.0.weight in _call_single:", params_i["decoder.0.weight"].shape)
|
|
194
|
-
# print("decoder.0.bias in _call_single:", params_i["decoder.0.bias"].shape)
|
|
195
|
-
out = torch.func.functional_call(
|
|
196
|
-
self.base_ae,
|
|
197
|
-
params_i,
|
|
198
|
-
(x_i.unsqueeze(0),),
|
|
199
|
-
kwargs=kwargs,
|
|
200
|
-
)
|
|
201
|
-
|
|
202
|
-
# Construct a new output of the same type, but with squeezed tensors.
|
|
203
|
-
squeezed_kwargs: Dict[str, Any] = {}
|
|
204
|
-
for f in self._output_fields: # type: ignore
|
|
205
|
-
v = getattr(out, f.name)
|
|
206
|
-
if torch.is_tensor(v):
|
|
207
|
-
squeezed_kwargs[f.name] = v.squeeze(0)
|
|
208
|
-
else:
|
|
209
|
-
squeezed_kwargs[f.name] = v
|
|
210
|
-
|
|
211
|
-
out_squeezed = self._output_type(**squeezed_kwargs) # type: ignore
|
|
212
|
-
|
|
213
|
-
# Convert the structured output into a tuple of tensors (for vmap).
|
|
214
|
-
tensors_tuple = tuple(
|
|
215
|
-
getattr(out_squeezed, name) for name in self._output_field_names # type: ignore
|
|
216
|
-
)
|
|
217
|
-
return tensors_tuple
|
|
218
|
-
|
|
219
|
-
def forward(self, x: torch.Tensor, **kwargs: Any):
|
|
220
|
-
"""
|
|
221
|
-
Forward pass with per-sample generated parameters.
|
|
222
|
-
|
|
223
|
-
For each sample in the batch:
|
|
224
|
-
- generate a distinct set of parameters for `target_modules`
|
|
225
|
-
- combine them with the shared parameters
|
|
226
|
-
- call the underlying `base_ae` via `torch.func.functional_call`
|
|
227
|
-
The underlying `base_ae` is evaluated in a batched, vectorized way using
|
|
228
|
-
`torch.func.vmap`.
|
|
229
|
-
|
|
230
|
-
Args:
|
|
231
|
-
x:
|
|
232
|
-
Input batch of shape (B, ...).
|
|
233
|
-
**kwargs:
|
|
234
|
-
Additional keyword arguments forwarded to `base_ae.forward`.
|
|
235
|
-
|
|
236
|
-
Returns:
|
|
237
|
-
A `ModelOutput` instance of the same type/structure as produced by
|
|
238
|
-
`base_ae`, but with each field batched over the leading dimension.
|
|
239
|
-
"""
|
|
240
|
-
# Shared parameters are the same for all samples.
|
|
241
|
-
shared = self._shared_param_dict
|
|
242
|
-
# Generated parameters are per-sample (B, *shape)
|
|
243
|
-
generated = self._generated_params(x)
|
|
244
|
-
|
|
245
|
-
# Combined view of all parameters (names mapped to tensors).
|
|
246
|
-
params: Dict[str, torch.Tensor] = {}
|
|
247
|
-
params.update(shared)
|
|
248
|
-
params.update(generated)
|
|
249
|
-
|
|
250
|
-
# Vectorized application over batch of (params, x) using prebuilt vmap.
|
|
251
|
-
batched_tensors_tuple = self._vmapped_call(params, x, **kwargs) # type: ignore
|
|
252
|
-
|
|
253
|
-
# Reconstruct a batched `ModelOutput` of the same type as the base AE.
|
|
254
|
-
batched_kwargs = {
|
|
255
|
-
name: tensor
|
|
256
|
-
for name, tensor in zip(self._output_field_names, batched_tensors_tuple) # type: ignore
|
|
257
|
-
}
|
|
258
|
-
batched_out = self._output_type(**batched_kwargs) # type: ignore
|
|
259
|
-
|
|
260
|
-
return batched_out
|
|
261
|
-
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|