qpytorch 0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of qpytorch might be problematic. Click here for more details.
- qpytorch/__init__.py +327 -0
- qpytorch/constraints/__init__.py +3 -0
- qpytorch/distributions/__init__.py +21 -0
- qpytorch/distributions/delta.py +86 -0
- qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
- qpytorch/distributions/multivariate_qexponential.py +581 -0
- qpytorch/distributions/power.py +113 -0
- qpytorch/distributions/qexponential.py +153 -0
- qpytorch/functions/__init__.py +58 -0
- qpytorch/kernels/__init__.py +80 -0
- qpytorch/kernels/grid_interpolation_kernel.py +213 -0
- qpytorch/kernels/inducing_point_kernel.py +151 -0
- qpytorch/kernels/kernel.py +695 -0
- qpytorch/kernels/matern32_kernel_grad.py +155 -0
- qpytorch/kernels/matern52_kernel_grad.py +194 -0
- qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
- qpytorch/kernels/polynomial_kernel_grad.py +88 -0
- qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
- qpytorch/kernels/rbf_kernel_grad.py +125 -0
- qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
- qpytorch/kernels/rff_kernel.py +153 -0
- qpytorch/lazy/__init__.py +9 -0
- qpytorch/likelihoods/__init__.py +66 -0
- qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
- qpytorch/likelihoods/beta_likelihood.py +76 -0
- qpytorch/likelihoods/gaussian_likelihood.py +472 -0
- qpytorch/likelihoods/laplace_likelihood.py +59 -0
- qpytorch/likelihoods/likelihood.py +437 -0
- qpytorch/likelihoods/likelihood_list.py +60 -0
- qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
- qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
- qpytorch/likelihoods/noise_models.py +184 -0
- qpytorch/likelihoods/qexponential_likelihood.py +494 -0
- qpytorch/likelihoods/softmax_likelihood.py +97 -0
- qpytorch/likelihoods/student_t_likelihood.py +90 -0
- qpytorch/means/__init__.py +23 -0
- qpytorch/metrics/__init__.py +17 -0
- qpytorch/mlls/__init__.py +53 -0
- qpytorch/mlls/_approximate_mll.py +79 -0
- qpytorch/mlls/deep_approximate_mll.py +30 -0
- qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
- qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
- qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
- qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
- qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
- qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
- qpytorch/mlls/marginal_log_likelihood.py +48 -0
- qpytorch/mlls/predictive_log_likelihood.py +76 -0
- qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
- qpytorch/mlls/variational_elbo.py +77 -0
- qpytorch/models/__init__.py +72 -0
- qpytorch/models/approximate_qep.py +115 -0
- qpytorch/models/deep_qeps/__init__.py +22 -0
- qpytorch/models/deep_qeps/deep_qep.py +155 -0
- qpytorch/models/deep_qeps/dspp.py +114 -0
- qpytorch/models/exact_prediction_strategies.py +880 -0
- qpytorch/models/exact_qep.py +349 -0
- qpytorch/models/model_list.py +100 -0
- qpytorch/models/pyro/__init__.py +28 -0
- qpytorch/models/pyro/_pyro_mixin.py +57 -0
- qpytorch/models/pyro/distributions/__init__.py +5 -0
- qpytorch/models/pyro/pyro_qep.py +105 -0
- qpytorch/models/qep.py +7 -0
- qpytorch/models/qeplvm/__init__.py +6 -0
- qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
- qpytorch/models/qeplvm/latent_variable.py +102 -0
- qpytorch/module.py +30 -0
- qpytorch/optim/__init__.py +5 -0
- qpytorch/priors/__init__.py +42 -0
- qpytorch/priors/qep_priors.py +81 -0
- qpytorch/test/__init__.py +22 -0
- qpytorch/test/base_likelihood_test_case.py +106 -0
- qpytorch/test/model_test_case.py +150 -0
- qpytorch/test/variational_test_case.py +400 -0
- qpytorch/utils/__init__.py +38 -0
- qpytorch/utils/warnings.py +37 -0
- qpytorch/variational/__init__.py +47 -0
- qpytorch/variational/_variational_distribution.py +61 -0
- qpytorch/variational/_variational_strategy.py +391 -0
- qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
- qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
- qpytorch/variational/cholesky_variational_distribution.py +65 -0
- qpytorch/variational/ciq_variational_strategy.py +352 -0
- qpytorch/variational/delta_variational_distribution.py +41 -0
- qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
- qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
- qpytorch/variational/lmc_variational_strategy.py +248 -0
- qpytorch/variational/mean_field_variational_distribution.py +58 -0
- qpytorch/variational/multitask_variational_strategy.py +317 -0
- qpytorch/variational/natural_variational_distribution.py +152 -0
- qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
- qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
- qpytorch/variational/tril_natural_variational_distribution.py +130 -0
- qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
- qpytorch/variational/unwhitened_variational_strategy.py +225 -0
- qpytorch/variational/variational_strategy.py +280 -0
- qpytorch/version.py +4 -0
- qpytorch-0.1.dist-info/LICENSE +21 -0
- qpytorch-0.1.dist-info/METADATA +177 -0
- qpytorch-0.1.dist-info/RECORD +102 -0
- qpytorch-0.1.dist-info/WHEEL +5 -0
- qpytorch-0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,487 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
from typing import Any, Optional, Union
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from jaxtyping import Float
|
|
7
|
+
from linear_operator import to_dense
|
|
8
|
+
from linear_operator.operators import DiagLinearOperator, LinearOperator, TriangularLinearOperator
|
|
9
|
+
from linear_operator.utils.cholesky import psd_safe_cholesky
|
|
10
|
+
from torch import LongTensor, Tensor
|
|
11
|
+
|
|
12
|
+
from ..distributions import MultivariateNormal, MultivariateQExponential
|
|
13
|
+
from ..models import ApproximateGP, ApproximateQEP, ExactGP, ExactQEP
|
|
14
|
+
from ..module import Module
|
|
15
|
+
from gpytorch.utils.errors import CachingError
|
|
16
|
+
from gpytorch.utils.memoize import add_to_cache, cached, pop_from_cache
|
|
17
|
+
from gpytorch.utils.nearest_neighbors import NNUtil
|
|
18
|
+
from ._variational_distribution import _VariationalDistribution
|
|
19
|
+
from .mean_field_variational_distribution import MeanFieldVariationalDistribution
|
|
20
|
+
from .unwhitened_variational_strategy import UnwhitenedVariationalStrategy
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class NNVariationalStrategy(UnwhitenedVariationalStrategy):
|
|
24
|
+
r"""
|
|
25
|
+
This strategy sets all inducing point locations to observed inputs,
|
|
26
|
+
and employs a :math:`k`-nearest-neighbor approximation. It was introduced as the
|
|
27
|
+
`Variational Nearest Neighbor Gaussian Processes (VNNGP)` in `Wu et al (2022)`_.
|
|
28
|
+
See the `VNNGP tutorial`_ for an example.
|
|
29
|
+
|
|
30
|
+
VNNGP assumes a k-nearest-neighbor generative process for inducing points :math:`\mathbf u`,
|
|
31
|
+
:math:`\mathbf q(\mathbf u) = \prod_{j=1}^M q(u_j | \mathbf u_{n(j)})`
|
|
32
|
+
where :math:`n(j)` denotes the indices of :math:`k` nearest neighbors for :math:`u_j` among
|
|
33
|
+
:math:`u_1, \cdots, u_{j-1}`. For any test observation :math:`\mathbf f`,
|
|
34
|
+
VNNGP makes predictive inference conditioned on its :math:`k` nearest inducing points
|
|
35
|
+
:math:`\mathbf u_{n(f)}`, i.e. :math:`p(f|\mathbf u_{n(f)})`.
|
|
36
|
+
|
|
37
|
+
VNNGP's objective factorizes over inducing points and observations, making stochastic optimization over both
|
|
38
|
+
immediately available. After a one-time cost of computing the :math:`k`-nearest neighbor structure,
|
|
39
|
+
the training and inference complexity is :math:`O(k^3)`.
|
|
40
|
+
Since VNNGP uses observations as inducing points, it is a user choice to either (1)
|
|
41
|
+
use the same mini-batch of inducing points and observations (recommended),
|
|
42
|
+
or (2) use different mini-batches of inducing points and observations. See the `VNNGP tutorial`_ for
|
|
43
|
+
implementation and comparison.
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
.. note::
|
|
47
|
+
|
|
48
|
+
The current implementation only supports :obj:`~qpytorch.variational.MeanFieldVariationalDistribution`.
|
|
49
|
+
|
|
50
|
+
We recommend installing the `faiss`_ library (requiring separate package installment)
|
|
51
|
+
for nearest neighbor search, which is significantly faster than the `scikit-learn` nearest neighbor search.
|
|
52
|
+
GPyTorch will automatically use `faiss` if it is installed, but will revert to `scikit-learn` otherwise.
|
|
53
|
+
|
|
54
|
+
Different inducing point orderings will produce in different nearest neighbor approximations.
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
:param ~gpytorch.models.ApproximateGP (~qpytorch.models.ApproximateQEP) model: Model this strategy is applied to.
|
|
58
|
+
Typically passed in when the VariationalStrategy is created in the
|
|
59
|
+
__init__ method of the user defined model.
|
|
60
|
+
It should contain power if Q-Exponential distribution is involved in.
|
|
61
|
+
:param inducing_points: Tensor containing a set of inducing
|
|
62
|
+
points to use for variational inference.
|
|
63
|
+
:param variational_distribution: A
|
|
64
|
+
VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`
|
|
65
|
+
:param k: Number of nearest neighbors.
|
|
66
|
+
:param training_batch_size: The number of data points that will be in the training batch size.
|
|
67
|
+
:param jitter_val: Amount of diagonal jitter to add for covariance matrix numerical stability.
|
|
68
|
+
:param compute_full_kl: Whether to compute full kl divergence or stochastic estimate.
|
|
69
|
+
|
|
70
|
+
.. _Wu et al (2022):
|
|
71
|
+
https://arxiv.org/pdf/2202.01694.pdf
|
|
72
|
+
.. _VNNGP tutorial:
|
|
73
|
+
examples/04_Variational_and_Approximate_GPs/VNNGP.html
|
|
74
|
+
.. _faiss:
|
|
75
|
+
https://github.com/facebookresearch/faiss
|
|
76
|
+
"""
|
|
77
|
+
|
|
78
|
+
def __init__(
|
|
79
|
+
self,
|
|
80
|
+
model: Union[ApproximateGP, ApproximateQEP],
|
|
81
|
+
inducing_points: Float[Tensor, "... M D"],
|
|
82
|
+
variational_distribution: Float[_VariationalDistribution, "... M"],
|
|
83
|
+
k: int,
|
|
84
|
+
training_batch_size: Optional[int] = None,
|
|
85
|
+
jitter_val: Optional[float] = 1e-3,
|
|
86
|
+
compute_full_kl: Optional[bool] = False,
|
|
87
|
+
):
|
|
88
|
+
assert isinstance(
|
|
89
|
+
variational_distribution, MeanFieldVariationalDistribution
|
|
90
|
+
), "Currently, NNVariationalStrategy only supports MeanFieldVariationalDistribution."
|
|
91
|
+
|
|
92
|
+
super().__init__(
|
|
93
|
+
model, inducing_points, variational_distribution, learn_inducing_locations=False, jitter_val=jitter_val
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
# Model
|
|
97
|
+
object.__setattr__(self, "model", model)
|
|
98
|
+
|
|
99
|
+
self.inducing_points = inducing_points
|
|
100
|
+
self.M, self.D = inducing_points.shape[-2:]
|
|
101
|
+
self.k = k
|
|
102
|
+
assert self.k < self.M, (
|
|
103
|
+
f"Number of nearest neighbors k must be smaller than the number of inducing points, "
|
|
104
|
+
f"but got k = {k}, M = {self.M}."
|
|
105
|
+
)
|
|
106
|
+
|
|
107
|
+
self._inducing_batch_shape: torch.Size = inducing_points.shape[:-2]
|
|
108
|
+
self._model_batch_shape: torch.Size = self._variational_distribution.variational_mean.shape[:-1]
|
|
109
|
+
self._batch_shape: torch.Size = torch.broadcast_shapes(self._inducing_batch_shape, self._model_batch_shape)
|
|
110
|
+
|
|
111
|
+
self.nn_util: NNUtil = NNUtil(
|
|
112
|
+
k, dim=self.D, batch_shape=self._inducing_batch_shape, device=inducing_points.device
|
|
113
|
+
)
|
|
114
|
+
self._compute_nn()
|
|
115
|
+
# otherwise, no nearest neighbor approximation is used
|
|
116
|
+
|
|
117
|
+
self.training_batch_size = training_batch_size if training_batch_size is not None else self.M
|
|
118
|
+
self._set_training_iterator()
|
|
119
|
+
|
|
120
|
+
self.compute_full_kl = compute_full_kl
|
|
121
|
+
|
|
122
|
+
@property
|
|
123
|
+
@cached(name="prior_distribution_memo")
|
|
124
|
+
def prior_distribution(self) -> Union[Float[MultivariateNormal, "... M"], Float[MultivariateQExponential, "... M"]]:
|
|
125
|
+
out = self.model.forward(self.inducing_points)
|
|
126
|
+
if hasattr(self.model, 'power'):
|
|
127
|
+
res = MultivariateQExponential(out.mean, out.lazy_covariance_matrix.add_jitter(self.jitter_val), power=self.model.power)
|
|
128
|
+
else:
|
|
129
|
+
res = MultivariateNormal(out.mean, out.lazy_covariance_matrix.add_jitter(self.jitter_val))
|
|
130
|
+
return res
|
|
131
|
+
|
|
132
|
+
def _cholesky_factor(
|
|
133
|
+
self, induc_induc_covar: Float[LinearOperator, "... M M"]
|
|
134
|
+
) -> Float[TriangularLinearOperator, "... M M"]:
|
|
135
|
+
# Uncached version
|
|
136
|
+
L = psd_safe_cholesky(to_dense(induc_induc_covar))
|
|
137
|
+
return TriangularLinearOperator(L)
|
|
138
|
+
|
|
139
|
+
def __call__(
|
|
140
|
+
self, x: Float[Tensor, "... N D"], prior: bool = False, **kwargs: Any
|
|
141
|
+
) -> Union[Float[MultivariateNormal, "... N"], Float[MultivariateQExponential, "... N"]]:
|
|
142
|
+
# If we're in prior mode, then we're done!
|
|
143
|
+
if prior:
|
|
144
|
+
return self.model.forward(x, **kwargs)
|
|
145
|
+
|
|
146
|
+
if x is not None:
|
|
147
|
+
# Make sure x and inducing points have the same batch shape
|
|
148
|
+
if not (self.inducing_points.shape[:-2] == x.shape[:-2]):
|
|
149
|
+
try:
|
|
150
|
+
x = x.expand(*self.inducing_points.shape[:-2], *x.shape[-2:]).contiguous()
|
|
151
|
+
except RuntimeError:
|
|
152
|
+
raise RuntimeError(
|
|
153
|
+
f"x batch shape must match or broadcast with the inducing points' batch shape, "
|
|
154
|
+
f"but got x batch shape = {x.shape[:-2]}, "
|
|
155
|
+
f"inducing points batch shape = {self.inducing_points.shape[:-2]}."
|
|
156
|
+
)
|
|
157
|
+
|
|
158
|
+
# Delete previously cached items from the training distribution
|
|
159
|
+
if self.training:
|
|
160
|
+
self._clear_cache()
|
|
161
|
+
|
|
162
|
+
# (Maybe) initialize variational distribution
|
|
163
|
+
if not self.variational_params_initialized.item():
|
|
164
|
+
prior_dist = self.prior_distribution
|
|
165
|
+
self._variational_distribution.variational_mean.data.copy_(prior_dist.mean)
|
|
166
|
+
self._variational_distribution.variational_mean.data.add_(
|
|
167
|
+
torch.randn_like(prior_dist.mean), alpha=self._variational_distribution.mean_init_std
|
|
168
|
+
)
|
|
169
|
+
# initialize with a small variational stddev for quicker conv. of kl divergence
|
|
170
|
+
self._variational_distribution._variational_stddev.data.copy_(torch.tensor(1e-2))
|
|
171
|
+
self.variational_params_initialized.fill_(1)
|
|
172
|
+
|
|
173
|
+
return self.forward(
|
|
174
|
+
x, self.inducing_points, inducing_values=None, variational_inducing_covar=None, **kwargs
|
|
175
|
+
)
|
|
176
|
+
else:
|
|
177
|
+
# Ensure inducing_points and x are the same size
|
|
178
|
+
inducing_points = self.inducing_points
|
|
179
|
+
return self.forward(x, inducing_points, inducing_values=None, variational_inducing_covar=None, **kwargs)
|
|
180
|
+
|
|
181
|
+
def forward(
|
|
182
|
+
self,
|
|
183
|
+
x: Float[Tensor, "... N D"],
|
|
184
|
+
inducing_points: Float[Tensor, "... M D"],
|
|
185
|
+
inducing_values: Float[Tensor, "... M"],
|
|
186
|
+
variational_inducing_covar: Optional[Float[LinearOperator, "... M M"]] = None,
|
|
187
|
+
**kwargs: Any,
|
|
188
|
+
) -> Union[Float[MultivariateNormal, "... N"], Float[MultivariateQExponential, "... N"]]:
|
|
189
|
+
if self.training:
|
|
190
|
+
# In training mode, note that the full inducing points set = full training dataset
|
|
191
|
+
# Users have the option to choose input None or a tensor of training data for x
|
|
192
|
+
# If x is None, will sample training data from inducing points
|
|
193
|
+
# Otherwise, will find the indices of inducing points that are equal to x
|
|
194
|
+
if x is None:
|
|
195
|
+
x_indices = self._get_training_indices()
|
|
196
|
+
kl_indices = x_indices
|
|
197
|
+
|
|
198
|
+
predictive_mean = self._variational_distribution.variational_mean[..., x_indices]
|
|
199
|
+
predictive_var = self._variational_distribution._variational_stddev[..., x_indices] ** 2
|
|
200
|
+
|
|
201
|
+
else:
|
|
202
|
+
# find the indices of inducing points that correspond to x
|
|
203
|
+
x_indices = self.nn_util.find_nn_idx(x.float(), k=1).squeeze(-1) # (*inducing_batch_shape, batch_size)
|
|
204
|
+
|
|
205
|
+
expanded_x_indices = x_indices.expand(*self._batch_shape, x_indices.shape[-1])
|
|
206
|
+
expanded_variational_mean = self._variational_distribution.variational_mean.expand(
|
|
207
|
+
*self._batch_shape, self.M
|
|
208
|
+
)
|
|
209
|
+
expanded_variational_var = (
|
|
210
|
+
self._variational_distribution._variational_stddev.expand(*self._batch_shape, self.M) ** 2
|
|
211
|
+
)
|
|
212
|
+
|
|
213
|
+
predictive_mean = expanded_variational_mean.gather(-1, expanded_x_indices)
|
|
214
|
+
predictive_var = expanded_variational_var.gather(-1, expanded_x_indices)
|
|
215
|
+
|
|
216
|
+
# sample a different indices for stochastic estimation of kl
|
|
217
|
+
kl_indices = self._get_training_indices()
|
|
218
|
+
|
|
219
|
+
kl = self._kl_divergence(kl_indices)
|
|
220
|
+
add_to_cache(self, "kl_divergence_memo", kl)
|
|
221
|
+
|
|
222
|
+
# if hasattr(self.model, 'power'):
|
|
223
|
+
# return MultivariateQExponential(predictive_mean, DiagLinearOperator(predictive_var), power=self.model.power)
|
|
224
|
+
# else:
|
|
225
|
+
# return MultivariateNormal(predictive_mean, DiagLinearOperator(predictive_var))
|
|
226
|
+
else:
|
|
227
|
+
nn_indices = self.nn_util.find_nn_idx(x.float())
|
|
228
|
+
|
|
229
|
+
x_batch_shape = x.shape[:-2]
|
|
230
|
+
batch_shape = torch.broadcast_shapes(self._batch_shape, x_batch_shape)
|
|
231
|
+
x_bsz = x.shape[-2]
|
|
232
|
+
assert nn_indices.shape == (*x_batch_shape, x_bsz, self.k), nn_indices.shape
|
|
233
|
+
|
|
234
|
+
# select K nearest neighbors from inducing points for test point x
|
|
235
|
+
expanded_nn_indices = nn_indices.unsqueeze(-1).expand(*x_batch_shape, x_bsz, self.k, self.D)
|
|
236
|
+
expanded_inducing_points = inducing_points.unsqueeze(-2).expand(*x_batch_shape, self.M, self.k, self.D)
|
|
237
|
+
inducing_points = expanded_inducing_points.gather(-3, expanded_nn_indices)
|
|
238
|
+
assert inducing_points.shape == (*x_batch_shape, x_bsz, self.k, self.D)
|
|
239
|
+
|
|
240
|
+
# get variational mean and covar for nearest neighbors
|
|
241
|
+
inducing_values = self._variational_distribution.variational_mean
|
|
242
|
+
expanded_inducing_values = inducing_values.unsqueeze(-1).expand(*batch_shape, self.M, self.k)
|
|
243
|
+
expanded_nn_indices = nn_indices.expand(*batch_shape, x_bsz, self.k)
|
|
244
|
+
inducing_values = expanded_inducing_values.gather(-2, expanded_nn_indices)
|
|
245
|
+
assert inducing_values.shape == (*batch_shape, x_bsz, self.k)
|
|
246
|
+
|
|
247
|
+
variational_stddev = self._variational_distribution._variational_stddev
|
|
248
|
+
assert variational_stddev.shape == (*self._model_batch_shape, self.M)
|
|
249
|
+
expanded_variational_stddev = variational_stddev.unsqueeze(-1).expand(*batch_shape, self.M, self.k)
|
|
250
|
+
variational_inducing_covar = expanded_variational_stddev.gather(-2, expanded_nn_indices) ** 2
|
|
251
|
+
assert variational_inducing_covar.shape == (*batch_shape, x_bsz, self.k)
|
|
252
|
+
variational_inducing_covar = DiagLinearOperator(variational_inducing_covar)
|
|
253
|
+
assert variational_inducing_covar.shape == (*batch_shape, x_bsz, self.k, self.k)
|
|
254
|
+
|
|
255
|
+
# Make everything batch mode
|
|
256
|
+
x = x.unsqueeze(-2)
|
|
257
|
+
assert x.shape == (*x_batch_shape, x_bsz, 1, self.D)
|
|
258
|
+
x = x.expand(*batch_shape, x_bsz, 1, self.D)
|
|
259
|
+
|
|
260
|
+
# Compute forward mode in the standard way
|
|
261
|
+
_batch_dims = tuple(range(len(batch_shape)))
|
|
262
|
+
_x = x.permute((-3,) + _batch_dims + (-2, -1)) # (x_bsz, *batch_shape, 1, D)
|
|
263
|
+
|
|
264
|
+
# inducing_points.shape (*x_batch_shape, x_bsz, self.k, self.D)
|
|
265
|
+
inducing_points = inducing_points.expand(*batch_shape, x_bsz, self.k, self.D)
|
|
266
|
+
_inducing_points = inducing_points.permute((-3,) + _batch_dims + (-2, -1)) # (x_bsz, *batch_shape, k, D)
|
|
267
|
+
_inducing_values = inducing_values.permute((-2,) + _batch_dims + (-1,))
|
|
268
|
+
_variational_inducing_covar = variational_inducing_covar.permute((-3,) + _batch_dims + (-2, -1))
|
|
269
|
+
dist = super().forward(_x, _inducing_points, _inducing_values, _variational_inducing_covar, **kwargs)
|
|
270
|
+
|
|
271
|
+
_x_batch_dims = tuple(range(1, 1 + len(batch_shape)))
|
|
272
|
+
predictive_mean = dist.mean # (x_bsz, *x_batch_shape, 1)
|
|
273
|
+
predictive_covar = dist.covariance_matrix # (x_bsz, *x_batch_shape, 1, 1)
|
|
274
|
+
predictive_mean = predictive_mean.permute(_x_batch_dims + (0, -1))
|
|
275
|
+
predictive_covar = predictive_covar.permute(_x_batch_dims + (0, -2, -1))
|
|
276
|
+
|
|
277
|
+
# Undo batch mode
|
|
278
|
+
predictive_mean = predictive_mean.squeeze(-1)
|
|
279
|
+
predictive_var = predictive_covar.squeeze(-2).squeeze(-1)
|
|
280
|
+
assert predictive_var.shape == predictive_covar.shape[:-2]
|
|
281
|
+
assert predictive_mean.shape == predictive_covar.shape[:-2]
|
|
282
|
+
|
|
283
|
+
# Return the distribution
|
|
284
|
+
if hasattr(self.model, 'power'):
|
|
285
|
+
return MultivariateQExponential(predictive_mean, DiagLinearOperator(predictive_var), power=self.model.power)
|
|
286
|
+
else:
|
|
287
|
+
return MultivariateNormal(predictive_mean, DiagLinearOperator(predictive_var))
|
|
288
|
+
|
|
289
|
+
def get_fantasy_model(
|
|
290
|
+
self,
|
|
291
|
+
inputs: Float[Tensor, "... N D"],
|
|
292
|
+
targets: Float[Tensor, "... N"],
|
|
293
|
+
mean_module: Optional[Module] = None,
|
|
294
|
+
covar_module: Optional[Module] = None,
|
|
295
|
+
**kwargs,
|
|
296
|
+
) -> Union[ExactGP, ExactQEP]:
|
|
297
|
+
raise NotImplementedError(
|
|
298
|
+
f"No fantasy model support for {self.__class__.__name__}. "
|
|
299
|
+
"Only VariationalStrategy and UnwhitenedVariationalStrategy are currently supported."
|
|
300
|
+
)
|
|
301
|
+
|
|
302
|
+
def _set_training_iterator(self) -> None:
|
|
303
|
+
self._training_indices_iter = 0
|
|
304
|
+
if self.training_batch_size == self.M:
|
|
305
|
+
self._training_indices_iterator = (torch.arange(self.M, device=self.inducing_points.device),)
|
|
306
|
+
else:
|
|
307
|
+
# The first training batch always contains the first k inducing points
|
|
308
|
+
# This is because computing the KL divergence for the first k inducing points is special-cased
|
|
309
|
+
# (since the first k inducing points have < k neighbors)
|
|
310
|
+
# Note that there is a special function _firstk_kl_helper for this
|
|
311
|
+
training_indices = torch.randperm(self.M - self.k, device=self.inducing_points.device) + self.k
|
|
312
|
+
self._training_indices_iterator = (torch.arange(self.k),) + training_indices.split(self.training_batch_size)
|
|
313
|
+
self._total_training_batches = len(self._training_indices_iterator)
|
|
314
|
+
|
|
315
|
+
def _get_training_indices(self) -> LongTensor:
|
|
316
|
+
self.current_training_indices = self._training_indices_iterator[self._training_indices_iter]
|
|
317
|
+
self._training_indices_iter += 1
|
|
318
|
+
if self._training_indices_iter == self._total_training_batches:
|
|
319
|
+
self._set_training_iterator()
|
|
320
|
+
return self.current_training_indices
|
|
321
|
+
|
|
322
|
+
def _firstk_kl_helper(self) -> Float[Tensor, "..."]:
|
|
323
|
+
# Compute the KL divergence for first k inducing points
|
|
324
|
+
train_x_firstk = self.inducing_points[..., : self.k, :]
|
|
325
|
+
full_output = self.model.forward(train_x_firstk)
|
|
326
|
+
|
|
327
|
+
induc_mean, induc_induc_covar = full_output.mean, full_output.lazy_covariance_matrix
|
|
328
|
+
|
|
329
|
+
induc_induc_covar = induc_induc_covar.add_jitter(self.jitter_val)
|
|
330
|
+
if hasattr(self.model, 'power'):
|
|
331
|
+
prior_dist = MultivariateQExponential(induc_mean, induc_induc_covar, power=self.model.power)
|
|
332
|
+
else:
|
|
333
|
+
prior_dist = MultivariateNormal(induc_mean, induc_induc_covar)
|
|
334
|
+
|
|
335
|
+
inducing_values = self._variational_distribution.variational_mean[..., : self.k]
|
|
336
|
+
variational_covar_fisrtk = self._variational_distribution._variational_stddev[..., : self.k] ** 2
|
|
337
|
+
variational_inducing_covar = DiagLinearOperator(variational_covar_fisrtk)
|
|
338
|
+
|
|
339
|
+
if hasattr(self.model, 'power'):
|
|
340
|
+
variational_distribution = MultivariateQExponential(inducing_values, variational_inducing_covar, power=self.model.power)
|
|
341
|
+
else:
|
|
342
|
+
variational_distribution = MultivariateNormal(inducing_values, variational_inducing_covar)
|
|
343
|
+
kl = torch.distributions.kl.kl_divergence(variational_distribution, prior_dist) # model_batch_shape
|
|
344
|
+
return kl
|
|
345
|
+
|
|
346
|
+
def _stochastic_kl_helper(self, kl_indices: Float[Tensor, "n_batch"]) -> Float[Tensor, "..."]: # noqa: F821
|
|
347
|
+
# Compute the KL divergence for a mini batch of the rest M-k inducing points
|
|
348
|
+
# See paper appendix for kl breakdown
|
|
349
|
+
kl_bs = len(kl_indices) # training_batch_size
|
|
350
|
+
variational_mean = self._variational_distribution.variational_mean # (*model_bs, M)
|
|
351
|
+
variational_stddev = self._variational_distribution._variational_stddev
|
|
352
|
+
|
|
353
|
+
# (1) compute logdet_q
|
|
354
|
+
inducing_point_log_variational_covar = (variational_stddev[..., kl_indices] ** 2).log()
|
|
355
|
+
logdet_q = torch.sum(inducing_point_log_variational_covar, dim=-1) # model_bs
|
|
356
|
+
|
|
357
|
+
# (2) compute lodet_p
|
|
358
|
+
# Select a mini-batch of inducing points according to kl_indices
|
|
359
|
+
inducing_points = self.inducing_points[..., kl_indices, :].expand(*self._batch_shape, kl_bs, self.D)
|
|
360
|
+
# (*bs, kl_bs, D)
|
|
361
|
+
# Select their K nearest neighbors
|
|
362
|
+
nearest_neighbor_indices = self.nn_xinduce_idx[..., kl_indices - self.k, :].to(inducing_points.device)
|
|
363
|
+
# (*bs, kl_bs, K)
|
|
364
|
+
expanded_inducing_points_all = self.inducing_points.unsqueeze(-2).expand(
|
|
365
|
+
*self._batch_shape, self.M, self.k, self.D
|
|
366
|
+
)
|
|
367
|
+
expanded_nearest_neighbor_indices = nearest_neighbor_indices.unsqueeze(-1).expand(
|
|
368
|
+
*self._batch_shape, kl_bs, self.k, self.D
|
|
369
|
+
)
|
|
370
|
+
nearest_neighbors = expanded_inducing_points_all.gather(-3, expanded_nearest_neighbor_indices)
|
|
371
|
+
# (*bs, kl_bs, K, D)
|
|
372
|
+
|
|
373
|
+
# Compute prior distribution
|
|
374
|
+
# Move the kl_bs dimension to the first dimension to enable batch covar_module computation
|
|
375
|
+
nearest_neighbors_ = nearest_neighbors.permute((-3,) + tuple(range(len(self._batch_shape))) + (-2, -1))
|
|
376
|
+
# (kl_bs, *bs, K, D)
|
|
377
|
+
inducing_points_ = inducing_points.permute((-2,) + tuple(range(len(self._batch_shape))) + (-1,))
|
|
378
|
+
# (kl_bs, *bs, D)
|
|
379
|
+
full_output = self.model.forward(torch.cat([nearest_neighbors_, inducing_points_.unsqueeze(-2)], dim=-2))
|
|
380
|
+
full_mean, full_covar = full_output.mean, full_output.covariance_matrix
|
|
381
|
+
|
|
382
|
+
# Mean terms
|
|
383
|
+
_undo_permute_dims = tuple(range(1, 1 + len(self._batch_shape))) + (0, -1)
|
|
384
|
+
nearest_neighbors_prior_mean = full_mean[..., : self.k].permute(_undo_permute_dims) # (*inducing_bs, kl_bs, K)
|
|
385
|
+
inducing_prior_mean = full_mean[..., self.k :].permute(_undo_permute_dims).squeeze(-1) # (*inducing_bs, kl_bs)
|
|
386
|
+
# Covar terms
|
|
387
|
+
nearest_neighbors_prior_cov = full_covar[..., : self.k, : self.k]
|
|
388
|
+
nearest_neighbors_inducing_prior_cross_cov = full_covar[..., : self.k, self.k :]
|
|
389
|
+
inducing_prior_cov = full_covar[..., self.k :, self.k :]
|
|
390
|
+
inducing_prior_cov = (
|
|
391
|
+
inducing_prior_cov.squeeze(-1).squeeze(-1).permute((-1,) + tuple(range(len(self._batch_shape))))
|
|
392
|
+
)
|
|
393
|
+
|
|
394
|
+
# Interpolation term K_nn^{-1} k_{nu}
|
|
395
|
+
interp_term = torch.linalg.solve(
|
|
396
|
+
nearest_neighbors_prior_cov + self.jitter_val * torch.eye(self.k, device=self.inducing_points.device),
|
|
397
|
+
nearest_neighbors_inducing_prior_cross_cov,
|
|
398
|
+
).squeeze(
|
|
399
|
+
-1
|
|
400
|
+
) # (kl_bs, *inducing_bs, K)
|
|
401
|
+
interp_term = interp_term.permute(_undo_permute_dims) # (*inducing_bs, kl_bs, K)
|
|
402
|
+
nearest_neighbors_inducing_prior_cross_cov = nearest_neighbors_inducing_prior_cross_cov.squeeze(-1).permute(
|
|
403
|
+
_undo_permute_dims
|
|
404
|
+
) # k_{n(j),j}, (*inducing_bs, kl_bs, K)
|
|
405
|
+
|
|
406
|
+
invquad_term_for_F = torch.sum(
|
|
407
|
+
interp_term * nearest_neighbors_inducing_prior_cross_cov, dim=-1
|
|
408
|
+
) # (*inducing_bs, kl_bs)
|
|
409
|
+
|
|
410
|
+
inducing_prior_cov = self.model.covar_module.forward(
|
|
411
|
+
inducing_points, inducing_points, diag=True
|
|
412
|
+
) # (*inducing_bs, kl_bs)
|
|
413
|
+
|
|
414
|
+
F = inducing_prior_cov - invquad_term_for_F
|
|
415
|
+
F = F + self.jitter_val
|
|
416
|
+
# K_uu - k_un K_nn^{-1} k_nu
|
|
417
|
+
logdet_p = F.log().sum(dim=-1) # shape: inducing_bs
|
|
418
|
+
|
|
419
|
+
# (3) compute trace_term
|
|
420
|
+
expanded_variational_stddev = variational_stddev.unsqueeze(-1).expand(*self._batch_shape, self.M, self.k)
|
|
421
|
+
expanded_variational_mean = variational_mean.unsqueeze(-1).expand(*self._batch_shape, self.M, self.k)
|
|
422
|
+
expanded_nearest_neighbor_indices = nearest_neighbor_indices.expand(*self._batch_shape, kl_bs, self.k)
|
|
423
|
+
nearest_neighbor_variational_covar = (
|
|
424
|
+
expanded_variational_stddev.gather(-2, expanded_nearest_neighbor_indices) ** 2
|
|
425
|
+
) # (*batch_shape, kl_bs, k)
|
|
426
|
+
bjsquared_s_nearest_neighbors = torch.sum(
|
|
427
|
+
interp_term**2 * nearest_neighbor_variational_covar, dim=-1
|
|
428
|
+
) # (*batch_shape, kl_bs)
|
|
429
|
+
inducing_point_variational_covar = variational_stddev[..., kl_indices] ** 2 # (model_bs, kl_bs)
|
|
430
|
+
trace_term = (1.0 / F * (bjsquared_s_nearest_neighbors + inducing_point_variational_covar)).sum(
|
|
431
|
+
dim=-1
|
|
432
|
+
) # batch_shape
|
|
433
|
+
|
|
434
|
+
# (4) compute invquad_term
|
|
435
|
+
nearest_neighbors_variational_mean = expanded_variational_mean.gather(-2, expanded_nearest_neighbor_indices)
|
|
436
|
+
Bj_m_nearest_neighbors = torch.sum(
|
|
437
|
+
interp_term * (nearest_neighbors_variational_mean - nearest_neighbors_prior_mean), dim=-1
|
|
438
|
+
)
|
|
439
|
+
inducing_variational_mean = variational_mean[..., kl_indices]
|
|
440
|
+
invquad_term = torch.sum(
|
|
441
|
+
(inducing_variational_mean - inducing_prior_mean - Bj_m_nearest_neighbors) ** 2 / F, dim=-1
|
|
442
|
+
)
|
|
443
|
+
|
|
444
|
+
trace_plus_invquad_form = trace_term + invquad_term
|
|
445
|
+
if hasattr(self.model, 'power'): trace_plus_invquad_form = trace_plus_invquad_form**(self.model.power/2.)
|
|
446
|
+
kl = (logdet_p - logdet_q - kl_bs + trace_plus_invquad_form) * (1.0 / 2)
|
|
447
|
+
if hasattr(self.model, 'power') and self.model.power!=2:
|
|
448
|
+
kl -= kl_bs*(1.0/2-1./self.model.power)*(torch.log(trace_plus_invquad_form)+torch.distributions.Chi2(kl_bs).entropy())
|
|
449
|
+
assert kl.shape == self._batch_shape, kl.shape
|
|
450
|
+
|
|
451
|
+
return kl
|
|
452
|
+
|
|
453
|
+
def _kl_divergence(
    self, kl_indices: Optional[LongTensor] = None, batch_size: Optional[int] = None
) -> Float[Tensor, "..."]:
    """
    Compute the KL divergence term, either exactly or as a rescaled minibatch estimate.

    When ``self.compute_full_kl`` is set, or there is only one training batch, the KL is accumulated
    exactly: the term for the first ``k`` inducing points (``_firstk_kl_helper``) plus the terms for
    inducing points ``k, ..., M-1`` processed in chunks of ``batch_size`` (``_stochastic_kl_helper``).
    Otherwise only the supplied ``kl_indices`` are used and the result is rescaled by
    ``M / len(kl_indices)`` (or ``M / k`` for the first-``k`` term).

    :param kl_indices: Indices of the inducing points in the current minibatch. Required for the
        stochastic estimate; when ``self._training_indices_iter == 1`` it must have length ``k``.
        Ignored in the exact branch.
    :param batch_size: Chunk size used by the exact branch; defaults to ``self.training_batch_size``.
    :return: The KL divergence tensor.
    """
    if self.compute_full_kl or (self._total_training_batches == 1):
        if batch_size is None:
            batch_size = self.training_batch_size
        kl = self._firstk_kl_helper()
        # Use a distinct loop name so the `kl_indices` argument is not shadowed.
        for chunk_indices in torch.split(torch.arange(self.k, self.M), batch_size):
            # Out-of-place accumulation avoids in-place modification of a tensor in the autograd graph.
            kl = kl + self._stochastic_kl_helper(chunk_indices)
    else:
        # compute a stochastic estimate
        assert kl_indices is not None
        if self._training_indices_iter == 1:
            assert len(kl_indices) == self.k, (
                f"kl_indices should be the first batch data of length k, "
                f"but got len(kl_indices) = {len(kl_indices)} and k = {self.k}."
            )
            kl = self._firstk_kl_helper() * self.M / self.k
        else:
            kl = self._stochastic_kl_helper(kl_indices) * self.M / len(kl_indices)
    return kl
|
|
474
|
+
|
|
475
|
+
def kl_divergence(self) -> Float[Tensor, "..."]:
    """
    Pop and return the KL divergence memoized under ``"kl_divergence_memo"``.

    :raises RuntimeError: If no memoized value is available (``pop_from_cache`` raised
        ``CachingError``); the message states that nearest neighbors were not set yet.
    """
    try:
        return pop_from_cache(self, "kl_divergence_memo")
    except CachingError as err:
        # Chain the caching error so the root cause stays visible in the traceback.
        raise RuntimeError(
            "KL Divergence of variational strategy was called before nearest neighbors were set."
        ) from err
|
|
480
|
+
|
|
481
|
+
def _compute_nn(self) -> "NNVariationalStrategy":
    """
    Rebuild the nearest-neighbor index data from the current inducing points.

    Without gradient tracking, converts ``self.inducing_points.data`` to float32, passes it to
    ``self.nn_util.set_nn_idx``, and stores the result of
    ``self.nn_util.build_sequential_nn_idx`` in ``self.nn_xinduce_idx``.

    :return: ``self``, so calls can be chained.
    """
    with torch.no_grad():
        points_f32 = self.inducing_points.data.float()
        self.nn_util.set_nn_idx(points_f32)
        # shape (*_inducing_batch_shape, M-k, k)
        self.nn_xinduce_idx = self.nn_util.build_sequential_nn_idx(points_f32)
    return self
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
from typing import Optional, Union
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from linear_operator.operators import LinearOperator
|
|
7
|
+
from torch import Tensor
|
|
8
|
+
|
|
9
|
+
from ..distributions import MultivariateNormal, MultivariateQExponential
|
|
10
|
+
from gpytorch.utils.memoize import add_to_cache, cached
|
|
11
|
+
from ._variational_distribution import _VariationalDistribution
|
|
12
|
+
from ._variational_strategy import _VariationalStrategy
|
|
13
|
+
from .delta_variational_distribution import DeltaVariationalDistribution
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class OrthogonallyDecoupledVariationalStrategy(_VariationalStrategy):
    r"""
    Implements orthogonally decoupled VGPs as defined in `Salimbeni et al. (2018)`_.
    This variational strategy uses a different set of inducing points for the mean and covariance functions.
    The idea is to use more inducing points for the (computationally efficient) mean and fewer inducing points for the
    (computationally expensive) covariance.

    This variational strategy defines the inducing points/:obj:`~qpytorch.variational._VariationalDistribution`
    for the mean function.
    It then wraps a different :obj:`~qpytorch.variational._VariationalStrategy` which
    defines the covariance inducing points.

    The wrapped strategy may return either a :obj:`~qpytorch.distributions.MultivariateNormal` or a
    :obj:`~qpytorch.distributions.MultivariateQExponential`; the distributions built here follow the same
    family (keeping ``power`` in the q-exponential case). Any other output type raises a :obj:`TypeError`.

    :param covar_variational_strategy:
        The variational strategy for the covariance term.
    :param inducing_points: Tensor containing a set of inducing
        points to use for variational inference.
    :param variational_distribution: A
        VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`.
        Currently only :obj:`~qpytorch.variational.DeltaVariationalDistribution` is supported.
    :param jitter_val: Amount of diagonal jitter to add for Cholesky factorization numerical stability

    Example:
        >>> mean_inducing_points = torch.randn(1000, train_x.size(-1), dtype=train_x.dtype, device=train_x.device)
        >>> covar_inducing_points = torch.randn(100, train_x.size(-1), dtype=train_x.dtype, device=train_x.device)
        >>>
        >>> covar_variational_strategy = qpytorch.variational.VariationalStrategy(
        >>>     model, covar_inducing_points,
        >>>     qpytorch.variational.CholeskyVariationalDistribution(covar_inducing_points.size(-2)),
        >>>     learn_inducing_locations=True
        >>> )
        >>>
        >>> variational_strategy = qpytorch.variational.OrthogonallyDecoupledVariationalStrategy(
        >>>     covar_variational_strategy, mean_inducing_points,
        >>>     qpytorch.variational.DeltaVariationalDistribution(mean_inducing_points.size(-2)),
        >>> )

    .. _Salimbeni et al. (2018):
        https://arxiv.org/abs/1809.08820
    """

    def __init__(
        self,
        covar_variational_strategy: _VariationalStrategy,
        inducing_points: Tensor,
        variational_distribution: _VariationalDistribution,
        jitter_val: Optional[float] = None,
    ):
        if not isinstance(variational_distribution, DeltaVariationalDistribution):
            raise NotImplementedError(
                "OrthogonallyDecoupledVariationalStrategy currently works with DeltaVariationalDistribution"
            )

        # The covariance strategy is passed as the "model": every self.model(...) call below evaluates it.
        super().__init__(
            covar_variational_strategy,
            inducing_points,
            variational_distribution,
            learn_inducing_locations=True,
            jitter_val=jitter_val,
        )
        self.base_variational_strategy = covar_variational_strategy

    @staticmethod
    def _distribution_like(
        template: Union[MultivariateNormal, MultivariateQExponential],
        mean: Tensor,
        covar: Union[Tensor, LinearOperator],
    ) -> Union[MultivariateNormal, MultivariateQExponential]:
        """
        Build a distribution of the same family as ``template`` from ``mean`` and ``covar``.

        :raises TypeError: If ``template`` is neither a MultivariateQExponential nor a MultivariateNormal.
        """
        # Check the q-exponential family first so ``power`` is always carried over, even if
        # MultivariateQExponential were a subclass of MultivariateNormal.
        if isinstance(template, MultivariateQExponential):
            return MultivariateQExponential(mean, covar, power=template.power)
        if isinstance(template, MultivariateNormal):
            return MultivariateNormal(mean, covar)
        raise TypeError(
            f"{type(template).__name__} is not supported by OrthogonallyDecoupledVariationalStrategy; "
            "expected a MultivariateNormal or MultivariateQExponential."
        )

    @property
    @cached(name="prior_distribution_memo")
    def prior_distribution(self) -> Union[MultivariateNormal, MultivariateQExponential]:
        """Prior over the mean inducing points, with ``jitter_val`` added to the covariance diagonal."""
        out = self.model(self.inducing_points)
        return self._distribution_like(out, out.mean, out.lazy_covariance_matrix.add_jitter(self.jitter_val))

    def forward(
        self,
        x: Tensor,
        inducing_points: Tensor,
        inducing_values: Tensor,
        variational_inducing_covar: Optional[LinearOperator] = None,
        **kwargs,
    ) -> Union[MultivariateNormal, MultivariateQExponential]:
        """
        Predictive distribution at ``x``.

        The mean is the prior mean at ``x`` plus ``K_{xZ} @ inducing_values`` (``Z`` = mean inducing points);
        the covariance is the covariance strategy's output at ``x``.

        :raises NotImplementedError: If ``variational_inducing_covar`` is given (only delta distributions
            are supported).
        :raises TypeError: If the wrapped strategy returns an unsupported distribution type.
        """
        if variational_inducing_covar is not None:
            raise NotImplementedError(
                "OrthogonallyDecoupledVariationalStrategy currently works with DeltaVariationalDistribution"
            )

        num_data = x.size(-2)
        # Evaluate the covariance strategy jointly on the data and the mean inducing points.
        full_output = self.model(torch.cat([x, inducing_points], dim=-2), **kwargs)
        full_mean = full_output.mean
        full_covar = full_output.lazy_covariance_matrix

        if self.training:
            # Reuse the joint evaluation to fill the prior cache that `prior_distribution` reads.
            induc_mean = full_mean[..., num_data:]
            induc_induc_covar = full_covar[..., num_data:, num_data:]
            prior_dist = self._distribution_like(full_output, induc_mean, induc_induc_covar)
            add_to_cache(self, "prior_distribution_memo", prior_dist)

        test_mean = full_mean[..., :num_data]
        data_induc_covar = full_covar[..., :num_data, num_data:]
        predictive_mean = (data_induc_covar @ inducing_values.unsqueeze(-1)).squeeze(-1).add(test_mean)
        predictive_covar = full_covar[..., :num_data, :num_data]

        # Return the distribution (same family as the covariance strategy's output)
        return self._distribution_like(full_output, predictive_mean, predictive_covar)

    def kl_divergence(self) -> Tensor:
        r"""
        KL divergence term of the decoupled variational posterior.

        Sum of the wrapped covariance strategy's KL and
        :math:`\frac{1}{2} \mathbf m^\top \mathbf K \mathbf m`, where :math:`\mathbf m` is the mean of the
        variational distribution over the mean inducing values and :math:`\mathbf K` is the covariance of
        ``self.prior_distribution``.
        """
        mean = self.variational_distribution.mean
        induc_induc_covar = self.prior_distribution.lazy_covariance_matrix
        # Batched quadratic form: sum over the last dim of (K m) * m.
        kl = self.model.kl_divergence() + ((induc_induc_covar @ mean.unsqueeze(-1)).squeeze(-1) * mean).sum(-1).mul(0.5)
        return kl
|