qpytorch 0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of qpytorch might be problematic. Click here for more details.
- qpytorch/__init__.py +327 -0
- qpytorch/constraints/__init__.py +3 -0
- qpytorch/distributions/__init__.py +21 -0
- qpytorch/distributions/delta.py +86 -0
- qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
- qpytorch/distributions/multivariate_qexponential.py +581 -0
- qpytorch/distributions/power.py +113 -0
- qpytorch/distributions/qexponential.py +153 -0
- qpytorch/functions/__init__.py +58 -0
- qpytorch/kernels/__init__.py +80 -0
- qpytorch/kernels/grid_interpolation_kernel.py +213 -0
- qpytorch/kernels/inducing_point_kernel.py +151 -0
- qpytorch/kernels/kernel.py +695 -0
- qpytorch/kernels/matern32_kernel_grad.py +155 -0
- qpytorch/kernels/matern52_kernel_grad.py +194 -0
- qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
- qpytorch/kernels/polynomial_kernel_grad.py +88 -0
- qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
- qpytorch/kernels/rbf_kernel_grad.py +125 -0
- qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
- qpytorch/kernels/rff_kernel.py +153 -0
- qpytorch/lazy/__init__.py +9 -0
- qpytorch/likelihoods/__init__.py +66 -0
- qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
- qpytorch/likelihoods/beta_likelihood.py +76 -0
- qpytorch/likelihoods/gaussian_likelihood.py +472 -0
- qpytorch/likelihoods/laplace_likelihood.py +59 -0
- qpytorch/likelihoods/likelihood.py +437 -0
- qpytorch/likelihoods/likelihood_list.py +60 -0
- qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
- qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
- qpytorch/likelihoods/noise_models.py +184 -0
- qpytorch/likelihoods/qexponential_likelihood.py +494 -0
- qpytorch/likelihoods/softmax_likelihood.py +97 -0
- qpytorch/likelihoods/student_t_likelihood.py +90 -0
- qpytorch/means/__init__.py +23 -0
- qpytorch/metrics/__init__.py +17 -0
- qpytorch/mlls/__init__.py +53 -0
- qpytorch/mlls/_approximate_mll.py +79 -0
- qpytorch/mlls/deep_approximate_mll.py +30 -0
- qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
- qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
- qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
- qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
- qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
- qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
- qpytorch/mlls/marginal_log_likelihood.py +48 -0
- qpytorch/mlls/predictive_log_likelihood.py +76 -0
- qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
- qpytorch/mlls/variational_elbo.py +77 -0
- qpytorch/models/__init__.py +72 -0
- qpytorch/models/approximate_qep.py +115 -0
- qpytorch/models/deep_qeps/__init__.py +22 -0
- qpytorch/models/deep_qeps/deep_qep.py +155 -0
- qpytorch/models/deep_qeps/dspp.py +114 -0
- qpytorch/models/exact_prediction_strategies.py +880 -0
- qpytorch/models/exact_qep.py +349 -0
- qpytorch/models/model_list.py +100 -0
- qpytorch/models/pyro/__init__.py +28 -0
- qpytorch/models/pyro/_pyro_mixin.py +57 -0
- qpytorch/models/pyro/distributions/__init__.py +5 -0
- qpytorch/models/pyro/pyro_qep.py +105 -0
- qpytorch/models/qep.py +7 -0
- qpytorch/models/qeplvm/__init__.py +6 -0
- qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
- qpytorch/models/qeplvm/latent_variable.py +102 -0
- qpytorch/module.py +30 -0
- qpytorch/optim/__init__.py +5 -0
- qpytorch/priors/__init__.py +42 -0
- qpytorch/priors/qep_priors.py +81 -0
- qpytorch/test/__init__.py +22 -0
- qpytorch/test/base_likelihood_test_case.py +106 -0
- qpytorch/test/model_test_case.py +150 -0
- qpytorch/test/variational_test_case.py +400 -0
- qpytorch/utils/__init__.py +38 -0
- qpytorch/utils/warnings.py +37 -0
- qpytorch/variational/__init__.py +47 -0
- qpytorch/variational/_variational_distribution.py +61 -0
- qpytorch/variational/_variational_strategy.py +391 -0
- qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
- qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
- qpytorch/variational/cholesky_variational_distribution.py +65 -0
- qpytorch/variational/ciq_variational_strategy.py +352 -0
- qpytorch/variational/delta_variational_distribution.py +41 -0
- qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
- qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
- qpytorch/variational/lmc_variational_strategy.py +248 -0
- qpytorch/variational/mean_field_variational_distribution.py +58 -0
- qpytorch/variational/multitask_variational_strategy.py +317 -0
- qpytorch/variational/natural_variational_distribution.py +152 -0
- qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
- qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
- qpytorch/variational/tril_natural_variational_distribution.py +130 -0
- qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
- qpytorch/variational/unwhitened_variational_strategy.py +225 -0
- qpytorch/variational/variational_strategy.py +280 -0
- qpytorch/version.py +4 -0
- qpytorch-0.1.dist-info/LICENSE +21 -0
- qpytorch-0.1.dist-info/METADATA +177 -0
- qpytorch-0.1.dist-info/RECORD +102 -0
- qpytorch-0.1.dist-info/WHEEL +5 -0
- qpytorch-0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,880 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
import string
|
|
5
|
+
import warnings
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from linear_operator import to_dense, to_linear_operator
|
|
9
|
+
from linear_operator.operators import (
|
|
10
|
+
AddedDiagLinearOperator,
|
|
11
|
+
BatchRepeatLinearOperator,
|
|
12
|
+
ConstantMulLinearOperator,
|
|
13
|
+
InterpolatedLinearOperator,
|
|
14
|
+
LinearOperator,
|
|
15
|
+
LowRankRootAddedDiagLinearOperator,
|
|
16
|
+
MaskedLinearOperator,
|
|
17
|
+
MatmulLinearOperator,
|
|
18
|
+
RootLinearOperator,
|
|
19
|
+
ZeroLinearOperator,
|
|
20
|
+
)
|
|
21
|
+
from linear_operator.utils.cholesky import psd_safe_cholesky
|
|
22
|
+
from linear_operator.utils.interpolation import left_interp, left_t_interp
|
|
23
|
+
from torch import Tensor
|
|
24
|
+
|
|
25
|
+
from .. import settings
|
|
26
|
+
|
|
27
|
+
from ..distributions import MultitaskMultivariateNormal, MultitaskMultivariateQExponential
|
|
28
|
+
from ..lazy import LazyEvaluatedKernelTensor
|
|
29
|
+
from gpytorch.utils.memoize import add_to_cache, cached, clear_cache_hook, pop_from_cache
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def prediction_strategy(train_inputs, train_prior_dist, train_labels, likelihood):
    """
    Build the prediction strategy object matching the training prior covariance.

    If the prior's lazy covariance matrix is a ``LazyEvaluatedKernelTensor``, the
    strategy class is taken from its kernel's ``prediction_strategy`` attribute;
    otherwise :class:`DefaultPredictionStrategy` is used.

    Args:
        train_inputs: Training inputs, forwarded unchanged to the strategy.
        train_prior_dist: Prior distribution over the training points; must expose
            ``lazy_covariance_matrix``.
        train_labels: Training targets, forwarded unchanged to the strategy.
        likelihood: Likelihood object, forwarded unchanged to the strategy.

    Returns:
        The object produced by calling the selected strategy with the four arguments.
    """
    prior_covar = train_prior_dist.lazy_covariance_matrix
    strategy_cls = (
        prior_covar.kernel.prediction_strategy
        if isinstance(prior_covar, LazyEvaluatedKernelTensor)
        else DefaultPredictionStrategy
    )
    return strategy_cls(train_inputs, train_prior_dist, train_labels, likelihood)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class DefaultPredictionStrategy(object):
    """
    Default strategy for exact GP (QEP) prediction.

    Holds the training inputs, prior, labels and likelihood, lazily computes and caches
    the quantities needed at test time (``mean_cache`` and ``covar_cache``), and exposes
    :meth:`exact_prediction` to turn a joint train/test prior into the predictive mean
    and covariance of the test points.
    """

    def __init__(self, train_inputs, train_prior_dist, train_labels, likelihood, root=None, inv_root=None):
        """
        Args:
            train_inputs: Training inputs; passed to the likelihood together with the prior.
            train_prior_dist: Prior distribution at the training inputs. Must expose
                ``event_shape``, ``mean`` and ``lazy_covariance_matrix``.
            train_labels: Training targets. Their trailing ``len(event_shape)`` dimensions are
                flattened into a single dimension of size ``event_shape.numel()``.
            likelihood: Callable invoked as ``likelihood(train_prior_dist, train_inputs)``; the
                result must expose ``lazy_covariance_matrix``.
            root: Optional precomputed root, stored (wrapped in a ``RootLinearOperator``) under the
                ``"root_decomposition"`` cache entry of the noisy train-train covariance.
            inv_root: Optional precomputed inverse root, stored the same way under the
                ``"root_inv_decomposition"`` cache entry.

        Raises:
            RuntimeError: If the training labels cannot be reshaped to match the prior's event shape.
        """
        # Get training shape
        self._train_shape = train_prior_dist.event_shape

        # Flatten the training labels
        try:
            train_labels = train_labels.reshape(
                *train_labels.shape[: -len(self.train_shape)], self._train_shape.numel()
            )
        except RuntimeError as err:
            # Chain the original error so the underlying reshape failure stays visible.
            raise RuntimeError(
                "Flattening the training labels failed. The most common cause of this error is "
                + "that the shapes of the prior mean and the training labels are mismatched. "
                + "The shape of the train targets is {0}, ".format(train_labels.shape)
                + "while the reported shape of the mean is {0}.".format(train_prior_dist.mean.shape)
            ) from err

        self.train_inputs = train_inputs
        self.train_prior_dist = train_prior_dist
        self.train_labels = train_labels
        self.likelihood = likelihood
        # Set by exact_predictive_covar (when fast_pred_var is on) and read by covar_cache.
        self._last_test_train_covar = None
        lik = self.likelihood(train_prior_dist, train_inputs)
        self.lik_train_train_covar = lik.lazy_covariance_matrix

        if root is not None:
            add_to_cache(self.lik_train_train_covar, "root_decomposition", RootLinearOperator(root))

        if inv_root is not None:
            add_to_cache(self.lik_train_train_covar, "root_inv_decomposition", RootLinearOperator(inv_root))

    def __deepcopy__(self, memo):
        # deepcopying prediction strategies of a model evaluated on inputs that require gradients fails
        # with RuntimeError (Only Tensors created explicitly by the user (graph leaves) support the deepcopy
        # protocol at the moment). Overwriting this method makes sure that the prediction strategies of a
        # model are set to None upon deepcopying.
        return None

    def _exact_predictive_covar_inv_quad_form_cache(self, train_train_covar_inv_root, test_train_covar):
        """
        Computes a cache for K_X*X (K_XX + sigma^2 I)^-1 K_XX* if possible. By default, this does no work and
        returns the first argument (detached if ``settings.detach_test_caches`` is on).

        Args:
            train_train_covar_inv_root (:obj:`torch.tensor`): a root of (K_XX + sigma^2 I)^-1
            test_train_covar (:obj:`torch.tensor`): covariance between test and train inputs. Unused by
                this default implementation (``covar_cache`` passes the last value recorded by
                ``exact_predictive_covar``, which may be ``None``).

        Returns
            A precomputed cache
        """
        res = train_train_covar_inv_root
        if settings.detach_test_caches.on():
            res = res.detach()

        if res.grad_fn is not None:
            # Attach clear_cache_hook (bound to this strategy) to the autograd node of the cache.
            # NOTE(review): presumably this invalidates the cached values once a backward pass
            # runs through them -- confirm against gpytorch.utils.memoize.clear_cache_hook.
            wrapper = functools.partial(clear_cache_hook, self)
            functools.update_wrapper(wrapper, clear_cache_hook)
            res.grad_fn.register_hook(wrapper)

        return res

    def _exact_predictive_covar_inv_quad_form_root(self, precomputed_cache, test_train_covar):
        r"""
        Computes :math:`K_{X^{*}X} S` given a precomputed cache
        Where :math:`S` is a tensor such that :math:`SS^{\top} = (K_{XX} + \sigma^2 I)^{-1}`

        Args:
            precomputed_cache (:obj:`torch.tensor`): What was computed in _exact_predictive_covar_inv_quad_form_cache
            test_train_covar (:obj:`torch.tensor`): Covariance between test and train inputs,
                :math:`K_{X^{*}X}`

        Returns
            :obj:`~linear_operator.operators.LinearOperator`: :math:`K_{X^{*}X} S`
        """
        # Here the precomputed cache represents S,
        # where S S^T = (K_XX + sigma^2 I)^-1
        return test_train_covar.matmul(precomputed_cache)

    def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_output, **kwargs):
        """
        Returns a new PredictionStrategy that incorporates the specified inputs and targets as new training data.

        This method is primarily responsible for updating the mean and covariance caches. To add fantasy data to a
        GP (QEP) model, use the :meth:`~gpytorch.models.ExactGP.get_fantasy_model`
        (:meth:`~qpytorch.models.ExactQEP.get_fantasy_model`) method.

        Note that this also records ``inputs`` and ``targets`` on the current strategy as
        ``fantasy_inputs`` and ``fantasy_targets``.

        Args:
            inputs (Tensor `b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`): Locations of fantasy
                observations.
            targets (Tensor `b1 x ... x bk x m` or `f x b1 x ... x bk x m`): Labels of fantasy observations.
            full_inputs (Tensor `b1 x ... x bk x n+m x d` or `f x b1 x ... x bk x n+m x d`): Training data
                concatenated with fantasy inputs
            full_targets (Tensor `b1 x ... x bk x n+m` or `f x b1 x ... x bk x n+m`): Training labels
                concatenated with fantasy labels.
            full_output (:class:`gpytorch.distributions.MultivariateNormal` or
                :class:`qpytorch.distributions.MultivariateQExponential`): Prior called on full_inputs
            **kwargs: Forwarded to ``likelihood.get_fantasy_likelihood`` and to the fantasy likelihood call.

        Returns:
            A `DefaultPredictionStrategy` model with `n + m` training examples, where the `m` fantasy examples have
            been added and all test-time caches have been updated.
        """
        is_multitask = isinstance(full_output, (MultitaskMultivariateNormal, MultitaskMultivariateQExponential))

        # Multitask targets carry an extra trailing task dimension.
        if not is_multitask:
            target_batch_shape = targets.shape[:-1]
        else:
            target_batch_shape = targets.shape[:-2]

        full_mean, full_covar = full_output.mean, full_output.lazy_covariance_matrix

        batch_shape = full_inputs[0].shape[:-2]

        num_train = self.num_train

        if is_multitask:
            num_tasks = full_output.event_shape[-1]
            full_mean = full_mean.view(*batch_shape, -1, num_tasks)
            fant_mean = full_mean[..., (num_train // num_tasks) :, :]
            full_targets = full_targets.view(*target_batch_shape, -1)
        else:
            full_mean = full_mean.view(*batch_shape, -1)
            fant_mean = full_mean[..., num_train:]

        # Evaluate fant x train and fant x fant covariance matrices, leave train x train unevaluated.
        fant_fant_covar = full_covar[..., num_train:, num_train:]
        dist = self.train_prior_dist.__class__(fant_mean, fant_fant_covar)
        if hasattr(self.train_prior_dist, "power"):
            # Q-exponential priors carry a `power` attribute; propagate it to the fantasy prior.
            dist.power = self.train_prior_dist.power
        fant_likelihood = self.likelihood.get_fantasy_likelihood(**kwargs)
        fant_obs = fant_likelihood(dist, inputs, **kwargs)

        fant_fant_covar = fant_obs.covariance_matrix
        fant_train_covar = to_dense(full_covar[..., num_train:, :num_train])

        self.fantasy_inputs = inputs
        self.fantasy_targets = targets

        # Compute a new mean cache given the old mean cache.
        #
        # We have \alpha = K^{-1}y, and we want to solve [K U; U' S][a; b] = [y; y_f], where U' is
        # fant_train_covar, S is fant_fant_covar, and y_f is (targets - fant_mean)
        #
        # To do this, we solve the bordered linear system of equations for [a; b]:
        #     KQ = U  # Q = fant_solve
        #     [S - U'Q]b = y_f - U'\alpha   ==> b = [S - U'Q]^{-1}(y_f - U'\alpha)
        #     a = \alpha - Qb

        # Get cached K inverse decomp. (or compute if we somehow don't already have the covariance cache)
        K_inverse = self.lik_train_train_covar.root_inv_decomposition()
        fant_solve = K_inverse.matmul(fant_train_covar.transpose(-2, -1))

        # Solve for "b", the lower portion of the *new* \alpha corresponding to the fantasy points.
        schur_complement = fant_fant_covar - fant_train_covar.matmul(fant_solve)

        # we'd like to use a less hacky approach for the following, but einsum can be much faster
        # than unsqueezing/squeezing here (esp. in backward passes), unfortunately it currently has some
        # issues with broadcasting: https://github.com/pytorch/pytorch/issues/15671
        prefix = string.ascii_lowercase[: max(fant_train_covar.dim() - self.mean_cache.dim() - 1, 0)]
        ftcm = torch.einsum(prefix + "...yz,...z->" + prefix + "...y", [fant_train_covar, self.mean_cache])

        small_system_rhs = targets - fant_mean - ftcm
        small_system_rhs = small_system_rhs.unsqueeze(-1)
        # Schur complement of a spd matrix is guaranteed to be positive definite
        schur_cholesky = psd_safe_cholesky(schur_complement)
        fant_cache_lower = torch.cholesky_solve(small_system_rhs, schur_cholesky)

        # Get "a", the new upper portion of the cache corresponding to the old training points.
        fant_cache_upper = self.mean_cache.unsqueeze(-1) - fant_solve.matmul(fant_cache_lower)

        fant_cache_upper = fant_cache_upper.squeeze(-1)
        fant_cache_lower = fant_cache_lower.squeeze(-1)

        # New mean cache.
        fant_mean_cache = torch.cat((fant_cache_upper, fant_cache_lower), dim=-1)

        # now update the root and root inverse
        new_lt = self.lik_train_train_covar.cat_rows(fant_train_covar, fant_fant_covar)
        new_root = new_lt.root_decomposition().root
        if settings.detach_test_caches.on():
            new_covar_cache = new_lt.root_inv_decomposition().root.detach()
        else:
            new_covar_cache = new_lt.root_inv_decomposition().root

        # Expand inputs accordingly if necessary (for fantasies at the same points)
        if full_inputs[0].dim() <= full_targets.dim():
            fant_batch_shape = full_targets.shape[:1]
            n_batch = len(full_mean.shape[:-1])
            repeat_shape = fant_batch_shape + torch.Size([1] * n_batch)
            full_inputs = [fi.expand(fant_batch_shape + fi.shape) for fi in full_inputs]
            full_mean = full_mean.expand(fant_batch_shape + full_mean.shape)
            full_covar = BatchRepeatLinearOperator(full_covar, repeat_shape)
            new_root = BatchRepeatLinearOperator(new_root, repeat_shape)
            # no need to repeat the covar cache, broadcasting will do the right thing

        if is_multitask:
            full_mean = full_mean.view(*target_batch_shape, -1, num_tasks).contiguous()

        # Rebuild the prior over train + fantasy points, passing `power` along when the prior has one.
        prior_cls = self.train_prior_dist.__class__
        if hasattr(self.train_prior_dist, "power"):
            full_prior = prior_cls(full_mean, full_covar, self.train_prior_dist.power)
        else:
            full_prior = prior_cls(full_mean, full_covar)

        # Create new DefaultPredictionStrategy object
        fant_strat = self.__class__(
            train_inputs=full_inputs,
            train_prior_dist=full_prior,
            train_labels=full_targets,
            likelihood=fant_likelihood,
            root=new_root,
            inv_root=new_covar_cache,
        )
        add_to_cache(fant_strat, "mean_cache", fant_mean_cache)
        add_to_cache(fant_strat, "covar_cache", new_covar_cache.to_dense())
        return fant_strat

    @property
    @cached(name="covar_cache")
    def covar_cache(self):
        """Cached quantity used for fast predictive covariances, built from a dense root of the inverse
        noisy train-train covariance."""
        train_train_covar = self.lik_train_train_covar
        train_train_covar_inv_root = to_dense(train_train_covar.root_inv_decomposition().root)
        return self._exact_predictive_covar_inv_quad_form_cache(train_train_covar_inv_root, self._last_test_train_covar)

    @property
    def mean_cache(self):
        """The mean cache computed under the currently active observation NaN policy."""
        return self._mean_cache(settings.observation_nan_policy.value())

    @cached(name="mean_cache")
    def _mean_cache(self, nan_policy: str) -> Tensor:
        """
        Solve the noisy train-train covariance against the de-meaned training labels.

        :param str nan_policy: ``"ignore"``, ``"mask"``, or anything else (treated as ``"fill"``);
            selects how missing (NaN) labels are handled.
        :return: The solve result with the trailing singleton dimension removed. Under ``"mask"`` and
            ``"fill"``, entries at missing labels are NaN.
        """
        lik = self.likelihood(self.train_prior_dist, self.train_inputs)
        train_mean, train_train_covar = lik.loc, lik.lazy_covariance_matrix

        train_labels_offset = (self.train_labels - train_mean).unsqueeze(-1)

        if nan_policy == "ignore":
            mean_cache = train_train_covar.evaluate_kernel().solve(train_labels_offset).squeeze(-1)
        elif nan_policy == "mask":
            # Mask all rows and columns in the kernel matrix corresponding to the missing observations.
            observed = settings.observation_nan_policy._get_observed(
                self.train_labels, torch.Size((self.train_labels.shape[-1],))
            )
            mean_cache = torch.full_like(self.train_labels, torch.nan)
            kernel = MaskedLinearOperator(
                train_train_covar.evaluate_kernel(), observed.reshape(-1), observed.reshape(-1)
            )
            mean_cache[..., observed] = kernel.solve(train_labels_offset[..., observed, :]).squeeze(-1)
        else:  # 'fill'
            # Fill all rows and columns in the kernel matrix corresponding to the missing observations with 0.
            # Don't touch the corresponding diagonal elements to ensure a unique solution.
            # This ensures that missing data is ignored during solving.
            warnings.warn(
                "Observation NaN policy 'fill' makes the kernel matrix dense during exact prediction.",
                RuntimeWarning,
            )
            kernel = train_train_covar.evaluate_kernel()
            missing = torch.isnan(self.train_labels)
            kernel_mask = (~missing).to(torch.float)
            kernel_mask = kernel_mask[..., None] * kernel_mask[..., None, :]
            torch.diagonal(kernel_mask, dim1=-2, dim2=-1)[...] = 1
            kernel = kernel * kernel_mask  # Unfortunately, this makes the kernel dense at the moment.
            train_labels_offset = settings.observation_nan_policy._fill_tensor(train_labels_offset)
            mean_cache = kernel.solve(train_labels_offset).squeeze(-1)
            mean_cache[missing] = torch.nan  # Ensure that nobody expects these values to be valid.
        if settings.detach_test_caches.on():
            mean_cache = mean_cache.detach()

        if mean_cache.grad_fn is not None:
            # Same hook registration as in _exact_predictive_covar_inv_quad_form_cache.
            wrapper = functools.partial(clear_cache_hook, self)
            functools.update_wrapper(wrapper, clear_cache_hook)
            mean_cache.grad_fn.register_hook(wrapper)

        return mean_cache

    @property
    def num_train(self):
        """Total number of elements in the training event shape."""
        return self._train_shape.numel()

    @property
    def train_shape(self):
        """Event shape of the training prior distribution."""
        return self._train_shape

    def exact_prediction(self, joint_mean, joint_covar):
        """
        Split a joint train/test prior into its test parts and compute the predictive mean and covariance.

        :param joint_mean: Prior mean over train and test points; entries from index ``num_train``
            onward along the last dimension are treated as test points.
        :param joint_covar: Prior covariance over train and test points, partitioned the same way.
        :return: Tuple ``(predictive_mean, predictive_covar)``.
        """
        # Find the components of the distribution that contain test data
        test_mean = joint_mean[..., self.num_train :]
        # For efficiency: small joint covariances are densified once and then sliced
        if joint_covar.size(-1) <= settings.max_eager_kernel_size.value():
            test_covar = joint_covar[..., self.num_train :, :].to_dense()
            test_test_covar = test_covar[..., self.num_train :]
            test_train_covar = test_covar[..., : self.num_train]
        else:
            test_test_covar = joint_covar[..., self.num_train :, self.num_train :]
            test_train_covar = joint_covar[..., self.num_train :, : self.num_train]

        return (
            self.exact_predictive_mean(test_mean, test_train_covar),
            self.exact_predictive_covar(test_test_covar, test_train_covar),
        )

    def exact_predictive_mean(self, test_mean: Tensor, test_train_covar: LinearOperator) -> Tensor:
        """
        Computes the posterior predictive mean of a GP (QEP)

        :param Tensor test_mean: The test prior mean
        :param ~linear_operator.operators.LinearOperator test_train_covar:
            Covariance matrix between test and train inputs
        :return: The predictive posterior mean of the test points
        """
        # NOTE TO FUTURE SELF:
        # You **cannot** use addmv here, because test_train_covar may not actually be a non lazy tensor even for
        # an exact GP, and using addmv requires you to to_dense test_train_covar, which is obviously a huge no-no!

        # see https://github.com/cornellius-gp/gpytorch/pull/2317#discussion_r1157994719
        mean_cache = self.mean_cache
        if len(mean_cache.shape) == 4:
            mean_cache = mean_cache.squeeze(1)

        # Handle NaNs
        nan_policy = settings.observation_nan_policy.value()
        if nan_policy == "ignore":
            res = (test_train_covar @ mean_cache.unsqueeze(-1)).squeeze(-1)
        elif nan_policy == "mask":
            # Restrict train dimension to observed values
            observed = settings.observation_nan_policy._get_observed(mean_cache, torch.Size((mean_cache.shape[-1],)))
            full_mask = torch.ones(test_mean.shape[-1], dtype=torch.bool, device=test_mean.device)
            test_train_covar = MaskedLinearOperator(
                to_linear_operator(test_train_covar), full_mask, observed.reshape(-1)
            )
            res = (test_train_covar @ mean_cache[..., observed].unsqueeze(-1)).squeeze(-1)
        else:  # 'fill'
            # Set the columns corresponding to missing observations to 0 to ignore them during matmul.
            mask = (~torch.isnan(mean_cache)).to(torch.float)[..., None, :]
            test_train_covar = test_train_covar * mask
            mean = settings.observation_nan_policy._fill_tensor(mean_cache)
            res = (test_train_covar @ mean.unsqueeze(-1)).squeeze(-1)
        res = res + test_mean

        return res

    def exact_predictive_covar(
        self, test_test_covar: LinearOperator, test_train_covar: LinearOperator
    ) -> LinearOperator:
        """
        Computes the posterior predictive covariance of a GP (QEP)

        :param ~linear_operator.operators.LinearOperator test_test_covar: Covariance matrix between test inputs
        :param ~linear_operator.operators.LinearOperator test_train_covar:
            Covariance matrix between test and train inputs
        :return: A LinearOperator representing the predictive posterior covariance of the test points
        """
        if settings.fast_pred_var.on():
            # Remembered so that covar_cache can hand it to the cache-building hook.
            self._last_test_train_covar = test_train_covar

        if settings.skip_posterior_variances.on():
            return ZeroLinearOperator(*test_test_covar.size())

        if settings.fast_pred_var.off():
            # Exact path: solve against the noisy train-train covariance of a zero-mean copy of the prior.
            dist = self.train_prior_dist.__class__(
                torch.zeros_like(self.train_prior_dist.mean), self.train_prior_dist.lazy_covariance_matrix
            )
            if hasattr(self.train_prior_dist, "power"):
                dist.power = self.train_prior_dist.power
            if settings.detach_test_caches.on():
                train_train_covar = self.likelihood(dist, self.train_inputs).lazy_covariance_matrix.detach()
            else:
                train_train_covar = self.likelihood(dist, self.train_inputs).lazy_covariance_matrix

            test_train_covar = to_dense(test_train_covar)
            train_test_covar = test_train_covar.transpose(-1, -2)
            covar_correction_rhs = train_train_covar.solve(train_test_covar)
            # For efficiency
            if torch.is_tensor(test_test_covar):
                # We can use addmm in the 2d case
                if test_test_covar.dim() == 2:
                    return to_linear_operator(
                        torch.addmm(test_test_covar, test_train_covar, covar_correction_rhs, beta=1, alpha=-1)
                    )
                else:
                    return to_linear_operator(test_test_covar + test_train_covar @ covar_correction_rhs.mul(-1))
            # In other cases - we'll use the standard infrastructure
            else:
                return test_test_covar + MatmulLinearOperator(test_train_covar, covar_correction_rhs.mul(-1))

        # Fast path: subtract R R^T, where R is derived from the cached root of the inverse covariance.
        precomputed_cache = self.covar_cache
        covar_inv_quad_form_root = self._exact_predictive_covar_inv_quad_form_root(precomputed_cache, test_train_covar)
        if torch.is_tensor(test_test_covar):
            return to_linear_operator(
                torch.add(
                    test_test_covar, covar_inv_quad_form_root @ covar_inv_quad_form_root.transpose(-1, -2), alpha=-1
                )
            )
        else:
            return test_test_covar + MatmulLinearOperator(
                covar_inv_quad_form_root, covar_inv_quad_form_root.transpose(-1, -2).mul(-1)
            )
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
class InterpolatedPredictionStrategy(DefaultPredictionStrategy):
|
|
428
|
+
def __init__(self, train_inputs, train_prior_dist, train_labels, likelihood, uses_wiski=False):
    """
    Re-create the training prior with its kernel evaluated, then defer to the parent initializer.

    The prior is rebuilt from its own class using the original mean, the result of
    ``lazy_covariance_matrix.evaluate_kernel()``, and -- when the prior has one -- its ``power``.

    :param train_inputs: Training inputs, forwarded to the parent initializer.
    :param train_prior_dist: Prior distribution over the training points.
    :param train_labels: Training targets, forwarded to the parent initializer.
    :param likelihood: Likelihood, forwarded to the parent initializer.
    :param bool uses_wiski: Stored on the instance as ``uses_wiski``.
    """
    prior_cls = train_prior_dist.__class__
    prior_mean = train_prior_dist.mean
    evaluated_covar = train_prior_dist.lazy_covariance_matrix.evaluate_kernel()
    if hasattr(train_prior_dist, "power"):
        evaluated_prior = prior_cls(prior_mean, evaluated_covar, train_prior_dist.power)
    else:
        evaluated_prior = prior_cls(prior_mean, evaluated_covar)
    super().__init__(train_inputs, evaluated_prior, train_labels, likelihood)
    self.uses_wiski = uses_wiski
|
|
440
|
+
|
|
441
|
+
def _exact_predictive_covar_inv_quad_form_cache(self, train_train_covar_inv_root, test_train_covar):
    """
    Build the covariance cache for an interpolated test-train covariance.

    The inverse root is pushed through the transposed right (train-side) interpolation of
    ``test_train_covar`` via ``left_t_interp`` and then multiplied by its base linear operator.

    :param train_train_covar_inv_root: Root of the inverse noisy train-train covariance.
    :param test_train_covar: Must expose ``right_interp_indices``, ``right_interp_values``
        and ``base_linear_op``.
    :return: ``base_linear_op`` matmul'd with the interpolated inverse root.
    """
    base_op = test_train_covar.base_linear_op
    interpolated_inv_root = left_t_interp(
        test_train_covar.right_interp_indices,
        test_train_covar.right_interp_values,
        train_train_covar_inv_root,
        base_op.size(-1),
    )
    return base_op.matmul(interpolated_inv_root)
|
|
450
|
+
|
|
451
|
+
def _exact_predictive_covar_inv_quad_form_root(self, precomputed_cache, test_train_covar):
    """
    Apply the left (test-side) interpolation of ``test_train_covar`` to the precomputed cache.

    Here the precomputed cache represents K_UU W S, where S S^T = (K_XX + sigma^2 I)^-1.

    :param precomputed_cache: Output of ``_exact_predictive_covar_inv_quad_form_cache``.
    :param test_train_covar: Must expose ``left_interp_indices`` and ``left_interp_values``.
    :return: The result of ``left_interp`` on the cache.
    """
    return left_interp(
        test_train_covar.left_interp_indices,
        test_train_covar.left_interp_values,
        precomputed_cache,
    )
|
|
458
|
+
|
|
459
|
+
def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_output, **kwargs):
    r"""
    Implements the fantasy strategy described in https://arxiv.org/abs/2103.01454.

    Builds a new strategy over the training + fantasy points and seeds its
    ``interp_inner_prod`` and ``interp_response_cache`` caches from the current ones,
    updated with the fantasy observations. ``inputs`` is accepted for interface
    compatibility but is not read here.

    :param targets: Labels of the fantasy observations.
    :param full_inputs: Training inputs concatenated with the fantasy inputs.
    :param full_targets: Training labels concatenated with the fantasy labels.
    :param full_output: Prior evaluated on ``full_inputs``.
    :param kwargs: Forwarded to ``likelihood.get_fantasy_likelihood``.
    :return: A new strategy of the same class, created with ``uses_wiski=True``.
    """
    joint_covar = full_output.lazy_covariance_matrix
    joint_mean = full_output.mean.view(*full_inputs[0].shape[:-2], -1)
    n_train = self.num_train

    # Evaluate the fant x fant covariance block, leave train x train unevaluated.
    fant_covar = joint_covar[..., n_train:, n_train:].evaluate_kernel()
    fant_prior_mean = joint_mean[..., n_train:]

    fant_wmat = self.prepare_dense_wmat(fant_covar)
    fant_wmat_t = fant_wmat.transpose(-1, -2)

    fant_likelihood = self.likelihood.get_fantasy_likelihood(**kwargs)
    if len(fant_wmat.shape) > 2:
        fant_noise = fant_likelihood.noise_covar(fant_wmat_t)
    else:
        fant_noise = fant_likelihood.noise_covar(fant_wmat)
    fant_root_vector = fant_noise.sqrt_inv_matmul(fant_wmat_t).transpose(-1, -2)

    # Low-rank update of the W'W cache, and additive update of the response cache.
    updated_inner_prod = self.interp_inner_prod.add_low_rank(fant_root_vector.to_dense())
    residual = (targets - fant_prior_mean).unsqueeze(-1)
    updated_response_cache = self.interp_response_cache + fant_wmat.matmul(fant_noise.solve(residual))

    # Rebuild the prior over train + fantasy points, passing `power` along when the prior has one.
    prior_cls = self.train_prior_dist.__class__
    if hasattr(self.train_prior_dist, "power"):
        joint_prior = prior_cls(joint_mean, joint_covar, self.train_prior_dist.power)
    else:
        joint_prior = prior_cls(joint_mean, joint_covar)

    # Create the new prediction strategy object
    fant_strat = self.__class__(
        train_inputs=full_inputs,
        train_prior_dist=joint_prior,
        train_labels=full_targets,
        likelihood=fant_likelihood,
        uses_wiski=True,
    )
    add_to_cache(fant_strat, "interp_inner_prod", updated_inner_prod)
    add_to_cache(fant_strat, "interp_response_cache", updated_response_cache)
    return fant_strat
|
|
496
|
+
|
|
497
|
+
def prepare_dense_wmat(self, covar=None):
    """Build the dense interpolation matrix W for an interpolated covariance.

    Args:
        covar: Covariance operator exposing ``left_interp_indices``,
            ``left_interp_values`` and ``_sparse_left_interp_t``. When
            ``None``, the covariance of ``self.train_prior_dist`` is used.

    Returns:
        The densified result of ``_sparse_left_interp_t`` wrapped by
        ``to_linear_operator``. Per the original note this is
        ``batch_shape x m x n`` with ``n = covar.shape[-2]``.
    """
    source = self.train_prior_dist.lazy_covariance_matrix if covar is None else covar
    sparse_w = source._sparse_left_interp_t(source.left_interp_indices, source.left_interp_values)
    return to_linear_operator(sparse_w.to_dense())
|
|
503
|
+
|
|
504
|
+
@property
@cached(name="interp_inner_prod")
def interp_inner_prod(self):
    """Memoised product ``W D^{-1} W^T`` (the "W'W cache").

    ``W`` comes from :meth:`prepare_dense_wmat` and ``D`` is whatever
    ``self.likelihood.noise_covar`` returns for the training data.
    """
    w = self.prepare_dense_wmat()
    w_t = w.transpose(-1, -2)
    # The noise model receives W^T when W carries batch dimensions, W otherwise.
    # NOTE(review): this argument looks like it only sizes the noise term - confirm.
    if len(w.shape) > 2:
        noise_arg = w_t
    else:
        noise_arg = w
    noise = self.likelihood.noise_covar(noise_arg)
    return w.matmul(noise.solve(w_t))
|
|
512
|
+
|
|
513
|
+
@property
@cached(name="interp_response_cache")
def interp_response_cache(self):
    """Memoised vector ``W D^{-1} (y - mu)``.

    ``W`` comes from :meth:`prepare_dense_wmat`, ``D`` from
    ``self.likelihood.noise_covar``, ``y`` is ``self.train_labels`` and
    ``mu`` is the mean of ``self.train_prior_dist``.
    """
    w = self.prepare_dense_wmat()
    # Same argument convention as in ``interp_inner_prod``: W^T for batched W.
    if len(w.shape) > 2:
        noise_arg = w.transpose(-1, -2)
    else:
        noise_arg = w
    noise = self.likelihood.noise_covar(noise_arg)

    # Centre the targets and add a trailing column dimension for the solve.
    residual = (self.train_labels - self.train_prior_dist.mean).unsqueeze(-1)
    return w.matmul(noise.solve(residual))
|
|
521
|
+
|
|
522
|
+
@property
@cached(name="mean_cache")
def mean_cache(self):
    """Memoised factor used by :meth:`exact_predictive_mean`.

    Solves the noisy training covariance against the centred labels,
    applies the transposed training interpolation (``left_t_interp``) and
    multiplies by the base operator of the prior covariance. The result is
    detached when ``settings.detach_test_caches`` is on.
    """
    prior_covar = self.train_prior_dist.lazy_covariance_matrix
    interp_idx = prior_covar.left_interp_indices
    interp_vals = prior_covar.left_interp_values

    # Push the prior through the likelihood to obtain the noisy marginal.
    marginal = self.likelihood(self.train_prior_dist, self.train_inputs)
    noisy_covar = marginal.lazy_covariance_matrix
    residual = (self.train_labels - marginal.mean).unsqueeze(-1)
    solved = noisy_covar.solve(residual)

    # Map the solve onto the base operator and apply the base operator.
    n_base = prior_covar.base_linear_op.size(-1)
    projected = left_t_interp(interp_idx, interp_vals, solved, n_base)
    cache = prior_covar.base_linear_op.matmul(projected)

    # Optionally cut the autograd graph at the cache.
    return cache.detach() if settings.detach_test_caches.on() else cache
|
|
546
|
+
|
|
547
|
+
@property
@cached(name="fantasy_mean_cache")
def fantasy_mean_cache(self):
    """Memoised mean factor used when ``uses_wiski`` is set.

    With ``K`` the base operator of the prior covariance, ``L`` a Cholesky
    root of ``interp_inner_prod`` (``L L' ~= W D^{-1} W'``) and
    ``m = K @ interp_response_cache``, this returns
    ``m - (K L) (L' K L + I)^{-1} L' m``; the result is detached when
    ``settings.detach_test_caches`` is on.
    """
    # K_UU: base operator underneath the interpolated training covariance.
    k_uu = self.train_prior_dist.lazy_covariance_matrix.base_linear_op

    # L such that L L' approximates W D^{-1} W'.
    chol_root = self.interp_inner_prod.root_decomposition(method="cholesky").root
    chol_root_t = chol_root.transpose(-1, -2)

    # M = K L and Q = L' K L + I.
    compression = k_uu.matmul(chol_root)
    q_matrix = chol_root_t.matmul(compression).add_jitter(1.0)

    # m = K_UU W D^{-1} (y - mu), then Q^{-1} (L' m).
    response = k_uu.matmul(self.interp_response_cache)
    q_solve = q_matrix.solve(chol_root_t.matmul(response))

    cache = response - compression @ q_solve

    # Optionally cut the autograd graph at the cache.
    return cache.detach() if settings.detach_test_caches.on() else cache
|
|
577
|
+
|
|
578
|
+
@property
@cached(name="fantasy_covar_cache")
def fantasy_covar_cache(self):
    """Memoised covariance factor used when ``uses_wiski`` is set.

    Returns:
        A 2-tuple with exactly one populated slot:
        ``(root_of(K - C), None)`` when ``settings.fast_pred_samples`` is
        on, otherwise ``(None, root_of(C))``. Here ``K`` is the base
        operator of the prior covariance and ``C = M Q^{-1} M'`` with
        ``M = K L`` and ``Q = L' K L + I``. The populated root is detached
        when ``settings.detach_test_caches`` is on.
    """
    k_uu = self.train_prior_dist.lazy_covariance_matrix.base_linear_op

    # A Cholesky root is requested explicitly for numerical stability.
    chol_root = self.interp_inner_prod.root_decomposition(method="cholesky").root
    compression = k_uu.matmul(chol_root)
    q_matrix = chol_root.transpose(-1, -2).matmul(compression).add_jitter(1.0)

    if settings.fast_pred_var.on():
        q_inv_root = q_matrix.root_inv_decomposition()
        # The inverse root is densified here: evaluating it is slow, but
        # otherwise gradients cannot flow back through it.
        correction = RootLinearOperator(compression.matmul(q_inv_root.root.to_dense()))
    else:
        correction = compression.matmul(q_matrix.solve(compression.transpose(-1, -2)))

    if settings.fast_pred_samples.on():
        factor = (k_uu - correction).root_decomposition(method="cholesky").root
        # Optionally cut the autograd graph at the cache.
        if settings.detach_test_caches.on():
            factor = factor.detach()
        return factor, None

    factor = correction.root_decomposition(method="cholesky").root
    # Optionally cut the autograd graph at the cache.
    if settings.detach_test_caches.on():
        factor = factor.detach()
    return None, factor
|
|
617
|
+
|
|
618
|
+
@property
@cached(name="covar_cache")
def covar_cache(self):
    """Memoised covariance factor for the non-WISKI fast-prediction path.

    Builds an approximate inverse root of the noisy training covariance,
    seeded with randomly chosen probe and test vectors, and converts it
    with ``_exact_predictive_covar_inv_quad_form_cache`` using the
    cross-covariance stored in ``self._last_test_train_covar``.

    Returns:
        ``(inside_root, None)`` when ``settings.fast_pred_samples`` is on,
        otherwise ``(None, root)``. The populated entry is detached when
        ``settings.detach_test_caches`` is on.
    """
    prior_covar = self.train_prior_dist.lazy_covariance_matrix
    base_op = prior_covar.base_linear_op
    interp_idx = prior_covar.left_interp_indices
    interp_vals = prior_covar.left_interp_values

    # Two disjoint slices of one random permutation index the probe vectors
    # and the test vectors handed to the inverse-root routine.
    n_probe = settings.fast_pred_var.num_probe_vectors()
    shuffled = torch.randperm(base_op.size(-1)).type_as(interp_idx)
    probe_idx = shuffled[:n_probe].unsqueeze(1)
    held_out_idx = shuffled[n_probe : 2 * n_probe].unsqueeze(1)
    unit_vals = torch.ones(n_probe, 1, dtype=prior_covar.dtype, device=prior_covar.device)

    batch_shape = base_op.batch_shape

    def _batched(tensor):
        # Broadcast the trailing two dims over the base operator's batch shape.
        return tensor.expand(*batch_shape, *tensor.shape[-2:])

    probe_vectors = InterpolatedLinearOperator(
        base_op,
        _batched(interp_idx),
        _batched(interp_vals),
        _batched(probe_idx),
        _batched(unit_vals),
    ).to_dense()
    test_vectors = InterpolatedLinearOperator(
        base_op,
        _batched(interp_idx),
        _batched(interp_vals),
        _batched(held_out_idx),
        _batched(unit_vals),
    ).to_dense()

    # Zero-mean copy of the prior, pushed through the likelihood for its covariance.
    prior = self.train_prior_dist
    zero_mean_prior = prior.__class__(torch.zeros_like(prior.mean), prior.lazy_covariance_matrix)
    if hasattr(prior, 'power'):
        zero_mean_prior.power = prior.power
    noisy_covar = self.likelihood(zero_mean_prior, self.train_inputs).lazy_covariance_matrix

    # Inverse root of the noisy covariance, seeded with the vectors above.
    inv_root = noisy_covar.root_inv_decomposition(
        initial_vectors=probe_vectors, test_vectors=test_vectors
    ).root.to_dense()

    # New root factor.
    root = self._exact_predictive_covar_inv_quad_form_cache(inv_root, self._last_test_train_covar)

    if settings.fast_pred_samples.on():
        inside = base_op + RootLinearOperator(root).mul(-1)
        inside_root = inside.root_decomposition().root.to_dense()
        # Optionally cut the autograd graph at the cache.
        if settings.detach_test_caches.on():
            inside_root = inside_root.detach()
        return inside_root, None

    # Optionally cut the autograd graph at the cache.
    if settings.detach_test_caches.on():
        root = root.detach()
    return None, root
|
|
686
|
+
|
|
687
|
+
def exact_prediction(self, joint_mean, joint_covar):
    """Split the joint train/test prior and return the predictive moments.

    Args:
        joint_mean: Joint mean; entries from index ``self.num_train`` on
            along the last dimension are taken as the test part.
        joint_covar: Joint covariance, sliced the same way on its last two
            dimensions; the slices must support ``evaluate_kernel()``.

    Returns:
        ``(predictive_mean, predictive_covar)`` as produced by
        :meth:`exact_predictive_mean` and :meth:`exact_predictive_covar`.
    """
    split = self.num_train

    # Keep only the test rows; columns give test x test and test x train.
    test_mean = joint_mean[..., split:]
    test_block = joint_covar[..., split:, split:]
    cross_block = joint_covar[..., split:, :split]
    test_test_covar = test_block.evaluate_kernel()
    test_train_covar = cross_block.evaluate_kernel()

    predictive_mean = self.exact_predictive_mean(test_mean, test_train_covar)
    predictive_covar = self.exact_predictive_covar(test_test_covar, test_train_covar)
    return predictive_mean, predictive_covar
|
|
697
|
+
|
|
698
|
+
def exact_predictive_mean(self, test_mean, test_train_covar):
    """Return the predictive mean at the test points.

    Args:
        test_mean: Prior mean at the test points.
        test_train_covar: Test/train cross-covariance exposing
            ``left_interp_indices`` and ``left_interp_values``.

    Returns:
        The selected mean cache interpolated to the test points
        (``left_interp``), with its last dimension squeezed, plus
        ``test_mean``.
    """
    # WISKI strategies keep their own mean cache.
    if self.uses_wiski:
        cache = self.fantasy_mean_cache
    else:
        cache = self.mean_cache

    interpolated = left_interp(
        test_train_covar.left_interp_indices,
        test_train_covar.left_interp_values,
        cache,
    )
    return interpolated.squeeze(-1) + test_mean
|
|
704
|
+
|
|
705
|
+
def exact_predictive_covar(self, test_test_covar, test_train_covar):
    """Return the predictive covariance at the test points.

    When both ``settings.fast_pred_var`` and ``settings.fast_pred_samples``
    are off, the parent class implementation is used. Otherwise the result
    is built from a memoised factor: ``fantasy_covar_cache`` when
    ``self.uses_wiski`` is set, ``covar_cache`` otherwise.

    Both caches are 2-tuples with a single populated slot, chosen by
    ``fast_pred_samples`` at the time they were computed. If the setting
    has changed since then, the required slot is ``None``; the cache is
    dropped and recomputed. This check previously existed only for
    ``covar_cache``, so the WISKI path failed on ``None.to_dense()``.

    Args:
        test_test_covar: Prior covariance between test points.
        test_train_covar: Test/train cross-covariance exposing
            ``left_interp_indices`` and ``left_interp_values``.

    Returns:
        A ``RootLinearOperator`` when ``fast_pred_samples`` is on,
        otherwise ``test_test_covar`` minus a ``RootLinearOperator``.
    """
    if settings.fast_pred_var.off() and settings.fast_pred_samples.off():
        return super(InterpolatedPredictionStrategy, self).exact_predictive_covar(test_test_covar, test_train_covar)

    # ``covar_cache`` reads this attribute when it is (re)computed.
    self._last_test_train_covar = test_train_covar
    test_interp_indices = test_train_covar.left_interp_indices
    test_interp_values = test_train_covar.left_interp_values

    fps = settings.fast_pred_samples.on()
    # Slot 0 holds the sampling root, slot 1 the variance-correction root.
    slot = 0 if fps else 1
    cache_name = "fantasy_covar_cache" if self.uses_wiski else "covar_cache"

    precomputed_cache = getattr(self, cache_name)
    if precomputed_cache[slot] is None:
        # Cached under the other ``fast_pred_samples`` mode: recompute it.
        pop_from_cache(self, cache_name)
        precomputed_cache = getattr(self, cache_name)
    factor = precomputed_cache[slot]

    if self.uses_wiski:
        root = left_interp(test_interp_indices, test_interp_values, factor.to_dense())
        if fps:
            return RootLinearOperator(root)
        return test_test_covar + RootLinearOperator(root).mul(-1)

    # Compute the exact predictive posterior
    if fps:
        res = self._exact_predictive_covar_inv_quad_form_root(factor, test_train_covar)
        return RootLinearOperator(res)
    root = left_interp(test_interp_indices, test_interp_values, factor)
    return test_test_covar + RootLinearOperator(root).mul(-1)
|
|
738
|
+
|
|
739
|
+
|
|
740
|
+
class RFFPredictionStrategy(DefaultPredictionStrategy):
    """Prediction strategy for models using RFFs.

    The training prior covariance must expose a ``root`` attribute,
    optionally wrapped in a ``ConstantMulLinearOperator``.
    """

    def __init__(self, train_inputs, train_prior_dist, train_labels, likelihood):
        """Initialise the parent strategy, then rebuild the prior with an evaluated kernel.

        The stored prior is replaced by a new instance of the same class,
        built from its mean, its covariance after ``evaluate_kernel()`` and,
        when the prior has one, its ``power`` attribute.
        """
        super().__init__(train_inputs, train_prior_dist, train_labels, likelihood)
        prior = self.train_prior_dist
        ctor_args = [prior.mean, prior.lazy_covariance_matrix.evaluate_kernel()]
        if hasattr(prior, 'power'):
            ctor_args.append(prior.power)
        self.train_prior_dist = prior.__class__(*ctor_args)

    def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_output, **kwargs):
        """Fantasy updates are unavailable here; always raises ``NotImplementedError``."""
        raise NotImplementedError("Fantasy observation updates not yet supported for models using RFFs")

    @property
    @cached(name="covar_cache")
    def covar_cache(self):
        """Memoised Cholesky factor of ``I - c * F^T A^{-1} F``.

        ``F`` is the dense root of the training prior covariance, ``c`` its
        constant multiplier (``1.0`` if the covariance is not a
        ``ConstantMulLinearOperator``) and ``A`` is
        ``self.lik_train_train_covar``.
        """
        prior_covar = self.train_prior_dist.lazy_covariance_matrix
        if isinstance(prior_covar, ConstantMulLinearOperator):
            scale = prior_covar.expanded_constant
            prior_covar = prior_covar.base_linear_op
        else:
            scale = torch.tensor(1.0, dtype=prior_covar.dtype, device=prior_covar.device)

        features = prior_covar.root.to_dense()
        noisy_covar = self.lik_train_train_covar
        identity = torch.eye(features.size(-1), dtype=features.dtype, device=features.device)
        projected = features.transpose(-1, -2) @ noisy_covar.solve(features)
        return psd_safe_cholesky(identity - projected * scale)

    def exact_prediction(self, joint_mean, joint_covar):
        """Split the joint prior at ``self.num_train`` and return ``(mean, covar)`` predictions."""
        split = self.num_train

        # Keep only the test rows; columns give test x test and test x train.
        test_mean = joint_mean[..., split:]
        test_block = joint_covar[..., split:, split:]
        cross_block = joint_covar[..., split:, :split]
        test_test_covar = test_block.evaluate_kernel()
        test_train_covar = cross_block.evaluate_kernel()

        predictive_mean = self.exact_predictive_mean(test_mean, test_train_covar)
        predictive_covar = self.exact_predictive_covar(test_test_covar, test_train_covar)
        return predictive_mean, predictive_covar

    def exact_predictive_covar(self, test_test_covar, test_train_covar):
        """Return the predictive covariance as a root operator.

        Returns a ``ZeroLinearOperator`` of the same size when
        ``settings.skip_posterior_variances`` is on. Otherwise the dense
        root of ``test_test_covar``, scaled by the square root of its
        constant multiplier (``1.0`` if there is none), is multiplied by
        ``covar_cache`` and wrapped in a ``RootLinearOperator``.
        ``test_train_covar`` is not used.
        """
        if settings.skip_posterior_variances.on():
            return ZeroLinearOperator(*test_test_covar.size())

        if isinstance(test_test_covar, ConstantMulLinearOperator):
            scale = test_test_covar.expanded_constant
            test_test_covar = test_test_covar.base_linear_op
        else:
            scale = torch.tensor(1.0, dtype=test_test_covar.dtype, device=test_test_covar.device)

        chol_cache = self.covar_cache
        test_features = test_test_covar.root.to_dense() * scale.sqrt()
        return RootLinearOperator(test_features @ chol_cache)
|
|
793
|
+
|
|
794
|
+
|
|
795
|
+
class SQEPRPredictionStrategy(DefaultPredictionStrategy):
    """Prediction strategy for covariances of the form low-rank root plus diagonal.

    The training covariance is expected to expose ``_linear_op`` (the
    low-rank part) and ``_diag_tensor`` (the diagonal part) after
    ``evaluate_kernel()``.
    """

    @property
    @cached(name="covar_cache")
    def covar_cache(self):
        """Memoised matrix ``R^T (R R^T + D)^{-1} R``.

        ``R`` is the dense root of the low-rank part of the noisy training
        covariance and ``D`` its diagonal part. Per the original notes,
        ``R R^T`` stands for ``K_XX`` and ``D`` for ``sigma^2 I``, making
        the cache
        ``K_UU^{-1/2} K_UX (K_XX + sigma^2 I)^{-1} K_XU K_UU^{-1/2}``.
        """
        # The inverse is formed with the Woodbury identity:
        #   (R R^T + D)^{-1} = D^{-1} - D^{-1} R (I + R^T D^{-1} R)^{-1} R^T D^{-1}
        # which for D = sigma^2 I reads
        #   sigma^{-2} ( I - sigma^{-2} R (I + sigma^{-2} R^T R)^{-1} R^T )
        train_train_covar = self.lik_train_train_covar.evaluate_kernel()

        # Get terms needed for woodbury
        root = train_train_covar._linear_op.root_decomposition().root.to_dense()  # R
        inv_diag = train_train_covar._diag_tensor.inverse()  # D^{-1}, i.e. sigma^{-2} I

        # Inner matrix I + R^T D^{-1} R (not yet inverted; it is inverted
        # below through a triangular solve against its Cholesky factor).
        ones = torch.tensor(1.0, dtype=root.dtype, device=root.device)
        chol_factor = to_linear_operator(root.transpose(-1, -2) @ (inv_diag @ root)).add_diagonal(
            ones
        )
        woodbury_term = inv_diag @ torch.linalg.solve_triangular(
            chol_factor.cholesky().to_dense(), root.transpose(-1, -2), upper=False
        ).transpose(-1, -2)
        # woodbury_term @ woodbury_term^T = D^{-1} R (I + R^T D^{-1} R)^{-1} R^T D^{-1}

        # (R R^T + D)^{-1} = D^{-1} - woodbury_term @ woodbury_term^T
        inverse = AddedDiagLinearOperator(
            inv_diag, MatmulLinearOperator(-woodbury_term, woodbury_term.transpose(-1, -2))
        )

        return root.transpose(-1, -2) @ (inverse @ root)

    def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_output, **kwargs):
        """Fantasy updates are unavailable here; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "Fantasy observation updates not yet supported for models using SQEPRPredictionStrategy"
        )

    def exact_prediction(self, joint_mean, joint_covar):
        """Split the joint prior at ``self.num_train`` and return ``(mean, covar)`` predictions.

        When the test/test block is still a ``LazyEvaluatedKernelTensor``
        whose kernel is an ``InducingPointKernel``, it is rebuilt on that
        kernel's ``base_kernel``, so the prior covariance at the test
        points comes from the base kernel.
        """
        # Imported here rather than at module level.
        # NOTE(review): presumably to avoid a circular import - confirm.
        from ..kernels import InducingPointKernel

        # Find the components of the distribution that contain test data
        test_mean = joint_mean[..., self.num_train :]

        # If we're in lazy evaluation mode, let's use the base kernel of the SQEPR output to compute the prior covar
        test_test_covar = joint_covar[..., self.num_train :, self.num_train :]
        if isinstance(test_test_covar, LazyEvaluatedKernelTensor) and isinstance(
            test_test_covar.kernel, InducingPointKernel
        ):
            test_test_covar = LazyEvaluatedKernelTensor(
                test_test_covar.x1,
                test_test_covar.x2,
                test_test_covar.kernel.base_kernel,
                test_test_covar.last_dim_is_batch,
                **test_test_covar.params,
            )

        test_train_covar = joint_covar[..., self.num_train :, : self.num_train].evaluate_kernel()

        return (
            self.exact_predictive_mean(test_mean, test_train_covar),
            self.exact_predictive_covar(test_test_covar, test_train_covar),
        )

    def exact_predictive_covar(self, test_test_covar, test_train_covar):
        """Return ``test_test_covar - L @ covar_cache @ L^T``.

        ``L`` is the dense left factor of ``test_train_covar``.

        Raises:
            ValueError: If ``test_train_covar`` is neither a
                ``MatmulLinearOperator`` nor a
                ``LowRankRootAddedDiagLinearOperator``.
        """
        covar_cache = self.covar_cache
        # covar_cache = K_{UU}^{-1/2} K_{UX}( K_{XX} + \sigma^2 I )^{-1} K_{XU} K_{UU}^{-1/2}

        # Decompose test_train_covar = l, r
        # Main case: test_x and train_x are different - test_train_covar is a MatmulLinearOperator
        if isinstance(test_train_covar, MatmulLinearOperator):
            L = test_train_covar.left_linear_op.to_dense()
        # Edge case: test_x and train_x are the same - test_train_covar is a LowRankRootAddedDiagLinearOperator
        elif isinstance(test_train_covar, LowRankRootAddedDiagLinearOperator):
            L = test_train_covar._linear_op.root.to_dense()
        else:
            # We should not hit this point of the code - this is to catch potential bugs in GPyTorch
            # The message names the exact class accepted by the branch above.
            raise ValueError(
                "Expected SQEPR output to be a MatmulLinearOperator or LowRankRootAddedDiagLinearOperator. "
                f"Got {test_train_covar.__class__.__name__} instead. "
                "This is likely a bug in GPyTorch."
            )

        res = test_test_covar - MatmulLinearOperator(L, covar_cache @ L.mT)
        return res
|
|
878
|
+
|
|
879
|
+
class SGPRPredictionStrategy(SQEPRPredictionStrategy):
    """``SQEPRPredictionStrategy`` under the SGPR name; adds no behaviour.

    NOTE(review): presumably kept so code written against the GPyTorch
    SGPR name keeps working - confirm.
    """
|