qpytorch-0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of qpytorch might be problematic.
- qpytorch/__init__.py +327 -0
- qpytorch/constraints/__init__.py +3 -0
- qpytorch/distributions/__init__.py +21 -0
- qpytorch/distributions/delta.py +86 -0
- qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
- qpytorch/distributions/multivariate_qexponential.py +581 -0
- qpytorch/distributions/power.py +113 -0
- qpytorch/distributions/qexponential.py +153 -0
- qpytorch/functions/__init__.py +58 -0
- qpytorch/kernels/__init__.py +80 -0
- qpytorch/kernels/grid_interpolation_kernel.py +213 -0
- qpytorch/kernels/inducing_point_kernel.py +151 -0
- qpytorch/kernels/kernel.py +695 -0
- qpytorch/kernels/matern32_kernel_grad.py +155 -0
- qpytorch/kernels/matern52_kernel_grad.py +194 -0
- qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
- qpytorch/kernels/polynomial_kernel_grad.py +88 -0
- qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
- qpytorch/kernels/rbf_kernel_grad.py +125 -0
- qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
- qpytorch/kernels/rff_kernel.py +153 -0
- qpytorch/lazy/__init__.py +9 -0
- qpytorch/likelihoods/__init__.py +66 -0
- qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
- qpytorch/likelihoods/beta_likelihood.py +76 -0
- qpytorch/likelihoods/gaussian_likelihood.py +472 -0
- qpytorch/likelihoods/laplace_likelihood.py +59 -0
- qpytorch/likelihoods/likelihood.py +437 -0
- qpytorch/likelihoods/likelihood_list.py +60 -0
- qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
- qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
- qpytorch/likelihoods/noise_models.py +184 -0
- qpytorch/likelihoods/qexponential_likelihood.py +494 -0
- qpytorch/likelihoods/softmax_likelihood.py +97 -0
- qpytorch/likelihoods/student_t_likelihood.py +90 -0
- qpytorch/means/__init__.py +23 -0
- qpytorch/metrics/__init__.py +17 -0
- qpytorch/mlls/__init__.py +53 -0
- qpytorch/mlls/_approximate_mll.py +79 -0
- qpytorch/mlls/deep_approximate_mll.py +30 -0
- qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
- qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
- qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
- qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
- qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
- qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
- qpytorch/mlls/marginal_log_likelihood.py +48 -0
- qpytorch/mlls/predictive_log_likelihood.py +76 -0
- qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
- qpytorch/mlls/variational_elbo.py +77 -0
- qpytorch/models/__init__.py +72 -0
- qpytorch/models/approximate_qep.py +115 -0
- qpytorch/models/deep_qeps/__init__.py +22 -0
- qpytorch/models/deep_qeps/deep_qep.py +155 -0
- qpytorch/models/deep_qeps/dspp.py +114 -0
- qpytorch/models/exact_prediction_strategies.py +880 -0
- qpytorch/models/exact_qep.py +349 -0
- qpytorch/models/model_list.py +100 -0
- qpytorch/models/pyro/__init__.py +28 -0
- qpytorch/models/pyro/_pyro_mixin.py +57 -0
- qpytorch/models/pyro/distributions/__init__.py +5 -0
- qpytorch/models/pyro/pyro_qep.py +105 -0
- qpytorch/models/qep.py +7 -0
- qpytorch/models/qeplvm/__init__.py +6 -0
- qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
- qpytorch/models/qeplvm/latent_variable.py +102 -0
- qpytorch/module.py +30 -0
- qpytorch/optim/__init__.py +5 -0
- qpytorch/priors/__init__.py +42 -0
- qpytorch/priors/qep_priors.py +81 -0
- qpytorch/test/__init__.py +22 -0
- qpytorch/test/base_likelihood_test_case.py +106 -0
- qpytorch/test/model_test_case.py +150 -0
- qpytorch/test/variational_test_case.py +400 -0
- qpytorch/utils/__init__.py +38 -0
- qpytorch/utils/warnings.py +37 -0
- qpytorch/variational/__init__.py +47 -0
- qpytorch/variational/_variational_distribution.py +61 -0
- qpytorch/variational/_variational_strategy.py +391 -0
- qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
- qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
- qpytorch/variational/cholesky_variational_distribution.py +65 -0
- qpytorch/variational/ciq_variational_strategy.py +352 -0
- qpytorch/variational/delta_variational_distribution.py +41 -0
- qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
- qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
- qpytorch/variational/lmc_variational_strategy.py +248 -0
- qpytorch/variational/mean_field_variational_distribution.py +58 -0
- qpytorch/variational/multitask_variational_strategy.py +317 -0
- qpytorch/variational/natural_variational_distribution.py +152 -0
- qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
- qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
- qpytorch/variational/tril_natural_variational_distribution.py +130 -0
- qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
- qpytorch/variational/unwhitened_variational_strategy.py +225 -0
- qpytorch/variational/variational_strategy.py +280 -0
- qpytorch/version.py +4 -0
- qpytorch-0.1.dist-info/LICENSE +21 -0
- qpytorch-0.1.dist-info/METADATA +177 -0
- qpytorch-0.1.dist-info/RECORD +102 -0
- qpytorch-0.1.dist-info/WHEEL +5 -0
- qpytorch-0.1.dist-info/top_level.txt +1 -0
qpytorch/variational/_variational_strategy.py
@@ -0,0 +1,391 @@
#!/usr/bin/env python3

import functools
from abc import ABC, abstractproperty
from copy import deepcopy
from typing import Optional, Tuple, Union

import torch
from linear_operator.operators import LinearOperator
from torch import Tensor

from .. import settings
from ..distributions import Delta, Distribution, MultivariateNormal, MultivariateQExponential
from ..kernels import Kernel
from ..likelihoods import GaussianLikelihood, QExponentialLikelihood
from ..means import Mean
from ..models import ApproximateGP, ApproximateQEP, ExactGP, ExactQEP
from ..models.exact_prediction_strategies import DefaultPredictionStrategy
from ..module import Module
from gpytorch.utils.memoize import add_to_cache, cached, clear_cache_hook
from . import _VariationalDistribution


class _BaseExactGP(ExactGP):
    def __init__(
        self,
        train_inputs: Optional[Union[Tensor, Tuple[Tensor, ...]]],
        train_targets: Optional[Tensor],
        likelihood: GaussianLikelihood,
        mean_module: Mean,
        covar_module: Kernel,
    ):
        super().__init__(train_inputs, train_targets, likelihood)
        self.mean_module = mean_module
        self.covar_module = covar_module

    def forward(self, x: Tensor, **kwargs) -> MultivariateNormal:
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return MultivariateNormal(mean, covar)


class _BaseExactQEP(ExactQEP):
    def __init__(
        self,
        train_inputs: Optional[Union[Tensor, Tuple[Tensor, ...]]],
        train_targets: Optional[Tensor],
        likelihood: QExponentialLikelihood,
        mean_module: Mean,
        covar_module: Kernel,
    ):
        super().__init__(train_inputs, train_targets, likelihood)
        self.mean_module = mean_module
        self.covar_module = covar_module

    def forward(self, x: Tensor, **kwargs) -> MultivariateQExponential:
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return MultivariateQExponential(mean, covar, power=self.likelihood.power)


def _add_cache_hook(tsr: Tensor, pred_strat: DefaultPredictionStrategy) -> Tensor:
    if tsr.grad_fn is not None:
        wrapper = functools.partial(clear_cache_hook, pred_strat)
        functools.update_wrapper(wrapper, clear_cache_hook)
        tsr.grad_fn.register_hook(wrapper)
    return tsr


class _VariationalStrategy(Module, ABC):
    """
    Abstract base class for all Variational Strategies.
    """

    has_fantasy_strategy = False

    def __init__(
        self,
        model: Union[ApproximateGP, ApproximateQEP, "_VariationalStrategy"],
        inducing_points: Tensor,
        variational_distribution: _VariationalDistribution,
        learn_inducing_locations: bool = True,
        jitter_val: Optional[float] = None,
    ):
        super().__init__()

        self._jitter_val = jitter_val

        # Model
        object.__setattr__(self, "model", model)

        # Inducing points
        inducing_points = inducing_points.clone()
        if inducing_points.dim() == 1:
            inducing_points = inducing_points.unsqueeze(-1)
        if learn_inducing_locations:
            self.register_parameter(name="inducing_points", parameter=torch.nn.Parameter(inducing_points))
        else:
            self.register_buffer("inducing_points", inducing_points)

        # Variational distribution
        self._variational_distribution = variational_distribution
        self.register_buffer("variational_params_initialized", torch.tensor(0))

    def _clear_cache(self) -> None:
        clear_cache_hook(self)

    def _expand_inputs(self, x: Tensor, inducing_points: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Pre-processing step in __call__ to make x the same batch_shape as the inducing points
        """
        batch_shape = torch.broadcast_shapes(inducing_points.shape[:-2], x.shape[:-2])
        inducing_points = inducing_points.expand(*batch_shape, *inducing_points.shape[-2:])
        x = x.expand(*batch_shape, *x.shape[-2:])
        return x, inducing_points

    @property
    def jitter_val(self) -> float:
        if self._jitter_val is None:
            return settings.variational_cholesky_jitter.value(dtype=self.inducing_points.dtype)
        return self._jitter_val

    @jitter_val.setter
    def jitter_val(self, jitter_val: float):
        self._jitter_val = jitter_val

    @abstractproperty
    @cached(name="prior_distribution_memo")
    def prior_distribution(self) -> Union[MultivariateNormal, MultivariateQExponential]:
        r"""
        The :func:`~qpytorch.variational.VariationalStrategy.prior_distribution` method determines how to compute the
        GP (QEP) prior distribution of the inducing points, e.g. :math:`p(u) \sim N(\mu(X_u), K(X_u, X_u))` or :math:`p(u) \sim QED(\mu(X_u), K(X_u, X_u))`.
        Most commonly, this is done simply by calling the user defined GP (QEP) prior on the inducing point data directly.

        :rtype: :obj:`~gpytorch.distributions.MultivariateNormal` or :obj:`~qpytorch.distributions.MultivariateQExponential`
        :return: The distribution :math:`p( \mathbf u)`
        """
        raise NotImplementedError

    @property
    @cached(name="variational_distribution_memo")
    def variational_distribution(self) -> Distribution:
        return self._variational_distribution()

    def forward(
        self,
        x: Tensor,
        inducing_points: Tensor,
        inducing_values: Tensor,
        variational_inducing_covar: Optional[LinearOperator] = None,
        **kwargs,
    ) -> Union[MultivariateNormal, MultivariateQExponential]:
        r"""
        The :func:`~qpytorch.variational.VariationalStrategy.forward` method determines how to marginalize out the
        inducing point function values. Specifically, forward defines how to transform a variational distribution
        over the inducing point values, :math:`q(u)`, into a variational distribution over the function values at
        specified locations x, :math:`q(f|x)`, by integrating :math:`\int p(f|x, u)q(u)du`

        :param x: Locations :math:`\mathbf X` to get the
            variational posterior of the function values at.
        :param inducing_points: Locations :math:`\mathbf Z` of the inducing points
        :param inducing_values: Samples of the inducing function values :math:`\mathbf u`
            (or the mean of the distribution :math:`q(\mathbf u)` if q is a Gaussian or Q-Exponential).
        :param variational_inducing_covar: If the distribution :math:`q(\mathbf u)` is
            Gaussian (Q-Exponential), then this variable is the covariance matrix of that Gaussian (Q-Exponential).
            Otherwise, it will be None.

        :rtype: :obj:`~gpytorch.distributions.MultivariateNormal` (`~qpytorch.distributions.MultivariateQExponential`)
        :return: The distribution :math:`q( \mathbf f(\mathbf X))`
        """
        raise NotImplementedError

    def kl_divergence(self) -> Tensor:
        r"""
        Compute the KL divergence between the variational inducing distribution :math:`q(\mathbf u)`
        and the prior inducing distribution :math:`p(\mathbf u)`.
        """
        with settings.max_preconditioner_size(0):
            kl_divergence = torch.distributions.kl.kl_divergence(self.variational_distribution, self.prior_distribution)
        return kl_divergence

    @cached(name="amortized_exact_")
    def amortized_exact_(
        self, mean_module: Optional[Module] = None, covar_module: Optional[Module] = None
    ) -> Union[ExactGP, ExactQEP]:
        mean_module = self.model.mean_module if mean_module is None else mean_module
        covar_module = self.model.covar_module if covar_module is None else covar_module

        with torch.no_grad():
            # from here on down, we refer to the inducing points as pseudo_inputs
            pseudo_target_covar, pseudo_target_mean = self.pseudo_points
            pseudo_inputs = self.inducing_points.detach()
            if pseudo_inputs.ndim < pseudo_target_mean.ndim:
                pseudo_inputs = pseudo_inputs.expand(*pseudo_target_mean.shape[:-2], *pseudo_inputs.shape)
            # TODO: add flag for conditioning into SGPR after building fantasy strategy for SGPR
            new_covar_module = deepcopy(covar_module)

            # update inducing mean if necessary
            pseudo_target_mean = pseudo_target_mean.squeeze() + mean_module(pseudo_inputs)

            if 'Gaussian' in self.model.likelihood.__class__.__name__:
                inducing_exact_model = _BaseExactGP(
                    pseudo_inputs,
                    pseudo_target_mean,
                    mean_module=deepcopy(mean_module),
                    covar_module=new_covar_module,
                    likelihood=deepcopy(self.model.likelihood),
                )
            elif 'QExponential' in self.model.likelihood.__class__.__name__:
                inducing_exact_model = _BaseExactQEP(
                    pseudo_inputs,
                    pseudo_target_mean,
                    mean_module=deepcopy(mean_module),
                    covar_module=new_covar_module,
                    likelihood=deepcopy(self.model.likelihood),
                )
            else:
                raise RuntimeError("Exact model can only handle Gaussian or Q-Exponential likelihoods")

            # now fantasize around this model
            # as this model is new, we need to compute a posterior to construct the prediction strategy
            # which uses the likelihood pseudo caches
            faked_points = torch.randn(
                *pseudo_target_mean.shape[:-2],
                1,
                pseudo_inputs.shape[-1],
                device=pseudo_inputs.device,
                dtype=pseudo_inputs.dtype,
            )
            inducing_exact_model.eval()
            _ = inducing_exact_model(faked_points)

            # then we overwrite the likelihood to take into account the multivariate normal term
            pred_strat = inducing_exact_model.prediction_strategy
            pred_strat._memoize_cache = {}
            with torch.no_grad():
                updated_lik_train_train_covar = pred_strat.train_prior_dist.lazy_covariance_matrix + pseudo_target_covar
                pred_strat.lik_train_train_covar = updated_lik_train_train_covar

            # do the mean cache because the mean cache doesn't solve against lik_train_train_covar
            train_mean = inducing_exact_model.mean_module(*inducing_exact_model.train_inputs)
            train_labels_offset = (inducing_exact_model.prediction_strategy.train_labels - train_mean).unsqueeze(-1)
            mean_cache = updated_lik_train_train_covar.solve(train_labels_offset).squeeze(-1)
            mean_cache = _add_cache_hook(mean_cache, inducing_exact_model.prediction_strategy)
            add_to_cache(pred_strat, "mean_cache", mean_cache)
            # TODO: check to see if we need to do the covar_cache?

            inducing_exact_model.prediction_strategy = pred_strat
        return inducing_exact_model

    def pseudo_points(self) -> Tuple[Tensor, Tensor]:
        raise NotImplementedError("Each variational strategy must implement its own pseudo points method")

    def get_fantasy_model(
        self,
        inputs: Tensor,
        targets: Tensor,
        mean_module: Optional[Module] = None,
        covar_module: Optional[Module] = None,
        **kwargs,
    ) -> Union[ExactGP, ExactQEP]:
        r"""
        Performs the online variational conditioning (OVC) strategy of Maddox et al, '21 to return
        an exact GP (QEP) model that incorporates the inputs and targets alongside the variational model's inducing
        points and targets.

        Currently, instead of directly updating the variational parameters (and inducing points), we
        return an ExactGP (ExactQEP) model rather than an updated variational GP (QEP) model. This is done primarily for
        numerical stability.

        Unlike the ExactGP's (ExactQEP's) call for get_fantasy_model, we enable options for mean_module and covar_module
        that allow specification of the mean / covariance. We expect that either the mean and covariance
        modules are attributes of the model itself called mean_module and covar_module respectively OR that you
        pass them into this method explicitly.

        :param inputs: (`b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`) Locations of fantasy
            observations.
        :param targets: (`b1 x ... x bk x m` or `f x b1 x ... x bk x m`) Labels of fantasy observations.
        :param mean_module: torch module describing the mean function of the GP (QEP) model. Optional if
            `mean_module` is already an attribute of the variational GP (QEP).
        :param covar_module: torch module describing the covariance function of the GP (QEP) model. Optional
            if `covar_module` is already an attribute of the variational GP (QEP).
        :return: An `ExactGP` (`ExactQEP`) model with `k + m` training examples, where the `m` fantasy examples have been added
            and all test-time caches have been updated. We assume that there are `k` inducing points in this variational
            GP (QEP). Note that we return an `ExactGP` rather than a variational GP (QEP).

        Reference: "Conditioning Sparse Variational Gaussian Processes for Online Decision-Making,"
            Maddox, Stanton, Wilson, NeurIPS, '21
            https://papers.nips.cc/paper/2021/hash/325eaeac5bef34937cfdc1bd73034d17-Abstract.html
        """

        # currently, we only support fantasization for CholeskyVariationalDistribution and
        # whitened / unwhitened variational strategies
        if not self.has_fantasy_strategy:
            raise NotImplementedError(
                f"No fantasy model support for {self.__class__.__name__}. "
                "Only VariationalStrategy and UnwhitenedVariationalStrategy are currently supported."
            )
        else:
            from . import CholeskyVariationalDistribution  # Circular import otherwise

            if not isinstance(self._variational_distribution, CholeskyVariationalDistribution):
                raise NotImplementedError(
                    "Fantasy models are only supported for variational models with CholeskyVariationalDistribution."
                )

            if not isinstance(self.model.likelihood, (GaussianLikelihood, QExponentialLikelihood)):
                raise NotImplementedError(
                    f"No fantasy model support for {self.model.likelihood.__class__.__name__}. "
                    "Only GaussianLikelihoods and QExponentialLikelihoods are currently supported."
                )
            # we assume that either the user has given the model a mean_module and a covar_module
            # or that it will be passed into the get_fantasy_model function. we check for these.
            if mean_module is None:
                mean_module = getattr(self.model, "mean_module", None)
                if mean_module is None:
                    raise ModuleNotFoundError(
                        "Either you must provide a mean_module as input to get_fantasy_model "
                        "or it must be an attribute of the model called mean_module."
                    )
            if covar_module is None:
                covar_module = getattr(self.model, "covar_module", None)
                if covar_module is None:
                    # raise an error
                    raise ModuleNotFoundError(
                        "Either you must provide a covar_module as input to get_fantasy_model "
                        "or it must be an attribute of the model called covar_module."
                    )

            # first we construct an exact model over the inducing points with the inducing covariance
            # matrix
            inducing_exact_model = self.amortized_exact_(mean_module=mean_module, covar_module=covar_module)

            # then we update this model by adding in the inputs and pseudo targets
            # finally we fantasize wrt targets
            fantasy_model = inducing_exact_model.get_fantasy_model(inputs, targets, **kwargs)
            fant_pred_strat = fantasy_model.prediction_strategy

            # first we update the lik_train_train_covar
            # do the mean cache again because the mean cache resets the likelihood forward
            train_mean = fantasy_model.mean_module(*fantasy_model.train_inputs)
            train_labels_offset = (fant_pred_strat.train_labels - train_mean).unsqueeze(-1)
            fantasy_lik_train_root_inv = fant_pred_strat.lik_train_train_covar.root_inv_decomposition()
            mean_cache = fantasy_lik_train_root_inv.matmul(train_labels_offset).squeeze(-1)
            mean_cache = _add_cache_hook(mean_cache, fant_pred_strat)
            add_to_cache(fant_pred_strat, "mean_cache", mean_cache)
            # TODO: should we update the covar_cache?

            fantasy_model.prediction_strategy = fant_pred_strat
            return fantasy_model

    def __call__(self, x: Tensor, prior: bool = False, **kwargs) -> Union[MultivariateNormal, MultivariateQExponential]:
        # If we're in prior mode, then we're done!
        if prior:
            return self.model.forward(x, **kwargs)

        # Delete previously cached items from the training distribution
        if self.training:
            self._clear_cache()
        # (Maybe) initialize variational distribution
        if not self.variational_params_initialized.item():
            prior_dist = self.prior_distribution
            self._variational_distribution.initialize_variational_distribution(prior_dist)
            self.variational_params_initialized.fill_(1)

        # Ensure inducing_points and x are the same size
        inducing_points = self.inducing_points
        if inducing_points.shape[:-2] != x.shape[:-2]:
            x, inducing_points = self._expand_inputs(x, inducing_points)

        # Get p(u)/q(u)
        variational_dist_u = self.variational_distribution

        # Get q(f)
        if isinstance(variational_dist_u, (MultivariateNormal, MultivariateQExponential)):
            return super().__call__(
                x,
                inducing_points,
                inducing_values=variational_dist_u.mean,
                variational_inducing_covar=variational_dist_u.lazy_covariance_matrix,
                **kwargs,
            )
        elif isinstance(variational_dist_u, Delta):
            return super().__call__(
                x, inducing_points, inducing_values=variational_dist_u.mean, variational_inducing_covar=None, **kwargs
            )
        else:
            raise RuntimeError(
                f"Invalid variational distribution ({type(variational_dist_u)}). "
                "Expected a multivariate normal (q-exponential) or a delta distribution."
            )
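The class above is only the abstract base; the concrete strategies shipped in this release (e.g. qpytorch/variational/variational_strategy.py in the listing) subclass it and are what a user wires into an approximate model. As orientation for reviewers, the sketch below shows that wiring, assuming qpytorch follows the GPyTorch pattern it is derived from. The names ApproximateQEP, CholeskyVariationalDistribution, VariationalStrategy, MultivariateQExponential, and the means/kernels modules are taken from the file listing and imports in this diff; the constructor arguments beyond _VariationalStrategy.__init__ shown above, in particular the power attribute, are assumptions rather than documented API.

import torch
import qpytorch

class SketchApproximateQEP(qpytorch.models.ApproximateQEP):
    # Hypothetical usage sketch only: the strategy's constructor follows the
    # _VariationalStrategy.__init__ signature in this diff; everything else mirrors
    # GPyTorch conventions and is an assumption, not documented qpytorch API.
    def __init__(self, inducing_points, power=torch.tensor(1.0)):
        variational_distribution = qpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2)
        )
        variational_strategy = qpytorch.variational.VariationalStrategy(
            self,                       # stored by the strategy via object.__setattr__
            inducing_points,
            variational_distribution,
            learn_inducing_locations=True,
        )
        super().__init__(variational_strategy)
        self.power = power              # assumed attribute; note the hasattr(self.model, 'power') checks elsewhere
        self.mean_module = qpytorch.means.ConstantMean()
        self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())

    def forward(self, x):
        return qpytorch.distributions.MultivariateQExponential(
            self.mean_module(x), self.covar_module(x), power=self.power
        )

Note that the strategy stores its model with object.__setattr__ in __init__ above, which is why the model can hand itself to the strategy before calling super().__init__() without triggering Module registration.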
qpytorch/variational/additive_grid_interpolation_variational_strategy.py
@@ -0,0 +1,90 @@
#!/usr/bin/env python3

from typing import Iterable, Optional, Tuple, Union

import torch
from linear_operator.operators import LinearOperator
from torch import LongTensor, Tensor

from ..distributions import Delta, MultivariateNormal, MultivariateQExponential
from ..models import ApproximateGP, ApproximateQEP
from ._variational_distribution import _VariationalDistribution
from .grid_interpolation_variational_strategy import GridInterpolationVariationalStrategy


class AdditiveGridInterpolationVariationalStrategy(GridInterpolationVariationalStrategy):
    def __init__(
        self,
        model: Union[ApproximateGP, ApproximateQEP],
        grid_size: int,
        grid_bounds: Iterable[Tuple[float, float]],
        num_dim: int,
        variational_distribution: _VariationalDistribution,
        mixing_params: bool = False,
        sum_output: bool = True,
    ):
        super(AdditiveGridInterpolationVariationalStrategy, self).__init__(
            model, grid_size, grid_bounds, variational_distribution
        )
        self.num_dim = num_dim
        self.sum_output = sum_output
        # Mixing parameters
        if mixing_params:
            self.register_parameter(name="mixing_params", parameter=torch.nn.Parameter(torch.ones(num_dim) / num_dim))

    @property
    def prior_distribution(self) -> Union[MultivariateNormal, MultivariateQExponential]:
        # If desired, models can compare the input to forward to inducing_points and use a GridKernel for space
        # efficiency.
        # However, when using a default VariationalDistribution which has an O(m^2) space complexity anyways,
        # we find that GridKernel is typically not worth it due to the moderate slow down of using FFTs.
        out = super(AdditiveGridInterpolationVariationalStrategy, self).prior_distribution
        mean = out.mean.repeat(self.num_dim, 1)
        covar = out.lazy_covariance_matrix.repeat(self.num_dim, 1, 1)
        if hasattr(self.model, 'power'):
            return MultivariateQExponential(mean, covar, power=self.model.power)
        else:
            return MultivariateNormal(mean, covar)

    def _compute_grid(self, inputs: Tensor) -> Tuple[LongTensor, Tensor]:
        num_data, num_dim = inputs.size()
        inputs = inputs.transpose(0, 1).reshape(-1, 1)
        interp_indices, interp_values = super(AdditiveGridInterpolationVariationalStrategy, self)._compute_grid(inputs)
        interp_indices = interp_indices.view(num_dim, num_data, -1)
        interp_values = interp_values.view(num_dim, num_data, -1)

        if hasattr(self, "mixing_params"):
            interp_values = interp_values.mul(self.mixing_params.unsqueeze(1).unsqueeze(2))
        return interp_indices, interp_values

    def forward(
        self,
        x: Tensor,
        inducing_points: Tensor,
        inducing_values: Tensor,
        variational_inducing_covar: Optional[LinearOperator] = None,
        *params,
        **kwargs,
    ) -> Union[MultivariateNormal, MultivariateQExponential]:
        if x.ndimension() == 1:
            x = x.unsqueeze(-1)
        elif x.ndimension() != 2:
            raise RuntimeError("AdditiveGridInterpolationVariationalStrategy expects a 2d tensor.")

        num_data, num_dim = x.size()
        if num_dim != self.num_dim:
            raise RuntimeError("The number of dims should match the number specified.")

        output = super().forward(x, inducing_points, inducing_values, variational_inducing_covar)
        if self.sum_output:
            if variational_inducing_covar is not None:
                mean = output.mean.sum(0)
                covar = output.lazy_covariance_matrix.sum(-3)
                if hasattr(self.model, 'power'):
                    return MultivariateQExponential(mean, covar, power=self.model.power)
                else:
                    return MultivariateNormal(mean, covar)
            else:
                return Delta(output.mean.sum(0))
        else:
            return output
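For the additive strategy above, a comparable construction sketch follows, with the same caveats as the earlier one: the class and module names are taken from this diff's file listing, while everything beyond the __init__ signature shown above (notably the per-dimension batch_shape of the variational distribution and the power attribute) is an assumption borrowed from GPyTorch's additive-grid examples.

import torch
import qpytorch

class SketchAdditiveQEP(qpytorch.models.ApproximateQEP):
    # Hypothetical sketch: num_dim one-dimensional components share a single interpolation grid
    # and are summed into one output process (sum_output=True). The batch_shape of one variational
    # distribution per additive dimension is assumed, mirroring GPyTorch.
    def __init__(self, num_dim=4, grid_size=64, power=torch.tensor(1.0)):
        variational_distribution = qpytorch.variational.CholeskyVariationalDistribution(
            grid_size, batch_shape=torch.Size([num_dim])
        )
        variational_strategy = qpytorch.variational.AdditiveGridInterpolationVariationalStrategy(
            self,
            grid_size=grid_size,
            grid_bounds=[(-1.0, 1.0)],   # one shared 1-d grid; each input dimension is interpolated onto it
            num_dim=num_dim,
            variational_distribution=variational_distribution,
            sum_output=True,
        )
        super().__init__(variational_strategy)
        self.power = power
        self.mean_module = qpytorch.means.ConstantMean()
        self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())

    def forward(self, x):
        return qpytorch.distributions.MultivariateQExponential(
            self.mean_module(x), self.covar_module(x), power=self.power
        )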