qpytorch-0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of qpytorch might be problematic.

Files changed (102)
  1. qpytorch/__init__.py +327 -0
  2. qpytorch/constraints/__init__.py +3 -0
  3. qpytorch/distributions/__init__.py +21 -0
  4. qpytorch/distributions/delta.py +86 -0
  5. qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
  6. qpytorch/distributions/multivariate_qexponential.py +581 -0
  7. qpytorch/distributions/power.py +113 -0
  8. qpytorch/distributions/qexponential.py +153 -0
  9. qpytorch/functions/__init__.py +58 -0
  10. qpytorch/kernels/__init__.py +80 -0
  11. qpytorch/kernels/grid_interpolation_kernel.py +213 -0
  12. qpytorch/kernels/inducing_point_kernel.py +151 -0
  13. qpytorch/kernels/kernel.py +695 -0
  14. qpytorch/kernels/matern32_kernel_grad.py +155 -0
  15. qpytorch/kernels/matern52_kernel_grad.py +194 -0
  16. qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
  17. qpytorch/kernels/polynomial_kernel_grad.py +88 -0
  18. qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
  19. qpytorch/kernels/rbf_kernel_grad.py +125 -0
  20. qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
  21. qpytorch/kernels/rff_kernel.py +153 -0
  22. qpytorch/lazy/__init__.py +9 -0
  23. qpytorch/likelihoods/__init__.py +66 -0
  24. qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
  25. qpytorch/likelihoods/beta_likelihood.py +76 -0
  26. qpytorch/likelihoods/gaussian_likelihood.py +472 -0
  27. qpytorch/likelihoods/laplace_likelihood.py +59 -0
  28. qpytorch/likelihoods/likelihood.py +437 -0
  29. qpytorch/likelihoods/likelihood_list.py +60 -0
  30. qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
  31. qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
  32. qpytorch/likelihoods/noise_models.py +184 -0
  33. qpytorch/likelihoods/qexponential_likelihood.py +494 -0
  34. qpytorch/likelihoods/softmax_likelihood.py +97 -0
  35. qpytorch/likelihoods/student_t_likelihood.py +90 -0
  36. qpytorch/means/__init__.py +23 -0
  37. qpytorch/metrics/__init__.py +17 -0
  38. qpytorch/mlls/__init__.py +53 -0
  39. qpytorch/mlls/_approximate_mll.py +79 -0
  40. qpytorch/mlls/deep_approximate_mll.py +30 -0
  41. qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
  42. qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
  43. qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
  44. qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
  45. qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
  46. qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
  47. qpytorch/mlls/marginal_log_likelihood.py +48 -0
  48. qpytorch/mlls/predictive_log_likelihood.py +76 -0
  49. qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
  50. qpytorch/mlls/variational_elbo.py +77 -0
  51. qpytorch/models/__init__.py +72 -0
  52. qpytorch/models/approximate_qep.py +115 -0
  53. qpytorch/models/deep_qeps/__init__.py +22 -0
  54. qpytorch/models/deep_qeps/deep_qep.py +155 -0
  55. qpytorch/models/deep_qeps/dspp.py +114 -0
  56. qpytorch/models/exact_prediction_strategies.py +880 -0
  57. qpytorch/models/exact_qep.py +349 -0
  58. qpytorch/models/model_list.py +100 -0
  59. qpytorch/models/pyro/__init__.py +28 -0
  60. qpytorch/models/pyro/_pyro_mixin.py +57 -0
  61. qpytorch/models/pyro/distributions/__init__.py +5 -0
  62. qpytorch/models/pyro/pyro_qep.py +105 -0
  63. qpytorch/models/qep.py +7 -0
  64. qpytorch/models/qeplvm/__init__.py +6 -0
  65. qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
  66. qpytorch/models/qeplvm/latent_variable.py +102 -0
  67. qpytorch/module.py +30 -0
  68. qpytorch/optim/__init__.py +5 -0
  69. qpytorch/priors/__init__.py +42 -0
  70. qpytorch/priors/qep_priors.py +81 -0
  71. qpytorch/test/__init__.py +22 -0
  72. qpytorch/test/base_likelihood_test_case.py +106 -0
  73. qpytorch/test/model_test_case.py +150 -0
  74. qpytorch/test/variational_test_case.py +400 -0
  75. qpytorch/utils/__init__.py +38 -0
  76. qpytorch/utils/warnings.py +37 -0
  77. qpytorch/variational/__init__.py +47 -0
  78. qpytorch/variational/_variational_distribution.py +61 -0
  79. qpytorch/variational/_variational_strategy.py +391 -0
  80. qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
  81. qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
  82. qpytorch/variational/cholesky_variational_distribution.py +65 -0
  83. qpytorch/variational/ciq_variational_strategy.py +352 -0
  84. qpytorch/variational/delta_variational_distribution.py +41 -0
  85. qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
  86. qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
  87. qpytorch/variational/lmc_variational_strategy.py +248 -0
  88. qpytorch/variational/mean_field_variational_distribution.py +58 -0
  89. qpytorch/variational/multitask_variational_strategy.py +317 -0
  90. qpytorch/variational/natural_variational_distribution.py +152 -0
  91. qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
  92. qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
  93. qpytorch/variational/tril_natural_variational_distribution.py +130 -0
  94. qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
  95. qpytorch/variational/unwhitened_variational_strategy.py +225 -0
  96. qpytorch/variational/variational_strategy.py +280 -0
  97. qpytorch/version.py +4 -0
  98. qpytorch-0.1.dist-info/LICENSE +21 -0
  99. qpytorch-0.1.dist-info/METADATA +177 -0
  100. qpytorch-0.1.dist-info/RECORD +102 -0
  101. qpytorch-0.1.dist-info/WHEEL +5 -0
  102. qpytorch-0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,880 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import functools
4
+ import string
5
+ import warnings
6
+
7
+ import torch
8
+ from linear_operator import to_dense, to_linear_operator
9
+ from linear_operator.operators import (
10
+ AddedDiagLinearOperator,
11
+ BatchRepeatLinearOperator,
12
+ ConstantMulLinearOperator,
13
+ InterpolatedLinearOperator,
14
+ LinearOperator,
15
+ LowRankRootAddedDiagLinearOperator,
16
+ MaskedLinearOperator,
17
+ MatmulLinearOperator,
18
+ RootLinearOperator,
19
+ ZeroLinearOperator,
20
+ )
21
+ from linear_operator.utils.cholesky import psd_safe_cholesky
22
+ from linear_operator.utils.interpolation import left_interp, left_t_interp
23
+ from torch import Tensor
24
+
25
+ from .. import settings
26
+
27
+ from ..distributions import MultitaskMultivariateNormal, MultitaskMultivariateQExponential
28
+ from ..lazy import LazyEvaluatedKernelTensor
29
+ from gpytorch.utils.memoize import add_to_cache, cached, clear_cache_hook, pop_from_cache
30
+
31
+
32
+ def prediction_strategy(train_inputs, train_prior_dist, train_labels, likelihood):
33
+ train_train_covar = train_prior_dist.lazy_covariance_matrix
34
+ if isinstance(train_train_covar, LazyEvaluatedKernelTensor):
35
+ cls = train_train_covar.kernel.prediction_strategy
36
+ else:
37
+ cls = DefaultPredictionStrategy
38
+ return cls(train_inputs, train_prior_dist, train_labels, likelihood)
39
+
40
+
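For orientation, this factory is rarely called by hand: an exact model builds its prediction strategy lazily the first time it is evaluated on test data, and kernels that define their own `prediction_strategy` attribute (e.g. grid-interpolation or RFF kernels) dispatch to the matching subclass below. A minimal, hedged sketch using the GPyTorch analogue (qpytorch's ExactQEP mirrors this interface; the model, data, and hyperparameters below are made up for illustration):

import torch
import gpytorch

class TinyGP(gpytorch.models.ExactGP):
    """Hypothetical exact model used only to illustrate when the strategy gets built."""
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

train_x = torch.linspace(0, 1, 20).unsqueeze(-1)
train_y = torch.sin(6 * train_x).squeeze(-1) + 0.1 * torch.randn(20)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = TinyGP(train_x, train_y, likelihood)

model.eval(); likelihood.eval()
with torch.no_grad():
    # The first eval-mode call constructs a DefaultPredictionStrategy (via the factory above)
    # and caches it on the model for subsequent test-time predictions.
    pred = likelihood(model(torch.linspace(0, 1, 5).unsqueeze(-1)))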
41
+ class DefaultPredictionStrategy(object):
42
+ def __init__(self, train_inputs, train_prior_dist, train_labels, likelihood, root=None, inv_root=None):
43
+ # Get training shape
44
+ self._train_shape = train_prior_dist.event_shape
45
+
46
+ # Flatten the training labels
47
+ try:
48
+ train_labels = train_labels.reshape(
49
+ *train_labels.shape[: -len(self.train_shape)], self._train_shape.numel()
50
+ )
51
+ except RuntimeError:
52
+ raise RuntimeError(
53
+ "Flattening the training labels failed. The most common cause of this error is "
54
+ + "that the shapes of the prior mean and the training labels are mismatched. "
55
+ + "The shape of the train targets is {0}, ".format(train_labels.shape)
56
+ + "while the reported shape of the mean is {0}.".format(train_prior_dist.mean.shape)
57
+ )
58
+
59
+ self.train_inputs = train_inputs
60
+ self.train_prior_dist = train_prior_dist
61
+ self.train_labels = train_labels
62
+ self.likelihood = likelihood
63
+ self._last_test_train_covar = None
64
+ lik = self.likelihood(train_prior_dist, train_inputs)
65
+ self.lik_train_train_covar = lik.lazy_covariance_matrix
66
+
67
+ if root is not None:
68
+ add_to_cache(self.lik_train_train_covar, "root_decomposition", RootLinearOperator(root))
69
+
70
+ if inv_root is not None:
71
+ add_to_cache(self.lik_train_train_covar, "root_inv_decomposition", RootLinearOperator(inv_root))
72
+
73
+ def __deepcopy__(self, memo):
74
+ # deepcopying prediction strategies of a model evaluated on inputs that require gradients fails
75
+ # with RuntimeError (Only Tensors created explicitly by the user (graph leaves) support the deepcopy
76
+ # protocol at the moment). Overwriting this method makes sure that the prediction strategies of a
77
+ # model are set to None upon deepcopying.
78
+ pass
79
+
80
+ def _exact_predictive_covar_inv_quad_form_cache(self, train_train_covar_inv_root, test_train_covar):
81
+ """
82
+ Computes a cache for K_X*X (K_XX + sigma^2 I)^-1 K_XX* if possible. By default, this does no work and returns
83
+ the first argument.
84
+
85
+ Args:
86
+ train_train_covar_inv_root (:obj:`torch.tensor`): a root of (K_XX + sigma^2 I)^-1
87
+ test_train_covar (:obj:`torch.tensor`): the covariance matrix between the test and train inputs
88
+
89
+ Returns
90
+ A precomputed cache
91
+ """
92
+ res = train_train_covar_inv_root
93
+ if settings.detach_test_caches.on():
94
+ res = res.detach()
95
+
96
+ if res.grad_fn is not None:
97
+ wrapper = functools.partial(clear_cache_hook, self)
98
+ functools.update_wrapper(wrapper, clear_cache_hook)
99
+ res.grad_fn.register_hook(wrapper)
100
+
101
+ return res
102
+
103
+ def _exact_predictive_covar_inv_quad_form_root(self, precomputed_cache, test_train_covar):
104
+ r"""
105
+ Computes :math:`K_{X^{*}X} S` given a precomputed cache
106
+ Where :math:`S` is a tensor such that :math:`SS^{\top} = (K_{XX} + \sigma^2 I)^{-1}`
107
+
108
+ Args:
109
+ precomputed_cache (:obj:`torch.tensor`): What was computed in _exact_predictive_covar_inv_quad_form_cache
110
+ test_train_covar (:obj:`torch.tensor`): The covariance matrix between the test and train inputs
111
+
112
+ Returns
113
+ :obj:`~linear_operator.operators.LinearOperator`: :math:`K_{X^{*}X} S`
114
+ """
115
+ # Here the precomputed cache represents S,
116
+ # where S S^T = (K_XX + sigma^2 I)^-1
117
+ return test_train_covar.matmul(precomputed_cache)
118
+
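A torch-only sanity check of the identity these two helpers rely on (all names below are made up for illustration): if S is any root of (K_XX + sigma^2 I)^{-1}, i.e. S S^T equals that inverse, then the covariance correction K_*X (K_XX + sigma^2 I)^{-1} K_X* can be formed as (K_*X S)(K_*X S)^T, which is exactly what caching the root enables.

import torch

torch.manual_seed(0)
torch.set_default_dtype(torch.float64)
n, m = 6, 3
X, X_star = torch.randn(n, 2), torch.randn(m, 2)
k = lambda a, b: torch.exp(-torch.cdist(a, b) ** 2)      # toy RBF-style kernel
A = k(X, X) + 0.1 * torch.eye(n)                         # K_XX + sigma^2 I
S = torch.linalg.cholesky(torch.inverse(A))              # one valid root of A^{-1}
K_star_X = k(X_star, X)
direct = K_star_X @ torch.inverse(A) @ K_star_X.T
via_root = (K_star_X @ S) @ (K_star_X @ S).T
assert torch.allclose(direct, via_root, atol=1e-8)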
119
+ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_output, **kwargs):
120
+ """
121
+ Returns a new PredictionStrategy that incorporates the specified inputs and targets as new training data.
122
+
123
+ This method is primarily responsible for updating the mean and covariance caches. To add fantasy data to a
124
+ GP (QEP) model, use the :meth:`~gpytorch.models.ExactGP.get_fantasy_model` (:meth:`~qpytorch.models.ExactQEP.get_fantasy_model`) method.
125
+
126
+ Args:
127
+ inputs (Tensor `b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`): Locations of fantasy
128
+ observations.
129
+ targets (Tensor `b1 x ... x bk x m` or `f x b1 x ... x bk x m`): Labels of fantasy observations.
130
+ full_inputs (Tensor `b1 x ... x bk x n+m x d` or `f x b1 x ... x bk x n+m x d`): Training data
131
+ concatenated with fantasy inputs
132
+ full_targets (Tensor `b1 x ... x bk x n+m` or `f x b1 x ... x bk x n+m`): Training labels
133
+ concatenated with fantasy labels.
134
+ full_output (:class:`gpytorch.distributions.MultivariateNormal` or :class:`gpytorch.distributions.MultivariateQExponential`): Prior called on full_inputs
135
+
136
+ Returns:
137
+ A `DefaultPredictionStrategy` model with `n + m` training examples, where the `m` fantasy examples have
138
+ been added and all test-time caches have been updated.
139
+ """
140
+ if not isinstance(full_output, (MultitaskMultivariateNormal, MultitaskMultivariateQExponential)):
141
+ target_batch_shape = targets.shape[:-1]
142
+ else:
143
+ target_batch_shape = targets.shape[:-2]
144
+
145
+ full_mean, full_covar = full_output.mean, full_output.lazy_covariance_matrix
146
+
147
+ batch_shape = full_inputs[0].shape[:-2]
148
+
149
+ num_train = self.num_train
150
+
151
+ if isinstance(full_output, (MultitaskMultivariateNormal, MultitaskMultivariateQExponential)):
152
+ num_tasks = full_output.event_shape[-1]
153
+ full_mean = full_mean.view(*batch_shape, -1, num_tasks)
154
+ fant_mean = full_mean[..., (num_train // num_tasks) :, :]
155
+ full_targets = full_targets.view(*target_batch_shape, -1)
156
+ else:
157
+ full_mean = full_mean.view(*batch_shape, -1)
158
+ fant_mean = full_mean[..., num_train:]
159
+
160
+ # Evaluate fant x train and fant x fant covariance matrices, leave train x train unevaluated.
161
+ fant_fant_covar = full_covar[..., num_train:, num_train:]
162
+ dist = self.train_prior_dist.__class__(fant_mean, fant_fant_covar)
163
+ if hasattr(self.train_prior_dist, 'power'): dist.power = self.train_prior_dist.power
164
+ fant_likelihood = self.likelihood.get_fantasy_likelihood(**kwargs)
165
+ fant_obs = fant_likelihood(dist, inputs, **kwargs)
166
+
167
+ fant_fant_covar = fant_obs.covariance_matrix
168
+ fant_train_covar = to_dense(full_covar[..., num_train:, :num_train])
169
+
170
+ self.fantasy_inputs = inputs
171
+ self.fantasy_targets = targets
172
+
173
+ r"""
174
+ Compute a new mean cache given the old mean cache.
175
+
176
+ We have \alpha = K^{-1}y, and we want to solve [K U; U' S][a; b] = [y; y_f], where U' is fant_train_covar,
177
+ S is fant_fant_covar, and y_f is (targets - fant_mean)
178
+
179
+ To do this, we solve the bordered linear system of equations for [a; b]:
180
+ KQ = U # Q = fant_solve
181
+ [S - U'Q]b = y_f - U'\alpha ==> b = [S - U'Q]^{-1}(y_f - U'\alpha)
182
+ a = \alpha - Qb
183
+ """
184
+ # Get cached K inverse decomp. (or compute if we somehow don't already have the covariance cache)
185
+ K_inverse = self.lik_train_train_covar.root_inv_decomposition()
186
+ fant_solve = K_inverse.matmul(fant_train_covar.transpose(-2, -1))
187
+
188
+ # Solve for "b", the lower portion of the *new* \\alpha corresponding to the fantasy points.
189
+ schur_complement = fant_fant_covar - fant_train_covar.matmul(fant_solve)
190
+
191
+ # we'd like to use a less hacky approach for the following, but einsum can be much faster than
192
+ # unsqueezing/squeezing here (esp. in backward passes); unfortunately, it currently has some
193
+ # issues with broadcasting: https://github.com/pytorch/pytorch/issues/15671
194
+ prefix = string.ascii_lowercase[: max(fant_train_covar.dim() - self.mean_cache.dim() - 1, 0)]
195
+ ftcm = torch.einsum(prefix + "...yz,...z->" + prefix + "...y", [fant_train_covar, self.mean_cache])
196
+
197
+ small_system_rhs = targets - fant_mean - ftcm
198
+ small_system_rhs = small_system_rhs.unsqueeze(-1)
199
+ # The Schur complement of an SPD matrix is guaranteed to be positive definite
200
+ schur_cholesky = psd_safe_cholesky(schur_complement)
201
+ fant_cache_lower = torch.cholesky_solve(small_system_rhs, schur_cholesky)
202
+
203
+ # Get "a", the new upper portion of the cache corresponding to the old training points.
204
+ fant_cache_upper = self.mean_cache.unsqueeze(-1) - fant_solve.matmul(fant_cache_lower)
205
+
206
+ fant_cache_upper = fant_cache_upper.squeeze(-1)
207
+ fant_cache_lower = fant_cache_lower.squeeze(-1)
208
+
209
+ # New mean cache.
210
+ fant_mean_cache = torch.cat((fant_cache_upper, fant_cache_lower), dim=-1)
211
+
212
+ # now update the root and root inverse
213
+ new_lt = self.lik_train_train_covar.cat_rows(fant_train_covar, fant_fant_covar)
214
+ new_root = new_lt.root_decomposition().root
215
+ if settings.detach_test_caches.on():
216
+ new_covar_cache = new_lt.root_inv_decomposition().root.detach()
217
+ else:
218
+ new_covar_cache = new_lt.root_inv_decomposition().root
219
+
220
+ # Expand inputs accordingly if necessary (for fantasies at the same points)
221
+ if full_inputs[0].dim() <= full_targets.dim():
222
+ fant_batch_shape = full_targets.shape[:1]
223
+ n_batch = len(full_mean.shape[:-1])
224
+ repeat_shape = fant_batch_shape + torch.Size([1] * n_batch)
225
+ full_inputs = [fi.expand(fant_batch_shape + fi.shape) for fi in full_inputs]
226
+ full_mean = full_mean.expand(fant_batch_shape + full_mean.shape)
227
+ full_covar = BatchRepeatLinearOperator(full_covar, repeat_shape)
228
+ new_root = BatchRepeatLinearOperator(new_root, repeat_shape)
229
+ # no need to repeat the covar cache, broadcasting will do the right thing
230
+
231
+ if isinstance(full_output, (MultitaskMultivariateNormal, MultitaskMultivariateQExponential)):
232
+ full_mean = full_mean.view(*target_batch_shape, -1, num_tasks).contiguous()
233
+
234
+ # Create new DefaultPredictionStrategy object
235
+ fant_strat = self.__class__(
236
+ train_inputs=full_inputs,
237
+ train_prior_dist=self.train_prior_dist.__class__(full_mean, full_covar) if not hasattr(self.train_prior_dist, 'power') else
238
+ self.train_prior_dist.__class__(full_mean, full_covar, self.train_prior_dist.power),
239
+ train_labels=full_targets,
240
+ likelihood=fant_likelihood,
241
+ root=new_root,
242
+ inv_root=new_covar_cache,
243
+ )
244
+ add_to_cache(fant_strat, "mean_cache", fant_mean_cache)
245
+ add_to_cache(fant_strat, "covar_cache", new_covar_cache.to_dense())
246
+ return fant_strat
247
+
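As the docstring notes, this method is normally reached through the model-level fantasy API rather than called directly. A hedged sketch, reusing the hypothetical TinyGP model from the earlier example (qpytorch's ExactQEP.get_fantasy_model is the analogous entry point):

new_x, new_y = torch.rand(3, 1), torch.randn(3)
with torch.no_grad():
    # A prediction must have been made first so that the strategy and its caches exist.
    _ = likelihood(model(torch.rand(4, 1)))
    # Internally this calls prediction_strategy.get_fantasy_strategy to update the mean and
    # covariance caches instead of refitting the model from scratch.
    fantasy_model = model.get_fantasy_model(new_x, new_y)
    fantasy_pred = fantasy_model.likelihood(fantasy_model(torch.rand(4, 1)))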
248
+ @property
249
+ @cached(name="covar_cache")
250
+ def covar_cache(self):
251
+ train_train_covar = self.lik_train_train_covar
252
+ train_train_covar_inv_root = to_dense(train_train_covar.root_inv_decomposition().root)
253
+ return self._exact_predictive_covar_inv_quad_form_cache(train_train_covar_inv_root, self._last_test_train_covar)
254
+
255
+ @property
256
+ def mean_cache(self):
257
+ return self._mean_cache(settings.observation_nan_policy.value())
258
+
259
+ @cached(name="mean_cache")
260
+ def _mean_cache(self, nan_policy: str) -> Tensor:
261
+ lik = self.likelihood(self.train_prior_dist, self.train_inputs)
262
+ train_mean, train_train_covar = lik.loc, lik.lazy_covariance_matrix
263
+
264
+ train_labels_offset = (self.train_labels - train_mean).unsqueeze(-1)
265
+
266
+ if nan_policy == "ignore":
267
+ mean_cache = train_train_covar.evaluate_kernel().solve(train_labels_offset).squeeze(-1)
268
+ elif nan_policy == "mask":
269
+ # Mask all rows and columns in the kernel matrix corresponding to the missing observations.
270
+ observed = settings.observation_nan_policy._get_observed(
271
+ self.train_labels, torch.Size((self.train_labels.shape[-1],))
272
+ )
273
+ mean_cache = torch.full_like(self.train_labels, torch.nan)
274
+ kernel = MaskedLinearOperator(
275
+ train_train_covar.evaluate_kernel(), observed.reshape(-1), observed.reshape(-1)
276
+ )
277
+ mean_cache[..., observed] = kernel.solve(train_labels_offset[..., observed, :]).squeeze(-1)
278
+ else: # 'fill'
279
+ # Fill all rows and columns in the kernel matrix corresponding to the missing observations with 0.
280
+ # Don't touch the corresponding diagonal elements to ensure a unique solution.
281
+ # This ensures that missing data is ignored during solving.
282
+ warnings.warn(
283
+ "Observation NaN policy 'fill' makes the kernel matrix dense during exact prediction.",
284
+ RuntimeWarning,
285
+ )
286
+ kernel = train_train_covar.evaluate_kernel()
287
+ missing = torch.isnan(self.train_labels)
288
+ kernel_mask = (~missing).to(torch.float)
289
+ kernel_mask = kernel_mask[..., None] * kernel_mask[..., None, :]
290
+ torch.diagonal(kernel_mask, dim1=-2, dim2=-1)[...] = 1
291
+ kernel = kernel * kernel_mask # Unfortunately, this makes the kernel dense at the moment.
292
+ train_labels_offset = settings.observation_nan_policy._fill_tensor(train_labels_offset)
293
+ mean_cache = kernel.solve(train_labels_offset).squeeze(-1)
294
+ mean_cache[missing] = torch.nan # Ensure that nobody expects these values to be valid.
295
+ if settings.detach_test_caches.on():
296
+ mean_cache = mean_cache.detach()
297
+
298
+ if mean_cache.grad_fn is not None:
299
+ wrapper = functools.partial(clear_cache_hook, self)
300
+ functools.update_wrapper(wrapper, clear_cache_hook)
301
+ mean_cache.grad_fn.register_hook(wrapper)
302
+
303
+ return mean_cache
304
+
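The three branches above correspond to the observation_nan_policy setting ('ignore', 'mask', 'fill'). A hedged sketch of the 'mask' policy with the hypothetical TinyGP model from earlier, assuming a GPyTorch/qpytorch version that exposes settings.observation_nan_policy as a context manager (training targets may then contain NaNs for missing observations):

train_y_missing = train_y.clone()
train_y_missing[::4] = float("nan")                       # mark some observations as missing
nan_likelihood = gpytorch.likelihoods.GaussianLikelihood()
nan_model = TinyGP(train_x, train_y_missing, nan_likelihood)
nan_model.eval(); nan_likelihood.eval()
with torch.no_grad(), gpytorch.settings.observation_nan_policy("mask"):
    # Rows/columns of the kernel matrix for the missing observations are masked out
    # when the mean cache is solved for.
    pred = nan_likelihood(nan_model(torch.linspace(0, 1, 5).unsqueeze(-1)))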
305
+ @property
306
+ def num_train(self):
307
+ return self._train_shape.numel()
308
+
309
+ @property
310
+ def train_shape(self):
311
+ return self._train_shape
312
+
313
+ def exact_prediction(self, joint_mean, joint_covar):
314
+ # Find the components of the distribution that contain test data
315
+ test_mean = joint_mean[..., self.num_train :]
316
+ # For efficiency: when the joint covariance is small enough, evaluate it eagerly to a dense tensor before slicing
317
+ if joint_covar.size(-1) <= settings.max_eager_kernel_size.value():
318
+ test_covar = joint_covar[..., self.num_train :, :].to_dense()
319
+ test_test_covar = test_covar[..., self.num_train :]
320
+ test_train_covar = test_covar[..., : self.num_train]
321
+ else:
322
+ test_test_covar = joint_covar[..., self.num_train :, self.num_train :]
323
+ test_train_covar = joint_covar[..., self.num_train :, : self.num_train]
324
+
325
+ return (
326
+ self.exact_predictive_mean(test_mean, test_train_covar),
327
+ self.exact_predictive_covar(test_test_covar, test_train_covar),
328
+ )
329
+
330
+ def exact_predictive_mean(self, test_mean: Tensor, test_train_covar: LinearOperator) -> Tensor:
331
+ """
332
+ Computes the posterior predictive mean of a GP (QEP)
333
+
334
+ :param Tensor test_mean: The test prior mean
335
+ :param ~linear_operator.operators.LinearOperator test_train_covar:
336
+ Covariance matrix between test and train inputs
337
+ :return: The predictive posterior mean of the test points
338
+ """
339
+ # NOTE TO FUTURE SELF:
340
+ # You *cannot* use addmv here, because test_train_covar may not actually be a non-lazy tensor, even for an exact
341
+ # GP, and using addmv would require calling to_dense on test_train_covar, which is obviously a huge no-no!
342
+
343
+ # see https://github.com/cornellius-gp/gpytorch/pull/2317#discussion_r1157994719
344
+ mean_cache = self.mean_cache
345
+ if len(mean_cache.shape) == 4:
346
+ mean_cache = mean_cache.squeeze(1)
347
+
348
+ # Handle NaNs
349
+ nan_policy = settings.observation_nan_policy.value()
350
+ if nan_policy == "ignore":
351
+ res = (test_train_covar @ mean_cache.unsqueeze(-1)).squeeze(-1)
352
+ elif nan_policy == "mask":
353
+ # Restrict train dimension to observed values
354
+ observed = settings.observation_nan_policy._get_observed(mean_cache, torch.Size((mean_cache.shape[-1],)))
355
+ full_mask = torch.ones(test_mean.shape[-1], dtype=torch.bool, device=test_mean.device)
356
+ test_train_covar = MaskedLinearOperator(
357
+ to_linear_operator(test_train_covar), full_mask, observed.reshape(-1)
358
+ )
359
+ res = (test_train_covar @ mean_cache[..., observed].unsqueeze(-1)).squeeze(-1)
360
+ else: # 'fill'
361
+ # Set the columns corresponding to missing observations to 0 to ignore them during matmul.
362
+ mask = (~torch.isnan(mean_cache)).to(torch.float)[..., None, :]
363
+ test_train_covar = test_train_covar * mask
364
+ mean = settings.observation_nan_policy._fill_tensor(mean_cache)
365
+ res = (test_train_covar @ mean.unsqueeze(-1)).squeeze(-1)
366
+ res = res + test_mean
367
+
368
+ return res
369
+
370
+ def exact_predictive_covar(
371
+ self, test_test_covar: LinearOperator, test_train_covar: LinearOperator
372
+ ) -> LinearOperator:
373
+ """
374
+ Computes the posterior predictive covariance of a GP (QEP)
375
+
376
+ :param ~linear_operator.operators.LinearOperator test_train_covar:
377
+ Covariance matrix between test and train inputs
378
+ :param ~linear_operator.operators.LinearOperator test_test_covar: Covariance matrix between test inputs
379
+ :return: A LinearOperator representing the predictive posterior covariance of the test points
380
+ """
381
+ if settings.fast_pred_var.on():
382
+ self._last_test_train_covar = test_train_covar
383
+
384
+ if settings.skip_posterior_variances.on():
385
+ return ZeroLinearOperator(*test_test_covar.size())
386
+
387
+ if settings.fast_pred_var.off():
388
+ dist = self.train_prior_dist.__class__(
389
+ torch.zeros_like(self.train_prior_dist.mean), self.train_prior_dist.lazy_covariance_matrix
390
+ )
391
+ if hasattr(self.train_prior_dist, 'power'): dist.power = self.train_prior_dist.power
392
+ if settings.detach_test_caches.on():
393
+ train_train_covar = self.likelihood(dist, self.train_inputs).lazy_covariance_matrix.detach()
394
+ else:
395
+ train_train_covar = self.likelihood(dist, self.train_inputs).lazy_covariance_matrix
396
+
397
+ test_train_covar = to_dense(test_train_covar)
398
+ train_test_covar = test_train_covar.transpose(-1, -2)
399
+ covar_correction_rhs = train_train_covar.solve(train_test_covar)
400
+ # For efficiency
401
+ if torch.is_tensor(test_test_covar):
402
+ # We can use addmm in the 2d case
403
+ if test_test_covar.dim() == 2:
404
+ return to_linear_operator(
405
+ torch.addmm(test_test_covar, test_train_covar, covar_correction_rhs, beta=1, alpha=-1)
406
+ )
407
+ else:
408
+ return to_linear_operator(test_test_covar + test_train_covar @ covar_correction_rhs.mul(-1))
409
+ # In other cases - we'll use the standard infrastructure
410
+ else:
411
+ return test_test_covar + MatmulLinearOperator(test_train_covar, covar_correction_rhs.mul(-1))
412
+
413
+ precomputed_cache = self.covar_cache
414
+ covar_inv_quad_form_root = self._exact_predictive_covar_inv_quad_form_root(precomputed_cache, test_train_covar)
415
+ if torch.is_tensor(test_test_covar):
416
+ return to_linear_operator(
417
+ torch.add(
418
+ test_test_covar, covar_inv_quad_form_root @ covar_inv_quad_form_root.transpose(-1, -2), alpha=-1
419
+ )
420
+ )
421
+ else:
422
+ return test_test_covar + MatmulLinearOperator(
423
+ covar_inv_quad_form_root, covar_inv_quad_form_root.transpose(-1, -2).mul(-1)
424
+ )
425
+
426
+
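The branches above are driven by the runtime settings imported at the top of this module. A hedged usage sketch with the hypothetical model from the earlier example (the GPyTorch context managers shown are the same flags this method checks):

test_x = torch.linspace(0, 1, 5).unsqueeze(-1)
with torch.no_grad():
    exact_var = likelihood(model(test_x)).variance               # slow path: full solve against the train covariance
with torch.no_grad(), gpytorch.settings.fast_pred_var():
    fast_var = likelihood(model(test_x)).variance                # fast path: reuses the covar_cache root
with torch.no_grad(), gpytorch.settings.skip_posterior_variances():
    mean_only = model(test_x).mean                               # covariance is returned as a ZeroLinearOperator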
427
+ class InterpolatedPredictionStrategy(DefaultPredictionStrategy):
428
+ def __init__(self, train_inputs, train_prior_dist, train_labels, likelihood, uses_wiski=False):
429
+ args = (train_prior_dist.mean, train_prior_dist.lazy_covariance_matrix.evaluate_kernel())
430
+ if hasattr(train_prior_dist, 'power'): args = args+(train_prior_dist.power,)
431
+ train_prior_dist = train_prior_dist.__class__(*args)
432
+ super().__init__(train_inputs, train_prior_dist, train_labels, likelihood)
433
+ # covar = self.train_prior_dist.lazy_covariance_matrix.evaluate_kernel()
434
+ # if isinstance(covar, LazyEvaluatedKernelTensor):
435
+ # covar = covar.evaluate_kernel()
436
+ # self.train_prior_dist = self.train_prior_dist.__class__(
437
+ # self.train_prior_dist.mean, covar
438
+ # )
439
+ self.uses_wiski = uses_wiski
440
+
441
+ def _exact_predictive_covar_inv_quad_form_cache(self, train_train_covar_inv_root, test_train_covar):
442
+ train_interp_indices = test_train_covar.right_interp_indices
443
+ train_interp_values = test_train_covar.right_interp_values
444
+ base_linear_op = test_train_covar.base_linear_op
445
+ base_size = base_linear_op.size(-1)
446
+ res = base_linear_op.matmul(
447
+ left_t_interp(train_interp_indices, train_interp_values, train_train_covar_inv_root, base_size)
448
+ )
449
+ return res
450
+
451
+ def _exact_predictive_covar_inv_quad_form_root(self, precomputed_cache, test_train_covar):
452
+ # Here the precomputed cache represents K_UU W S,
453
+ # where S S^T = (K_XX + sigma^2 I)^-1
454
+ test_interp_indices = test_train_covar.left_interp_indices
455
+ test_interp_values = test_train_covar.left_interp_values
456
+ res = left_interp(test_interp_indices, test_interp_values, precomputed_cache)
457
+ return res
458
+
459
+ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_output, **kwargs):
460
+ r"""
461
+ Implements the fantasy strategy described in https://arxiv.org/abs/2103.01454.
462
+ """
463
+ full_mean, full_covar = full_output.mean, full_output.lazy_covariance_matrix
464
+
465
+ batch_shape = full_inputs[0].shape[:-2]
466
+
467
+ full_mean = full_mean.view(*batch_shape, -1)
468
+ num_train = self.num_train
469
+
470
+ # Evaluate fant x train and fant x fant covariance matrices, leave train x train unevaluated.
471
+ fant_fant_covar = full_covar[..., num_train:, num_train:].evaluate_kernel()
472
+ fant_mean = full_mean[..., num_train:]
473
+
474
+ fant_wmat = self.prepare_dense_wmat(fant_fant_covar)
475
+
476
+ fant_likelihood = self.likelihood.get_fantasy_likelihood(**kwargs)
477
+ fant_noise = fant_likelihood.noise_covar(fant_wmat.transpose(-1, -2) if len(fant_wmat.shape) > 2 else fant_wmat)
478
+ fant_root_vector = fant_noise.sqrt_inv_matmul(fant_wmat.transpose(-1, -2)).transpose(-1, -2)
479
+
480
+ new_wmat = self.interp_inner_prod.add_low_rank(fant_root_vector.to_dense())
481
+ mean_diff = (targets - fant_mean).unsqueeze(-1)
482
+ new_interp_response_cache = self.interp_response_cache + fant_wmat.matmul(fant_noise.solve(mean_diff))
483
+
484
+ # Create new DefaultPredictionStrategy object
485
+ fant_strat = self.__class__(
486
+ train_inputs=full_inputs,
487
+ train_prior_dist=self.train_prior_dist.__class__(full_mean, full_covar) if not hasattr(self.train_prior_dist, 'power') else
488
+ self.train_prior_dist.__class__(full_mean, full_covar, self.train_prior_dist.power),
489
+ train_labels=full_targets,
490
+ likelihood=fant_likelihood,
491
+ uses_wiski=True,
492
+ )
493
+ add_to_cache(fant_strat, "interp_inner_prod", new_wmat)
494
+ add_to_cache(fant_strat, "interp_response_cache", new_interp_response_cache)
495
+ return fant_strat
496
+
497
+ def prepare_dense_wmat(self, covar=None):
498
+ # prepare the w matrix which is batch shape x m x n, where n = covar.shape[-2]
499
+ if covar is None:
500
+ covar = self.train_prior_dist.lazy_covariance_matrix
501
+ wmat = covar._sparse_left_interp_t(covar.left_interp_indices, covar.left_interp_values).to_dense()
502
+ return to_linear_operator(wmat)
503
+
504
+ @property
505
+ @cached(name="interp_inner_prod")
506
+ def interp_inner_prod(self):
507
+ # the W D^{-1} W' cache
508
+ wmat = self.prepare_dense_wmat()
509
+ noise_term = self.likelihood.noise_covar(wmat.transpose(-1, -2) if len(wmat.shape) > 2 else wmat)
510
+ interp_inner_prod = wmat.matmul(noise_term.solve(wmat.transpose(-1, -2)))
511
+ return interp_inner_prod
512
+
513
+ @property
514
+ @cached(name="interp_response_cache")
515
+ def interp_response_cache(self):
516
+ wmat = self.prepare_dense_wmat()
517
+ noise_term = self.likelihood.noise_covar(wmat.transpose(-1, -2) if len(wmat.shape) > 2 else wmat)
518
+ demeaned_train_targets = self.train_labels - self.train_prior_dist.mean
519
+ dinv_y = noise_term.solve(demeaned_train_targets.unsqueeze(-1))
520
+ return wmat.matmul(dinv_y)
521
+
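A torch-only illustration of the two WISKI caches defined above, using a small dense stand-in for the (normally sparse) interpolation matrix W and a diagonal noise covariance D; every name here is made up for illustration:

import torch

torch.manual_seed(0)
m, n = 4, 10                                      # inducing points, training points
W = torch.rand(m, n)                              # dense stand-in for the sparse interpolation matrix
D = torch.diag(0.1 + torch.rand(n))               # observation noise covariance
resid = torch.randn(n, 1)                         # y - mu
interp_inner_prod = W @ torch.inverse(D) @ W.T            # the W D^{-1} W' cache
interp_response_cache = W @ torch.inverse(D) @ resid      # the W D^{-1} (y - mu) cache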
522
+ @property
523
+ @cached(name="mean_cache")
524
+ def mean_cache(self):
525
+ train_train_covar = self.train_prior_dist.lazy_covariance_matrix
526
+ train_interp_indices = train_train_covar.left_interp_indices
527
+ train_interp_values = train_train_covar.left_interp_values
528
+
529
+ lik = self.likelihood(self.train_prior_dist, self.train_inputs)
530
+ train_mean, train_train_covar_with_noise = lik.mean, lik.lazy_covariance_matrix
531
+
532
+ mean_diff = (self.train_labels - train_mean).unsqueeze(-1)
533
+ train_train_covar_inv_labels = train_train_covar_with_noise.solve(mean_diff)
534
+
535
+ # New root factor
536
+ base_size = train_train_covar.base_linear_op.size(-1)
537
+ mean_cache = train_train_covar.base_linear_op.matmul(
538
+ left_t_interp(train_interp_indices, train_interp_values, train_train_covar_inv_labels, base_size)
539
+ )
540
+
541
+ # Prevent backprop through this variable
542
+ if settings.detach_test_caches.on():
543
+ return mean_cache.detach()
544
+ else:
545
+ return mean_cache
546
+
547
+ @property
548
+ @cached(name="fantasy_mean_cache")
549
+ def fantasy_mean_cache(self):
550
+ # first construct K_UU
551
+ train_train_covar = self.train_prior_dist.lazy_covariance_matrix
552
+ inducing_covar = train_train_covar.base_linear_op
553
+
554
+ # now get L such that LL' \approx WD^{-1}W'
555
+ interp_inner_prod_root = self.interp_inner_prod.root_decomposition(method="cholesky").root
556
+ # M = KL
557
+ inducing_compression_matrix = inducing_covar.matmul(interp_inner_prod_root)
558
+
559
+ # Q = L'KL + I
560
+ current_qmatrix = interp_inner_prod_root.transpose(-1, -2).matmul(inducing_compression_matrix).add_jitter(1.0)
561
+
562
+ # m = K_UU WD^{-1}(y - \mu)
563
+ inducing_covar_response = inducing_covar.matmul(self.interp_response_cache)
564
+
565
+ # L' m
566
+ root_space_projection = interp_inner_prod_root.transpose(-1, -2).matmul(inducing_covar_response)
567
+ # Q^{-1} (L' m)
568
+ qmat_solve = current_qmatrix.solve(root_space_projection)
569
+
570
+ mean_cache = inducing_covar_response - inducing_compression_matrix @ qmat_solve
571
+
572
+ # Prevent backprop through this variable
573
+ if settings.detach_test_caches.on():
574
+ return mean_cache.detach()
575
+ else:
576
+ return mean_cache
577
+
578
+ @property
579
+ @cached(name="fantasy_covar_cache")
580
+ def fantasy_covar_cache(self):
581
+ train_train_covar = self.train_prior_dist.lazy_covariance_matrix
582
+ inducing_covar = train_train_covar.base_linear_op
583
+
584
+ # we need to enforce a cholesky here for numerical stability
585
+ interp_inner_prod_root = self.interp_inner_prod.root_decomposition(method="cholesky").root
586
+ inducing_compression_matrix = inducing_covar.matmul(interp_inner_prod_root)
587
+
588
+ current_qmatrix = interp_inner_prod_root.transpose(-1, -2).matmul(inducing_compression_matrix).add_jitter(1.0)
589
+
590
+ if settings.fast_pred_var.on():
591
+ qmat_inv_root = current_qmatrix.root_inv_decomposition()
592
+ # the inverse root has to be evaluated to a dense tensor here, which is slow;
593
+ # otherwise, you can't backprop your way through it
594
+ inner_cache = RootLinearOperator(inducing_compression_matrix.matmul(qmat_inv_root.root.to_dense()))
595
+ else:
596
+ inner_cache = inducing_compression_matrix.matmul(
597
+ current_qmatrix.solve(inducing_compression_matrix.transpose(-1, -2))
598
+ )
599
+
600
+ # Precomputed factor
601
+ if settings.fast_pred_samples.on():
602
+ predictive_covar_cache = inducing_covar - inner_cache
603
+ inside_root = predictive_covar_cache.root_decomposition(method="cholesky").root
604
+ # Prevent backprop through this variable
605
+ if settings.detach_test_caches.on():
606
+ inside_root = inside_root.detach()
607
+ covar_cache = inside_root, None
608
+ else:
609
+ root = inner_cache.root_decomposition(method="cholesky").root
610
+
611
+ # Prevent backprop through this variable
612
+ if settings.detach_test_caches.on():
613
+ root = root.detach()
614
+ covar_cache = None, root
615
+
616
+ return covar_cache
617
+
618
+ @property
619
+ @cached(name="covar_cache")
620
+ def covar_cache(self):
621
+ # Get inverse root
622
+ train_train_covar = self.train_prior_dist.lazy_covariance_matrix
623
+ train_interp_indices = train_train_covar.left_interp_indices
624
+ train_interp_values = train_train_covar.left_interp_values
625
+
626
+ # Get probe vectors for inverse root
627
+ num_probe_vectors = settings.fast_pred_var.num_probe_vectors()
628
+ num_inducing = train_train_covar.base_linear_op.size(-1)
629
+ vector_indices = torch.randperm(num_inducing).type_as(train_interp_indices)
630
+ probe_vector_indices = vector_indices[:num_probe_vectors]
631
+ test_vector_indices = vector_indices[num_probe_vectors : 2 * num_probe_vectors]
632
+
633
+ probe_interp_indices = probe_vector_indices.unsqueeze(1)
634
+ probe_test_interp_indices = test_vector_indices.unsqueeze(1)
635
+ dtype = train_train_covar.dtype
636
+ device = train_train_covar.device
637
+ probe_interp_values = torch.ones(num_probe_vectors, 1, dtype=dtype, device=device)
638
+
639
+ batch_shape = train_train_covar.base_linear_op.batch_shape
640
+ probe_vectors = InterpolatedLinearOperator(
641
+ train_train_covar.base_linear_op,
642
+ train_interp_indices.expand(*batch_shape, *train_interp_indices.shape[-2:]),
643
+ train_interp_values.expand(*batch_shape, *train_interp_values.shape[-2:]),
644
+ probe_interp_indices.expand(*batch_shape, *probe_interp_indices.shape[-2:]),
645
+ probe_interp_values.expand(*batch_shape, *probe_interp_values.shape[-2:]),
646
+ ).to_dense()
647
+ test_vectors = InterpolatedLinearOperator(
648
+ train_train_covar.base_linear_op,
649
+ train_interp_indices.expand(*batch_shape, *train_interp_indices.shape[-2:]),
650
+ train_interp_values.expand(*batch_shape, *train_interp_values.shape[-2:]),
651
+ probe_test_interp_indices.expand(*batch_shape, *probe_test_interp_indices.shape[-2:]),
652
+ probe_interp_values.expand(*batch_shape, *probe_interp_values.shape[-2:]),
653
+ ).to_dense()
654
+
655
+ # Put data through the likelihood
656
+ dist = self.train_prior_dist.__class__(
657
+ torch.zeros_like(self.train_prior_dist.mean), self.train_prior_dist.lazy_covariance_matrix
658
+ )
659
+ if hasattr(self.train_prior_dist, 'power'): dist.power = self.train_prior_dist.power
660
+ train_train_covar_plus_noise = self.likelihood(dist, self.train_inputs).lazy_covariance_matrix
661
+
662
+ # Get inverse root
663
+ train_train_covar_inv_root = train_train_covar_plus_noise.root_inv_decomposition(
664
+ initial_vectors=probe_vectors, test_vectors=test_vectors
665
+ ).root
666
+ train_train_covar_inv_root = train_train_covar_inv_root.to_dense()
667
+
668
+ # New root factor
669
+ root = self._exact_predictive_covar_inv_quad_form_cache(train_train_covar_inv_root, self._last_test_train_covar)
670
+
671
+ # Precomputed factor
672
+ if settings.fast_pred_samples.on():
673
+ inside = train_train_covar.base_linear_op + RootLinearOperator(root).mul(-1)
674
+ inside_root = inside.root_decomposition().root.to_dense()
675
+ # Prevent backprop through this variable
676
+ if settings.detach_test_caches.on():
677
+ inside_root = inside_root.detach()
678
+ covar_cache = inside_root, None
679
+ else:
680
+ # Prevent backprop through this variable
681
+ if settings.detach_test_caches.on():
682
+ root = root.detach()
683
+ covar_cache = None, root
684
+
685
+ return covar_cache
686
+
687
+ def exact_prediction(self, joint_mean, joint_covar):
688
+ # Find the components of the distribution that contain test data
689
+ test_mean = joint_mean[..., self.num_train :]
690
+ test_test_covar = joint_covar[..., self.num_train :, self.num_train :].evaluate_kernel()
691
+ test_train_covar = joint_covar[..., self.num_train :, : self.num_train].evaluate_kernel()
692
+
693
+ return (
694
+ self.exact_predictive_mean(test_mean, test_train_covar),
695
+ self.exact_predictive_covar(test_test_covar, test_train_covar),
696
+ )
697
+
698
+ def exact_predictive_mean(self, test_mean, test_train_covar):
699
+ precomputed_cache = self.fantasy_mean_cache if self.uses_wiski else self.mean_cache
700
+ test_interp_indices = test_train_covar.left_interp_indices
701
+ test_interp_values = test_train_covar.left_interp_values
702
+ res = left_interp(test_interp_indices, test_interp_values, precomputed_cache).squeeze(-1) + test_mean
703
+ return res
704
+
705
+ def exact_predictive_covar(self, test_test_covar, test_train_covar):
706
+ if settings.fast_pred_var.off() and settings.fast_pred_samples.off():
707
+ return super(InterpolatedPredictionStrategy, self).exact_predictive_covar(test_test_covar, test_train_covar)
708
+
709
+ self._last_test_train_covar = test_train_covar
710
+ test_interp_indices = test_train_covar.left_interp_indices
711
+ test_interp_values = test_train_covar.left_interp_values
712
+
713
+ if self.uses_wiski:
714
+ precomputed_cache = self.fantasy_covar_cache
715
+ fps = settings.fast_pred_samples.on()
716
+ if fps:
717
+ root = left_interp(test_interp_indices, test_interp_values, precomputed_cache[0].to_dense())
718
+ res = RootLinearOperator(root)
719
+ else:
720
+ root = left_interp(test_interp_indices, test_interp_values, precomputed_cache[1].to_dense())
721
+ res = test_test_covar + RootLinearOperator(root).mul(-1)
722
+ return res
723
+ else:
724
+ precomputed_cache = self.covar_cache
725
+ fps = settings.fast_pred_samples.on()
726
+ if (fps and precomputed_cache[0] is None) or (not fps and precomputed_cache[1] is None):
727
+ pop_from_cache(self, "covar_cache")
728
+ precomputed_cache = self.covar_cache
729
+
730
+ # Compute the exact predictive posterior
731
+ if settings.fast_pred_samples.on():
732
+ res = self._exact_predictive_covar_inv_quad_form_root(precomputed_cache[0], test_train_covar)
733
+ res = RootLinearOperator(res)
734
+ else:
735
+ root = left_interp(test_interp_indices, test_interp_values, precomputed_cache[1])
736
+ res = test_test_covar + RootLinearOperator(root).mul(-1)
737
+ return res
738
+
739
+
740
+ class RFFPredictionStrategy(DefaultPredictionStrategy):
741
+ def __init__(self, train_inputs, train_prior_dist, train_labels, likelihood):
742
+ super().__init__(train_inputs, train_prior_dist, train_labels, likelihood)
743
+ args = (self.train_prior_dist.mean, self.train_prior_dist.lazy_covariance_matrix.evaluate_kernel())
744
+ if hasattr(self.train_prior_dist, 'power'): args = args+(self.train_prior_dist.power,)
745
+ self.train_prior_dist = self.train_prior_dist.__class__(*args)
746
+
747
+ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_output, **kwargs):
748
+ raise NotImplementedError("Fantasy observation updates not yet supported for models using RFFs")
749
+
750
+ @property
751
+ @cached(name="covar_cache")
752
+ def covar_cache(self):
753
+ lt = self.train_prior_dist.lazy_covariance_matrix
754
+ if isinstance(lt, ConstantMulLinearOperator):
755
+ constant = lt.expanded_constant
756
+ lt = lt.base_linear_op
757
+ else:
758
+ constant = torch.tensor(1.0, dtype=lt.dtype, device=lt.device)
759
+
760
+ train_factor = lt.root.to_dense()
761
+ train_train_covar = self.lik_train_train_covar
762
+ inner_term = (
763
+ torch.eye(train_factor.size(-1), dtype=train_factor.dtype, device=train_factor.device)
764
+ - (train_factor.transpose(-1, -2) @ train_train_covar.solve(train_factor)) * constant
765
+ )
766
+ return psd_safe_cholesky(inner_term)
767
+
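A torch-only check of the low-rank identity behind this cache (illustrative names only): when the prior covariance is Phi Phi^T for a feature matrix Phi, the exact posterior test covariance Phi_* Phi_*^T - Phi_* Phi^T (Phi Phi^T + sigma^2 I)^{-1} Phi Phi_*^T equals Phi_* (I - Phi^T (Phi Phi^T + sigma^2 I)^{-1} Phi) Phi_*^T, and it is that inner term whose Cholesky factor is cached above.

import torch

torch.manual_seed(0)
torch.set_default_dtype(torch.float64)
n, m, d = 12, 5, 4                                # train points, test points, random features
Phi, Phi_star = torch.randn(n, d), torch.randn(m, d)
A = Phi @ Phi.T + 0.3 * torch.eye(n)              # K_XX + sigma^2 I with K = Phi Phi^T
exact = Phi_star @ Phi_star.T - (Phi_star @ Phi.T) @ torch.inverse(A) @ (Phi @ Phi_star.T)
inner = torch.eye(d) - Phi.T @ torch.inverse(A) @ Phi
low_rank = Phi_star @ inner @ Phi_star.T
assert torch.allclose(exact, low_rank, atol=1e-8)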
768
+ def exact_prediction(self, joint_mean, joint_covar):
769
+ # Find the components of the distribution that contain test data
770
+ test_mean = joint_mean[..., self.num_train :]
771
+ test_test_covar = joint_covar[..., self.num_train :, self.num_train :].evaluate_kernel()
772
+ test_train_covar = joint_covar[..., self.num_train :, : self.num_train].evaluate_kernel()
773
+
774
+ return (
775
+ self.exact_predictive_mean(test_mean, test_train_covar),
776
+ self.exact_predictive_covar(test_test_covar, test_train_covar),
777
+ )
778
+
779
+ def exact_predictive_covar(self, test_test_covar, test_train_covar):
780
+ if settings.skip_posterior_variances.on():
781
+ return ZeroLinearOperator(*test_test_covar.size())
782
+
783
+ if isinstance(test_test_covar, ConstantMulLinearOperator):
784
+ constant = test_test_covar.expanded_constant
785
+ test_test_covar = test_test_covar.base_linear_op
786
+ else:
787
+ constant = torch.tensor(1.0, dtype=test_test_covar.dtype, device=test_test_covar.device)
788
+
789
+ covar_cache = self.covar_cache
790
+ factor = test_test_covar.root.to_dense() * constant.sqrt()
791
+ res = RootLinearOperator(factor @ covar_cache)
792
+ return res
793
+
794
+
795
+ class SQEPRPredictionStrategy(DefaultPredictionStrategy):
796
+ @property
797
+ @cached(name="covar_cache")
798
+ def covar_cache(self):
799
+ # Here, the covar_cache is going to be K_{UU}^{-1/2} K_{UX}( K_{XX} + \sigma^2 I )^{-1} K_{XU} K_{UU}^{-1/2}
800
+ # This is easily computed using Woodbury
801
+ # ( K_{XX} + \sigma^2 I )^{-1} = ( R R^T + \sigma^2 I )^{-1}
802
+ # = \sigma^{-2} ( I - \sigma^{-2} R (I + \sigma^{-2} R^T R)^{-1} R^T )
803
+ train_train_covar = self.lik_train_train_covar.evaluate_kernel()
804
+
805
+ # Get terms needed for woodbury
806
+ root = train_train_covar._linear_op.root_decomposition().root.to_dense() # R
807
+ inv_diag = train_train_covar._diag_tensor.inverse() # \sigma^{-2}
808
+
809
+ # Form LT using woodbury
810
+ ones = torch.tensor(1.0, dtype=root.dtype, device=root.device)
811
+ chol_factor = to_linear_operator(root.transpose(-1, -2) @ (inv_diag @ root)).add_diagonal(
812
+ ones
813
+ ) # I + \sigma^{-2} R^T R
814
+ woodbury_term = inv_diag @ torch.linalg.solve_triangular(
815
+ chol_factor.cholesky().to_dense(), root.transpose(-1, -2), upper=False
816
+ ).transpose(-1, -2)
817
+ # woodbury_term @ woodbury_term^T = \sigma^{-2} R (I + \sigma^{-2} R^T R)^{-1} R^T \sigma^{-2}
818
+
819
+ inverse = AddedDiagLinearOperator(
820
+ inv_diag, MatmulLinearOperator(-woodbury_term, woodbury_term.transpose(-1, -2))
821
+ )
822
+ # \sigma^{-2} ( I - \sigma^{-2} R (I + \sigma^{-2} R^T R)^{-1} R^T )
823
+
824
+ return root.transpose(-1, -2) @ (inverse @ root)
825
+
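A torch-only sanity check of the Woodbury identity quoted in the comment above, with made-up shapes:

import torch

torch.manual_seed(0)
torch.set_default_dtype(torch.float64)
n, r = 8, 3
R = torch.randn(n, r)                             # low-rank root of the kernel
sigma2 = 0.25
A = R @ R.T + sigma2 * torch.eye(n)               # K_XX + sigma^2 I
inner = torch.eye(r) + (R.T @ R) / sigma2         # I + sigma^{-2} R^T R
woodbury = (torch.eye(n) - (R @ torch.inverse(inner) @ R.T) / sigma2) / sigma2
assert torch.allclose(torch.inverse(A), woodbury, atol=1e-8)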
826
+ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_output, **kwargs):
827
+ raise NotImplementedError(
828
+ "Fantasy observation updates not yet supported for models using SQEPRPredictionStrategy"
829
+ )
830
+
831
+ def exact_prediction(self, joint_mean, joint_covar):
832
+ from ..kernels import InducingPointKernel
833
+
834
+ # Find the components of the distribution that contain test data
835
+ test_mean = joint_mean[..., self.num_train :]
836
+
837
+ # If we're in lazy evaluation mode, let's use the base kernel of the SQEPR output to compute the prior covar
838
+ test_test_covar = joint_covar[..., self.num_train :, self.num_train :]
839
+ if isinstance(test_test_covar, LazyEvaluatedKernelTensor) and isinstance(
840
+ test_test_covar.kernel, InducingPointKernel
841
+ ):
842
+ test_test_covar = LazyEvaluatedKernelTensor(
843
+ test_test_covar.x1,
844
+ test_test_covar.x2,
845
+ test_test_covar.kernel.base_kernel,
846
+ test_test_covar.last_dim_is_batch,
847
+ **test_test_covar.params,
848
+ )
849
+
850
+ test_train_covar = joint_covar[..., self.num_train :, : self.num_train].evaluate_kernel()
851
+
852
+ return (
853
+ self.exact_predictive_mean(test_mean, test_train_covar),
854
+ self.exact_predictive_covar(test_test_covar, test_train_covar),
855
+ )
856
+
857
+ def exact_predictive_covar(self, test_test_covar, test_train_covar):
858
+ covar_cache = self.covar_cache
859
+ # covar_cache = K_{UU}^{-1/2} K_{UX}( K_{XX} + \sigma^2 I )^{-1} K_{XU} K_{UU}^{-1/2}
860
+
861
+ # Decompose test_train_covar = l, r
862
+ # Main case: test_x and train_x are different - test_train_covar is a MatmulLinearOperator
863
+ if isinstance(test_train_covar, MatmulLinearOperator):
864
+ L = test_train_covar.left_linear_op.to_dense()
865
+ # Edge case: test_x and train_x are the same - test_train_covar is a LowRankRootAddedDiagLinearOperator
866
+ elif isinstance(test_train_covar, LowRankRootAddedDiagLinearOperator):
867
+ L = test_train_covar._linear_op.root.to_dense()
868
+ else:
869
+ # We should not hit this point of the code - this is to catch potential bugs in GPyTorch
870
+ raise ValueError(
871
+ "Expected SQEPR output to be a MatmulLinearOperator or AddedDiagLinearOperator. "
872
+ f"Got {test_train_covar.__class__.__name__} instead. "
873
+ "This is likely a bug in GPyTorch."
874
+ )
875
+
876
+ res = test_test_covar - MatmulLinearOperator(L, covar_cache @ L.mT)
877
+ return res
878
+
879
+ class SGPRPredictionStrategy(SQEPRPredictionStrategy):
880
+ pass