qpytorch-0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of qpytorch might be problematic.
- qpytorch/__init__.py +327 -0
- qpytorch/constraints/__init__.py +3 -0
- qpytorch/distributions/__init__.py +21 -0
- qpytorch/distributions/delta.py +86 -0
- qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
- qpytorch/distributions/multivariate_qexponential.py +581 -0
- qpytorch/distributions/power.py +113 -0
- qpytorch/distributions/qexponential.py +153 -0
- qpytorch/functions/__init__.py +58 -0
- qpytorch/kernels/__init__.py +80 -0
- qpytorch/kernels/grid_interpolation_kernel.py +213 -0
- qpytorch/kernels/inducing_point_kernel.py +151 -0
- qpytorch/kernels/kernel.py +695 -0
- qpytorch/kernels/matern32_kernel_grad.py +155 -0
- qpytorch/kernels/matern52_kernel_grad.py +194 -0
- qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
- qpytorch/kernels/polynomial_kernel_grad.py +88 -0
- qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
- qpytorch/kernels/rbf_kernel_grad.py +125 -0
- qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
- qpytorch/kernels/rff_kernel.py +153 -0
- qpytorch/lazy/__init__.py +9 -0
- qpytorch/likelihoods/__init__.py +66 -0
- qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
- qpytorch/likelihoods/beta_likelihood.py +76 -0
- qpytorch/likelihoods/gaussian_likelihood.py +472 -0
- qpytorch/likelihoods/laplace_likelihood.py +59 -0
- qpytorch/likelihoods/likelihood.py +437 -0
- qpytorch/likelihoods/likelihood_list.py +60 -0
- qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
- qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
- qpytorch/likelihoods/noise_models.py +184 -0
- qpytorch/likelihoods/qexponential_likelihood.py +494 -0
- qpytorch/likelihoods/softmax_likelihood.py +97 -0
- qpytorch/likelihoods/student_t_likelihood.py +90 -0
- qpytorch/means/__init__.py +23 -0
- qpytorch/metrics/__init__.py +17 -0
- qpytorch/mlls/__init__.py +53 -0
- qpytorch/mlls/_approximate_mll.py +79 -0
- qpytorch/mlls/deep_approximate_mll.py +30 -0
- qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
- qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
- qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
- qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
- qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
- qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
- qpytorch/mlls/marginal_log_likelihood.py +48 -0
- qpytorch/mlls/predictive_log_likelihood.py +76 -0
- qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
- qpytorch/mlls/variational_elbo.py +77 -0
- qpytorch/models/__init__.py +72 -0
- qpytorch/models/approximate_qep.py +115 -0
- qpytorch/models/deep_qeps/__init__.py +22 -0
- qpytorch/models/deep_qeps/deep_qep.py +155 -0
- qpytorch/models/deep_qeps/dspp.py +114 -0
- qpytorch/models/exact_prediction_strategies.py +880 -0
- qpytorch/models/exact_qep.py +349 -0
- qpytorch/models/model_list.py +100 -0
- qpytorch/models/pyro/__init__.py +28 -0
- qpytorch/models/pyro/_pyro_mixin.py +57 -0
- qpytorch/models/pyro/distributions/__init__.py +5 -0
- qpytorch/models/pyro/pyro_qep.py +105 -0
- qpytorch/models/qep.py +7 -0
- qpytorch/models/qeplvm/__init__.py +6 -0
- qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
- qpytorch/models/qeplvm/latent_variable.py +102 -0
- qpytorch/module.py +30 -0
- qpytorch/optim/__init__.py +5 -0
- qpytorch/priors/__init__.py +42 -0
- qpytorch/priors/qep_priors.py +81 -0
- qpytorch/test/__init__.py +22 -0
- qpytorch/test/base_likelihood_test_case.py +106 -0
- qpytorch/test/model_test_case.py +150 -0
- qpytorch/test/variational_test_case.py +400 -0
- qpytorch/utils/__init__.py +38 -0
- qpytorch/utils/warnings.py +37 -0
- qpytorch/variational/__init__.py +47 -0
- qpytorch/variational/_variational_distribution.py +61 -0
- qpytorch/variational/_variational_strategy.py +391 -0
- qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
- qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
- qpytorch/variational/cholesky_variational_distribution.py +65 -0
- qpytorch/variational/ciq_variational_strategy.py +352 -0
- qpytorch/variational/delta_variational_distribution.py +41 -0
- qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
- qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
- qpytorch/variational/lmc_variational_strategy.py +248 -0
- qpytorch/variational/mean_field_variational_distribution.py +58 -0
- qpytorch/variational/multitask_variational_strategy.py +317 -0
- qpytorch/variational/natural_variational_distribution.py +152 -0
- qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
- qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
- qpytorch/variational/tril_natural_variational_distribution.py +130 -0
- qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
- qpytorch/variational/unwhitened_variational_strategy.py +225 -0
- qpytorch/variational/variational_strategy.py +280 -0
- qpytorch/version.py +4 -0
- qpytorch-0.1.dist-info/LICENSE +21 -0
- qpytorch-0.1.dist-info/METADATA +177 -0
- qpytorch-0.1.dist-info/RECORD +102 -0
- qpytorch-0.1.dist-info/WHEEL +5 -0
- qpytorch-0.1.dist-info/top_level.txt +1 -0

qpytorch/mlls/inducing_point_kernel_added_loss_term.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python3
+
+from typing import Union
+import torch
+
+from ..distributions import MultivariateNormal, MultivariateQExponential
+from ..likelihoods import GaussianLikelihood, MultitaskGaussianLikelihood, QExponentialLikelihood, MultitaskQExponentialLikelihood
+from gpytorch.mlls.added_loss_term import AddedLossTerm
+
+
+class InducingPointKernelAddedLossTerm(AddedLossTerm):
+    r"""
+    An added loss term that computes the additional "regularization trace term" of the SGPR (SQEPR) objective function.
+
+    .. math::
+        \text{Gaussian:} \quad -\frac{1}{2 \sigma^2} \text{Tr} \left( \mathbf K_{\mathbf X \mathbf X} - \mathbf Q \right)
+    .. math::
+        \text{Q-Exponential:} \quad \frac{d}{2} \left( -\log \sigma^2 + \left( \frac{q}{2} - 1 \right) \log r \right) - \frac{1}{2} r^{\frac{q}{2}},
+        \qquad r = \frac{1}{\sigma^2} \text{Tr} \left( \mathbf K_{\mathbf X \mathbf X} - \mathbf Q \right)
+
+
+    where :math:`\mathbf Q = \mathbf K_{\mathbf X \mathbf Z} \mathbf K_{\mathbf Z \mathbf Z}^{-1}
+    \mathbf K_{\mathbf Z \mathbf X}` is the Nystrom approximation of :math:`\mathbf K_{\mathbf X \mathbf X}`
+    given by inducing points :math:`\mathbf Z`, :math:`\sigma^2` is the observational noise
+    of the Gaussian (Q-Exponential) likelihood, and :math:`d` is the number of dimensions summed over,
+    i.e. :math:`N` for a single-task likelihood or :math:`ND` for a multi-task likelihood.
+
+    See `Titsias, 2009`_, Eq. 9 for more information.
+
+    :param prior_dist: A multivariate normal :math:`\mathcal N ( \mathbf 0, \mathbf K_{\mathbf X \mathbf X} )`
+        or q-exponential :math:`\mathcal Q ( \mathbf 0, \mathbf K_{\mathbf X \mathbf X} )`
+        with covariance matrix :math:`\mathbf K_{\mathbf X \mathbf X}`.
+    :param variational_dist: A multivariate normal :math:`\mathcal N ( \mathbf 0, \mathbf Q)`
+        or q-exponential :math:`\mathcal Q ( \mathbf 0, \mathbf Q)`
+        with covariance matrix :math:`\mathbf Q = \mathbf K_{\mathbf X \mathbf Z}
+        \mathbf K_{\mathbf Z \mathbf Z}^{-1} \mathbf K_{\mathbf Z \mathbf X}`.
+    :param likelihood: The Gaussian (QExponential) likelihood with observational noise :math:`\sigma^2`.
+
+    .. _Titsias, 2009:
+        https://proceedings.mlr.press/v9/titsias10a/titsias10a.pdf
+    """
+
+    def __init__(
+        self, prior_dist: Union[MultivariateNormal, MultivariateQExponential],
+        variational_dist: Union[MultivariateNormal, MultivariateQExponential],
+        likelihood: Union[GaussianLikelihood, QExponentialLikelihood],
+    ):
+        self.prior_dist = prior_dist
+        self.variational_dist = variational_dist
+        self.likelihood = likelihood
+
+    def loss(self, *params) -> torch.Tensor:
+        prior_covar = self.prior_dist.lazy_covariance_matrix
+        variational_covar = self.variational_dist.lazy_covariance_matrix
+        diag = prior_covar.diagonal(dim1=-1, dim2=-2) - variational_covar.diagonal(dim1=-1, dim2=-2)
+        shape = prior_covar.shape[:-1]
+        if isinstance(self.likelihood, (MultitaskGaussianLikelihood, MultitaskQExponentialLikelihood)):
+            shape = torch.Size([*shape, 1])
+            diag = diag.unsqueeze(-1)
+        noise_diag = self.likelihood._shaped_noise_covar(shape, *params).diagonal(dim1=-1, dim2=-2)
+        if isinstance(self.likelihood, (MultitaskGaussianLikelihood, MultitaskQExponentialLikelihood)):
+            noise_diag = noise_diag.reshape(*shape[:-1], -1)
+            r = (diag / noise_diag).sum(dim=[-1, -2])
+        else:
+            r = (diag / noise_diag).sum(dim=-1)
+        res = -0.5 * r ** (self.likelihood.power / 2. if hasattr(self.likelihood, 'power') else 1)
+        if 'QExponential' in self.likelihood.__class__.__name__:
+            if self.likelihood.power != 2: res += -0.5 * noise_diag.log().sum() + torch.tensor(noise_diag.shape[-2:]).prod() / 2. * (self.likelihood.power / 2. - 1) * r.log()
+        return res
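For context, a minimal standalone sketch of what this loss term evaluates in the Gaussian case. It assumes qpytorch re-exports MultivariateNormal, GaussianLikelihood, and this class the way gpytorch does; the toy matrices are illustrative only.

    import torch
    from qpytorch.distributions import MultivariateNormal
    from qpytorch.likelihoods import GaussianLikelihood
    from qpytorch.mlls import InducingPointKernelAddedLossTerm

    n = 5
    K_xx = torch.eye(n)            # full prior covariance K_XX (toy)
    Q = 0.5 * torch.eye(n)         # Nystrom approximation Q (toy)
    zero = torch.zeros(n)

    loss_term = InducingPointKernelAddedLossTerm(
        MultivariateNormal(zero, K_xx),
        MultivariateNormal(zero, Q),
        GaussianLikelihood(),
    )
    # For a Gaussian likelihood this evaluates -Tr(K_XX - Q) / (2 * sigma^2).
    print(loss_term.loss())
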
qpytorch/mlls/kl_qexponential_added_loss_term.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python3
+
+from torch.distributions import kl_divergence
+
+from ..distributions import MultivariateQExponential
+from gpytorch.mlls.added_loss_term import AddedLossTerm
+
+
+class KLQExponentialAddedLossTerm(AddedLossTerm):
+    r"""
+    This class is used by variational QEPLVM models.
+    It adds the KL divergence between two multivariate Q-Exponential distributions,
+    scaled by the size of the data and the number of output dimensions.
+
+    .. math::
+
+        D_\text{KL} \left( q(\mathbf x) \Vert p(\mathbf x) \right)
+
+
+    :param q_x: The QEP distribution :math:`q(\mathbf x)`.
+    :param p_x: The QEP distribution :math:`p(\mathbf x)`.
+    :param n: Size of the latent space.
+    :param data_dim: Dimensionality of the :math:`\mathbf Y` values.
+    """
+
+    def __init__(self, q_x: MultivariateQExponential, p_x: MultivariateQExponential, n: int, data_dim: int):
+        super().__init__()
+        self.q_x = q_x
+        self.p_x = p_x
+        self.n = n
+        self.data_dim = data_dim
+
+    def loss(self):
+        kl_per_latent_dim = kl_divergence(self.q_x, self.p_x).sum(axis=0)  # vector of size latent_dim
+        kl_per_point = kl_per_latent_dim.sum() / self.n  # scalar
+        # Inside the forward method of the variational ELBO,
+        # the added loss terms are expanded (using add_) to take the same
+        # shape as the log_lik term (which has shape data_dim)
+        # so they can be added together. Hence, we divide by data_dim to
+        # avoid over-counting the KL term.
+        return kl_per_point / self.data_dim
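A short usage sketch for this term (the names are hypothetical: q_x and p_x would be MultivariateQExponential distributions over the latent X of a QEPLVM with N data points and D output dimensions):

    from qpytorch.mlls import KLQExponentialAddedLossTerm

    added_loss = KLQExponentialAddedLossTerm(q_x, p_x, n=N, data_dim=D)
    # Scalar: sum_j KL(q || p)_j / (N * D), scaled so it can be
    # broadcast-added to the per-output-dimension log-likelihood term.
    kl_term = added_loss.loss()
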
qpytorch/mlls/leave_one_out_pseudo_likelihood.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python3
+import math
+from typing import Union
+import torch
+from torch import Tensor
+
+from ..distributions import MultivariateNormal, MultivariateQExponential
+from .exact_marginal_log_likelihood import ExactMarginalLogLikelihood
+
+
+class LeaveOneOutPseudoLikelihood(ExactMarginalLogLikelihood):
+    r"""
+    The leave one out cross-validation (LOO-CV) likelihood from RW 5.4.2 for an exact Gaussian (Q-Exponential) process with a
+    Gaussian (Q-Exponential) likelihood. This offers an alternative to the exact marginal log likelihood where we
+    instead maximize the sum of the leave one out log probabilities :math:`\log p(y_i | X, y_{-i}, \theta)`.
+
+    Naively, this will be :math:`O(n^4)` with Cholesky as we need to compute `n` Cholesky factorizations. Fortunately,
+    given the Cholesky factorization of the full kernel matrix (without any points removed), we can compute
+    both the mean and variance of each removed point via a bordered system formulation, making the total
+    complexity :math:`O(n^3)`.
+
+    The LOO-CV approach can be more robust against model mis-specification, as it gives an estimate of the
+    (log) predictive probability whether or not the assumptions of the model are fulfilled.
+
+    .. note::
+        This module will not work with anything other than a :obj:`~qpytorch.likelihoods.GaussianLikelihood`
+        (:obj:`~qpytorch.likelihoods.QExponentialLikelihood`) and a :obj:`~gpytorch.models.ExactGP` (:obj:`~qpytorch.models.ExactQEP`).
+        It also cannot be used in conjunction with stochastic optimization.
+
+    :param ~qpytorch.likelihoods.GaussianLikelihood (~qpytorch.likelihoods.QExponentialLikelihood) likelihood: The Gaussian (Q-Exponential) likelihood for the model
+    :param ~gpytorch.models.ExactGP (~qpytorch.models.ExactQEP) model: The exact GP (QEP) model
+
+    Example:
+        >>> # model is a qpytorch.models.ExactGP or qpytorch.models.ExactQEP
+        >>> # likelihood is a qpytorch.likelihoods.Likelihood
+        >>> loocv = qpytorch.mlls.LeaveOneOutPseudoLikelihood(likelihood, model)
+        >>>
+        >>> output = model(train_x)
+        >>> loss = -loocv(output, train_y)
+        >>> loss.backward()
+    """
+
+    def __init__(self, likelihood, model):
+        super().__init__(likelihood=likelihood, model=model)
+        self.likelihood = likelihood
+        self.model = model
+
+    def forward(self, function_dist: Union[MultivariateNormal, MultivariateQExponential], target: Tensor, *params) -> Tensor:
+        r"""
+        Computes the leave one out likelihood given :math:`p(\mathbf f)` and :math:`\mathbf y`.
+
+        :param ~gpytorch.distributions.MultivariateNormal (~qpytorch.distributions.MultivariateQExponential)
+            function_dist: the outputs of the latent function (the :obj:`~gpytorch.models.GP` or :obj:`~qpytorch.models.QEP`)
+        :param torch.Tensor target: :math:`\mathbf y` The target values
+        :param dict kwargs: Additional arguments to pass to the likelihood's forward function.
+        """
+        output = self.likelihood(function_dist, *params)
+        m, L = output.mean, output.lazy_covariance_matrix.cholesky(upper=False)
+        m = m.reshape(*target.shape)
+        identity = torch.eye(*L.shape[-2:], dtype=m.dtype, device=m.device)
+        sigma2 = 1.0 / L._cholesky_solve(identity, upper=False).diagonal(dim1=-1, dim2=-2)  # 1 / diag(inv(K))
+        mu = target - L._cholesky_solve((target - m).unsqueeze(-1), upper=False).squeeze(-1) * sigma2
+        term1 = -0.5 * sigma2.log()
+        power = getattr(self.likelihood, 'power', torch.tensor(2.0))
+        term2 = -0.5 * (target - mu).abs() ** power / sigma2 ** (power / 2.)
+        res = (term1 + term2).sum(dim=-1)
+        if power != 2: res += (power / 2. - 1) * ((target - mu).abs().log() + term1).sum(dim=-1)
+
+        res = self._add_other_terms(res, params)
+
+        # Scale by the amount of data we have and then add on the scaled constant
+        num_data = target.size(-1)
+        return res.div_(num_data) - 0.5 * math.log(2 * math.pi)
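The bordered-system identities the forward pass relies on can be sanity-checked directly in the Gaussian case: with K standing in for the noisy kernel matrix, sigma_i^2 = 1/[K^{-1}]_{ii} and mu_i = y_i - [K^{-1} y]_i * sigma_i^2 reproduce the brute-force leave-one-out predictive moments (RW Eq. 5.12). A small self-contained check:

    import torch

    torch.manual_seed(0)
    n = 6
    A = torch.randn(n, n)
    K = A @ A.T + n * torch.eye(n)   # SPD stand-in for K + sigma^2 I
    y = torch.randn(n)

    Kinv = torch.linalg.inv(K)
    sigma2 = 1.0 / Kinv.diagonal()   # LOO predictive variances
    mu = y - (Kinv @ y) * sigma2     # LOO predictive means

    for i in range(n):
        idx = [j for j in range(n) if j != i]
        K_ = K[idx][:, idx]
        k_ = K[idx, i]
        mu_i = k_ @ torch.linalg.solve(K_, y[idx])          # brute-force LOO mean
        var_i = K[i, i] - k_ @ torch.linalg.solve(K_, k_)   # brute-force LOO variance
        assert torch.allclose(mu[i], mu_i, atol=1e-4)
        assert torch.allclose(sigma2[i], var_i, atol=1e-4)
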
qpytorch/mlls/marginal_log_likelihood.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python3
+
+from ..models import GP, QEP
+from ..module import Module
+
+
+class MarginalLogLikelihood(Module):
+    r"""
+    These are modules to compute (or approximate/bound) the marginal log likelihood
+    (MLL) of the GP (QEP) model when applied to data. I.e., given a GP :math:`f \sim
+    \mathcal{GP}(\mu, K)` or QEP :math:`f \sim \mathcal{QEP}(\mu, K)`, and
+    data :math:`\mathbf X, \mathbf y`, these modules compute/approximate
+
+    .. math::
+
+        \begin{equation*}
+            \mathcal{L} = p_f(\mathbf y \! \mid \! \mathbf X)
+            = \int p \left( \mathbf y \! \mid \! f(\mathbf X) \right) \: p(f(\mathbf X) \! \mid \! \mathbf X) \: d f
+        \end{equation*}
+
+    This is computed exactly when GP (QEP) inference is computed exactly (e.g. regression with a Gaussian (Q-Exponential) likelihood).
+    It is approximated/bounded for GP (QEP) models that use approximate inference.
+
+    These modules are typically used as the "loss" functions for GP (QEP) models (though note that the output of
+    these functions must be negated for optimization).
+    """
+
+    def __init__(self, likelihood, model):
+        super(MarginalLogLikelihood, self).__init__()
+        if not isinstance(model, (GP, QEP)):
+            raise RuntimeError(
+                "All MarginalLogLikelihood objects must be given a GP (QEP) object as a model. If you are "
+                "using a more complicated model involving a GP (QEP), pass the underlying GP (QEP) object as the "
+                "model, not a full PyTorch module."
+            )
+        self.likelihood = likelihood
+        self.model = model
+
+    def forward(self, output, target, **kwargs):
+        r"""
+        Computes the MLL given :math:`p(\mathbf f)` and :math:`\mathbf y`.
+
+        :param ~gpytorch.distributions.MultivariateNormal or ~qpytorch.distributions.MultivariateQExponential
+            output: the outputs of the latent function (the :obj:`~gpytorch.models.GP` or :obj:`~qpytorch.models.QEP`)
+        :param torch.Tensor target: :math:`\mathbf y` The target values
+        :param dict kwargs: Additional arguments to pass to the likelihood's forward function.
+        """
+        raise NotImplementedError
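Concretely, a generic training-loop sketch for any concrete MLL subclass, showing the negation noted above; model, likelihood, train_x, and train_y are assumed to be set up as in the Example blocks elsewhere in this package:

    import torch
    import qpytorch

    mll = qpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    for _ in range(50):
        optimizer.zero_grad()
        output = model(train_x)       # p(f | X): a MultivariateNormal/QExponential
        loss = -mll(output, train_y)  # negated: optimizers minimize
        loss.backward()
        optimizer.step()
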
qpytorch/mlls/predictive_log_likelihood.py
@@ -0,0 +1,76 @@
+#!/usr/bin/env python3
+
+from ._approximate_mll import _ApproximateMarginalLogLikelihood
+
+
+class PredictiveLogLikelihood(_ApproximateMarginalLogLikelihood):
+    r"""
+    An alternative objective function for approximate GPs (QEPs), proposed in `Jankowiak et al., 2020`_.
+    It typically produces better predictive variances than the :obj:`qpytorch.mlls.VariationalELBO` objective.
+
+    .. math::
+
+        \begin{align*}
+            \mathcal{L}_\text{ELBO} &=
+            \mathbb{E}_{p_\text{data}( y, \mathbf x )} \left[
+                \log p( y \! \mid \! \mathbf x)
+            \right] - \beta \: \text{KL} \left[ q( \mathbf u) \Vert p( \mathbf u) \right]
+            \\
+            &\approx \sum_{i=1}^N \log \mathbb{E}_{q(\mathbf u)} \left[
+                \int p( y_i \! \mid \! f_i) p(f_i \! \mid \! \mathbf u, \mathbf x_i) \: d f_i
+            \right] - \beta \: \text{KL} \left[ q( \mathbf u) \Vert p( \mathbf u) \right]
+        \end{align*}
+
+    where :math:`N` is the total number of datapoints, :math:`q(\mathbf u)` is the variational distribution for
+    the inducing function values, and :math:`p(\mathbf u)` is the prior distribution for the inducing function
+    values.
+
+    :math:`\beta` is a scaling constant that reduces the regularization effect of the KL
+    divergence. Setting :math:`\beta=1` (default) results in an objective that can be motivated by a connection
+    to Stochastic Expectation Propagation (see `Jankowiak et al., 2020`_ for details).
+
+    .. note::
+        This objective is very similar to the variational ELBO.
+        The only difference is that the :math:`\log` occurs *outside* the expectation :math:`\mathbb{E}_{q(\mathbf u)}`.
+        This difference results in very different predictive performance (see `Jankowiak et al., 2020`_).
+
+    :param ~qpytorch.likelihoods.Likelihood likelihood: The likelihood for the model
+    :param ~gpytorch.models.ApproximateGP (~qpytorch.models.ApproximateQEP) model: The approximate GP (QEP) model
+    :param int num_data: The total number of training data points (necessary for SGD)
+    :param float beta: (optional, default=1.) A multiplicative factor for the KL divergence term.
+        Setting it to anything less than 1 reduces the regularization effect of the model
+        (similarly to what was proposed in `the beta-VAE paper`_).
+    :param bool combine_terms: (default=True) Whether or not to sum the
+        expected NLL with the KL terms.
+
+    Example:
+        >>> # model is a qpytorch.models.ApproximateGP or qpytorch.models.ApproximateQEP
+        >>> # likelihood is a qpytorch.likelihoods.Likelihood
+        >>> mll = qpytorch.mlls.PredictiveLogLikelihood(likelihood, model, num_data=100, beta=0.5)
+        >>>
+        >>> output = model(train_x)
+        >>> loss = -mll(output, train_y)
+        >>> loss.backward()
+
+    .. _Jankowiak et al., 2020:
+        https://arxiv.org/abs/1910.07123
+    """
+
+    def _log_likelihood_term(self, approximate_dist_f, target, **kwargs):
+        return self.likelihood.log_marginal(target, approximate_dist_f, **kwargs).sum(-1)
+
+    def forward(self, approximate_dist_f, target, **kwargs):
+        r"""
+        Computes the predictive cross entropy given :math:`q(\mathbf f)` and :math:`\mathbf y`.
+        Calling this function will call the likelihood's
+        :meth:`~qpytorch.likelihoods.Likelihood.forward` function.
+
+        :param ~gpytorch.distributions.MultivariateNormal (~qpytorch.distributions.MultivariateQExponential) approximate_dist_f: :math:`q(\mathbf f)`
+            the outputs of the latent function (the :obj:`gpytorch.models.ApproximateGP` or :obj:`qpytorch.models.ApproximateQEP`)
+        :param torch.Tensor target: :math:`\mathbf y` The target values
+        :param kwargs: Additional arguments passed to the
+            likelihood's :meth:`~qpytorch.likelihoods.Likelihood.forward` function.
+        :rtype: torch.Tensor
+        :return: Predictive log likelihood. Output shape corresponds to batch shape of the model/input data.
+        """
+        return super().forward(approximate_dist_f, target, **kwargs)
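The practical distinction from the ELBO, per data point: the ELBO uses E_{q(f_i)}[log p(y_i | f_i)], while this objective uses log E_{q(f_i)}[p(y_i | f_i)]. At the training-loop level it is a drop-in replacement; a sketch, reusing the assumed names from the Example above:

    mll = qpytorch.mlls.PredictiveLogLikelihood(likelihood, model, num_data=train_y.size(0))
    loss = -mll(model(train_x), train_y)
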
qpytorch/mlls/sum_marginal_log_likelihood.py
@@ -0,0 +1,40 @@
+#!/usr/bin/env python3
+
+from torch.nn import ModuleList
+
+from . import ExactMarginalLogLikelihood, MarginalLogLikelihood
+from gpytorch.utils.generic import length_safe_zip
+
+
+class SumMarginalLogLikelihood(MarginalLogLikelihood):
+    """Sum of marginal log likelihoods, to be used with multi-output models.
+
+    Args:
+        likelihood: A MultiOutputLikelihood
+        model: A MultiOutputModel
+        mll_cls: The marginal log likelihood class (default: ExactMarginalLogLikelihood)
+
+    In case the model outputs are independent/uncorrelated, this provides the MLL of the multi-output model.
+
+    """
+
+    def __init__(self, likelihood, model, mll_cls=ExactMarginalLogLikelihood):
+        super().__init__(model.likelihood, model)
+        self.mlls = ModuleList([mll_cls(mdl.likelihood, mdl) for mdl in model.models])
+
+    def forward(self, outputs, targets, *params):
+        """
+        Args:
+            outputs: (Iterable[MultivariateNormal/MultivariateQExponential]) - the outputs of the latent function
+            targets: (Iterable[Tensor]) - the target values
+            params: (Iterable[Iterable[Tensor]]) - the arguments to be passed through
+                (e.g. parameters in case of heteroskedastic likelihoods)
+        """
+        if len(params) == 0:
+            sum_mll = sum(mll(output, target) for mll, output, target in length_safe_zip(self.mlls, outputs, targets))
+        else:
+            sum_mll = sum(
+                mll(output, target, *iparams)
+                for mll, output, target, iparams in length_safe_zip(self.mlls, outputs, targets, params)
+            )
+        return sum_mll.div_(len(self.mlls))
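A usage sketch following gpytorch's model-list pattern (model1 and model2 are assumed to be exact models carrying their own training data):

    from qpytorch.likelihoods import LikelihoodList
    from qpytorch.mlls import SumMarginalLogLikelihood
    from qpytorch.models import IndependentModelList

    model = IndependentModelList(model1, model2)
    likelihood = LikelihoodList(model1.likelihood, model2.likelihood)
    mll = SumMarginalLogLikelihood(likelihood, model)

    outputs = model(*model.train_inputs)        # list of per-model outputs
    loss = -mll(outputs, model.train_targets)   # averaged over the sub-MLLs
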
qpytorch/mlls/variational_elbo.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python3
+
+from ._approximate_mll import _ApproximateMarginalLogLikelihood
+
+
+class VariationalELBO(_ApproximateMarginalLogLikelihood):
+    r"""
+    The variational evidence lower bound (ELBO). This is used to optimize
+    variational Gaussian (Q-Exponential) processes (with or without stochastic optimization).
+
+    .. math::
+
+        \begin{align*}
+            \mathcal{L}_\text{ELBO} &=
+            \mathbb{E}_{p_\text{data}( y, \mathbf x )} \left[
+                \mathbb{E}_{p(f \mid \mathbf u, \mathbf x) q(\mathbf u)} \left[ \log p( y \! \mid \! f) \right]
+            \right] - \beta \: \text{KL} \left[ q( \mathbf u) \Vert p( \mathbf u) \right]
+            \\
+            &\approx \sum_{i=1}^N \mathbb{E}_{q( f_i)} \left[
+                \log p( y_i \! \mid \! f_i) \right] - \beta \: \text{KL} \left[ q( \mathbf u) \Vert p( \mathbf u) \right]
+        \end{align*}
+
+    where :math:`N` is the number of datapoints, :math:`q(\mathbf u)` is the variational distribution for
+    the inducing function values, :math:`q(f_i)` is the marginal of
+    :math:`p(f_i \mid \mathbf u, \mathbf x_i) q(\mathbf u)`,
+    and :math:`p(\mathbf u)` is the prior distribution for the inducing function values.
+
+    :math:`\beta` is a scaling constant that reduces the regularization effect of the KL
+    divergence. Setting :math:`\beta=1` (default) results in the true variational ELBO.
+
+    For more information on this derivation, see `Scalable Variational Gaussian Process Classification`_
+    (Hensman et al., 2015).
+
+    :param ~qpytorch.likelihoods.Likelihood likelihood: The likelihood for the model
+    :param ~gpytorch.models.ApproximateGP (~qpytorch.models.ApproximateQEP) model: The approximate GP (QEP) model
+    :param int num_data: The total number of training data points (necessary for SGD)
+    :param float beta: (optional, default=1.) A multiplicative factor for the KL divergence term.
+        Setting it to 1 (default) recovers true variational inference
+        (as derived in `Scalable Variational Gaussian Process Classification`_).
+        Setting it to anything less than 1 reduces the regularization effect of the model
+        (similarly to what was proposed in `the beta-VAE paper`_).
+    :param bool combine_terms: (default=True) Whether or not to sum the
+        expected NLL with the KL terms.
+
+    Example:
+        >>> # model is a qpytorch.models.ApproximateGP or qpytorch.models.ApproximateQEP
+        >>> # likelihood is a qpytorch.likelihoods.Likelihood
+        >>> mll = qpytorch.mlls.VariationalELBO(likelihood, model, num_data=100, beta=0.5)
+        >>>
+        >>> output = model(train_x)
+        >>> loss = -mll(output, train_y)
+        >>> loss.backward()
+
+    .. _Scalable Variational Gaussian Process Classification:
+        http://proceedings.mlr.press/v38/hensman15.pdf
+    .. _the beta-VAE paper:
+        https://openreview.net/pdf?id=Sy2fzU9gl
+    """
+
+    def _log_likelihood_term(self, variational_dist_f, target, **kwargs):
+        return self.likelihood.expected_log_prob(target, variational_dist_f, **kwargs).sum(-1)
+
+    def forward(self, variational_dist_f, target, **kwargs):
+        r"""
+        Computes the Variational ELBO given :math:`q(\mathbf f)` and :math:`\mathbf y`.
+        Calling this function will call the likelihood's :meth:`~qpytorch.likelihoods.Likelihood.expected_log_prob`
+        function.
+
+        :param ~gpytorch.distributions.MultivariateNormal (~qpytorch.distributions.MultivariateQExponential) variational_dist_f: :math:`q(\mathbf f)`
+            the outputs of the latent function (the :obj:`gpytorch.models.ApproximateGP` or :obj:`qpytorch.models.ApproximateQEP`)
+        :param torch.Tensor target: :math:`\mathbf y` The target values
+        :param kwargs: Additional arguments passed to the
+            likelihood's :meth:`~qpytorch.likelihoods.Likelihood.expected_log_prob` function.
+        :rtype: torch.Tensor
+        :return: Variational ELBO. Output shape corresponds to batch shape of the model/input data.
+        """
+        return super().forward(variational_dist_f, target, **kwargs)
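A minibatch (SVI) training sketch with this objective; num_data must be the full training-set size so the minibatch likelihood term is rescaled correctly. Here train_dataset is an assumed torch Dataset of (x, y) pairs, and model/likelihood follow the Example above:

    import torch
    import qpytorch
    from torch.utils.data import DataLoader

    mll = qpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(train_dataset))
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    for x_batch, y_batch in DataLoader(train_dataset, batch_size=256, shuffle=True):
        optimizer.zero_grad()
        loss = -mll(model(x_batch), y_batch)  # num_data rescales the minibatch term
        loss.backward()
        optimizer.step()
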
qpytorch/models/__init__.py
@@ -0,0 +1,72 @@
+#!/usr/bin/env python3
+
+import warnings
+
+from gpytorch.models import deep_gps, gplvm
+from . import deep_qeps, exact_prediction_strategies, qeplvm, pyro
+from gpytorch.models.approximate_gp import ApproximateGP
+from gpytorch.models.exact_gp import ExactGP
+from gpytorch.models.gp import GP
+from .approximate_qep import ApproximateQEP
+from .exact_qep import ExactQEP
+from .qep import QEP
+from .model_list import AbstractModelList, IndependentModelList, UncorrelatedModelList
+from .pyro import PyroGP, PyroQEP
+
+# Alternative names for ApproximateGP, ApproximateQEP
+VariationalGP = ApproximateGP
+VariationalQEP = ApproximateQEP
+
+
+# Deprecated for 0.4 release
+class AbstractVariationalGP(ApproximateGP):
+    # Remove after 1.0
+    def __init__(self, *args, **kwargs):
+        warnings.warn("AbstractVariationalGP has been renamed to ApproximateGP.", DeprecationWarning)
+        super().__init__(*args, **kwargs)
+
+
+# Deprecated for 0.4 release
+class PyroVariationalGP(ApproximateGP):
+    # Remove after 1.0
+    def __init__(self, *args, **kwargs):
+        warnings.warn("PyroVariationalGP has been renamed to PyroGP.", DeprecationWarning)
+        super().__init__(*args, **kwargs)
+
+
+# Deprecated for 0.4 release
+class AbstractVariationalQEP(ApproximateQEP):
+    # Remove after 1.0
+    def __init__(self, *args, **kwargs):
+        warnings.warn("AbstractVariationalQEP has been renamed to ApproximateQEP.", DeprecationWarning)
+        super().__init__(*args, **kwargs)
+
+
+# Deprecated for 0.4 release
+class PyroVariationalQEP(ApproximateQEP):
+    # Remove after 1.0
+    def __init__(self, *args, **kwargs):
+        warnings.warn("PyroVariationalQEP has been renamed to PyroQEP.", DeprecationWarning)
+        super().__init__(*args, **kwargs)
+
+__all__ = [
+    "AbstractModelList",
+    "ApproximateGP",
+    "ApproximateQEP",
+    "ExactGP",
+    "ExactQEP",
+    "GP",
+    "QEP",
+    "IndependentModelList",
+    "PyroGP",
+    "PyroQEP",
+    "UncorrelatedModelList",
+    "VariationalGP",
+    "VariationalQEP",
+    "deep_gps",
+    "deep_qeps",
+    "gplvm",
+    "qeplvm",
+    "exact_prediction_strategies",
+    "pyro",
+]
qpytorch/models/approximate_qep.py
@@ -0,0 +1,115 @@
+#!/usr/bin/env python3
+
+from typing import Any, Optional
+
+from torch import Tensor
+
+from ..distributions import MultivariateQExponential
+from .exact_qep import ExactQEP
+
+from .qep import QEP
+from .pyro import _PyroMixin  # This will only contain functions if Pyro is installed
+
+
+class ApproximateQEP(QEP, _PyroMixin):
+    r"""
+    The base class for any Q-Exponential process latent function to be used in conjunction
+    with approximate inference (typically stochastic variational inference).
+    This base class can be used to implement most inducing point methods where the
+    variational parameters are learned directly.
+
+    :param ~qpytorch.variational._VariationalStrategy variational_strategy: The strategy that determines
+        how the model marginalizes over the variational distribution (over inducing points)
+        to produce the approximate posterior distribution (over data)
+
+    The :meth:`forward` function should describe how to compute the prior latent distribution
+    on a given input. Typically, this will involve a mean and kernel function.
+    The result must be a :obj:`~qpytorch.distributions.MultivariateQExponential`.
+
+    Example:
+        >>> class MyVariationalQEP(qpytorch.models.ApproximateQEP):
+        >>>     def __init__(self, variational_strategy, power=torch.tensor(1.0)):
+        >>>         super().__init__(variational_strategy)
+        >>>         self.mean_module = qpytorch.means.ZeroMean()
+        >>>         self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())
+        >>>         self.power = power
+        >>>
+        >>>     def forward(self, x):
+        >>>         mean = self.mean_module(x)
+        >>>         covar = self.covar_module(x)
+        >>>         return qpytorch.distributions.MultivariateQExponential(mean, covar, self.power)
+        >>>
+        >>> # variational_strategy = ...
+        >>> model = MyVariationalQEP(variational_strategy)
+        >>> likelihood = qpytorch.likelihoods.QExponentialLikelihood()
+        >>>
+        >>> # optimization loop for variational parameters...
+        >>>
+        >>> # test_x = ...;
+        >>> model(test_x)  # Returns the approximate QEP latent function at test_x
+        >>> likelihood(model(test_x))  # Returns the (approximate) predictive posterior distribution at test_x
+    """
+
+    def __init__(self, variational_strategy):
+        super().__init__()
+
+        self.variational_strategy = variational_strategy
+
+    def forward(self, x: Tensor):
+        raise NotImplementedError
+
+    def pyro_guide(self, input: Tensor, beta: float = 1.0, name_prefix: str = ""):
+        r"""
+        (For Pyro integration only). The component of a `pyro.guide` that
+        corresponds to drawing samples from the latent QEP function.
+
+        :param input: The inputs :math:`\mathbf X`.
+        :param beta: (default=1.) How much to scale the :math:`\text{KL} [ q(\mathbf f) \Vert p(\mathbf f) ]`
+            term by.
+        :param name_prefix: (default="") A name prefix to prepend to pyro sample sites.
+        """
+        return super().pyro_guide(input, beta=beta, name_prefix=name_prefix)
+
+    def pyro_model(self, input: Tensor, beta: float = 1.0, name_prefix: str = "") -> Tensor:
+        r"""
+        (For Pyro integration only). The component of a `pyro.model` that
+        corresponds to drawing samples from the latent QEP function.
+
+        :param input: The inputs :math:`\mathbf X`.
+        :param beta: (default=1.) How much to scale the :math:`\text{KL} [ q(\mathbf f) \Vert p(\mathbf f) ]`
+            term by.
+        :param name_prefix: (default="") A name prefix to prepend to pyro sample sites.
+        :return: samples from :math:`q(\mathbf f)`
+        """
+        return super().pyro_model(input, beta=beta, name_prefix=name_prefix)
+
+    def get_fantasy_model(self, inputs: Tensor, targets: Tensor, **kwargs: Any) -> ExactQEP:
+        r"""
+        Returns a new QEP model that incorporates the specified inputs and targets as new training data using
+        online variational conditioning (OVC).
+
+        This function first casts the inducing points and variational parameters into pseudo-points before
+        returning an equivalent ExactQEP model with a specialized likelihood.
+
+        .. note::
+            If `targets` is a batch (e.g. `b x m`), then the QEP returned from this method will be a batch mode QEP.
+            If `inputs` is of the same (or lesser) dimension as `targets`, then it is assumed that the fantasy points
+            are the same for each target batch.
+
+        :param inputs: (`b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`) Locations of fantasy
+            observations.
+        :param targets: (`b1 x ... x bk x m` or `f x b1 x ... x bk x m`) Labels of fantasy observations.
+        :return: An `ExactQEP` model with `n + m` training examples, where the `m` fantasy examples have been added
+            and all test-time caches have been updated.
+
+        Reference: "Conditioning Sparse Variational Gaussian Processes for Online Decision-Making,"
+        Maddox, Stanton, Wilson, NeurIPS '21.
+        https://papers.nips.cc/paper/2021/hash/325eaeac5bef34937cfdc1bd73034d17-Abstract.html
+
+        """
+        return self.variational_strategy.get_fantasy_model(inputs=inputs, targets=targets, **kwargs)
+
+    def __call__(self, inputs: Optional[Tensor], prior: bool = False, **kwargs) -> MultivariateQExponential:
+        if inputs is not None and inputs.dim() == 1:
+            inputs = inputs.unsqueeze(-1)
+        return self.variational_strategy(inputs, prior=prior, **kwargs)
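A hedged sketch of online variational conditioning through get_fantasy_model; model is assumed to be a trained ApproximateQEP, and new_x/new_y/test_x are assumed tensors:

    import torch

    # Fold the new observations into an equivalent exact model, no retraining.
    fantasy_model = model.get_fantasy_model(inputs=new_x, targets=new_y)
    fantasy_model.eval()
    with torch.no_grad():
        pred = fantasy_model(test_x)  # posterior conditioned on (new_x, new_y)
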
qpytorch/models/deep_qeps/__init__.py
@@ -0,0 +1,22 @@
+#!/usr/bin/env python3
+
+import warnings
+
+from .deep_qep import DeepQEP, DeepQEPLayer, DeepLikelihood
+
+
+# Deprecated for 1.0 release
+class AbstractDeepQEP(DeepQEP):
+    def __init__(self, *args, **kwargs):
+        warnings.warn("AbstractDeepQEP has been renamed to DeepQEP.", DeprecationWarning)
+        super().__init__(*args, **kwargs)
+
+
+# Deprecated for 1.0 release
+class AbstractDeepQEPLayer(DeepQEPLayer):
+    def __init__(self, *args, **kwargs):
+        warnings.warn("AbstractDeepQEPLayer has been renamed to DeepQEPLayer.", DeprecationWarning)
+        super().__init__(*args, **kwargs)
+
+
+__all__ = ["DeepQEPLayer", "DeepQEP", "AbstractDeepQEPLayer", "AbstractDeepQEP", "DeepLikelihood"]