qpytorch-0.1-py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Potentially problematic release: this version of qpytorch has been flagged as potentially problematic.
- qpytorch/__init__.py +327 -0
- qpytorch/constraints/__init__.py +3 -0
- qpytorch/distributions/__init__.py +21 -0
- qpytorch/distributions/delta.py +86 -0
- qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
- qpytorch/distributions/multivariate_qexponential.py +581 -0
- qpytorch/distributions/power.py +113 -0
- qpytorch/distributions/qexponential.py +153 -0
- qpytorch/functions/__init__.py +58 -0
- qpytorch/kernels/__init__.py +80 -0
- qpytorch/kernels/grid_interpolation_kernel.py +213 -0
- qpytorch/kernels/inducing_point_kernel.py +151 -0
- qpytorch/kernels/kernel.py +695 -0
- qpytorch/kernels/matern32_kernel_grad.py +155 -0
- qpytorch/kernels/matern52_kernel_grad.py +194 -0
- qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
- qpytorch/kernels/polynomial_kernel_grad.py +88 -0
- qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
- qpytorch/kernels/rbf_kernel_grad.py +125 -0
- qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
- qpytorch/kernels/rff_kernel.py +153 -0
- qpytorch/lazy/__init__.py +9 -0
- qpytorch/likelihoods/__init__.py +66 -0
- qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
- qpytorch/likelihoods/beta_likelihood.py +76 -0
- qpytorch/likelihoods/gaussian_likelihood.py +472 -0
- qpytorch/likelihoods/laplace_likelihood.py +59 -0
- qpytorch/likelihoods/likelihood.py +437 -0
- qpytorch/likelihoods/likelihood_list.py +60 -0
- qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
- qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
- qpytorch/likelihoods/noise_models.py +184 -0
- qpytorch/likelihoods/qexponential_likelihood.py +494 -0
- qpytorch/likelihoods/softmax_likelihood.py +97 -0
- qpytorch/likelihoods/student_t_likelihood.py +90 -0
- qpytorch/means/__init__.py +23 -0
- qpytorch/metrics/__init__.py +17 -0
- qpytorch/mlls/__init__.py +53 -0
- qpytorch/mlls/_approximate_mll.py +79 -0
- qpytorch/mlls/deep_approximate_mll.py +30 -0
- qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
- qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
- qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
- qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
- qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
- qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
- qpytorch/mlls/marginal_log_likelihood.py +48 -0
- qpytorch/mlls/predictive_log_likelihood.py +76 -0
- qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
- qpytorch/mlls/variational_elbo.py +77 -0
- qpytorch/models/__init__.py +72 -0
- qpytorch/models/approximate_qep.py +115 -0
- qpytorch/models/deep_qeps/__init__.py +22 -0
- qpytorch/models/deep_qeps/deep_qep.py +155 -0
- qpytorch/models/deep_qeps/dspp.py +114 -0
- qpytorch/models/exact_prediction_strategies.py +880 -0
- qpytorch/models/exact_qep.py +349 -0
- qpytorch/models/model_list.py +100 -0
- qpytorch/models/pyro/__init__.py +28 -0
- qpytorch/models/pyro/_pyro_mixin.py +57 -0
- qpytorch/models/pyro/distributions/__init__.py +5 -0
- qpytorch/models/pyro/pyro_qep.py +105 -0
- qpytorch/models/qep.py +7 -0
- qpytorch/models/qeplvm/__init__.py +6 -0
- qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
- qpytorch/models/qeplvm/latent_variable.py +102 -0
- qpytorch/module.py +30 -0
- qpytorch/optim/__init__.py +5 -0
- qpytorch/priors/__init__.py +42 -0
- qpytorch/priors/qep_priors.py +81 -0
- qpytorch/test/__init__.py +22 -0
- qpytorch/test/base_likelihood_test_case.py +106 -0
- qpytorch/test/model_test_case.py +150 -0
- qpytorch/test/variational_test_case.py +400 -0
- qpytorch/utils/__init__.py +38 -0
- qpytorch/utils/warnings.py +37 -0
- qpytorch/variational/__init__.py +47 -0
- qpytorch/variational/_variational_distribution.py +61 -0
- qpytorch/variational/_variational_strategy.py +391 -0
- qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
- qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
- qpytorch/variational/cholesky_variational_distribution.py +65 -0
- qpytorch/variational/ciq_variational_strategy.py +352 -0
- qpytorch/variational/delta_variational_distribution.py +41 -0
- qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
- qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
- qpytorch/variational/lmc_variational_strategy.py +248 -0
- qpytorch/variational/mean_field_variational_distribution.py +58 -0
- qpytorch/variational/multitask_variational_strategy.py +317 -0
- qpytorch/variational/natural_variational_distribution.py +152 -0
- qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
- qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
- qpytorch/variational/tril_natural_variational_distribution.py +130 -0
- qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
- qpytorch/variational/unwhitened_variational_strategy.py +225 -0
- qpytorch/variational/variational_strategy.py +280 -0
- qpytorch/version.py +4 -0
- qpytorch-0.1.dist-info/LICENSE +21 -0
- qpytorch-0.1.dist-info/METADATA +177 -0
- qpytorch-0.1.dist-info/RECORD +102 -0
- qpytorch-0.1.dist-info/WHEEL +5 -0
- qpytorch-0.1.dist-info/top_level.txt +1 -0
qpytorch/models/exact_qep.py
ADDED

@@ -0,0 +1,349 @@
#!/usr/bin/env python3

from __future__ import annotations

import warnings

from collections.abc import Iterable
from copy import deepcopy

import torch
from torch import Tensor

from gpytorch.utils.generic import length_safe_zip

from .. import settings
from ..distributions import MultitaskMultivariateQExponential, MultivariateQExponential
from ..likelihoods import _QExponentialLikelihoodBase
from ..utils.warnings import QEPInputWarning
from .exact_prediction_strategies import prediction_strategy
from .qep import QEP


class ExactQEP(QEP):
    r"""
    The base class for any Q-Exponential process latent function to be used in conjunction
    with exact inference.

    :param torch.Tensor train_inputs: (size n x d) The training features :math:`\mathbf X`.
    :param torch.Tensor train_targets: (size n) The training targets :math:`\mathbf y`.
    :param ~qpytorch.likelihoods.QExponentialLikelihood likelihood: The Q-Exponential likelihood that defines
        the observational distribution. Since we're using exact inference, the likelihood must be Q-Exponential.

    The :meth:`forward` function should describe how to compute the prior latent distribution
    on a given input. Typically, this will involve a mean and kernel function.
    The result must be a :obj:`~qpytorch.distributions.MultivariateQExponential`.

    Calling this model will return the posterior of the latent Q-Exponential process when conditioned
    on the training data. The output will be a :obj:`~qpytorch.distributions.MultivariateQExponential`.

    Example:
        >>> class MyQEP(qpytorch.models.ExactQEP):
        >>>     def __init__(self, train_x, train_y, likelihood):
        >>>         super().__init__(train_x, train_y, likelihood)
        >>>         self.mean_module = qpytorch.means.ZeroMean()
        >>>         self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())
        >>>
        >>>     def forward(self, x):
        >>>         mean = self.mean_module(x)
        >>>         covar = self.covar_module(x)
        >>>         return qpytorch.distributions.MultivariateQExponential(mean, covar, self.likelihood.power)
        >>>
        >>> # train_x = ...; train_y = ...
        >>> likelihood = qpytorch.likelihoods.QExponentialLikelihood(power=torch.tensor(1.0))
        >>> model = MyQEP(train_x, train_y, likelihood)
        >>>
        >>> # test_x = ...;
        >>> model(test_x)  # Returns the QEP latent function at test_x
        >>> likelihood(model(test_x))  # Returns the (approximate) predictive posterior distribution at test_x
    """

    def __init__(
        self,
        train_inputs: Tensor | Iterable[Tensor] | None,
        train_targets: Tensor | None,
        likelihood: _QExponentialLikelihoodBase,
    ):
        if train_inputs is not None and isinstance(train_inputs, Tensor):
            train_inputs = (train_inputs,)
        if train_inputs is not None and not all(isinstance(train_input, Tensor) for train_input in train_inputs):
            raise RuntimeError("Train inputs must be a tensor, or a list/tuple of tensors")
        if not isinstance(likelihood, _QExponentialLikelihoodBase):
            raise RuntimeError("ExactQEP can only handle Q-Exponential likelihoods")

        super().__init__()
        if train_inputs is not None:
            self.train_inputs = tuple(tri.unsqueeze(-1) if tri.ndimension() == 1 else tri for tri in train_inputs)
            self.train_targets = train_targets
        else:
            self.train_inputs = None
            self.train_targets = None
        self.likelihood = likelihood

        self.prediction_strategy = None

    @property
    def train_targets(self) -> Tensor | None:
        return self._train_targets

    @train_targets.setter
    def train_targets(self, value: Tensor | None) -> None:
        object.__setattr__(self, "_train_targets", value)

    def _apply(self, fn):
        if self.train_inputs is not None:
            self.train_inputs = tuple(fn(train_input) for train_input in self.train_inputs)
            self.train_targets = fn(self.train_targets)
        return super()._apply(fn)

    def _clear_cache(self) -> None:
        # The precomputed caches from test time live in prediction_strategy
        self.prediction_strategy = None

    def local_load_samples(self, samples_dict, memo, prefix):
        """
        Replace the model's learned hyperparameters with samples from a posterior distribution.
        """
        # Pyro always puts the samples in the first batch dimension
        num_samples = next(iter(samples_dict.values())).size(0)
        self.train_inputs = tuple(tri.unsqueeze(0).expand(num_samples, *tri.shape) for tri in self.train_inputs)
        self.train_targets = self.train_targets.unsqueeze(0).expand(num_samples, *self.train_targets.shape)
        super().local_load_samples(samples_dict, memo, prefix)

    def set_train_data(
        self, inputs: Tensor | Iterable[Tensor] | None = None, targets: Tensor | None = None, strict: bool = True
    ) -> None:
        """
        Set training data (does not re-fit model hyper-parameters).

        :param inputs: The new training inputs.
        :param targets: The new training targets.
        :param strict: If `True`, the new inputs and targets must have the same shape,
            dtype, and device as the current inputs and targets. Otherwise, any
            shape/dtype/device is allowed.
        """
        if inputs is not None:
            if isinstance(inputs, Tensor):
                inputs = (inputs,)
            inputs = tuple(input_.unsqueeze(-1) if input_.ndimension() == 1 else input_ for input_ in inputs)
            if strict:
                for input_, t_input in length_safe_zip(inputs, self.train_inputs or (None,)):
                    for attr in {"shape", "dtype", "device"}:
                        expected_attr = getattr(t_input, attr, None)
                        found_attr = getattr(input_, attr, None)
                        if expected_attr != found_attr:
                            msg = "Cannot modify {attr} of inputs (expected {e_attr}, found {f_attr})."
                            msg = msg.format(attr=attr, e_attr=expected_attr, f_attr=found_attr)
                            raise RuntimeError(msg)
            self.train_inputs = inputs
        if targets is not None:
            if strict:
                for attr in {"shape", "dtype", "device"}:
                    expected_attr = getattr(self.train_targets, attr, None)
                    found_attr = getattr(targets, attr, None)
                    if expected_attr != found_attr:
                        msg = "Cannot modify {attr} of targets (expected {e_attr}, found {f_attr})."
                        msg = msg.format(attr=attr, e_attr=expected_attr, f_attr=found_attr)
                        raise RuntimeError(msg)
            self.train_targets = targets
        self.prediction_strategy = None

    def get_fantasy_model(self, inputs, targets, **kwargs):
        """
        Returns a new QEP model that incorporates the specified inputs and targets as new training data.

        Using this method is more efficient than updating with `set_train_data` when the number of inputs is
        relatively small, because any computed test-time caches will be updated in linear time rather than
        computed from scratch.

        .. note::
            If `targets` is a batch (e.g. `b x m`), then the QEP returned from this method will be a batch mode QEP.
            If `inputs` is of the same (or lesser) dimension as `targets`, then it is assumed that the fantasy points
            are the same for each target batch.

        :param torch.Tensor inputs: (`b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`) Locations of fantasy
            observations.
        :param torch.Tensor targets: (`b1 x ... x bk x m` or `f x b1 x ... x bk x m`) Labels of fantasy observations.
        :return: An `ExactQEP` model with `n + m` training examples, where the `m` fantasy examples have been added
            and all test-time caches have been updated.
        :rtype: ~qpytorch.models.ExactQEP
        """
        if self.prediction_strategy is None:
            raise RuntimeError(
                "Fantasy observations can only be added after making predictions with a model so that "
                "all test independent caches exist. Call the model on some data first!"
            )

        model_batch_shape = self.train_inputs[0].shape[:-2]

        if not isinstance(inputs, list):
            inputs = [inputs]

        inputs = [i.unsqueeze(-1) if i.ndimension() == 1 else i for i in inputs]

        if not isinstance(self.prediction_strategy.train_prior_dist, MultitaskMultivariateQExponential):
            data_dim_start = -1
        else:
            data_dim_start = -2

        target_batch_shape = targets.shape[:data_dim_start]
        input_batch_shape = inputs[0].shape[:-2]
        tbdim, ibdim = len(target_batch_shape), len(input_batch_shape)

        if not (tbdim == ibdim + 1 or tbdim == ibdim):
            raise RuntimeError(
                f"Unsupported batch shapes: The target batch shape ({target_batch_shape}) must have either the "
                f"same dimension as or one more dimension than the input batch shape ({input_batch_shape})"
            )

        # Check whether we can properly broadcast batch dimensions
        try:
            torch.broadcast_shapes(model_batch_shape, target_batch_shape)
        except RuntimeError:
            raise RuntimeError(
                f"Model batch shape ({model_batch_shape}) and target batch shape "
                f"({target_batch_shape}) are not broadcastable."
            )

        if len(model_batch_shape) > len(input_batch_shape):
            input_batch_shape = model_batch_shape
        if len(model_batch_shape) > len(target_batch_shape):
            target_batch_shape = model_batch_shape

        # If input has no fantasy batch dimension but target does, we can save memory and computation by not
        # computing the covariance for each element of the batch. Therefore we don't expand the inputs to the
        # size of the fantasy model here - this is done below, after the evaluation and fast fantasy update
        train_inputs = [tin.expand(input_batch_shape + tin.shape[-2:]) for tin in self.train_inputs]
        train_targets = self.train_targets.expand(target_batch_shape + self.train_targets.shape[data_dim_start:])

        full_inputs = [
            torch.cat(
                [train_input, input.expand(input_batch_shape + input.shape[-2:])],
                dim=-2,
            )
            for train_input, input in length_safe_zip(train_inputs, inputs)
        ]
        full_targets = torch.cat(
            [train_targets, targets.expand(target_batch_shape + targets.shape[data_dim_start:])], dim=data_dim_start
        )

        try:
            fantasy_kwargs = {"noise": kwargs.pop("noise")}
        except KeyError:
            fantasy_kwargs = {}

        full_output = super().__call__(*full_inputs, **kwargs)

        # Copy model without copying training data or prediction strategy (since we'll overwrite those)
        old_pred_strat = self.prediction_strategy
        old_train_inputs = self.train_inputs
        old_train_targets = self.train_targets
        old_likelihood = self.likelihood
        self.prediction_strategy = None
        self.train_inputs = None
        self.train_targets = None
        self.likelihood = None
        new_model = deepcopy(self)
        self.prediction_strategy = old_pred_strat
        self.train_inputs = old_train_inputs
        self.train_targets = old_train_targets
        self.likelihood = old_likelihood

        new_model.likelihood = old_likelihood.get_fantasy_likelihood(**fantasy_kwargs)
        new_model.prediction_strategy = old_pred_strat.get_fantasy_strategy(
            inputs, targets, full_inputs, full_targets, full_output, **fantasy_kwargs
        )

        # If the fantasies are at the same points, we need to expand the inputs for the new model
        if tbdim == ibdim + 1:
            new_model.train_inputs = [fi.expand(target_batch_shape + fi.shape[-2:]) for fi in full_inputs]
        else:
            new_model.train_inputs = full_inputs
        new_model.train_targets = full_targets

        return new_model

    def __call__(self, *args, **kwargs):
        train_inputs = list(self.train_inputs) if self.train_inputs is not None else []
        inputs = [i.unsqueeze(-1) if i.ndimension() == 1 else i for i in args]

        # Training mode: optimizing
        if self.training:
            if self.train_inputs is None:
                raise RuntimeError(
                    "train_inputs cannot be None in training mode. "
                    "Call .eval() for prior predictions, or call .set_train_data() to add training data."
                )
            if settings.debug.on():
                if not all(
                    torch.equal(train_input, input) for train_input, input in length_safe_zip(train_inputs, inputs)
                ):
                    raise RuntimeError("You must train on the training inputs!")
            res = super().__call__(*inputs, **kwargs)
            return res

        # Prior mode
        elif settings.prior_mode.on() or self.train_inputs is None or self.train_targets is None:
            full_inputs = args
            full_output = super().__call__(*full_inputs, **kwargs)
            if settings.debug.on():
                if not isinstance(full_output, MultivariateQExponential):
                    raise RuntimeError("ExactQEP.forward must return a MultivariateQExponential")
            return full_output

        # Posterior mode
        else:
            if settings.debug.on():
                if all(torch.equal(train_input, input) for train_input, input in length_safe_zip(train_inputs, inputs)):
                    warnings.warn(
                        "The input matches the stored training data. Did you forget to call model.train()?",
                        QEPInputWarning,
                    )

            # Get the terms that only depend on training data
            if self.prediction_strategy is None:
                train_output = super().__call__(*train_inputs, **kwargs)

                # Create the prediction strategy
                self.prediction_strategy = prediction_strategy(
                    train_inputs=train_inputs,
                    train_prior_dist=train_output,
                    train_labels=self.train_targets,
                    likelihood=self.likelihood,
                )

            # Concatenate the input to the training input
            full_inputs = []
            batch_shape = train_inputs[0].shape[:-2]
            for train_input, input in length_safe_zip(train_inputs, inputs):
                # Make sure the batch shapes agree for training/test data
                if batch_shape != train_input.shape[:-2]:
                    batch_shape = torch.broadcast_shapes(batch_shape, train_input.shape[:-2])
                    train_input = train_input.expand(*batch_shape, *train_input.shape[-2:])
                if batch_shape != input.shape[:-2]:
                    batch_shape = torch.broadcast_shapes(batch_shape, input.shape[:-2])
                    train_input = train_input.expand(*batch_shape, *train_input.shape[-2:])
                    input = input.expand(*batch_shape, *input.shape[-2:])
                full_inputs.append(torch.cat([train_input, input], dim=-2))

            # Get the joint distribution for training/test data
            full_output = super().__call__(*full_inputs, **kwargs)
            if settings.debug.on():
                if not isinstance(full_output, MultivariateQExponential):
                    raise RuntimeError("ExactQEP.forward must return a MultivariateQExponential")
            full_mean, full_covar = full_output.loc, full_output.lazy_covariance_matrix

            # Determine the shape of the joint distribution
            batch_shape = full_output.batch_shape
            joint_shape = full_output.event_shape
            tasks_shape = joint_shape[1:]  # For multitask learning
            test_shape = torch.Size([joint_shape[0] - self.prediction_strategy.train_shape[0], *tasks_shape])

            # Make the prediction
            with settings.cg_tolerance(settings.eval_cg_tolerance.value()):
                (
                    predictive_mean,
                    predictive_covar,
                ) = self.prediction_strategy.exact_prediction(full_mean, full_covar)

            # Reshape predictive mean to match the appropriate event shape
            predictive_mean = predictive_mean.view(*batch_shape, *test_shape).contiguous()
            return full_output.__class__(predictive_mean, predictive_covar, power=full_output.power)
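
The class docstring above shows construction but stops short of training. Below is a minimal end-to-end sketch of how this class is normally driven, assuming the usual gpytorch-style training loop carries over to qpytorch; the synthetic data, optimizer settings, and iteration count are illustrative only, while ExactMarginalLogLikelihood and QExponentialLikelihood come from the qpytorch.mlls and qpytorch.likelihoods modules in the file list above.

import torch
import qpytorch

class MyQEP(qpytorch.models.ExactQEP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = qpytorch.means.ZeroMean()
        self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return qpytorch.distributions.MultivariateQExponential(mean, covar, self.likelihood.power)

# Illustrative 1-d regression data (not part of this release)
train_x = torch.linspace(0, 1, 50)
train_y = torch.sin(6 * train_x) + 0.1 * torch.randn(50)

likelihood = qpytorch.likelihoods.QExponentialLikelihood(power=torch.tensor(1.0))
model = MyQEP(train_x, train_y, likelihood)

# Fit hyperparameters by maximizing the exact marginal log likelihood
model.train()
likelihood.train()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
mll = qpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
for _ in range(100):
    optimizer.zero_grad()
    loss = -mll(model(train_x), train_y)
    loss.backward()
    optimizer.step()

# Posterior prediction; this first call also populates the prediction-strategy
# caches that get_fantasy_model requires
model.eval()
likelihood.eval()
test_x = torch.linspace(0, 1, 10)
observed_pred = likelihood(model(test_x))

# Fast conditioning on m new observations without re-fitting
new_x = torch.tensor([0.25, 0.75])
new_y = torch.sin(6 * new_x)
fantasy_model = model.get_fantasy_model(new_x, new_y)

Note that get_fantasy_model raises a RuntimeError unless the model has already been called on some test data, which is why the prediction step precedes the fantasy update.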
qpytorch/models/model_list.py
ADDED

@@ -0,0 +1,100 @@
#!/usr/bin/env python3

from abc import ABC

import torch
from torch.nn import ModuleList

from gpytorch.utils.generic import length_safe_zip

from ..likelihoods import LikelihoodList
from . import GP, QEP


class AbstractModelList(GP, QEP, ABC):
    def forward_i(self, i, *args, **kwargs):
        """Forward restricted to the i-th model only."""
        raise NotImplementedError

    def likelihood_i(self, i, *args, **kwargs):
        """Evaluate the likelihood of the i-th model only."""
        raise NotImplementedError


class IndependentModelList(AbstractModelList):
    def __init__(self, *models):
        super().__init__()
        self.models = ModuleList(models)
        for m in models:
            if not hasattr(m, "likelihood"):
                raise ValueError(
                    "IndependentModelList currently only supports models that have a likelihood (e.g. ExactGPs)"
                )
        self.likelihood = LikelihoodList(*[m.likelihood for m in models])

    def forward_i(self, i, *args, **kwargs):
        return self.models[i].forward(*args, **kwargs)

    def likelihood_i(self, i, *args, **kwargs):
        return self.likelihood.likelihoods[i](*args, **kwargs)

    def forward(self, *args, **kwargs):
        return [
            model.forward(*args_, **kwargs) for model, args_ in length_safe_zip(self.models, _get_tensor_args(*args))
        ]

    def get_fantasy_model(self, inputs, targets, **kwargs):
        """
        Returns a new GP (QEP) model that incorporates the specified inputs and targets as new training data.

        This is a simple wrapper that creates fantasy models for each of the models in the model list,
        and returns the same class of fantasy models.

        Args:
            inputs: List of locations of fantasy observations, one for each model.
            targets: List of labels of fantasy observations, one for each model.

        Returns:
            An `IndependentModelList` model, where each sub-model is the fantasy model of the respective
            sub-model in the original model at the corresponding input locations / labels.
        """
        if "noise" in kwargs:
            noise = kwargs.pop("noise")
            kwargs = [{**kwargs, "noise": noise_} if noise_ is not None else kwargs for noise_ in noise]
        else:
            kwargs = [kwargs] * len(inputs)

        fantasy_models = [
            model.get_fantasy_model(*inputs_, *targets_, **kwargs_)
            for model, inputs_, targets_, kwargs_ in length_safe_zip(
                self.models,
                _get_tensor_args(*inputs),
                _get_tensor_args(*targets),
                kwargs,
            )
        ]
        return self.__class__(*fantasy_models)

    def __call__(self, *args, **kwargs):
        return [
            model.__call__(*args_, **kwargs) for model, args_ in length_safe_zip(self.models, _get_tensor_args(*args))
        ]

    @property
    def train_inputs(self):
        return [model.train_inputs for model in self.models]

    @property
    def train_targets(self):
        return [model.train_targets for model in self.models]


def _get_tensor_args(*args):
    for arg in args:
        if torch.is_tensor(arg):
            yield (arg,)
        else:
            yield arg


class UncorrelatedModelList(IndependentModelList):
    pass
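
A sketch of how IndependentModelList is typically used, reusing the MyQEP class from the ExactQEP example above. SumMarginalLogLikelihood appears in the qpytorch.mlls file list, and IndependentModelList is assumed to be re-exported from qpytorch.models as in gpytorch; the per-model tensors (train_x1, train_y1, train_x2, train_y2, test_x1, test_x2) are assumed to exist.

import torch
import qpytorch

likelihood1 = qpytorch.likelihoods.QExponentialLikelihood(power=torch.tensor(1.0))
likelihood2 = qpytorch.likelihoods.QExponentialLikelihood(power=torch.tensor(1.0))
model = qpytorch.models.IndependentModelList(
    MyQEP(train_x1, train_y1, likelihood1),
    MyQEP(train_x2, train_y2, likelihood2),
)
mll = qpytorch.mlls.SumMarginalLogLikelihood(model.likelihood, model)

# Each sub-model is trained on its own data; the summed MLL gives one objective
model.train()
output = model(*model.train_inputs)   # one distribution per sub-model
loss = -mll(output, model.train_targets)
loss.backward()

# Predictions come back as a list, one posterior per sub-model
model.eval()
predictions = model(test_x1, test_x2)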
qpytorch/models/pyro/__init__.py
ADDED

@@ -0,0 +1,28 @@
#!/usr/bin/env python3

try:
    from gpytorch.models.pyro.pyro_gp import PyroGP

    from ._pyro_mixin import _PyroMixin
    from .pyro_qep import PyroQEP
except ImportError:

    class PyroGP(object):
        def __init__(self, *args, **kwargs):
            raise RuntimeError("Cannot use a PyroGP because you don't have Pyro installed.")

    class PyroQEP(object):
        def __init__(self, *args, **kwargs):
            raise RuntimeError("Cannot use a PyroQEP because you don't have Pyro installed.")

    class _PyroMixin(object):
        def pyro_factors(self, *args, **kwargs):
            raise RuntimeError("Cannot call `pyro_factors` because you don't have Pyro installed.")

        def pyro_guide(self, *args, **kwargs):
            raise RuntimeError("Cannot call `pyro_guide` because you don't have Pyro installed.")

        def pyro_model(self, *args, **kwargs):
            raise RuntimeError("Cannot call `pyro_model` because you don't have Pyro installed.")


__all__ = ["PyroGP", "_PyroMixin", "PyroQEP"]
qpytorch/models/pyro/_pyro_mixin.py
ADDED

@@ -0,0 +1,57 @@
#!/usr/bin/env python3

import pyro
import torch

# from .distributions import QExponential


class _PyroMixin(object):
    def pyro_guide(self, input, beta=1.0, name_prefix=""):
        # Inducing values q(u)
        with pyro.poutine.scale(scale=beta):
            variational_distribution = self.variational_strategy.variational_distribution
            variational_distribution = variational_distribution.to_event(len(variational_distribution.batch_shape))
            pyro.sample(name_prefix + ".u", variational_distribution)

        # Draw samples from q(f), re-wrapped as a Pyro distribution; dispatch on the
        # type of the latent distribution the model produced
        function_dist = self(input, prior=False)
        if "Normal" in function_dist.__class__.__name__:
            function_dist = pyro.distributions.Normal(loc=function_dist.mean, scale=function_dist.stddev).to_event(
                len(function_dist.event_shape) - 1
            )
        elif "QExponential" in function_dist.__class__.__name__:
            function_dist = pyro.distributions.QExponential(
                loc=function_dist.mean, scale=function_dist.stddev, power=function_dist.power
            ).to_event(len(function_dist.event_shape) - 1)
        return function_dist.mask(False)

    def pyro_model(self, input, beta=1.0, name_prefix=""):
        # Inducing values p(u)
        with pyro.poutine.scale(scale=beta):
            prior_distribution = self.variational_strategy.prior_distribution
            prior_distribution = prior_distribution.to_event(len(prior_distribution.batch_shape))
            u_samples = pyro.sample(name_prefix + ".u", prior_distribution)

        # Include term for GPyTorch priors
        log_prior = torch.tensor(0.0, dtype=u_samples.dtype, device=u_samples.device)
        for _, module, prior, closure, _ in self.named_priors():
            log_prior.add_(prior.log_prob(closure(module)).sum())
        pyro.factor(name_prefix + ".log_prior", log_prior)

        # Include factor for added loss terms
        added_loss = torch.tensor(0.0, dtype=u_samples.dtype, device=u_samples.device)
        for added_loss_term in self.added_loss_terms():
            added_loss.add_(added_loss_term.loss())
        pyro.factor(name_prefix + ".added_loss", added_loss)

        # Draw samples from p(f), re-wrapped as a Pyro distribution; same dispatch as above
        function_dist = self(input, prior=True)
        if "Normal" in function_dist.__class__.__name__:
            function_dist = pyro.distributions.Normal(loc=function_dist.mean, scale=function_dist.stddev).to_event(
                len(function_dist.event_shape) - 1
            )
        elif "QExponential" in function_dist.__class__.__name__:
            function_dist = pyro.distributions.QExponential(
                loc=function_dist.mean, scale=function_dist.stddev, power=function_dist.power
            ).to_event(len(function_dist.event_shape) - 1)
        return function_dist.mask(False)
qpytorch/models/pyro/pyro_qep.py
ADDED

@@ -0,0 +1,105 @@
#!/usr/bin/env python3

import pyro

from ..qep import QEP
from ._pyro_mixin import _PyroMixin


class PyroQEP(QEP, _PyroMixin):
    """
    A :obj:`~qpytorch.models.ApproximateQEP` designed to work with Pyro.

    This module makes it possible to include QEP models within more complex probabilistic models,
    or to use likelihood functions with additional variational/approximate distributions.

    The parameters of these models are learned using Pyro's inference tools, unlike other models
    that optimize models with respect to a :obj:`~qpytorch.mlls.MarginalLogLikelihood`.
    See `the Pyro examples <examples/09_Pyro_Integration/index.html>`_ for detailed examples.

    Args:
        variational_strategy (:obj:`~qpytorch.variational.VariationalStrategy`):
            The variational strategy that defines the variational distribution and
            the marginalization strategy.
        likelihood (:obj:`~qpytorch.likelihoods.Likelihood`):
            The likelihood for the model.
        num_data (int):
            The total number of training data points (necessary for SGD).
        name_prefix (str, optional):
            A prefix to put in front of pyro sample/plate sites.
        beta (float - default 1.):
            A multiplicative factor for the KL divergence term.
            Setting it to 1 (default) recovers true variational inference
            (as derived in `Scalable Variational Gaussian Process Classification`_).
            Setting it to anything less than 1 reduces the regularization effect of the model
            (similarly to what was proposed in `the beta-VAE paper`_).

    Example:
        >>> class MyVariationalQEP(qpytorch.models.PyroQEP):
        >>>     # implementation
        >>>
        >>> # variational_strategy = ...
        >>> likelihood = qpytorch.likelihoods.QExponentialLikelihood()
        >>> model = MyVariationalQEP(variational_strategy, likelihood, train_y.size(0))
        >>>
        >>> optimizer = pyro.optim.Adam({"lr": 0.01})
        >>> elbo = pyro.infer.Trace_ELBO(num_particles=64, vectorize_particles=True)
        >>> svi = pyro.infer.SVI(model.model, model.guide, optimizer, elbo)
        >>>
        >>> # Optimize variational parameters
        >>> for _ in range(n_iter):
        >>>     loss = svi.step(train_x, train_y)

    .. _Scalable Variational Gaussian Process Classification:
        http://proceedings.mlr.press/v38/hensman15.pdf
    .. _the beta-VAE paper:
        https://openreview.net/pdf?id=Sy2fzU9gl
    """

    def __init__(self, variational_strategy, likelihood, num_data, name_prefix="", beta=1.0):
        super().__init__()
        self.variational_strategy = variational_strategy
        self.name_prefix = name_prefix
        self.likelihood = likelihood
        self.num_data = num_data
        self.beta = beta

        # Set values for the likelihood
        self.likelihood.num_data = num_data
        self.likelihood.name_prefix = name_prefix

    def guide(self, input, target, *args, **kwargs):
        r"""
        Guide function for Pyro inference.
        Includes the guide for the QEP's likelihood function as well.

        :param torch.Tensor input: :math:`\mathbf X` The input values
        :param torch.Tensor target: :math:`\mathbf y` The target values
        :param args: Additional arguments passed to the likelihood's forward function.
        :param kwargs: Additional keyword arguments passed to the likelihood's forward function.
        """
        # Get q(f)
        function_dist = self.pyro_guide(input, beta=self.beta, name_prefix=self.name_prefix)
        return self.likelihood.pyro_guide(function_dist, target, *args, **kwargs)

    def model(self, input, target, *args, **kwargs):
        r"""
        Model function for Pyro inference.
        Includes the model for the QEP's likelihood function as well.

        :param torch.Tensor input: :math:`\mathbf X` The input values
        :param torch.Tensor target: :math:`\mathbf y` The target values
        :param args: Additional arguments passed to the likelihood's forward function.
        :param kwargs: Additional keyword arguments passed to the likelihood's forward function.
        """
        # Include module
        pyro.module(self.name_prefix + ".qep", self)

        # Get p(f)
        function_dist = self.pyro_model(input, beta=self.beta, name_prefix=self.name_prefix)
        return self.likelihood.pyro_model(function_dist, target, *args, **kwargs)

    def __call__(self, inputs, prior=False):
        if inputs.dim() == 1:
            inputs = inputs.unsqueeze(-1)
        return self.variational_strategy(inputs, prior=prior)
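
The docstring example above leaves the subclass body as "# implementation". Below is one hypothetical way to fill it in, assuming qpytorch's CholeskyVariationalDistribution and VariationalStrategy (both present in the qpytorch/variational file list) keep their gpytorch constructor signatures; nothing in this diff pins down these choices, so treat it as a sketch rather than the package's prescribed pattern.

import torch
import qpytorch

class MyVariationalQEP(qpytorch.models.PyroQEP):
    def __init__(self, inducing_points, likelihood, num_data):
        # Standard sparse-variational setup (signatures assumed to match gpytorch)
        variational_distribution = qpytorch.variational.CholeskyVariationalDistribution(inducing_points.size(-2))
        variational_strategy = qpytorch.variational.VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy, likelihood, num_data=num_data)
        self.mean_module = qpytorch.means.ZeroMean()
        self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())

    def forward(self, x):
        # The power parameter rides along from the likelihood, as in the ExactQEP example
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return qpytorch.distributions.MultivariateQExponential(mean, covar, self.likelihood.power)

The resulting model's model/guide pair is then handed to pyro.infer.SVI exactly as in the docstring example.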
qpytorch/models/qeplvm/__init__.py
ADDED

@@ -0,0 +1,6 @@
#!/usr/bin/env python3

from .bayesian_qeplvm import BayesianQEPLVM
from .latent_variable import MAPLatentVariable, PointLatentVariable, VariationalLatentVariable

__all__ = ["BayesianQEPLVM", "PointLatentVariable", "MAPLatentVariable", "VariationalLatentVariable"]