qpytorch-0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of qpytorch might be problematic.
Files changed (102)
  1. qpytorch/__init__.py +327 -0
  2. qpytorch/constraints/__init__.py +3 -0
  3. qpytorch/distributions/__init__.py +21 -0
  4. qpytorch/distributions/delta.py +86 -0
  5. qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
  6. qpytorch/distributions/multivariate_qexponential.py +581 -0
  7. qpytorch/distributions/power.py +113 -0
  8. qpytorch/distributions/qexponential.py +153 -0
  9. qpytorch/functions/__init__.py +58 -0
  10. qpytorch/kernels/__init__.py +80 -0
  11. qpytorch/kernels/grid_interpolation_kernel.py +213 -0
  12. qpytorch/kernels/inducing_point_kernel.py +151 -0
  13. qpytorch/kernels/kernel.py +695 -0
  14. qpytorch/kernels/matern32_kernel_grad.py +155 -0
  15. qpytorch/kernels/matern52_kernel_grad.py +194 -0
  16. qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
  17. qpytorch/kernels/polynomial_kernel_grad.py +88 -0
  18. qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
  19. qpytorch/kernels/rbf_kernel_grad.py +125 -0
  20. qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
  21. qpytorch/kernels/rff_kernel.py +153 -0
  22. qpytorch/lazy/__init__.py +9 -0
  23. qpytorch/likelihoods/__init__.py +66 -0
  24. qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
  25. qpytorch/likelihoods/beta_likelihood.py +76 -0
  26. qpytorch/likelihoods/gaussian_likelihood.py +472 -0
  27. qpytorch/likelihoods/laplace_likelihood.py +59 -0
  28. qpytorch/likelihoods/likelihood.py +437 -0
  29. qpytorch/likelihoods/likelihood_list.py +60 -0
  30. qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
  31. qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
  32. qpytorch/likelihoods/noise_models.py +184 -0
  33. qpytorch/likelihoods/qexponential_likelihood.py +494 -0
  34. qpytorch/likelihoods/softmax_likelihood.py +97 -0
  35. qpytorch/likelihoods/student_t_likelihood.py +90 -0
  36. qpytorch/means/__init__.py +23 -0
  37. qpytorch/metrics/__init__.py +17 -0
  38. qpytorch/mlls/__init__.py +53 -0
  39. qpytorch/mlls/_approximate_mll.py +79 -0
  40. qpytorch/mlls/deep_approximate_mll.py +30 -0
  41. qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
  42. qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
  43. qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
  44. qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
  45. qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
  46. qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
  47. qpytorch/mlls/marginal_log_likelihood.py +48 -0
  48. qpytorch/mlls/predictive_log_likelihood.py +76 -0
  49. qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
  50. qpytorch/mlls/variational_elbo.py +77 -0
  51. qpytorch/models/__init__.py +72 -0
  52. qpytorch/models/approximate_qep.py +115 -0
  53. qpytorch/models/deep_qeps/__init__.py +22 -0
  54. qpytorch/models/deep_qeps/deep_qep.py +155 -0
  55. qpytorch/models/deep_qeps/dspp.py +114 -0
  56. qpytorch/models/exact_prediction_strategies.py +880 -0
  57. qpytorch/models/exact_qep.py +349 -0
  58. qpytorch/models/model_list.py +100 -0
  59. qpytorch/models/pyro/__init__.py +28 -0
  60. qpytorch/models/pyro/_pyro_mixin.py +57 -0
  61. qpytorch/models/pyro/distributions/__init__.py +5 -0
  62. qpytorch/models/pyro/pyro_qep.py +105 -0
  63. qpytorch/models/qep.py +7 -0
  64. qpytorch/models/qeplvm/__init__.py +6 -0
  65. qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
  66. qpytorch/models/qeplvm/latent_variable.py +102 -0
  67. qpytorch/module.py +30 -0
  68. qpytorch/optim/__init__.py +5 -0
  69. qpytorch/priors/__init__.py +42 -0
  70. qpytorch/priors/qep_priors.py +81 -0
  71. qpytorch/test/__init__.py +22 -0
  72. qpytorch/test/base_likelihood_test_case.py +106 -0
  73. qpytorch/test/model_test_case.py +150 -0
  74. qpytorch/test/variational_test_case.py +400 -0
  75. qpytorch/utils/__init__.py +38 -0
  76. qpytorch/utils/warnings.py +37 -0
  77. qpytorch/variational/__init__.py +47 -0
  78. qpytorch/variational/_variational_distribution.py +61 -0
  79. qpytorch/variational/_variational_strategy.py +391 -0
  80. qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
  81. qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
  82. qpytorch/variational/cholesky_variational_distribution.py +65 -0
  83. qpytorch/variational/ciq_variational_strategy.py +352 -0
  84. qpytorch/variational/delta_variational_distribution.py +41 -0
  85. qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
  86. qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
  87. qpytorch/variational/lmc_variational_strategy.py +248 -0
  88. qpytorch/variational/mean_field_variational_distribution.py +58 -0
  89. qpytorch/variational/multitask_variational_strategy.py +317 -0
  90. qpytorch/variational/natural_variational_distribution.py +152 -0
  91. qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
  92. qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
  93. qpytorch/variational/tril_natural_variational_distribution.py +130 -0
  94. qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
  95. qpytorch/variational/unwhitened_variational_strategy.py +225 -0
  96. qpytorch/variational/variational_strategy.py +280 -0
  97. qpytorch/version.py +4 -0
  98. qpytorch-0.1.dist-info/LICENSE +21 -0
  99. qpytorch-0.1.dist-info/METADATA +177 -0
  100. qpytorch-0.1.dist-info/RECORD +102 -0
  101. qpytorch-0.1.dist-info/WHEEL +5 -0
  102. qpytorch-0.1.dist-info/top_level.txt +1 -0
qpytorch/variational/_variational_strategy.py
@@ -0,0 +1,391 @@
+ #!/usr/bin/env python3
+
+ import functools
+ from abc import ABC, abstractproperty
+ from copy import deepcopy
+ from typing import Optional, Tuple, Union
+
+ import torch
+ from linear_operator.operators import LinearOperator
+ from torch import Tensor
+
+ from .. import settings
+ from ..distributions import Delta, Distribution, MultivariateNormal, MultivariateQExponential
+ from ..kernels import Kernel
+ from ..likelihoods import GaussianLikelihood, QExponentialLikelihood
+ from ..means import Mean
+ from ..models import ApproximateGP, ApproximateQEP, ExactGP, ExactQEP
+ from ..models.exact_prediction_strategies import DefaultPredictionStrategy
+ from ..module import Module
+ from gpytorch.utils.memoize import add_to_cache, cached, clear_cache_hook
+ from . import _VariationalDistribution
+
+
+ class _BaseExactGP(ExactGP):
+     def __init__(
+         self,
+         train_inputs: Optional[Union[Tensor, Tuple[Tensor, ...]]],
+         train_targets: Optional[Tensor],
+         likelihood: GaussianLikelihood,
+         mean_module: Mean,
+         covar_module: Kernel,
+     ):
+         super().__init__(train_inputs, train_targets, likelihood)
+         self.mean_module = mean_module
+         self.covar_module = covar_module
+
+     def forward(self, x: Tensor, **kwargs) -> MultivariateNormal:
+         mean = self.mean_module(x)
+         covar = self.covar_module(x)
+         return MultivariateNormal(mean, covar)
+
+
+ class _BaseExactQEP(ExactQEP):
+     def __init__(
+         self,
+         train_inputs: Optional[Union[Tensor, Tuple[Tensor, ...]]],
+         train_targets: Optional[Tensor],
+         likelihood: QExponentialLikelihood,
+         mean_module: Mean,
+         covar_module: Kernel,
+     ):
+         super().__init__(train_inputs, train_targets, likelihood)
+         self.mean_module = mean_module
+         self.covar_module = covar_module
+
+     def forward(self, x: Tensor, **kwargs) -> MultivariateQExponential:
+         mean = self.mean_module(x)
+         covar = self.covar_module(x)
+         return MultivariateQExponential(mean, covar, power=self.likelihood.power)
+
+
+ def _add_cache_hook(tsr: Tensor, pred_strat: DefaultPredictionStrategy) -> Tensor:
+     if tsr.grad_fn is not None:
+         wrapper = functools.partial(clear_cache_hook, pred_strat)
+         functools.update_wrapper(wrapper, clear_cache_hook)
+         tsr.grad_fn.register_hook(wrapper)
+     return tsr
+
+
+ class _VariationalStrategy(Module, ABC):
+     """
+     Abstract base class for all Variational Strategies.
+     """
+
+     has_fantasy_strategy = False
+
+     def __init__(
+         self,
+         model: Union[ApproximateGP, ApproximateQEP, "_VariationalStrategy"],
+         inducing_points: Tensor,
+         variational_distribution: _VariationalDistribution,
+         learn_inducing_locations: bool = True,
+         jitter_val: Optional[float] = None,
+     ):
+         super().__init__()
+
+         self._jitter_val = jitter_val
+
+         # Model
+         object.__setattr__(self, "model", model)
+
+         # Inducing points
+         inducing_points = inducing_points.clone()
+         if inducing_points.dim() == 1:
+             inducing_points = inducing_points.unsqueeze(-1)
+         if learn_inducing_locations:
+             self.register_parameter(name="inducing_points", parameter=torch.nn.Parameter(inducing_points))
+         else:
+             self.register_buffer("inducing_points", inducing_points)
+
+         # Variational distribution
+         self._variational_distribution = variational_distribution
+         self.register_buffer("variational_params_initialized", torch.tensor(0))
+
+     def _clear_cache(self) -> None:
+         clear_cache_hook(self)
+
+     def _expand_inputs(self, x: Tensor, inducing_points: Tensor) -> Tuple[Tensor, Tensor]:
+         """
+         Pre-processing step in __call__ to make x the same batch_shape as the inducing points.
+         """
+         batch_shape = torch.broadcast_shapes(inducing_points.shape[:-2], x.shape[:-2])
+         inducing_points = inducing_points.expand(*batch_shape, *inducing_points.shape[-2:])
+         x = x.expand(*batch_shape, *x.shape[-2:])
+         return x, inducing_points
+
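Note: _expand_inputs broadcasts the batch dimensions of the inputs against those of the inducing points, so a single inducing set can serve a batch of input sets (and vice versa). A minimal shape sketch, with hypothetical sizes:

    import torch

    x = torch.randn(3, 100, 2)  # batch of 3 input sets, each with 100 points in R^2
    z = torch.randn(50, 2)      # one shared set of 50 inducing points

    batch_shape = torch.broadcast_shapes(z.shape[:-2], x.shape[:-2])  # torch.Size([3])
    z = z.expand(*batch_shape, *z.shape[-2:])  # -> (3, 50, 2)
    x = x.expand(*batch_shape, *x.shape[-2:])  # -> (3, 100, 2)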
+     @property
+     def jitter_val(self) -> float:
+         if self._jitter_val is None:
+             return settings.variational_cholesky_jitter.value(dtype=self.inducing_points.dtype)
+         return self._jitter_val
+
+     @jitter_val.setter
+     def jitter_val(self, jitter_val: float):
+         self._jitter_val = jitter_val
+
+     @abstractproperty
+     @cached(name="prior_distribution_memo")
+     def prior_distribution(self) -> Union[MultivariateNormal, MultivariateQExponential]:
+         r"""
+         The :func:`~qpytorch.variational.VariationalStrategy.prior_distribution` method determines how to compute the
+         GP (QEP) prior distribution of the inducing points, e.g. :math:`p(u) \sim N(\mu(X_u), K(X_u, X_u))` or
+         :math:`p(u) \sim QED(\mu(X_u), K(X_u, X_u))`. Most commonly, this is done simply by calling the
+         user-defined GP (QEP) prior on the inducing point data directly.
+
+         :rtype: :obj:`~gpytorch.distributions.MultivariateNormal` or :obj:`~qpytorch.distributions.MultivariateQExponential`
+         :return: The distribution :math:`p( \mathbf u)`
+         """
+         raise NotImplementedError
+
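Note: concrete strategies override prior_distribution. As an illustrative sketch (not code from this package), an unwhitened strategy could evaluate the user-defined prior directly at the inducing locations, assuming a Gaussian model; whitened strategies instead place a standardized prior over whitened inducing values:

    @property
    def prior_distribution(self):
        # p(u) = N(mu(Z), K(Z, Z) + jitter * I), evaluated at the inducing inputs Z
        prior = self.model.forward(self.inducing_points)
        return MultivariateNormal(
            prior.mean, prior.lazy_covariance_matrix.add_jitter(self.jitter_val)
        )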
+     @property
+     @cached(name="variational_distribution_memo")
+     def variational_distribution(self) -> Distribution:
+         return self._variational_distribution()
+
+     def forward(
+         self,
+         x: Tensor,
+         inducing_points: Tensor,
+         inducing_values: Tensor,
+         variational_inducing_covar: Optional[LinearOperator] = None,
+         **kwargs,
+     ) -> Union[MultivariateNormal, MultivariateQExponential]:
+         r"""
+         The :func:`~qpytorch.variational.VariationalStrategy.forward` method determines how to marginalize out the
+         inducing point function values. Specifically, forward defines how to transform a variational distribution
+         over the inducing point values, :math:`q(u)`, into a variational distribution over the function values at
+         specified locations x, :math:`q(f|x)`, by integrating :math:`\int p(f|x, u)q(u)du`.
+
+         :param x: Locations :math:`\mathbf X` to get the
+             variational posterior of the function values at.
+         :param inducing_points: Locations :math:`\mathbf Z` of the inducing points
+         :param inducing_values: Samples of the inducing function values :math:`\mathbf u`
+             (or the mean of the distribution :math:`q(\mathbf u)` if q is a Gaussian or Q-Exponential).
+         :param variational_inducing_covar: If the distribution :math:`q(\mathbf u)` is
+             Gaussian (Q-Exponential), then this variable is the covariance matrix of that Gaussian (Q-Exponential).
+             Otherwise, it will be None.
+
+         :rtype: :obj:`~gpytorch.distributions.MultivariateNormal` (`~qpytorch.distributions.MultivariateQExponential`)
+         :return: The distribution :math:`q( \mathbf f(\mathbf X))`
+         """
+         raise NotImplementedError
+
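Note: in the Gaussian case with :math:`q(\mathbf u) = N(\mathbf m, \mathbf S)`, this integral has the familiar closed form (written here in the unwhitened parameterization; the q-exponential variants attach the power parameter to the same mean and covariance):

    q(\mathbf f \mid \mathbf x)
      = \mathcal N\Big( \mu_{\mathbf x} + K_{\mathbf{xz}} K_{\mathbf{zz}}^{-1} (\mathbf m - \mu_{\mathbf z}),\;
        K_{\mathbf{xx}} - K_{\mathbf{xz}} K_{\mathbf{zz}}^{-1} (K_{\mathbf{zz}} - \mathbf S) K_{\mathbf{zz}}^{-1} K_{\mathbf{zx}} \Big)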
+     def kl_divergence(self) -> Tensor:
+         r"""
+         Compute the KL divergence between the variational inducing distribution :math:`q(\mathbf u)`
+         and the prior inducing distribution :math:`p(\mathbf u)`.
+         """
+         with settings.max_preconditioner_size(0):
+             kl_divergence = torch.distributions.kl.kl_divergence(self.variational_distribution, self.prior_distribution)
+         return kl_divergence
+
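Note: this KL term is the regularizer of the variational ELBO. A minimal training-loop sketch, assuming qpytorch mirrors the GPyTorch training API (model, likelihood, train_x, and train_y are placeholders):

    import torch
    import qpytorch

    mll = qpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(0))
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    for _ in range(200):
        optimizer.zero_grad()
        output = model(train_x)
        loss = -mll(output, train_y)  # E_q[log p(y|f)] minus a scaled kl_divergence()
        loss.backward()
        optimizer.step()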
+     @cached(name="amortized_exact_")
+     def amortized_exact_(
+         self, mean_module: Optional[Module] = None, covar_module: Optional[Module] = None
+     ) -> Union[ExactGP, ExactQEP]:
+         mean_module = self.model.mean_module if mean_module is None else mean_module
+         covar_module = self.model.covar_module if covar_module is None else covar_module
+
+         with torch.no_grad():
+             # from here on down, we refer to the inducing points as pseudo_inputs
+             pseudo_target_covar, pseudo_target_mean = self.pseudo_points
+             pseudo_inputs = self.inducing_points.detach()
+             if pseudo_inputs.ndim < pseudo_target_mean.ndim:
+                 pseudo_inputs = pseudo_inputs.expand(*pseudo_target_mean.shape[:-2], *pseudo_inputs.shape)
+             # TODO: add flag for conditioning into SGPR after building fantasy strategy for SGPR
+             new_covar_module = deepcopy(covar_module)
+
+             # update inducing mean if necessary
+             pseudo_target_mean = pseudo_target_mean.squeeze() + mean_module(pseudo_inputs)
+
+             if 'Gaussian' in self.model.likelihood.__class__.__name__:
+                 inducing_exact_model = _BaseExactGP(
+                     pseudo_inputs,
+                     pseudo_target_mean,
+                     mean_module=deepcopy(mean_module),
+                     covar_module=new_covar_module,
+                     likelihood=deepcopy(self.model.likelihood),
+                 )
+             elif 'QExponential' in self.model.likelihood.__class__.__name__:
+                 inducing_exact_model = _BaseExactQEP(
+                     pseudo_inputs,
+                     pseudo_target_mean,
+                     mean_module=deepcopy(mean_module),
+                     covar_module=new_covar_module,
+                     likelihood=deepcopy(self.model.likelihood),
+                 )
+             else:
+                 raise RuntimeError("Exact model can only handle Gaussian or Q-Exponential likelihoods")
+
+             # now fantasize around this model
+             # as this model is new, we need to compute a posterior to construct the prediction strategy
+             # which uses the likelihood pseudo caches
+             faked_points = torch.randn(
+                 *pseudo_target_mean.shape[:-2],
+                 1,
+                 pseudo_inputs.shape[-1],
+                 device=pseudo_inputs.device,
+                 dtype=pseudo_inputs.dtype,
+             )
+             inducing_exact_model.eval()
+             _ = inducing_exact_model(faked_points)
+
+             # then we overwrite the likelihood to take into account the multivariate normal term
+             pred_strat = inducing_exact_model.prediction_strategy
+             pred_strat._memoize_cache = {}
+             with torch.no_grad():
+                 updated_lik_train_train_covar = pred_strat.train_prior_dist.lazy_covariance_matrix + pseudo_target_covar
+             pred_strat.lik_train_train_covar = updated_lik_train_train_covar
+
+             # do the mean cache because the mean cache doesn't solve against lik_train_train_covar
+             train_mean = inducing_exact_model.mean_module(*inducing_exact_model.train_inputs)
+             train_labels_offset = (inducing_exact_model.prediction_strategy.train_labels - train_mean).unsqueeze(-1)
+             mean_cache = updated_lik_train_train_covar.solve(train_labels_offset).squeeze(-1)
+             mean_cache = _add_cache_hook(mean_cache, inducing_exact_model.prediction_strategy)
+             add_to_cache(pred_strat, "mean_cache", mean_cache)
+             # TODO: check to see if we need to do the covar_cache?
+
+         inducing_exact_model.prediction_strategy = pred_strat
+         return inducing_exact_model
+
+     @property
+     def pseudo_points(self) -> Tuple[Tensor, Tensor]:
+         # note: decorated as a property, since amortized_exact_ unpacks self.pseudo_points without calling it
+         raise NotImplementedError("Each variational strategy must implement its own pseudo points method")
+
+     def get_fantasy_model(
+         self,
+         inputs: Tensor,
+         targets: Tensor,
+         mean_module: Optional[Module] = None,
+         covar_module: Optional[Module] = None,
+         **kwargs,
+     ) -> Union[ExactGP, ExactQEP]:
+         r"""
+         Performs the online variational conditioning (OVC) strategy of Maddox et al., '21 to return
+         an exact GP (QEP) model that incorporates the inputs and targets alongside the variational model's inducing
+         points and targets.
+
+         Currently, instead of directly updating the variational parameters (and inducing points), we
+         return an ExactGP (ExactQEP) model rather than an updated variational GP (QEP) model. This is done primarily
+         for numerical stability.
+
+         Unlike the ExactGP's (ExactQEP's) call for get_fantasy_model, we enable options for mean_module and covar_module
+         that allow specification of the mean / covariance. We expect that either the mean and covariance
+         modules are attributes of the model itself, called mean_module and covar_module respectively, OR that you
+         pass them into this method explicitly.
+
+         :param inputs: (`b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`) Locations of fantasy
+             observations.
+         :param targets: (`b1 x ... x bk x m` or `f x b1 x ... x bk x m`) Labels of fantasy observations.
+         :param mean_module: torch module describing the mean function of the GP (QEP) model. Optional if
+             `mean_module` is already an attribute of the variational GP (QEP).
+         :param covar_module: torch module describing the covariance function of the GP (QEP) model. Optional
+             if `covar_module` is already an attribute of the variational GP (QEP).
+         :return: An `ExactGP` (`ExactQEP`) model with `k + m` training examples, where the `m` fantasy examples have been
+             added and all test-time caches have been updated. We assume that there are `k` inducing points in this
+             variational GP (QEP). Note that we return an `ExactGP` (`ExactQEP`) rather than a variational GP (QEP).
+
+         Reference: "Conditioning Sparse Variational Gaussian Processes for Online Decision-Making,"
+             Maddox, Stanton, Wilson, NeurIPS '21
+             https://papers.nips.cc/paper/2021/hash/325eaeac5bef34937cfdc1bd73034d17-Abstract.html
+         """
+
+         # currently, we only support fantasization for CholeskyVariationalDistribution and
+         # whitened / unwhitened variational strategies
+         if not self.has_fantasy_strategy:
+             raise NotImplementedError(
+                 f"No fantasy model support for {self.__class__.__name__}. "
+                 "Only VariationalStrategy and UnwhitenedVariationalStrategy are currently supported."
+             )
+         else:
+             from . import CholeskyVariationalDistribution  # Circular import otherwise
+
+             if not isinstance(self._variational_distribution, CholeskyVariationalDistribution):
+                 raise NotImplementedError(
+                     "Fantasy models are only supported for variational models with CholeskyVariationalDistribution."
+                 )
+
+         if not isinstance(self.model.likelihood, (GaussianLikelihood, QExponentialLikelihood)):
+             raise NotImplementedError(
+                 f"No fantasy model support for {self.model.likelihood.__class__.__name__}. "
+                 "Only GaussianLikelihoods and QExponentialLikelihoods are currently supported."
+             )
+         # we assume that either the user has given the model a mean_module and a covar_module
+         # or that it will be passed into the get_fantasy_model function. we check for these.
+         if mean_module is None:
+             mean_module = getattr(self.model, "mean_module", None)
+             if mean_module is None:
+                 raise ModuleNotFoundError(
+                     "Either you must provide a mean_module as input to get_fantasy_model "
+                     "or it must be an attribute of the model called mean_module."
+                 )
+         if covar_module is None:
+             covar_module = getattr(self.model, "covar_module", None)
+             if covar_module is None:
+                 # raise an error
+                 raise ModuleNotFoundError(
+                     "Either you must provide a covar_module as input to get_fantasy_model "
+                     "or it must be an attribute of the model called covar_module."
+                 )
+
+         # first we construct an exact model over the inducing points with the inducing covariance
+         # matrix
+         inducing_exact_model = self.amortized_exact_(mean_module=mean_module, covar_module=covar_module)
+
+         # then we update this model by adding in the inputs and pseudo targets
+         # finally we fantasize wrt targets
+         fantasy_model = inducing_exact_model.get_fantasy_model(inputs, targets, **kwargs)
+         fant_pred_strat = fantasy_model.prediction_strategy
+
+         # first we update the lik_train_train_covar
+         # do the mean cache again because the mean cache resets the likelihood forward
+         train_mean = fantasy_model.mean_module(*fantasy_model.train_inputs)
+         train_labels_offset = (fant_pred_strat.train_labels - train_mean).unsqueeze(-1)
+         fantasy_lik_train_root_inv = fant_pred_strat.lik_train_train_covar.root_inv_decomposition()
+         mean_cache = fantasy_lik_train_root_inv.matmul(train_labels_offset).squeeze(-1)
+         mean_cache = _add_cache_hook(mean_cache, fant_pred_strat)
+         add_to_cache(fant_pred_strat, "mean_cache", mean_cache)
+         # TODO: should we update the covar_cache?
+
+         fantasy_model.prediction_strategy = fant_pred_strat
+         return fantasy_model
+
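Note: a hedged usage sketch of the OVC path above (variable names are hypothetical; the strategy must set has_fantasy_strategy = True and use a CholeskyVariationalDistribution):

    import torch

    # Condition a trained variational model on new observations without re-fitting.
    new_x = torch.randn(5, 2)  # 5 fantasy inputs in R^2
    new_y = torch.randn(5)     # their targets

    fantasy_model = model.variational_strategy.get_fantasy_model(inputs=new_x, targets=new_y)
    fantasy_model.eval()
    with torch.no_grad():
        predictive = fantasy_model.likelihood(fantasy_model(test_x))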
+     def __call__(self, x: Tensor, prior: bool = False, **kwargs) -> Union[MultivariateNormal, MultivariateQExponential]:
+         # If we're in prior mode, then we're done!
+         if prior:
+             return self.model.forward(x, **kwargs)
+
+         # Delete previously cached items from the training distribution
+         if self.training:
+             self._clear_cache()
+         # (Maybe) initialize variational distribution
+         if not self.variational_params_initialized.item():
+             prior_dist = self.prior_distribution
+             self._variational_distribution.initialize_variational_distribution(prior_dist)
+             self.variational_params_initialized.fill_(1)
+
+         # Ensure inducing_points and x are the same size
+         inducing_points = self.inducing_points
+         if inducing_points.shape[:-2] != x.shape[:-2]:
+             x, inducing_points = self._expand_inputs(x, inducing_points)
+
+         # Get p(u)/q(u)
+         variational_dist_u = self.variational_distribution
+
+         # Get q(f)
+         if isinstance(variational_dist_u, (MultivariateNormal, MultivariateQExponential)):
+             return super().__call__(
+                 x,
+                 inducing_points,
+                 inducing_values=variational_dist_u.mean,
+                 variational_inducing_covar=variational_dist_u.lazy_covariance_matrix,
+                 **kwargs,
+             )
+         elif isinstance(variational_dist_u, Delta):
+             return super().__call__(
+                 x, inducing_points, inducing_values=variational_dist_u.mean, variational_inducing_covar=None, **kwargs
+             )
+         else:
+             raise RuntimeError(
+                 f"Invalid variational distribution ({type(variational_dist_u)}). "
+                 "Expected a multivariate normal (q-exponential) or a delta distribution."
+             )
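Note: to see where this abstract class sits in practice, here is a sketch of a sparse variational model built on the concrete VariationalStrategy subclass. It follows the standard GPyTorch pattern; the qpytorch-specific pieces (ApproximateQEP, the power attribute consumed by MultivariateQExponential) are inferred from the classes in this file rather than from package documentation:

    import torch
    import qpytorch

    class SVQEP(qpytorch.models.ApproximateQEP):
        def __init__(self, inducing_points, power=torch.tensor(1.0)):
            variational_distribution = qpytorch.variational.CholeskyVariationalDistribution(
                inducing_points.size(-2)
            )
            variational_strategy = qpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            )
            super().__init__(variational_strategy)
            self.power = power
            self.mean_module = qpytorch.means.ConstantMean()
            self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())

        def forward(self, x):
            return qpytorch.distributions.MultivariateQExponential(
                self.mean_module(x), self.covar_module(x), power=self.power
            )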
qpytorch/variational/additive_grid_interpolation_variational_strategy.py
@@ -0,0 +1,90 @@
+ #!/usr/bin/env python3
+
+ from typing import Iterable, Optional, Tuple, Union
+
+ import torch
+ from linear_operator.operators import LinearOperator
+ from torch import LongTensor, Tensor
+
+ from ..distributions import Delta, MultivariateNormal, MultivariateQExponential
+ from ..models import ApproximateGP, ApproximateQEP
+ from ._variational_distribution import _VariationalDistribution
+ from .grid_interpolation_variational_strategy import GridInterpolationVariationalStrategy
+
+
+ class AdditiveGridInterpolationVariationalStrategy(GridInterpolationVariationalStrategy):
+     def __init__(
+         self,
+         model: Union[ApproximateGP, ApproximateQEP],
+         grid_size: int,
+         grid_bounds: Iterable[Tuple[float, float]],
+         num_dim: int,
+         variational_distribution: _VariationalDistribution,
+         mixing_params: bool = False,
+         sum_output: bool = True,
+     ):
+         super(AdditiveGridInterpolationVariationalStrategy, self).__init__(
+             model, grid_size, grid_bounds, variational_distribution
+         )
+         self.num_dim = num_dim
+         self.sum_output = sum_output
+         # Mixing parameters
+         if mixing_params:
+             self.register_parameter(name="mixing_params", parameter=torch.nn.Parameter(torch.ones(num_dim) / num_dim))
+
+     @property
+     def prior_distribution(self) -> Union[MultivariateNormal, MultivariateQExponential]:
+         # If desired, models can compare the input to forward to inducing_points and use a GridKernel for space
+         # efficiency.
+         # However, when using a default VariationalDistribution which has an O(m^2) space complexity anyways,
+         # we find that GridKernel is typically not worth it due to the moderate slow down of using FFTs.
+         out = super(AdditiveGridInterpolationVariationalStrategy, self).prior_distribution
+         mean = out.mean.repeat(self.num_dim, 1)
+         covar = out.lazy_covariance_matrix.repeat(self.num_dim, 1, 1)
+         if hasattr(self.model, 'power'):
+             return MultivariateQExponential(mean, covar, power=self.model.power)
+         else:
+             return MultivariateNormal(mean, covar)
+
+     def _compute_grid(self, inputs: Tensor) -> Tuple[LongTensor, Tensor]:
+         num_data, num_dim = inputs.size()
+         inputs = inputs.transpose(0, 1).reshape(-1, 1)
+         interp_indices, interp_values = super(AdditiveGridInterpolationVariationalStrategy, self)._compute_grid(inputs)
+         interp_indices = interp_indices.view(num_dim, num_data, -1)
+         interp_values = interp_values.view(num_dim, num_data, -1)
+
+         if hasattr(self, "mixing_params"):
+             interp_values = interp_values.mul(self.mixing_params.unsqueeze(1).unsqueeze(2))
+         return interp_indices, interp_values
+
+     def forward(
+         self,
+         x: Tensor,
+         inducing_points: Tensor,
+         inducing_values: Tensor,
+         variational_inducing_covar: Optional[LinearOperator] = None,
+         *params,
+         **kwargs,
+     ) -> Union[MultivariateNormal, MultivariateQExponential]:
+         if x.ndimension() == 1:
+             x = x.unsqueeze(-1)
+         elif x.ndimension() != 2:
+             raise RuntimeError("AdditiveGridInterpolationVariationalStrategy expects a 2d tensor.")
+
+         num_data, num_dim = x.size()
+         if num_dim != self.num_dim:
+             raise RuntimeError("The number of dims should match the number specified.")
+
+         output = super().forward(x, inducing_points, inducing_values, variational_inducing_covar)
+         if self.sum_output:
+             if variational_inducing_covar is not None:
+                 mean = output.mean.sum(0)
+                 covar = output.lazy_covariance_matrix.sum(-3)
+                 if hasattr(self.model, 'power'):
+                     return MultivariateQExponential(mean, covar, power=self.model.power)
+                 else:
+                     return MultivariateNormal(mean, covar)
+             else:
+                 return Delta(output.mean.sum(0))
+         else:
+             return output
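Note: a hedged construction sketch for the additive strategy above (hyperparameters and the batch_shape choice are illustrative; the repeated prior in prior_distribution suggests one variational distribution per additive dimension):

    import torch
    import qpytorch

    class AdditiveQEP(qpytorch.models.ApproximateQEP):
        # Sketch: models f(x) = sum_d f_d(x_d) over a shared 1-d interpolation grid.
        def __init__(self, num_dim=4, grid_size=64, power=torch.tensor(1.0)):
            variational_distribution = qpytorch.variational.CholeskyVariationalDistribution(
                grid_size, batch_shape=torch.Size([num_dim])  # one q(u) per dimension
            )
            variational_strategy = qpytorch.variational.AdditiveGridInterpolationVariationalStrategy(
                self,
                grid_size=grid_size,
                grid_bounds=[(-1.0, 1.0)],  # bounds of the shared 1-d grid
                num_dim=num_dim,
                variational_distribution=variational_distribution,
                sum_output=True,  # sum the per-dimension processes into one output
            )
            super().__init__(variational_strategy)
            self.power = power  # presence of `power` makes the strategy return a QExponential
            self.mean_module = qpytorch.means.ConstantMean()
            self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())

        def forward(self, x):
            return qpytorch.distributions.MultivariateQExponential(
                self.mean_module(x), self.covar_module(x), power=self.power
            )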