qpytorch-0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of qpytorch might be problematic.
Files changed (102)
  1. qpytorch/__init__.py +327 -0
  2. qpytorch/constraints/__init__.py +3 -0
  3. qpytorch/distributions/__init__.py +21 -0
  4. qpytorch/distributions/delta.py +86 -0
  5. qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
  6. qpytorch/distributions/multivariate_qexponential.py +581 -0
  7. qpytorch/distributions/power.py +113 -0
  8. qpytorch/distributions/qexponential.py +153 -0
  9. qpytorch/functions/__init__.py +58 -0
  10. qpytorch/kernels/__init__.py +80 -0
  11. qpytorch/kernels/grid_interpolation_kernel.py +213 -0
  12. qpytorch/kernels/inducing_point_kernel.py +151 -0
  13. qpytorch/kernels/kernel.py +695 -0
  14. qpytorch/kernels/matern32_kernel_grad.py +155 -0
  15. qpytorch/kernels/matern52_kernel_grad.py +194 -0
  16. qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
  17. qpytorch/kernels/polynomial_kernel_grad.py +88 -0
  18. qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
  19. qpytorch/kernels/rbf_kernel_grad.py +125 -0
  20. qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
  21. qpytorch/kernels/rff_kernel.py +153 -0
  22. qpytorch/lazy/__init__.py +9 -0
  23. qpytorch/likelihoods/__init__.py +66 -0
  24. qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
  25. qpytorch/likelihoods/beta_likelihood.py +76 -0
  26. qpytorch/likelihoods/gaussian_likelihood.py +472 -0
  27. qpytorch/likelihoods/laplace_likelihood.py +59 -0
  28. qpytorch/likelihoods/likelihood.py +437 -0
  29. qpytorch/likelihoods/likelihood_list.py +60 -0
  30. qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
  31. qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
  32. qpytorch/likelihoods/noise_models.py +184 -0
  33. qpytorch/likelihoods/qexponential_likelihood.py +494 -0
  34. qpytorch/likelihoods/softmax_likelihood.py +97 -0
  35. qpytorch/likelihoods/student_t_likelihood.py +90 -0
  36. qpytorch/means/__init__.py +23 -0
  37. qpytorch/metrics/__init__.py +17 -0
  38. qpytorch/mlls/__init__.py +53 -0
  39. qpytorch/mlls/_approximate_mll.py +79 -0
  40. qpytorch/mlls/deep_approximate_mll.py +30 -0
  41. qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
  42. qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
  43. qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
  44. qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
  45. qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
  46. qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
  47. qpytorch/mlls/marginal_log_likelihood.py +48 -0
  48. qpytorch/mlls/predictive_log_likelihood.py +76 -0
  49. qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
  50. qpytorch/mlls/variational_elbo.py +77 -0
  51. qpytorch/models/__init__.py +72 -0
  52. qpytorch/models/approximate_qep.py +115 -0
  53. qpytorch/models/deep_qeps/__init__.py +22 -0
  54. qpytorch/models/deep_qeps/deep_qep.py +155 -0
  55. qpytorch/models/deep_qeps/dspp.py +114 -0
  56. qpytorch/models/exact_prediction_strategies.py +880 -0
  57. qpytorch/models/exact_qep.py +349 -0
  58. qpytorch/models/model_list.py +100 -0
  59. qpytorch/models/pyro/__init__.py +28 -0
  60. qpytorch/models/pyro/_pyro_mixin.py +57 -0
  61. qpytorch/models/pyro/distributions/__init__.py +5 -0
  62. qpytorch/models/pyro/pyro_qep.py +105 -0
  63. qpytorch/models/qep.py +7 -0
  64. qpytorch/models/qeplvm/__init__.py +6 -0
  65. qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
  66. qpytorch/models/qeplvm/latent_variable.py +102 -0
  67. qpytorch/module.py +30 -0
  68. qpytorch/optim/__init__.py +5 -0
  69. qpytorch/priors/__init__.py +42 -0
  70. qpytorch/priors/qep_priors.py +81 -0
  71. qpytorch/test/__init__.py +22 -0
  72. qpytorch/test/base_likelihood_test_case.py +106 -0
  73. qpytorch/test/model_test_case.py +150 -0
  74. qpytorch/test/variational_test_case.py +400 -0
  75. qpytorch/utils/__init__.py +38 -0
  76. qpytorch/utils/warnings.py +37 -0
  77. qpytorch/variational/__init__.py +47 -0
  78. qpytorch/variational/_variational_distribution.py +61 -0
  79. qpytorch/variational/_variational_strategy.py +391 -0
  80. qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
  81. qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
  82. qpytorch/variational/cholesky_variational_distribution.py +65 -0
  83. qpytorch/variational/ciq_variational_strategy.py +352 -0
  84. qpytorch/variational/delta_variational_distribution.py +41 -0
  85. qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
  86. qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
  87. qpytorch/variational/lmc_variational_strategy.py +248 -0
  88. qpytorch/variational/mean_field_variational_distribution.py +58 -0
  89. qpytorch/variational/multitask_variational_strategy.py +317 -0
  90. qpytorch/variational/natural_variational_distribution.py +152 -0
  91. qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
  92. qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
  93. qpytorch/variational/tril_natural_variational_distribution.py +130 -0
  94. qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
  95. qpytorch/variational/unwhitened_variational_strategy.py +225 -0
  96. qpytorch/variational/variational_strategy.py +280 -0
  97. qpytorch/version.py +4 -0
  98. qpytorch-0.1.dist-info/LICENSE +21 -0
  99. qpytorch-0.1.dist-info/METADATA +177 -0
  100. qpytorch-0.1.dist-info/RECORD +102 -0
  101. qpytorch-0.1.dist-info/WHEEL +5 -0
  102. qpytorch-0.1.dist-info/top_level.txt +1 -0
qpytorch/models/qeplvm/bayesian_qeplvm.py ADDED
@@ -0,0 +1,40 @@
+ #!/usr/bin/env python3
+
+ from ..approximate_qep import ApproximateQEP
+
+
+ class BayesianQEPLVM(ApproximateQEP):
+     """
+     The Q-Exponential Process Latent Variable Model (QEPLVM) class for unsupervised learning.
+     The class supports:
+
+     1. Point estimates for the latent X when prior_x = None
+     2. MAP inference for X when prior_x is not None and inference == 'map'
+     3. A Q-Exponential variational distribution q(X) when prior_x is not None and inference == 'variational'
+
+     .. seealso::
+         The `GPLVM tutorial
+         <examples/04_Variational_and_Approximate_GPs/Gaussian_Process_Latent_Variable_Models_with_Stochastic_Variational_Inference.ipynb>`_
+         for usage instructions.
+
+     :param X: An instance of a subclass of the LatentVariable class: one of
+         :class:`~gpytorch.models.qeplvm.PointLatentVariable`, :class:`~gpytorch.models.qeplvm.MAPLatentVariable`, or
+         :class:`~gpytorch.models.qeplvm.VariationalLatentVariable`, to facilitate inference modes 1, 2, or 3 respectively.
+     :type X: ~gpytorch.models.LatentVariable
+     :param ~gpytorch.variational._VariationalStrategy variational_strategy: The strategy that determines
+         how the model marginalizes over the variational distribution (over inducing points)
+         to produce the approximate posterior distribution (over data).
+     """
+
+     def __init__(self, X, variational_strategy):
+         super().__init__(variational_strategy)
+
+         # Assign the latent variable
+         self.X = X
+
+     def forward(self):
+         raise NotImplementedError
+
+     def sample_latent_variable(self):
+         sample = self.X()
+         return sample
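
To make the three inference modes concrete, here is a hedged training-loop sketch in the style of the GPLVM tutorial linked above. The names `model` (a BayesianQEPLVM subclass that defines `forward`), `mll` (a variational ELBO), and the data tensor `Y` are illustrative assumptions, not part of this diff.

    # Hypothetical training fragment; model, mll, and Y are assumed to exist.
    # sample_latent_variable() returns a point estimate, the MAP value, or a
    # reparameterized draw from q(X), depending on the LatentVariable in use.
    import torch

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    for _ in range(1000):
        optimizer.zero_grad()
        sample = model.sample_latent_variable()
        output = model(sample)
        loss = -mll(output, Y.T).sum()  # KL(q(X) || p(X)) enters via the "x_kl" added loss term
        loss.backward()
        optimizer.step()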
qpytorch/models/qeplvm/latent_variable.py ADDED
@@ -0,0 +1,102 @@
+ #!/usr/bin/env python3
+
+ import torch
+
+ from ...module import Module
+
+
+ class LatentVariable(Module):
+     """
+     This superclass is used to describe the type of inference
+     used for the latent variable :math:`\\mathbf X` in QEPLVM models.
+
+     :param int n: Size of the latent space.
+     :param int dim: Dimensionality of the latent space.
+     """
+
+     def __init__(self, n, dim):
+         super().__init__()
+         self.n = n
+         self.latent_dim = dim
+
+     def forward(self, x):
+         raise NotImplementedError
+
+
+ class PointLatentVariable(LatentVariable):
+     """
+     This class is used for QEPLVM models to recover an MLE estimate of
+     the latent variable :math:`\\mathbf X`.
+
+     :param int n: Size of the latent space.
+     :param int latent_dim: Dimensionality of the latent space.
+     :param torch.Tensor X_init: initialization for the point estimate of :math:`\\mathbf X`
+     """
+
+     def __init__(self, n, latent_dim, X_init):
+         super().__init__(n, latent_dim)
+         self.register_parameter("X", X_init)
+
+     def forward(self):
+         return self.X
+
+
+ class MAPLatentVariable(LatentVariable):
+     """
+     This class is used for QEPLVM models to recover a MAP estimate of
+     the latent variable :math:`\\mathbf X`, based on some supplied prior.
+
+     :param int n: Size of the latent space.
+     :param int latent_dim: Dimensionality of the latent space.
+     :param torch.Tensor X_init: initialization for the point estimate of :math:`\\mathbf X`
+     :param ~gpytorch.priors.Prior prior_x: prior for :math:`\\mathbf X`
+     """
+
+     def __init__(self, n, latent_dim, X_init, prior_x):
+         super().__init__(n, latent_dim)
+         self.prior_x = prior_x
+         self.register_parameter("X", X_init)
+         self.register_prior("prior_x", prior_x, "X")
+
+     def forward(self):
+         return self.X
+
+
+ class VariationalLatentVariable(LatentVariable):
+     """
+     This class is used for QEPLVM models to recover a variational approximation of
+     the latent variable :math:`\\mathbf X`. The variational approximation is
+     an isotropic Q-Exponential distribution.
+
+     :param int n: Size of the latent space.
+     :param int data_dim: Dimensionality of the :math:`\\mathbf Y` values.
+     :param int latent_dim: Dimensionality of the latent space.
+     :param torch.Tensor X_init: initialization for the variational mean of :math:`\\mathbf X`
+     :param ~gpytorch.priors.Prior prior_x: prior for :math:`\\mathbf X`
+     """
+
+     def __init__(self, n, data_dim, latent_dim, X_init, prior_x, **kwargs):
+         super().__init__(n, latent_dim)
+
+         self.data_dim = data_dim
+         self.prior_x = prior_x
+         # G: there might be some issues here if someone calls .cuda() on their BayesianQEPLVM
+         # after initializing on the CPU
+
+         # Local variational params per latent point with dimensionality latent_dim
+         self.q_mu = torch.nn.Parameter(X_init)
+         self.q_log_sigma = torch.nn.Parameter(torch.randn(n, latent_dim))
+         # This will add the KL divergence KL(q(X) || p(X)) to the loss
+         self.register_added_loss_term("x_kl")
+
+         self.power = kwargs.pop('power', getattr(self.prior_x, 'power', torch.tensor(2.0)))
+
+     def forward(self):
+         from ...distributions import QExponential
+         from ...mlls import KLQExponentialAddedLossTerm
+
+         # Variational distribution over the latent variable q(x)
+         q_x = QExponential(self.q_mu, torch.nn.functional.softplus(self.q_log_sigma), power=self.power)
+         x_kl = KLQExponentialAddedLossTerm(q_x, self.prior_x, self.n, self.data_dim)
+         self.update_added_loss_term("x_kl", x_kl)  # Update the KL term
+         return q_x.rsample()
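
For orientation, a hypothetical construction of the three flavors (shapes are illustrative). Note that `PointLatentVariable` and `MAPLatentVariable` hand `X_init` straight to `register_parameter`, so it must already be a `torch.nn.Parameter`, while `VariationalLatentVariable` wraps it itself:

    import torch
    from qpytorch.models.qeplvm.latent_variable import (
        MAPLatentVariable,
        PointLatentVariable,
        VariationalLatentVariable,
    )
    from qpytorch.priors import QExponentialPrior

    n, data_dim, latent_dim = 100, 12, 2
    prior_x = QExponentialPrior(torch.zeros(n, latent_dim), torch.ones(n, latent_dim))

    # 1. MLE point estimate (no prior)
    X_mle = PointLatentVariable(n, latent_dim, torch.nn.Parameter(torch.randn(n, latent_dim)))

    # 2. MAP estimate under prior_x
    X_map = MAPLatentVariable(n, latent_dim, torch.nn.Parameter(torch.randn(n, latent_dim)), prior_x)

    # 3. Variational q(X); forward() adds KL(q(X) || p(X)) via the "x_kl" loss term
    X_var = VariationalLatentVariable(n, data_dim, latent_dim, torch.randn(n, latent_dim), prior_x)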
qpytorch/module.py ADDED
@@ -0,0 +1,30 @@
+ #!/usr/bin/env python3
+
+ from gpytorch.module import Module as GModule
+
+
+ class Module(GModule):
+     def named_hyperparameters(self):
+         from .variational._variational_distribution import _VariationalDistribution
+
+         for module_prefix, module in self.named_modules():
+             if not isinstance(module, _VariationalDistribution):
+                 for elem in module.named_parameters(prefix=module_prefix, recurse=False):
+                     yield elem
+
+     def named_variational_parameters(self):
+         from .variational._variational_distribution import _VariationalDistribution
+
+         for module_prefix, module in self.named_modules():
+             if isinstance(module, _VariationalDistribution):
+                 for elem in module.named_parameters(prefix=module_prefix, recurse=False):
+                     yield elem
+
+     def update_added_loss_term(self, name, added_loss_term):
+         from .mlls import AddedLossTerm
+
+         if not isinstance(added_loss_term, AddedLossTerm):
+             raise RuntimeError("added_loss_term must be an AddedLossTerm")
+         if name not in self._added_loss_terms.keys():
+             raise RuntimeError("added_loss_term {} not registered".format(name))
+         self._added_loss_terms[name] = added_loss_term
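
The point of the two overridden iterators is that they partition a model's parameters by whether the owning submodule is a variational distribution. A minimal sketch, assuming `model` is any qpytorch `Module`:

    # Every parameter lands in exactly one of the two sets.
    hyper = {name for name, _ in model.named_hyperparameters()}
    variational = {name for name, _ in model.named_variational_parameters()}
    assert hyper.isdisjoint(variational)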
qpytorch/optim/__init__.py ADDED
@@ -0,0 +1,5 @@
+ #!/usr/bin/env python3
+
+ from gpytorch.optim.ngd import NGD
+
+ __all__ = ["NGD"]
qpytorch/priors/__init__.py ADDED
@@ -0,0 +1,42 @@
+ #!/usr/bin/env python3
+
+ from gpytorch.priors.horseshoe_prior import HorseshoePrior
+ from gpytorch.priors.lkj_prior import LKJCholeskyFactorPrior, LKJCovariancePrior, LKJPrior
+ from gpytorch.priors.prior import Prior
+ from gpytorch.priors.smoothed_box_prior import SmoothedBoxPrior
+ from gpytorch.priors.torch_priors import (
+     GammaPrior,
+     HalfCauchyPrior,
+     HalfNormalPrior,
+     LogNormalPrior,
+     MultivariateNormalPrior,
+     NormalPrior,
+     UniformPrior,
+ )
+ from .qep_priors import (
+     MultivariateQExponentialPrior,
+     QExponentialPrior,
+ )
+
+ # from .wishart_prior import InverseWishartPrior, WishartPrior
+
+
+ __all__ = [
+     "Prior",
+     "GammaPrior",
+     "HalfCauchyPrior",
+     "HalfNormalPrior",
+     "HorseshoePrior",
+     "LKJPrior",
+     "LKJCholeskyFactorPrior",
+     "LKJCovariancePrior",
+     "LogNormalPrior",
+     "MultivariateNormalPrior",
+     "MultivariateQExponentialPrior",
+     "NormalPrior",
+     "QExponentialPrior",
+     "SmoothedBoxPrior",
+     "UniformPrior",
+     # "InverseWishartPrior",
+     # "WishartPrior",
+ ]
qpytorch/priors/qep_priors.py ADDED
@@ -0,0 +1,81 @@
+ #!/usr/bin/env python3
+
+ import torch
+ from torch.nn import Module as TModule
+ from linear_operator import to_linear_operator
+
+ from ..distributions import QExponential, MultivariateQExponential
+ from gpytorch.priors.prior import Prior
+ from gpytorch.priors.utils import _bufferize_attributes, _del_attributes
+
+ QEP_LAZY_PROPERTIES = ("covariance_matrix",)
+
+
+ class QExponentialPrior(Prior, QExponential):
+     """
+     Q-Exponential prior
+
+     pdf(x) = q/2 * (2 * pi * sigma^2)^(-0.5) * |(x - mu)/sigma|^(q/2 - 1) * exp(-0.5 * |(x - mu)/sigma|^q)
+
+     where mu is the mean and sigma^2 is the variance.
+     """
+
+     def __init__(self, loc, scale, power=torch.tensor(1.0), validate_args=False, transform=None):
+         TModule.__init__(self)
+         QExponential.__init__(self, loc=loc, scale=scale, power=power, validate_args=validate_args)
+         _bufferize_attributes(self, ("loc", "scale"))
+         self._transform = transform
+
+     def expand(self, batch_shape):
+         batch_shape = torch.Size(batch_shape)
+         return QExponentialPrior(self.loc.expand(batch_shape), self.scale.expand(batch_shape), self.power)
+
+
+ class MultivariateQExponentialPrior(Prior, MultivariateQExponential):
+     """Multivariate Q-Exponential prior
+
+     pdf(x) = q/2 * det(2 * pi * Sigma)^(-0.5) * r^((q/2 - 1) * d/2) * exp(-0.5 * r^(q/2)), with r = (x - mu)' Sigma^(-1) (x - mu),
+
+     where mu is the mean and Sigma > 0 is the covariance matrix.
+     """
+
+     def __init__(
+         self, mean, covariance_matrix, power=torch.tensor(1.0), validate_args=False, transform=None
+     ):
+         TModule.__init__(self)
+         MultivariateQExponential.__init__(
+             self,
+             mean=mean,
+             covariance_matrix=covariance_matrix,
+             power=power,
+             validate_args=validate_args,
+         )
+         _bufferize_attributes(self, ("loc",))
+         self._transform = transform
+
+     def cuda(self, device=None):
+         """Applies module-level cuda() call and resets all lazy properties"""
+         module = self._apply(lambda t: t.cuda(device))
+         _del_attributes(module, QEP_LAZY_PROPERTIES)
+         return module
+
+     def cpu(self):
+         """Applies module-level cpu() call and resets all lazy properties"""
+         module = self._apply(lambda t: t.cpu())
+         _del_attributes(module, QEP_LAZY_PROPERTIES)
+         return module
+
+     @property
+     def lazy_covariance_matrix(self):
+         if self.islazy:
+             return self._covar
+         else:
+             return to_linear_operator(super().covariance_matrix)
+
+     def expand(self, batch_shape):
+         batch_shape = torch.Size(batch_shape)
+         cov_shape = batch_shape + self.event_shape
+         new_loc = self.loc.expand(batch_shape)
+         new_covar = self._covar.expand(cov_shape)
+
+         return MultivariateQExponentialPrior(mean=new_loc, covariance_matrix=new_covar, power=self.power)
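
For reference, a hedged usage sketch: these priors plug into gpytorch's standard prior machinery, so they can be attached to a kernel at construction time (assuming qpytorch re-exports the gpytorch kernels, as the files-changed list suggests; the shapes below are illustrative):

    import torch
    from qpytorch.kernels import RBFKernel
    from qpytorch.priors import QExponentialPrior

    # Standalone use: log density and batch expansion
    prior = QExponentialPrior(loc=torch.zeros(2), scale=torch.ones(2), power=torch.tensor(1.0))
    log_p = prior.log_prob(torch.randn(2))      # inherited from QExponential
    batched = prior.expand(torch.Size([3, 2]))  # batch-expanded copy, same power

    # Attached to a model: prior on an RBF kernel's lengthscale
    kernel = RBFKernel(lengthscale_prior=QExponentialPrior(torch.zeros(1), torch.ones(1)))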
qpytorch/test/__init__.py ADDED
@@ -0,0 +1,22 @@
+ #!/usr/bin/env python3
+
+ from gpytorch.test.base_test_case import BaseTestCase
+ from gpytorch.test.base_keops_test_case import BaseKeOpsTestCase
+ from gpytorch.test.base_kernel_test_case import BaseKernelTestCase
+ from .base_likelihood_test_case import BaseLikelihoodTestCase
+ from gpytorch.test.base_mean_test_case import BaseMeanTestCase
+ from .model_test_case import BaseModelTestCase, VariationalModelTestCase
+ from gpytorch.test import utils
+ from .variational_test_case import VariationalTestCase
+
+ __all__ = [
+     "BaseKeOpsTestCase",
+     "BaseKernelTestCase",
+     "BaseLikelihoodTestCase",
+     "BaseMeanTestCase",
+     "BaseModelTestCase",
+     "BaseTestCase",
+     "utils",
+     "VariationalModelTestCase",
+     "VariationalTestCase",
+ ]
qpytorch/test/base_likelihood_test_case.py ADDED
@@ -0,0 +1,106 @@
+ #!/usr/bin/env python3
+
+ from abc import abstractmethod
+
+ import torch
+ from torch.distributions import Distribution
+
+ import qpytorch
+ from ..distributions import MultivariateNormal, MultivariateQExponential
+ from ..likelihoods import Likelihood
+
+ from gpytorch.test.base_test_case import BaseTestCase
+
+
+ class BaseLikelihoodTestCase(BaseTestCase):
+     @abstractmethod
+     def create_likelihood(self, **kwargs):
+         raise NotImplementedError()
+
+     def _create_conditional_input(self, batch_shape=torch.Size()):
+         return torch.randn(*batch_shape, 5)
+
+     def _create_marginal_input(self, batch_shape=torch.Size()):
+         mat = torch.randn(*batch_shape, 5, 5)
+         eye = torch.diag_embed(torch.ones(*batch_shape, 5))
+         if 'Gaussian' in self.__class__.__name__:
+             return MultivariateNormal(torch.randn(*batch_shape, 5), mat @ mat.transpose(-1, -2) + eye)
+         elif 'QExponential' in self.__class__.__name__:
+             return MultivariateQExponential(torch.randn(*batch_shape, 5), mat @ mat.transpose(-1, -2) + eye, torch.tensor(getattr(self, '_power', 2.0)))
+
+     def _create_targets(self, batch_shape=torch.Size()):
+         return torch.randn(*batch_shape, 5)
+
+     def _test_conditional(self, batch_shape):
+         likelihood = self.create_likelihood()
+         likelihood.max_plate_nesting += len(batch_shape)
+         input = self._create_conditional_input(batch_shape)
+         output = likelihood(input)
+
+         self.assertTrue(isinstance(output, Distribution))
+         self.assertEqual(output.sample().shape, input.shape)
+
+     def _test_log_marginal(self, batch_shape):
+         likelihood = self.create_likelihood()
+         likelihood.max_plate_nesting += len(batch_shape)
+         input = self._create_marginal_input(batch_shape)
+         target = self._create_targets(batch_shape)
+         with qpytorch.settings.num_likelihood_samples(512):
+             output = likelihood.log_marginal(target, input)
+
+         self.assertTrue(torch.is_tensor(output))
+         self.assertEqual(output.shape, batch_shape + torch.Size([5]))
+         with qpytorch.settings.num_likelihood_samples(512):
+             default_log_prob = Likelihood.log_marginal(likelihood, target, input)
+         self.assertAllClose(output, default_log_prob, rtol=0.25)
+
+     def _test_log_prob(self, batch_shape):
+         likelihood = self.create_likelihood()
+         likelihood.max_plate_nesting += len(batch_shape)
+         input = self._create_marginal_input(batch_shape)
+         target = self._create_targets(batch_shape)
+         with qpytorch.settings.num_likelihood_samples(512):
+             output = likelihood.expected_log_prob(target, input)
+
+         self.assertTrue(torch.is_tensor(output))
+         self.assertEqual(output.shape, batch_shape + torch.Size([5]))
+         with qpytorch.settings.num_likelihood_samples(512):
+             default_log_prob = Likelihood.expected_log_prob(likelihood, target, input)
+         self.assertAllClose(output, default_log_prob, rtol=0.25)
+
+     def _test_marginal(self, batch_shape):
+         likelihood = self.create_likelihood()
+         likelihood.max_plate_nesting += len(batch_shape)
+         input = self._create_marginal_input(batch_shape)
+         output = likelihood(input)
+
+         self.assertTrue(isinstance(output, Distribution))
+         self.assertEqual(output.sample().shape[-len(input.sample().shape):], input.sample().shape)
+
+         # Compare against default implementation
+         with qpytorch.settings.num_likelihood_samples(30000):
+             default = Likelihood.marginal(likelihood, input)
+         # print(output.mean, default.mean)
+         default_mean = default.mean
+         actual_mean = output.mean
+         if default_mean.dim() > actual_mean.dim():
+             default_mean = default_mean.mean(0)
+         self.assertAllClose(default_mean, actual_mean, rtol=0.25, atol=0.25)
+
+     def test_nonbatch(self):
+         self._test_conditional(batch_shape=torch.Size([]))
+         self._test_log_marginal(batch_shape=torch.Size([]))
+         self._test_log_prob(batch_shape=torch.Size([]))
+         self._test_marginal(batch_shape=torch.Size([]))
+
+     def test_batch(self):
+         self._test_conditional(batch_shape=torch.Size([3]))
+         self._test_log_marginal(batch_shape=torch.Size([3]))
+         self._test_log_prob(batch_shape=torch.Size([3]))
+         self._test_marginal(batch_shape=torch.Size([3]))
+
+     def test_multi_batch(self):
+         self._test_conditional(batch_shape=torch.Size([2, 3]))
+         self._test_log_marginal(batch_shape=torch.Size([2, 3]))
+         self._test_log_prob(batch_shape=torch.Size([2, 3]))
+         self._test_marginal(batch_shape=torch.Size([2, 3]))
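
A concrete test class then only needs `create_likelihood`; note that `_create_marginal_input` keys off the substring 'Gaussian' or 'QExponential' in the test class name, so the name matters. A hypothetical example:

    import unittest

    from qpytorch.likelihoods import GaussianLikelihood
    from qpytorch.test import BaseLikelihoodTestCase


    # 'Gaussian' in the class name selects the MultivariateNormal branch above.
    class TestGaussianLikelihood(BaseLikelihoodTestCase, unittest.TestCase):
        def create_likelihood(self, **kwargs):
            return GaussianLikelihood(**kwargs)


    if __name__ == "__main__":
        unittest.main()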
qpytorch/test/model_test_case.py ADDED
@@ -0,0 +1,150 @@
+ #!/usr/bin/env python3
+
+ from abc import abstractmethod
+
+ import torch
+
+ import qpytorch
+
+
+ class BaseModelTestCase(object):
+     @abstractmethod
+     def create_model(self, train_x, train_y, likelihood):
+         raise NotImplementedError()
+
+     @abstractmethod
+     def create_test_data(self):
+         raise NotImplementedError()
+
+     @abstractmethod
+     def create_likelihood_and_labels(self):
+         raise NotImplementedError()
+
+     @abstractmethod
+     def create_batch_test_data(self, batch_shape=torch.Size([3])):
+         raise NotImplementedError()
+
+     @abstractmethod
+     def create_batch_likelihood_and_labels(self, batch_shape=torch.Size([3])):
+         raise NotImplementedError()
+
+     def test_forward_train(self):
+         data = self.create_test_data()
+         likelihood, labels = self.create_likelihood_and_labels()
+         model = self.create_model(data, labels, likelihood)
+         model.train()
+         output = model(data)
+         self.assertTrue(output.lazy_covariance_matrix.dim() == 2)
+         self.assertTrue(output.lazy_covariance_matrix.size(-1) == data.size(-2))
+         self.assertTrue(output.lazy_covariance_matrix.size(-2) == data.size(-2))
+
+     def test_batch_forward_train(self):
+         batch_data = self.create_batch_test_data()
+         likelihood, labels = self.create_batch_likelihood_and_labels()
+         model = self.create_model(batch_data, labels, likelihood)
+         model.train()
+         output = model(batch_data)
+         self.assertTrue(output.lazy_covariance_matrix.dim() == 3)
+         self.assertTrue(output.lazy_covariance_matrix.size(-1) == batch_data.size(-2))
+         self.assertTrue(output.lazy_covariance_matrix.size(-2) == batch_data.size(-2))
+
+     def test_multi_batch_forward_train(self):
+         batch_data = self.create_batch_test_data(batch_shape=torch.Size([2, 3]))
+         likelihood, labels = self.create_batch_likelihood_and_labels(batch_shape=torch.Size([2, 3]))
+         model = self.create_model(batch_data, labels, likelihood)
+         model.train()
+         output = model(batch_data)
+         self.assertTrue(output.lazy_covariance_matrix.dim() == 4)
+         self.assertTrue(output.lazy_covariance_matrix.size(-1) == batch_data.size(-2))
+         self.assertTrue(output.lazy_covariance_matrix.size(-2) == batch_data.size(-2))
+
+     def test_forward_eval(self):
+         data = self.create_test_data()
+         likelihood, labels = self.create_likelihood_and_labels()
+         model = self.create_model(data, labels, likelihood)
+         model.eval()
+         output = model(data)
+         self.assertTrue(output.lazy_covariance_matrix.dim() == 2)
+         self.assertTrue(output.lazy_covariance_matrix.size(-1) == data.size(-2))
+         self.assertTrue(output.lazy_covariance_matrix.size(-2) == data.size(-2))
+
+     def test_batch_forward_eval(self):
+         batch_data = self.create_batch_test_data()
+         likelihood, labels = self.create_batch_likelihood_and_labels()
+         model = self.create_model(batch_data, labels, likelihood)
+         model.eval()
+         output = model(batch_data)
+         self.assertTrue(output.lazy_covariance_matrix.dim() == 3)
+         self.assertTrue(output.lazy_covariance_matrix.size(-1) == batch_data.size(-2))
+         self.assertTrue(output.lazy_covariance_matrix.size(-2) == batch_data.size(-2))
+
+     def test_multi_batch_forward_eval(self):
+         batch_data = self.create_batch_test_data(batch_shape=torch.Size([2, 3]))
+         likelihood, labels = self.create_batch_likelihood_and_labels(batch_shape=torch.Size([2, 3]))
+         model = self.create_model(batch_data, labels, likelihood)
+         model.eval()
+         output = model(batch_data)
+         self.assertTrue(output.lazy_covariance_matrix.dim() == 4)
+         self.assertTrue(output.lazy_covariance_matrix.size(-1) == batch_data.size(-2))
+         self.assertTrue(output.lazy_covariance_matrix.size(-2) == batch_data.size(-2))
+
+
+ class VariationalModelTestCase(BaseModelTestCase):
+     def test_backward_train(self):
+         data = self.create_test_data()
+         likelihood, labels = self.create_likelihood_and_labels()
+         model = self.create_model(data, labels, likelihood)
+         mll = qpytorch.mlls.VariationalELBO(likelihood, model, num_data=labels.size(-1))
+         model.train()
+         likelihood.train()
+
+         # We'll just do one step of gradient descent to mix up the params a bit
+         optimizer = torch.optim.Adam([{"params": model.parameters()}, {"params": likelihood.parameters()}], lr=0.01)
+
+         output = model(data)
+         loss = -mll(output, labels)
+         loss.backward()
+         optimizer.step()
+         optimizer.zero_grad()
+         output = model(data)
+         loss = -mll(output, labels)
+         loss.backward()
+
+         for _, param in model.named_parameters():
+             self.assertTrue(param.grad is not None)
+             self.assertGreater(param.grad.norm().item(), 0)
+         for _, param in likelihood.named_parameters():
+             self.assertTrue(param.grad is not None)
+             self.assertGreater(param.grad.norm().item(), 0)
+         optimizer.step()
+
+     def test_batch_backward_train(self, batch_shape=torch.Size([3])):
+         data = self.create_batch_test_data(batch_shape)
+         likelihood, labels = self.create_batch_likelihood_and_labels(batch_shape)
+         model = self.create_model(data, labels, likelihood)
+         mll = qpytorch.mlls.VariationalELBO(likelihood, model, num_data=labels.size(-1))
+         model.train()
+         likelihood.train()
+
+         # We'll just do one step of gradient descent to mix up the params a bit
+         optimizer = torch.optim.Adam([{"params": model.parameters()}, {"params": likelihood.parameters()}], lr=0.01)
+
+         output = model(data)
+         loss = -mll(output, labels).sum()
+         loss.backward()
+         optimizer.step()
+         optimizer.zero_grad()
+         output = model(data)
+         loss = -mll(output, labels).sum()
+         loss.backward()
+
+         for _, param in model.named_parameters():
+             self.assertTrue(param.grad is not None)
+             self.assertGreater(param.grad.norm().item(), 0)
+         for _, param in likelihood.named_parameters():
+             self.assertTrue(param.grad is not None)
+             self.assertGreater(param.grad.norm().item(), 0)
+         optimizer.step()
+
+     def test_multi_batch_backward_train(self, batch_shape=torch.Size([2, 3])):
+         return self.test_batch_backward_train(batch_shape=batch_shape)
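
Concrete model tests implement the abstract factories above; a hypothetical skeleton follows (the likelihood constructor arguments and data shapes are illustrative assumptions, and `create_model` is left for the real test to fill in with an ApproximateQEP model):

    import unittest

    import torch

    from qpytorch.likelihoods import QExponentialLikelihood
    from qpytorch.test import VariationalModelTestCase


    class TestMyQEPModel(VariationalModelTestCase, unittest.TestCase):
        def create_test_data(self):
            return torch.randn(50, 1)

        def create_likelihood_and_labels(self):
            return QExponentialLikelihood(power=torch.tensor(1.0)), torch.randn(50)

        def create_batch_test_data(self, batch_shape=torch.Size([3])):
            return torch.randn(*batch_shape, 50, 1)

        def create_batch_likelihood_and_labels(self, batch_shape=torch.Size([3])):
            likelihood = QExponentialLikelihood(batch_shape=batch_shape, power=torch.tensor(1.0))
            return likelihood, torch.randn(*batch_shape, 50)

        def create_model(self, train_x, train_y, likelihood):
            raise NotImplementedError("build and return an ApproximateQEP model here")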