qpytorch 0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of qpytorch might be problematic. Click here for more details.
- qpytorch/__init__.py +327 -0
- qpytorch/constraints/__init__.py +3 -0
- qpytorch/distributions/__init__.py +21 -0
- qpytorch/distributions/delta.py +86 -0
- qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
- qpytorch/distributions/multivariate_qexponential.py +581 -0
- qpytorch/distributions/power.py +113 -0
- qpytorch/distributions/qexponential.py +153 -0
- qpytorch/functions/__init__.py +58 -0
- qpytorch/kernels/__init__.py +80 -0
- qpytorch/kernels/grid_interpolation_kernel.py +213 -0
- qpytorch/kernels/inducing_point_kernel.py +151 -0
- qpytorch/kernels/kernel.py +695 -0
- qpytorch/kernels/matern32_kernel_grad.py +155 -0
- qpytorch/kernels/matern52_kernel_grad.py +194 -0
- qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
- qpytorch/kernels/polynomial_kernel_grad.py +88 -0
- qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
- qpytorch/kernels/rbf_kernel_grad.py +125 -0
- qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
- qpytorch/kernels/rff_kernel.py +153 -0
- qpytorch/lazy/__init__.py +9 -0
- qpytorch/likelihoods/__init__.py +66 -0
- qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
- qpytorch/likelihoods/beta_likelihood.py +76 -0
- qpytorch/likelihoods/gaussian_likelihood.py +472 -0
- qpytorch/likelihoods/laplace_likelihood.py +59 -0
- qpytorch/likelihoods/likelihood.py +437 -0
- qpytorch/likelihoods/likelihood_list.py +60 -0
- qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
- qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
- qpytorch/likelihoods/noise_models.py +184 -0
- qpytorch/likelihoods/qexponential_likelihood.py +494 -0
- qpytorch/likelihoods/softmax_likelihood.py +97 -0
- qpytorch/likelihoods/student_t_likelihood.py +90 -0
- qpytorch/means/__init__.py +23 -0
- qpytorch/metrics/__init__.py +17 -0
- qpytorch/mlls/__init__.py +53 -0
- qpytorch/mlls/_approximate_mll.py +79 -0
- qpytorch/mlls/deep_approximate_mll.py +30 -0
- qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
- qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
- qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
- qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
- qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
- qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
- qpytorch/mlls/marginal_log_likelihood.py +48 -0
- qpytorch/mlls/predictive_log_likelihood.py +76 -0
- qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
- qpytorch/mlls/variational_elbo.py +77 -0
- qpytorch/models/__init__.py +72 -0
- qpytorch/models/approximate_qep.py +115 -0
- qpytorch/models/deep_qeps/__init__.py +22 -0
- qpytorch/models/deep_qeps/deep_qep.py +155 -0
- qpytorch/models/deep_qeps/dspp.py +114 -0
- qpytorch/models/exact_prediction_strategies.py +880 -0
- qpytorch/models/exact_qep.py +349 -0
- qpytorch/models/model_list.py +100 -0
- qpytorch/models/pyro/__init__.py +28 -0
- qpytorch/models/pyro/_pyro_mixin.py +57 -0
- qpytorch/models/pyro/distributions/__init__.py +5 -0
- qpytorch/models/pyro/pyro_qep.py +105 -0
- qpytorch/models/qep.py +7 -0
- qpytorch/models/qeplvm/__init__.py +6 -0
- qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
- qpytorch/models/qeplvm/latent_variable.py +102 -0
- qpytorch/module.py +30 -0
- qpytorch/optim/__init__.py +5 -0
- qpytorch/priors/__init__.py +42 -0
- qpytorch/priors/qep_priors.py +81 -0
- qpytorch/test/__init__.py +22 -0
- qpytorch/test/base_likelihood_test_case.py +106 -0
- qpytorch/test/model_test_case.py +150 -0
- qpytorch/test/variational_test_case.py +400 -0
- qpytorch/utils/__init__.py +38 -0
- qpytorch/utils/warnings.py +37 -0
- qpytorch/variational/__init__.py +47 -0
- qpytorch/variational/_variational_distribution.py +61 -0
- qpytorch/variational/_variational_strategy.py +391 -0
- qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
- qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
- qpytorch/variational/cholesky_variational_distribution.py +65 -0
- qpytorch/variational/ciq_variational_strategy.py +352 -0
- qpytorch/variational/delta_variational_distribution.py +41 -0
- qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
- qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
- qpytorch/variational/lmc_variational_strategy.py +248 -0
- qpytorch/variational/mean_field_variational_distribution.py +58 -0
- qpytorch/variational/multitask_variational_strategy.py +317 -0
- qpytorch/variational/natural_variational_distribution.py +152 -0
- qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
- qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
- qpytorch/variational/tril_natural_variational_distribution.py +130 -0
- qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
- qpytorch/variational/unwhitened_variational_strategy.py +225 -0
- qpytorch/variational/variational_strategy.py +280 -0
- qpytorch/version.py +4 -0
- qpytorch-0.1.dist-info/LICENSE +21 -0
- qpytorch-0.1.dist-info/METADATA +177 -0
- qpytorch-0.1.dist-info/RECORD +102 -0
- qpytorch-0.1.dist-info/WHEEL +5 -0
- qpytorch-0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
from typing import Optional, Tuple, Union
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from linear_operator.operators import LinearOperator, MatmulLinearOperator, SumLinearOperator
|
|
7
|
+
from torch import Tensor
|
|
8
|
+
from torch.distributions.kl import kl_divergence
|
|
9
|
+
|
|
10
|
+
from ..distributions import Delta, MultivariateNormal, MultivariateQExponential
|
|
11
|
+
from ..models import ApproximateGP, ApproximateQEP
|
|
12
|
+
from gpytorch.utils.errors import CachingError
|
|
13
|
+
from gpytorch.utils.memoize import pop_from_cache_ignore_args
|
|
14
|
+
from ._variational_distribution import _VariationalDistribution
|
|
15
|
+
from .delta_variational_distribution import DeltaVariationalDistribution
|
|
16
|
+
from .variational_strategy import VariationalStrategy
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class BatchDecoupledVariationalStrategy(VariationalStrategy):
    r"""
    A VariationalStrategy that uses a different set of inducing points for the
    variational mean and variational covar. It follows the "decoupled" model
    proposed by `Jankowiak et al. (2020)`_ (which is roughly based on the strategies
    proposed by `Cheng et al. (2017)`_).

    Let :math:`\mathbf Z_\mu` and :math:`\mathbf Z_\sigma` be the mean/variance
    inducing points. The variational distribution for an input :math:`\mathbf
    x` is given by:

    .. math::

        \begin{align*}
            \mathbb E[ f(\mathbf x) ] &= \mathbf k_{\mathbf Z_\mu \mathbf x}^\top
                \mathbf K_{\mathbf Z_\mu \mathbf Z_\mu}^{-1} \mathbf m
            \\
            \text{Var}[ f(\mathbf x) ] &= k_{\mathbf x \mathbf x} - \mathbf k_{\mathbf Z_\sigma \mathbf x}^\top
                \mathbf K_{\mathbf Z_\sigma \mathbf Z_\sigma}^{-1}
                \left( \mathbf K_{\mathbf Z_\sigma \mathbf Z_\sigma} - \mathbf S \right)
                \mathbf K_{\mathbf Z_\sigma \mathbf Z_\sigma}^{-1}
                \mathbf k_{\mathbf Z_\sigma \mathbf x}
        \end{align*}

    where :math:`\mathbf m` and :math:`\mathbf S` are the variational parameters.
    (The implementation works in the whitened parameterization: it computes
    :math:`\mathbf k^\top \mathbf L^{-\top} \mathbf m + \mu(\mathbf x)` for the mean and
    :math:`k_{\mathbf x \mathbf x} + \mathbf k^\top \mathbf L^{-\top} (\mathbf S - \mathbf I) \mathbf L^{-1} \mathbf k`
    for the covariance, where :math:`\mathbf L` is the Cholesky factor of the inducing covariance.)
    Unlike the original proposed implementation, :math:`\mathbf Z_\mu` and :math:`\mathbf Z_\sigma`
    have **the same number of inducing points**, which allows us to perform batched operations.

    Additionally, you can use a different set of kernel hyperparameters for the mean and the variance function.
    We recommend using this feature only with the :obj:`~qpytorch.mlls.PredictiveLogLikelihood` objective function
    as proposed in "Parametric Gaussian Process Regressors" (`Jankowiak et al. (2020)`_).
    Use the mean_var_batch_dim to indicate which batch dimension corresponds to the different mean/var
    kernels.

    .. note::
        We recommend using the "right-most" batch dimension (i.e. ``mean_var_batch_dim=-1``) for the dimension
        that corresponds to the different mean/variance kernel parameters.

        Assuming you want `b1` many independent GPs (or QEPs), the :obj:`~qpytorch.variational._VariationalDistribution`
        objects should have a batch shape of `b1`, and the mean/covar modules
        of the GP (QEP) should have a batch shape of `b1 x 2`.
        (The 2 corresponds to the mean/variance hyperparameters.)

    .. seealso::
        :obj:`~qpytorch.variational.OrthogonallyDecoupledVariationalStrategy` (a variant proposed by
        `Salimbeni et al. (2018)`_ that uses orthogonal projections.)

    :param model: Model this strategy is applied to.
        Typically passed in when the VariationalStrategy is created in the
        __init__ method of the user defined model.
        If the model has a ``power`` attribute, q-exponential distributions with that
        power are used in place of Gaussian (multivariate normal) ones.
    :param inducing_points: Tensor containing a set of inducing
        points to use for variational inference. A 1D tensor is treated as
        a set of one-dimensional inducing points.
    :param variational_distribution: A
        VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`.
        :obj:`~qpytorch.variational.DeltaVariationalDistribution` is not supported.
    :param learn_inducing_locations: (Default True): Whether or not
        the inducing point locations :math:`\mathbf Z` should be learned (i.e. are they
        parameters of the model).
    :param mean_var_batch_dim: (Default `None`):
        Set this parameter (ideally to `-1`) to indicate which dimension corresponds to different
        kernel hyperparameters for the mean/variance functions. Must be negative if given.
    :param jitter_val: Amount of diagonal jitter to add for Cholesky factorization numerical stability
    :raises NotImplementedError: if ``variational_distribution`` is a DeltaVariationalDistribution.
    :raises ValueError: if ``mean_var_batch_dim`` is non-negative.

    .. _Cheng et al. (2017):
        https://arxiv.org/abs/1711.10127

    .. _Salimbeni et al. (2018):
        https://arxiv.org/abs/1809.08820

    .. _Jankowiak et al. (2020):
        https://arxiv.org/abs/1910.07123

    Example (**different** hypers for mean/variance):
        >>> class MeanFieldDecoupledModel(qpytorch.models.ApproximateGP):  # or qpytorch.models.ApproximateQEP
        >>>     '''
        >>>     A batch of 3 independent MeanFieldDecoupled PPGPR (PPQEP) models.
        >>>     '''
        >>>     def __init__(self, inducing_points):
        >>>         # The variational parameters have a batch_shape of [3]
        >>>         variational_distribution = qpytorch.variational.MeanFieldVariationalDistribution(
        >>>             inducing_points.size(-1), batch_shape=torch.Size([3]),
        >>>         )
        >>>         variational_strategy = qpytorch.variational.BatchDecoupledVariationalStrategy(
        >>>             self, inducing_points, variational_distribution, learn_inducing_locations=True,
        >>>             mean_var_batch_dim=-1
        >>>         )
        >>>
        >>>         # The mean/covar modules have a batch_shape of [3, 2]
        >>>         # where the last batch dim corresponds to the mean & variance hyperparameters
        >>>         super().__init__(variational_strategy)
        >>>         self.mean_module = qpytorch.means.ConstantMean(batch_shape=torch.Size([3, 2]))
        >>>         self.covar_module = qpytorch.kernels.ScaleKernel(
        >>>             qpytorch.kernels.RBFKernel(batch_shape=torch.Size([3, 2])),
        >>>             batch_shape=torch.Size([3, 2]),
        >>>         )

    Example (**shared** hypers for mean/variance):
        >>> class MeanFieldDecoupledModel(qpytorch.models.ApproximateGP):  # or qpytorch.models.ApproximateQEP
        >>>     '''
        >>>     A batch of 3 independent MeanFieldDecoupled PPGPR (PPQEP) models.
        >>>     '''
        >>>     def __init__(self, inducing_points):
        >>>         # The variational parameters have a batch_shape of [3]
        >>>         variational_distribution = qpytorch.variational.MeanFieldVariationalDistribution(
        >>>             inducing_points.size(-1), batch_shape=torch.Size([3]),
        >>>         )
        >>>         variational_strategy = qpytorch.variational.BatchDecoupledVariationalStrategy(
        >>>             self, inducing_points, variational_distribution, learn_inducing_locations=True,
        >>>         )
        >>>
        >>>         # The mean/covar modules have a batch_shape of [3, 1]
        >>>         # where the singleton dimension corresponds to the shared mean/variance hyperparameters
        >>>         super().__init__(variational_strategy)
        >>>         self.mean_module = qpytorch.means.ConstantMean(batch_shape=torch.Size([3, 1]))
        >>>         self.covar_module = qpytorch.kernels.ScaleKernel(
        >>>             qpytorch.kernels.RBFKernel(batch_shape=torch.Size([3, 1])),
        >>>             batch_shape=torch.Size([3, 1]),
        >>>         )
    """

    def __init__(
        self,
        model: Union[ApproximateGP, ApproximateQEP],
        inducing_points: Tensor,
        variational_distribution: _VariationalDistribution,
        learn_inducing_locations: bool = True,
        mean_var_batch_dim: Optional[int] = None,
        jitter_val: Optional[float] = None,
    ):
        if isinstance(variational_distribution, DeltaVariationalDistribution):
            raise NotImplementedError(
                "BatchDecoupledVariationalStrategy does not work with DeltaVariationalDistribution"
            )

        # Negative indexing is required: the dim is later offset by -2 / -1 to skip the
        # trailing (num_points, input_dim) / (num_points,) dimensions of inputs and means.
        if mean_var_batch_dim is not None and mean_var_batch_dim >= 0:
            raise ValueError(f"mean_var_batch_dim should be negative indexed, got {mean_var_batch_dim}")
        self.mean_var_batch_dim = mean_var_batch_dim

        # Maybe unsqueeze inducing points
        if inducing_points.dim() == 1:
            inducing_points = inducing_points.unsqueeze(-1)

        # We're going to create two sets of inducing points:
        # one set for computing the mean, one set for computing the variance.
        # They are stacked along the mean/var batch dimension (offset by -2 to skip
        # the trailing (num_inducing, input_dim) dimensions); -3 when no dim is given.
        if self.mean_var_batch_dim is not None:
            inducing_points = torch.stack([inducing_points, inducing_points], dim=(self.mean_var_batch_dim - 2))
        else:
            inducing_points = torch.stack([inducing_points, inducing_points], dim=-3)
        super().__init__(
            model, inducing_points, variational_distribution, learn_inducing_locations, jitter_val=jitter_val
        )

    def _expand_inputs(self, x: Tensor, inducing_points: Tensor) -> Tuple[Tensor, Tensor]:
        """Insert the mean/var batch dimension into ``x`` so it lines up with the stacked inducing points."""
        # If we haven't explicitly marked a dimension as batch, add the corresponding batch dimension to the input
        if self.mean_var_batch_dim is None:
            x = x.unsqueeze(-3)
        else:
            x = x.unsqueeze(self.mean_var_batch_dim - 2)
        return super()._expand_inputs(x, inducing_points)

    def forward(
        self,
        x: Tensor,
        inducing_points: Tensor,
        inducing_values: Tensor,
        variational_inducing_covar: Optional[LinearOperator] = None,
        **kwargs,
    ) -> Union[MultivariateNormal, MultivariateQExponential]:
        """
        Compute the decoupled predictive distribution q(f(x)).

        The mean uses index 0 of the mean/var batch dimension (mean inducing points and
        hyperparameters); the covariance uses index 1 (variance inducing points and hyperparameters).

        :param x: Test inputs (already expanded by :meth:`_expand_inputs`).
        :param inducing_points: Stacked mean/variance inducing points.
        :param inducing_values: Whitened variational mean :math:`\\mathbf m`.
        :param variational_inducing_covar: Whitened variational covariance :math:`\\mathbf S`, if any.
        :param kwargs: Forwarded to ``self.model.forward``.
        :return: A MultivariateQExponential (with ``self.model.power``) if the model has a ``power``
            attribute, otherwise a MultivariateNormal.
        """
        # We'll compute the covariance, and cross-covariance terms for both the
        # pred-mean and pred-covar, using their different inducing points (and maybe kernel hypers)

        # __init__ guarantees the stored value is None or negative, so `or` only replaces None
        mean_var_batch_dim = self.mean_var_batch_dim or -1

        # Compute full prior distribution
        full_inputs = torch.cat([inducing_points, x], dim=-2)
        full_output = self.model.forward(full_inputs, **kwargs)
        full_covar = full_output.lazy_covariance_matrix

        # Covariance terms
        num_induc = inducing_points.size(-2)
        test_mean = full_output.mean[..., num_induc:]
        induc_induc_covar = full_covar[..., :num_induc, :num_induc].add_jitter(self.jitter_val)
        induc_data_covar = full_covar[..., :num_induc, num_induc:].to_dense()
        data_data_covar = full_covar[..., num_induc:, num_induc:]

        # Compute interpolation terms
        # K_ZZ^{-1/2} K_ZX
        # K_ZZ^{-1/2} \mu_Z
        L = self._cholesky_factor(induc_induc_covar)
        if L.shape != induc_induc_covar.shape:
            # Aggressive caching can cause nasty shape incompatibilities when evaluating with different batch shapes
            # TODO: Use a hook to make this cleaner
            try:
                pop_from_cache_ignore_args(self, "cholesky_factor")
            except CachingError:
                pass
            L = self._cholesky_factor(induc_induc_covar)
        # Solve in double precision for stability, then cast back to the input dtype
        interp_term = L.solve(induc_data_covar.double()).to(full_inputs.dtype)
        # interp_term has trailing dims (num_induc, num_data), hence the -2 offset
        mean_interp_term = interp_term.select(mean_var_batch_dim - 2, 0)
        var_interp_term = interp_term.select(mean_var_batch_dim - 2, 1)

        # Compute the mean of q(f)
        # k_XZ K_ZZ^{-1/2} m + \mu_X
        # Here we're using the terms that correspond to the mean's inducing points
        # (test_mean has a single trailing dim (num_data,), hence the -1 offset)
        predictive_mean = torch.add(
            torch.matmul(mean_interp_term.transpose(-1, -2), inducing_values.unsqueeze(-1)).squeeze(-1),
            test_mean.select(mean_var_batch_dim - 1, 0),
        )

        # Compute the covariance of q(f)
        # K_XX + k_XZ K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2} k_ZX
        # Here we're using the terms that correspond to the variance's inducing points
        middle_term = self.prior_distribution.lazy_covariance_matrix.mul(-1)
        if variational_inducing_covar is not None:
            middle_term = SumLinearOperator(variational_inducing_covar, middle_term)
        predictive_covar = SumLinearOperator(
            data_data_covar.add_jitter(self.jitter_val).to_dense().select(mean_var_batch_dim - 2, 1),
            MatmulLinearOperator(var_interp_term.transpose(-1, -2), middle_term @ var_interp_term),
        )

        if hasattr(self.model, 'power'):
            return MultivariateQExponential(predictive_mean, predictive_covar, power=self.model.power)
        else:
            return MultivariateNormal(predictive_mean, predictive_covar)

    def kl_divergence(self) -> Tensor:
        """
        KL divergence between the decoupled variational distribution and the prior.

        The mean and covariance parts are treated separately: a point mass (Delta) at the
        variational mean, plus a zero-mean distribution with the variational covariance.
        Each is compared against the prior and the two divergences are summed.
        """
        variational_dist = self.variational_distribution
        prior_dist = self.prior_distribution

        mean_dist = Delta(variational_dist.mean)
        if hasattr(self.model, 'power'):
            covar_dist = MultivariateQExponential(
                torch.zeros_like(variational_dist.mean), variational_dist.lazy_covariance_matrix, power=self.model.power
            )
        else:
            covar_dist = MultivariateNormal(
                torch.zeros_like(variational_dist.mean), variational_dist.lazy_covariance_matrix
            )
        return kl_divergence(mean_dist, prior_dist) + kl_divergence(covar_dist, prior_dist)
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
from typing import Union
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from linear_operator.operators import CholLinearOperator, TriangularLinearOperator
|
|
7
|
+
|
|
8
|
+
from ..distributions import MultivariateNormal, MultivariateQExponential
|
|
9
|
+
from ._variational_distribution import _VariationalDistribution
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class CholeskyVariationalDistribution(_VariationalDistribution):
    r"""
    A :obj:`~qpytorch.variational._VariationalDistribution` that is defined to be a multivariate normal
    (or multivariate q-exponential) distribution with a full covariance matrix.

    The most common way this distribution is defined is to parameterize it in terms of a mean vector and a covariance
    matrix. In order to keep the covariance matrix a valid covariance, it is parameterized through a
    lower-triangular Cholesky factor :math:`\mathbf L` (only the lower triangle of ``chol_variational_covar``
    is used), and the covariance is formed as :math:`\mathbf L \mathbf L^\top`.

    :param num_inducing_points: Size of the variational distribution. This implies that the variational mean
        should be this size, and the variational covariance matrix should have this many rows and columns.
    :param batch_shape: Specifies an optional batch size
        for the variational parameters. This is useful for example when doing additive variational inference.
    :param mean_init_std: (Default: 1e-3) Standard deviation of the noise added to the mean initialization
        (drawn with ``torch.randn_like``, i.e. Gaussian).
    :param power: (Optional keyword argument.) If given, it is stored as ``self.power`` and :meth:`forward`
        returns a :obj:`MultivariateQExponential` with that power; otherwise :meth:`forward` returns a
        :obj:`MultivariateNormal`. Any other keyword arguments are accepted but ignored.
    """

    def __init__(
        self,
        num_inducing_points: int,
        batch_shape: torch.Size = torch.Size([]),
        mean_init_std: float = 1e-3,
        **kwargs,
    ):
        super().__init__(num_inducing_points=num_inducing_points, batch_shape=batch_shape, mean_init_std=mean_init_std)
        # Start from a zero mean and an identity Cholesky factor, replicated over the batch shape
        mean_init = torch.zeros(num_inducing_points)
        covar_init = torch.eye(num_inducing_points, num_inducing_points)
        mean_init = mean_init.repeat(*batch_shape, 1)
        covar_init = covar_init.repeat(*batch_shape, 1, 1)

        self.register_parameter(name="variational_mean", parameter=torch.nn.Parameter(mean_init))
        self.register_parameter(name="chol_variational_covar", parameter=torch.nn.Parameter(covar_init))

        # The presence of `power` is what switches forward() to a q-exponential distribution
        if 'power' in kwargs:
            self.power = kwargs.pop('power')

    def forward(self) -> Union[MultivariateNormal, MultivariateQExponential]:
        """
        Build q(u) from the current variational parameters.

        :return: ``MultivariateQExponential(variational_mean, L L^T, power=self.power)`` if a power was
            given at construction, otherwise ``MultivariateNormal(variational_mean, L L^T)``.
        """
        chol_variational_covar = self.chol_variational_covar
        dtype = chol_variational_covar.dtype
        device = chol_variational_covar.device

        # First make sure the Cholesky factor is lower triangular: zero out everything above the diagonal
        lower_mask = torch.ones(chol_variational_covar.shape[-2:], dtype=dtype, device=device).tril(0)
        chol_variational_covar = TriangularLinearOperator(chol_variational_covar.mul(lower_mask))

        # Now construct the actual matrix: S = L L^T
        variational_covar = CholLinearOperator(chol_variational_covar)
        if not hasattr(self, 'power'):
            return MultivariateNormal(self.variational_mean, variational_covar)
        return MultivariateQExponential(self.variational_mean, variational_covar, power=self.power)

    def initialize_variational_distribution(self, prior_dist: Union[MultivariateNormal, MultivariateQExponential]) -> None:
        """
        Initialize the variational parameters in place from ``prior_dist``.

        The mean is set to the prior mean plus ``mean_init_std``-scaled standard-normal noise;
        the Cholesky factor is set to the Cholesky factor of the prior covariance.
        """
        self.variational_mean.data.copy_(prior_dist.mean)
        self.variational_mean.data.add_(torch.randn_like(prior_dist.mean), alpha=self.mean_init_std)
        self.chol_variational_covar.data.copy_(prior_dist.lazy_covariance_matrix.cholesky().to_dense())
|
|
@@ -0,0 +1,352 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
from typing import Optional, Tuple, Union
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from linear_operator import to_linear_operator
|
|
7
|
+
from linear_operator.operators import DiagLinearOperator, LinearOperator, MatmulLinearOperator, SumLinearOperator
|
|
8
|
+
from linear_operator.utils import linear_cg
|
|
9
|
+
from torch import Tensor
|
|
10
|
+
from torch.autograd.function import FunctionCtx
|
|
11
|
+
|
|
12
|
+
from .. import settings
|
|
13
|
+
from ..distributions import Delta, Distribution, MultivariateNormal, MultivariateQExponential
|
|
14
|
+
from ..module import Module
|
|
15
|
+
from gpytorch.utils.memoize import cached
|
|
16
|
+
from ._variational_strategy import _VariationalStrategy
|
|
17
|
+
from .natural_variational_distribution import NaturalVariationalDistribution
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class _NgdInterpTerms(torch.autograd.Function):
    """
    Custom autograd function used by the CIQ variational strategy for natural gradient descent (NGD).

    This function takes in

    - the kernel interpolation term K_ZZ^{-1/2} k_ZX
    - the natural parameters of the variational distribution (natural_vec = S^{-1} m,
      natural_mat = -1/2 S^{-1}; this convention is implied by the precision computation in forward)

    and returns

    - the interpolated predictive mean k^T K^{-1/2} m
    - the interpolated predictive variance (diagonal only) k^T K^{-1/2} S K^{-1/2} k
    - a placeholder for the inducing KL divergence KL( q(u) || p(u)) -- forward returns zeros,
      but backward injects the KL's gradient contribution through kl_div_grad

    However, the gradients returned for natural_vec / natural_mat are the gradients with respect to
    the **expectation parameters** (m and mm^T + S) of the variational distribution, rather than the
    **natural parameters**. Plugging these in as gradients of the natural parameters corresponds to
    performing natural gradient descent on the variational distribution.
    """

    @staticmethod
    def forward(
        ctx: FunctionCtx,
        interp_term: torch.Tensor,
        natural_vec: torch.Tensor,
        natural_mat: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Compute precision: S^{-1} = -2 * natural_mat
        prec = natural_mat.mul(-2.0)
        # Diagonal of the precision, used as a Jacobi preconditioner for CG below
        diag = prec.diagonal(dim1=-1, dim2=-2).unsqueeze(-1)

        # Make sure that interp_term and natural_vec are the same batch shape
        # (assumes interp_term is (..., M, N) and natural_vec is (..., M) -- TODO confirm at call site)
        batch_shape = torch.broadcast_shapes(interp_term.shape[:-2], natural_vec.shape[:-1])
        expanded_interp_term = interp_term.expand(*batch_shape, *interp_term.shape[-2:])
        expanded_natural_vec = natural_vec.expand(*batch_shape, natural_vec.size(-1))

        # Compute necessary solves with the precision (one batched CG call). We need
        # m = expec_vec = S * natural_vec
        # S K^{-1/2} k
        solves = linear_cg(
            prec.matmul,
            torch.cat([expanded_natural_vec.unsqueeze(-1), expanded_interp_term], dim=-1),
            n_tridiag=0,
            max_iter=settings.max_cg_iterations.value(),
            tolerance=min(settings.eval_cg_tolerance.value(), settings.cg_tolerance.value()),
            max_tridiag_iter=settings.max_lanczos_quadrature_iterations.value(),
            preconditioner=lambda x: x / diag,
        )
        # Column 0 solves against natural_vec; the remaining columns solve against interp_term
        expec_vec = solves[..., 0]
        s_times_interp_term = solves[..., 1:]

        # Compute the interpolated mean
        # k^T K^{-1/2} m
        interp_mean = (s_times_interp_term.transpose(-1, -2) @ natural_vec.unsqueeze(-1)).squeeze(-1)

        # Compute the interpolated variance (diagonal only)
        # k^T K^{-1/2} S K^{-1/2} k = k^T K^{-1/2} (expec_mat - expec_vec expec_vec^T) K^{-1/2} k
        interp_var = (s_times_interp_term * interp_term).sum(dim=-2)

        # Let's not bother actually computing the KL-div in the forward pass
        # 1/2 ( -log | S | + tr(S) + m^T m - len(m) )
        # = 1/2 ( -log | expec_mat - expec_vec expec_vec^T | + tr(expec_mat) - len(m) )
        # Only its gradient matters; it is supplied in backward via kl_div_grad.
        kl_div = torch.zeros_like(interp_mean[..., 0])

        # We're done!
        ctx.save_for_backward(interp_term, s_times_interp_term, interp_mean, natural_vec, expec_vec, prec)
        return interp_mean, interp_var, kl_div

    @staticmethod
    def backward(
        ctx: FunctionCtx, interp_mean_grad: torch.Tensor, interp_var_grad: torch.Tensor, kl_div_grad: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None]:
        # Get the saved terms
        interp_term, s_times_interp_term, interp_mean, natural_vec, expec_vec, prec = ctx.saved_tensors

        # Expand data-dependent gradients (add a row dim so they broadcast over the inducing dim)
        interp_mean_grad = interp_mean_grad.unsqueeze(-2)
        interp_var_grad = interp_var_grad.unsqueeze(-2)

        # Compute gradient of interp term (K^{-1/2} k)
        # interp_mean component: m
        # interp_var component: S K^{-1/2} k
        # kl component: 0
        interp_term_grad = (interp_var_grad * s_times_interp_term).mul(2.0) + (
            interp_mean_grad * expec_vec.unsqueeze(-1)
        )

        # Compute gradient of expected vector (m)
        # interp_mean component: K^{-1/2} k
        # interp_var component: (k^T K^{-1/2} m) K^{-1/2} k
        # kl component: S^{-1} m
        # (returned in the natural_vec slot -- this is what makes the update a natural gradient)
        expec_vec_grad = (
            (interp_var_grad * interp_mean.unsqueeze(-2) * interp_term).sum(dim=-1).mul(-2)
            + (interp_mean_grad * interp_term).sum(dim=-1)
            + (kl_div_grad.unsqueeze(-1) * natural_vec)
        )

        # Compute gradient of expected matrix (mm^T + S)
        # interp_mean component: 0
        # interp_var component: K^{-1/2} k k^T K^{-1/2}
        # kl component: 1/2 ( I - S^{-1} )
        # (returned in the natural_mat slot)
        eye = torch.eye(expec_vec.size(-1), device=expec_vec.device, dtype=expec_vec.dtype)
        expec_mat_grad = torch.add(
            (interp_var_grad * interp_term) @ interp_term.transpose(-1, -2),
            (kl_div_grad.unsqueeze(-1).unsqueeze(-1) * (eye - prec).mul(0.5)),
        )

        # We're done!
        # NOTE(review): forward takes only three inputs; the trailing None looks like a leftover and relies
        # on autograd accepting extra None gradients -- confirm before removing.
        return interp_term_grad, expec_vec_grad, expec_mat_grad, None  # Extra "None" for the kwarg
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class CiqVariationalStrategy(_VariationalStrategy):
|
|
129
|
+
r"""
|
|
130
|
+
Similar to :class:`~qpytorch.variational.VariationalStrategy`,
|
|
131
|
+
except the whitening operation is performed using Contour Integral Quadrature
|
|
132
|
+
rather than Cholesky (see `Pleiss et al. (2020)`_ for more info).
|
|
133
|
+
See the `CIQ-SVGP tutorial`_ for an example.
|
|
134
|
+
|
|
135
|
+
Contour Integral Quadrature uses iterative matrix-vector multiplication to approximate
|
|
136
|
+
the :math:`\mathbf K_{\mathbf Z \mathbf Z}^{-1/2}` matrix used for the whitening operation.
|
|
137
|
+
This can be more efficient than the standard variational strategy for large numbers
|
|
138
|
+
of inducing points (e.g. :math:`M > 1000`) or when the inducing points have structure
|
|
139
|
+
(e.g. they lie on an evenly-spaced grid).
|
|
140
|
+
|
|
141
|
+
.. note::
|
|
142
|
+
|
|
143
|
+
It is recommended that this object is used in conjunction with
|
|
144
|
+
:obj:`~qpytorch.variational.NaturalVariationalDistribution` and
|
|
145
|
+
`natural gradient descent`_.
|
|
146
|
+
|
|
147
|
+
:param model: Model this strategy is applied to.
|
|
148
|
+
Typically passed in when the VariationalStrategy is created in the
|
|
149
|
+
__init__ method of the user defined model.
|
|
150
|
+
It should contain power if Q-Exponential distribution is involved in.
|
|
151
|
+
:param inducing_points: Tensor containing a set of inducing
|
|
152
|
+
points to use for variational inference.
|
|
153
|
+
:param variational_distribution: A
|
|
154
|
+
VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`
|
|
155
|
+
:param learn_inducing_locations: (Default True): Whether or not
|
|
156
|
+
the inducing point locations :math:`\mathbf Z` should be learned (i.e. are they
|
|
157
|
+
parameters of the model).
|
|
158
|
+
:param jitter_val: Amount of diagonal jitter to add for Cholesky factorization numerical stability
|
|
159
|
+
|
|
160
|
+
.. _Pleiss et al. (2020):
|
|
161
|
+
https://arxiv.org/pdf/2006.11267.pdf
|
|
162
|
+
.. _CIQ-SVGP tutorial:
|
|
163
|
+
examples/04_Variational_and_Approximate_GPs/SVGP_CIQ.html
|
|
164
|
+
.. _natural gradient descent:
|
|
165
|
+
examples/04_Variational_and_Approximate_GPs/Natural_Gradient_Descent.html
|
|
166
|
+
"""
|
|
167
|
+
|
|
168
|
+
def _ngd(self) -> bool:
    """Report whether the wrapped variational distribution is a NaturalVariationalDistribution (NGD mode)."""
    wrapped = self._variational_distribution
    is_natural = isinstance(wrapped, NaturalVariationalDistribution)
    return is_natural
|
|
170
|
+
|
|
171
|
+
@property
@cached(name="prior_distribution_memo")
def prior_distribution(self) -> Union[MultivariateNormal, MultivariateQExponential]:
    """
    Whitened prior over the inducing values: zero mean and identity (diagonal) covariance.

    The tensors take their shape, dtype and device from the variational distribution.
    A MultivariateQExponential with ``self.model.power`` is returned when the model has a
    ``power`` attribute; otherwise a MultivariateNormal. The result is memoized.
    """
    var_dist = self._variational_distribution
    prior_mean = torch.zeros(var_dist.shape(), dtype=var_dist.dtype, device=var_dist.device)
    prior_covar = DiagLinearOperator(torch.ones_like(prior_mean))
    if not hasattr(self.model, 'power'):
        return MultivariateNormal(prior_mean, prior_covar)
    return MultivariateQExponential(prior_mean, prior_covar, power=self.model.power)
|
|
185
|
+
|
|
186
|
+
@property
@cached(name="variational_distribution_memo")
def variational_distribution(self) -> Distribution:
    r"""Variational distribution :math:`q(\mathbf u)` from the parent strategy.

    The value is cached as ``variational_distribution_memo``.

    :raises RuntimeError: When natural gradient descent is in use. On that path
        :math:`q(\mathbf u)` is never materialized here, because ``forward`` consumes
        the natural parameters directly.
    """
    if not self._ngd():
        return super().variational_distribution
    raise RuntimeError(
        "Variational distribution for NGD-CIQ should be computed during forward calls. "
        "This is probably a bug in GPyTorch."
    )
|
|
195
|
+
|
|
196
|
+
def forward(
    self,
    x: torch.Tensor,
    inducing_points: torch.Tensor,
    inducing_values: torch.Tensor,
    variational_inducing_covar: Optional[LinearOperator] = None,
    *params,
    **kwargs,
) -> Union[MultivariateNormal, MultivariateQExponential]:
    r"""
    Compute :math:`q(\mathbf f)` at ``x`` using the CIQ-whitened interpolation term
    :math:`\mathbf K_{\mathbf Z \mathbf Z}^{-1/2} \mathbf K_{\mathbf Z \mathbf X}`.

    The joint prior over ``[inducing_points; x]`` is evaluated with a single call to
    ``self.model.forward``. The result is then split into inducing/data blocks.

    :param x: Inputs at which to evaluate :math:`q(\mathbf f)`.
    :param inducing_points: Inducing locations :math:`\mathbf Z`. They are concatenated
        with ``x`` along dim ``-2``, so all other dims must match ``x``.
    :param inducing_values: Mean of the whitened :math:`q(\mathbf u)`. Not read on the
        natural-gradient (NGD) path.
    :param variational_inducing_covar: Covariance of the whitened :math:`q(\mathbf u)`,
        or ``None`` to omit that term. Not read on the NGD path.
    :param params: Extra positional arguments forwarded to ``self.model.forward``.
    :param kwargs: Extra keyword arguments forwarded to ``self.model.forward``.
    :return: A ``MultivariateQExponential`` if ``self.model`` has a ``power`` attribute,
        otherwise a ``MultivariateNormal``. On the NGD path the covariance is diagonal.

    .. note::
        On the NGD path this also stores the KL divergence in
        ``self._memoize_cache["kl"]``, where :meth:`kl_divergence` reads it.
    """
    # Compute full prior distribution
    full_inputs = torch.cat([inducing_points, x], dim=-2)
    full_output = self.model.forward(full_inputs, *params, **kwargs)
    full_covar = full_output.lazy_covariance_matrix

    # Covariance terms
    num_induc = inducing_points.size(-2)
    test_mean = full_output.mean[..., num_induc:]
    induc_induc_covar = full_covar[..., :num_induc, :num_induc].evaluate_kernel().add_jitter(self.jitter_val)
    induc_data_covar = full_covar[..., :num_induc, num_induc:].to_dense()
    data_data_covar = full_covar[..., num_induc:, num_induc:].add_jitter(self.jitter_val)

    # Compute interpolation terms
    # interp_term = K_ZZ^{-1/2} K_ZX, so interp_term^T (.) gives K_XZ K_ZZ^{-1/2} (.)
    with settings.max_preconditioner_size(0):  # Turn off preconditioning for CIQ
        interp_term = to_linear_operator(induc_induc_covar).sqrt_inv_matmul(induc_data_covar)

    # Compute interpolated mean and variance terms
    # We have separate computation rules for NGD versus standard GD
    if self._ngd():
        # NGD: the interpolated mean/variance and the KL term are all produced from the
        # natural parameters by _NgdInterpTerms; q(u) itself is never formed.
        interp_mean, interp_var, kl_div = _NgdInterpTerms.apply(
            interp_term,
            self._variational_distribution.natural_vec,
            self._variational_distribution.natural_mat,
        )

        # Compute the covariance of q(f)
        # diag(K_XX) - diag(K_XZ K_ZZ^{-1} K_ZX) + interp_var, clamped to the minimum variance.
        # Only the diagonal is kept on this path.
        predictive_var = data_data_covar.diagonal(dim1=-1, dim2=-2) - interp_term.pow(2).sum(dim=-2) + interp_var
        predictive_var = torch.clamp_min(predictive_var, settings.min_variance.value(predictive_var.dtype))
        predictive_covar = DiagLinearOperator(predictive_var)

        # Also compute and cache the KL divergence (read back by kl_divergence())
        if not hasattr(self, "_memoize_cache"):
            self._memoize_cache = dict()
        self._memoize_cache["kl"] = kl_div

    else:
        # Compute interpolated mean term: K_XZ K_ZZ^{-1/2} (m - prior mean)
        interp_mean = torch.matmul(
            interp_term.transpose(-1, -2), (inducing_values - self.prior_distribution.mean).unsqueeze(-1)
        ).squeeze(-1)

        # Compute the covariance of q(f)
        # K_XX + K_XZ K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2} K_ZX, where S is variational_inducing_covar
        # and I is the whitened prior covariance; only the -I term is used when S is None.
        middle_term = self.prior_distribution.lazy_covariance_matrix.mul(-1)
        if variational_inducing_covar is not None:
            middle_term = SumLinearOperator(variational_inducing_covar, middle_term)
        # NOTE(review): data_data_covar already had jitter added above, so this path adds
        # jitter twice while the NGD path adds it once; confirm this is intended.
        predictive_covar = SumLinearOperator(
            data_data_covar.add_jitter(self.jitter_val),
            MatmulLinearOperator(interp_term.transpose(-1, -2), middle_term @ interp_term),
        )

    # Compute the mean of q(f)
    # interp_mean + \mu_X (on the standard-GD path: K_XZ K_ZZ^{-1/2} m + \mu_X, since the
    # whitened prior mean is zero)
    predictive_mean = interp_mean + test_mean

    # Return the distribution
    if hasattr(self.model, 'power'):
        return MultivariateQExponential(predictive_mean, predictive_covar, power=self.model.power)
    else:
        return MultivariateNormal(predictive_mean, predictive_covar)
|
|
266
|
+
|
|
267
|
+
def kl_divergence(self) -> Tensor:
    r"""
    Compute the KL divergence between the variational inducing distribution :math:`q(\mathbf u)`
    and the prior inducing distribution :math:`p(\mathbf u)`.

    With natural gradient descent (NGD), the KL term is computed as a by-product of
    :meth:`forward` and cached in ``self._memoize_cache["kl"]``; this method only returns
    that cached value. Without NGD it defers to the parent strategy.

    :rtype: torch.Tensor
    :raises RuntimeError: If NGD is in use and no KL value has been cached yet.
    """
    if not self._ngd():
        return super().kl_divergence()
    if hasattr(self, "_memoize_cache") and "kl" in self._memoize_cache:
        return self._memoize_cache["kl"]
    # The two literals are concatenated, so the first one needs a trailing space.
    raise RuntimeError(
        "KL divergence for NGD-CIQ should be computed during forward calls. "
        "This is probably a bug in GPyTorch."
    )
|
|
284
|
+
|
|
285
|
+
def __call__(self, x: torch.Tensor, prior: bool = False, *params, **kwargs) -> Union[MultivariateNormal, MultivariateQExponential]:
    r"""
    Evaluate the approximate posterior :math:`q(\mathbf f)` at ``x``, or the prior if ``prior=True``.

    This is mostly the same as ``_VariationalStrategy.__call__`` but has special rules for
    natural gradient descent (NGD). Under NGD, :math:`q(\mathbf u)` is never materialized,
    which avoids an :math:`\mathcal O(M^3)` computation.

    :param x: Input locations.
    :param prior: If ``True``, return ``self.model.forward`` evaluated at ``x`` directly.
    :param params: Extra positional arguments forwarded to ``self.model.forward``.
    :param kwargs: Extra keyword arguments forwarded to ``self.model.forward``.
    :raises RuntimeError: If the variational distribution is neither a multivariate
        normal / q-exponential nor a ``Delta`` distribution.
    """
    # If we're in prior mode, then we're done!
    # Extra arguments are forwarded exactly as forward() forwards them on the posterior path.
    if prior:
        return self.model.forward(x, *params, **kwargs)

    # Delete previously cached items from the training distribution
    if self.training:
        self._clear_cache()

    # (Maybe) initialize variational distribution
    if not self.variational_params_initialized.item():
        if self._ngd():
            # NGD: a small random natural_vec and natural_mat = -0.5 * I
            noise = torch.randn_like(self.prior_distribution.mean).mul_(1e-3)
            eye = torch.eye(noise.size(-1), dtype=noise.dtype, device=noise.device).mul(-0.5)
            self._variational_distribution.natural_vec.data.copy_(noise)
            self._variational_distribution.natural_mat.data.copy_(eye)
        else:
            prior_dist = self.prior_distribution
            self._variational_distribution.initialize_variational_distribution(prior_dist)
        self.variational_params_initialized.fill_(1)

    # Ensure inducing_points and x are the same size
    inducing_points = self.inducing_points
    if inducing_points.shape[:-2] != x.shape[:-2]:
        x, inducing_points = self._expand_inputs(x, inducing_points)

    # Get q(f)
    if self._ngd():
        # forward() reads the natural parameters directly, so no inducing values or covariance
        # are passed. They are given positionally so that *params lands after them: passing
        # them as keywords followed by *params would bind params[0] to inducing_values.
        return Module.__call__(self, x, inducing_points, None, None, *params, **kwargs)

    # Get p(u)/q(u)
    variational_dist_u = self.variational_distribution
    if isinstance(variational_dist_u, (MultivariateNormal, MultivariateQExponential)):
        inducing_values = variational_dist_u.mean
        variational_inducing_covar = variational_dist_u.lazy_covariance_matrix
    elif isinstance(variational_dist_u, Delta):
        inducing_values = variational_dist_u.mean
        variational_inducing_covar = None
    else:
        raise RuntimeError(
            f"Invalid variational distribution ({type(variational_dist_u)}). "
            "Expected a multivariate normal (q-exponential) or a delta distribution."
        )
    return Module.__call__(
        self,
        x,
        inducing_points,
        inducing_values,
        variational_inducing_covar,
        *params,
        **kwargs,
    )
|