qpytorch-0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of qpytorch might be problematic.
- qpytorch/__init__.py +327 -0
- qpytorch/constraints/__init__.py +3 -0
- qpytorch/distributions/__init__.py +21 -0
- qpytorch/distributions/delta.py +86 -0
- qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
- qpytorch/distributions/multivariate_qexponential.py +581 -0
- qpytorch/distributions/power.py +113 -0
- qpytorch/distributions/qexponential.py +153 -0
- qpytorch/functions/__init__.py +58 -0
- qpytorch/kernels/__init__.py +80 -0
- qpytorch/kernels/grid_interpolation_kernel.py +213 -0
- qpytorch/kernels/inducing_point_kernel.py +151 -0
- qpytorch/kernels/kernel.py +695 -0
- qpytorch/kernels/matern32_kernel_grad.py +155 -0
- qpytorch/kernels/matern52_kernel_grad.py +194 -0
- qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
- qpytorch/kernels/polynomial_kernel_grad.py +88 -0
- qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
- qpytorch/kernels/rbf_kernel_grad.py +125 -0
- qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
- qpytorch/kernels/rff_kernel.py +153 -0
- qpytorch/lazy/__init__.py +9 -0
- qpytorch/likelihoods/__init__.py +66 -0
- qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
- qpytorch/likelihoods/beta_likelihood.py +76 -0
- qpytorch/likelihoods/gaussian_likelihood.py +472 -0
- qpytorch/likelihoods/laplace_likelihood.py +59 -0
- qpytorch/likelihoods/likelihood.py +437 -0
- qpytorch/likelihoods/likelihood_list.py +60 -0
- qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
- qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
- qpytorch/likelihoods/noise_models.py +184 -0
- qpytorch/likelihoods/qexponential_likelihood.py +494 -0
- qpytorch/likelihoods/softmax_likelihood.py +97 -0
- qpytorch/likelihoods/student_t_likelihood.py +90 -0
- qpytorch/means/__init__.py +23 -0
- qpytorch/metrics/__init__.py +17 -0
- qpytorch/mlls/__init__.py +53 -0
- qpytorch/mlls/_approximate_mll.py +79 -0
- qpytorch/mlls/deep_approximate_mll.py +30 -0
- qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
- qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
- qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
- qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
- qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
- qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
- qpytorch/mlls/marginal_log_likelihood.py +48 -0
- qpytorch/mlls/predictive_log_likelihood.py +76 -0
- qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
- qpytorch/mlls/variational_elbo.py +77 -0
- qpytorch/models/__init__.py +72 -0
- qpytorch/models/approximate_qep.py +115 -0
- qpytorch/models/deep_qeps/__init__.py +22 -0
- qpytorch/models/deep_qeps/deep_qep.py +155 -0
- qpytorch/models/deep_qeps/dspp.py +114 -0
- qpytorch/models/exact_prediction_strategies.py +880 -0
- qpytorch/models/exact_qep.py +349 -0
- qpytorch/models/model_list.py +100 -0
- qpytorch/models/pyro/__init__.py +28 -0
- qpytorch/models/pyro/_pyro_mixin.py +57 -0
- qpytorch/models/pyro/distributions/__init__.py +5 -0
- qpytorch/models/pyro/pyro_qep.py +105 -0
- qpytorch/models/qep.py +7 -0
- qpytorch/models/qeplvm/__init__.py +6 -0
- qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
- qpytorch/models/qeplvm/latent_variable.py +102 -0
- qpytorch/module.py +30 -0
- qpytorch/optim/__init__.py +5 -0
- qpytorch/priors/__init__.py +42 -0
- qpytorch/priors/qep_priors.py +81 -0
- qpytorch/test/__init__.py +22 -0
- qpytorch/test/base_likelihood_test_case.py +106 -0
- qpytorch/test/model_test_case.py +150 -0
- qpytorch/test/variational_test_case.py +400 -0
- qpytorch/utils/__init__.py +38 -0
- qpytorch/utils/warnings.py +37 -0
- qpytorch/variational/__init__.py +47 -0
- qpytorch/variational/_variational_distribution.py +61 -0
- qpytorch/variational/_variational_strategy.py +391 -0
- qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
- qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
- qpytorch/variational/cholesky_variational_distribution.py +65 -0
- qpytorch/variational/ciq_variational_strategy.py +352 -0
- qpytorch/variational/delta_variational_distribution.py +41 -0
- qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
- qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
- qpytorch/variational/lmc_variational_strategy.py +248 -0
- qpytorch/variational/mean_field_variational_distribution.py +58 -0
- qpytorch/variational/multitask_variational_strategy.py +317 -0
- qpytorch/variational/natural_variational_distribution.py +152 -0
- qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
- qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
- qpytorch/variational/tril_natural_variational_distribution.py +130 -0
- qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
- qpytorch/variational/unwhitened_variational_strategy.py +225 -0
- qpytorch/variational/variational_strategy.py +280 -0
- qpytorch/version.py +4 -0
- qpytorch-0.1.dist-info/LICENSE +21 -0
- qpytorch-0.1.dist-info/METADATA +177 -0
- qpytorch-0.1.dist-info/RECORD +102 -0
- qpytorch-0.1.dist-info/WHEEL +5 -0
- qpytorch-0.1.dist-info/top_level.txt +1 -0
qpytorch/test/variational_test_case.py
@@ -0,0 +1,400 @@
+#!/usr/bin/env python3
+
+from abc import abstractproperty
+from unittest.mock import MagicMock, patch
+
+import linear_operator
+import torch
+
+import qpytorch
+
+from gpytorch.test.base_test_case import BaseTestCase
+
+
+class VariationalTestCase(BaseTestCase):
+    def _make_model_and_likelihood(
+        self,
+        num_inducing=16,
+        batch_shape=torch.Size([]),
+        inducing_batch_shape=torch.Size([]),
+        strategy_cls=qpytorch.variational.VariationalStrategy,
+        distribution_cls=qpytorch.variational.CholeskyVariationalDistribution,
+        constant_mean=True,
+    ):
+        _power = getattr(self, '_power', 2.0)
+        class _SV_PRegressionModel(qpytorch.models.ApproximateGP if _power==2 else qpytorch.models.ApproximateQEP):
+            def __init__(self, inducing_points):
+                if _power!=2: self.power = torch.tensor(_power)
+                variational_distribution = distribution_cls(num_inducing, batch_shape=batch_shape, power=self.power) if hasattr(self, 'power') \
+                                           else distribution_cls(num_inducing, batch_shape=batch_shape)
+                variational_strategy = strategy_cls(
+                    self,
+                    inducing_points,
+                    variational_distribution,
+                    learn_inducing_locations=True,
+                )
+                super().__init__(variational_strategy)
+                if constant_mean:
+                    self.mean_module = qpytorch.means.ConstantMean()
+                    self.mean_module.initialize(constant=1.0)
+                else:
+                    self.mean_module = qpytorch.means.ZeroMean()
+                self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())
+
+            def forward(self, x):
+                mean_x = self.mean_module(x)
+                covar_x = self.covar_module(x)
+                latent_pred = qpytorch.distributions.MultivariateQExponential(mean_x, covar_x, power=self.power) if hasattr(self, 'power') \
+                              else qpytorch.distributions.MultivariateNormal(mean_x, covar_x)
+                return latent_pred
+
+        inducing_points = torch.randn(num_inducing, 2).repeat(*inducing_batch_shape, 1, 1)
+        return _SV_PRegressionModel(inducing_points), self.likelihood_cls()
+
+    def _training_iter(
+        self,
+        model,
+        likelihood,
+        batch_shape=torch.Size([]),
+        mll_cls=qpytorch.mlls.VariationalELBO,
+        cuda=False,
+    ):
+        train_x = torch.randn(*batch_shape, 32, 2).clamp(-2.5, 2.5)
+        train_y = torch.linspace(-1, 1, self.event_shape[0])
+        train_y = train_y.view(self.event_shape[0], *([1] * (len(self.event_shape) - 1)))
+        train_y = train_y.expand(*self.event_shape)
+        mll = mll_cls(likelihood, model, num_data=train_x.size(-2))
+        if cuda:
+            train_x = train_x.cuda()
+            train_y = train_y.cuda()
+            model = model.cuda()
+            likelihood = likelihood.cuda()
+
+        # Single optimization iteration
+        model.train()
+        likelihood.train()
+        output = model(train_x)
+        loss = -mll(output, train_y)
+        loss.sum().backward()
+
+        # Make sure we have gradients for all parameters
+        for _, param in model.named_parameters():
+            self.assertTrue(param.grad is not None)
+            self.assertGreater(param.grad.norm().item(), 0)
+        for _, param in likelihood.named_parameters():
+            self.assertTrue(param.grad is not None)
+            self.assertGreater(param.grad.norm().item(), 0)
+
+        return output, loss
+
+    def _eval_iter(self, model, batch_shape=torch.Size([]), cuda=False):
+        test_x = torch.randn(*batch_shape, 32, 2).clamp(-2.5, 2.5)
+        if cuda:
+            test_x = test_x.cuda()
+            model = model.cuda()
+
+        # Single optimization iteration
+        model.eval()
+        with torch.no_grad():
+            output = model(test_x)
+
+        return output
+
+    def _fantasy_iter(
+        self,
+        model,
+        likelihood,
+        batch_shape=torch.Size([]),
+        cuda=False,
+        num_fant=10,
+        covar_module=None,
+        mean_module=None,
+    ):
+        model.likelihood = likelihood
+        val_x = torch.randn(*batch_shape, num_fant, 2).clamp(-2.5, 2.5)
+        val_y = torch.linspace(-1, 1, num_fant)
+        val_y = val_y.view(num_fant, *([1] * (len(self.event_shape) - 1)))
+        val_y = val_y.expand(*batch_shape, num_fant, *self.event_shape[1:])
+        if cuda:
+            model = model.cuda()
+            val_x = val_x.cuda()
+            val_y = val_y.cuda()
+        updated_model = model.get_fantasy_model(val_x, val_y, covar_module=covar_module, mean_module=mean_module)
+        return updated_model
+
+    @abstractproperty
+    def batch_shape(self):
+        raise NotImplementedError
+
+    @abstractproperty
+    def distribution_cls(self):
+        raise NotImplementedError
+
+    @property
+    def event_shape(self):
+        return torch.Size([32])
+
+    @property
+    def likelihood_cls(self):
+        return qpytorch.likelihoods.GaussianLikelihood if self._power==2 else qpytorch.likelihoods.QExponentialLikelihood
+
+    @abstractproperty
+    def mll_cls(self):
+        raise NotImplementedError
+
+    @abstractproperty
+    def strategy_cls(self):
+        raise NotImplementedError
+
+    @property
+    def cuda(self):
+        return False
+
+    def test_eval_iteration(
+        self,
+        data_batch_shape=None,
+        inducing_batch_shape=None,
+        model_batch_shape=None,
+        eval_data_batch_shape=None,
+        expected_batch_shape=None,
+    ):
+        # Batch shapes
+        model_batch_shape = model_batch_shape if model_batch_shape is not None else self.batch_shape
+        data_batch_shape = data_batch_shape if data_batch_shape is not None else self.batch_shape
+        inducing_batch_shape = inducing_batch_shape if inducing_batch_shape is not None else self.batch_shape
+        expected_batch_shape = expected_batch_shape if expected_batch_shape is not None else self.batch_shape
+        eval_data_batch_shape = eval_data_batch_shape if eval_data_batch_shape is not None else self.batch_shape
+
+        # Mocks
+        _wrapped_cholesky = MagicMock(wraps=torch.linalg.cholesky_ex)
+        _wrapped_cg = MagicMock(wraps=linear_operator.utils.linear_cg)
+        _wrapped_ciq = MagicMock(wraps=linear_operator.utils.contour_integral_quad)
+        _cholesky_mock = patch("torch.linalg.cholesky_ex", new=_wrapped_cholesky)
+        _cg_mock = patch("linear_operator.utils.linear_cg", new=_wrapped_cg)
+        _ciq_mock = patch("linear_operator.utils.contour_integral_quad", new=_wrapped_ciq)
+
+        # Make model and likelihood
+        model, likelihood = self._make_model_and_likelihood(
+            batch_shape=model_batch_shape,
+            inducing_batch_shape=inducing_batch_shape,
+            distribution_cls=self.distribution_cls,
+            strategy_cls=self.strategy_cls,
+        )
+
+        # Do one forward pass
+        self._training_iter(model, likelihood, data_batch_shape, mll_cls=self.mll_cls, cuda=self.cuda)
+
+        # Now do evaluation
+        with _cholesky_mock as cholesky_mock, _cg_mock as cg_mock, _ciq_mock as ciq_mock:
+            # Iter 1
+            _ = self._eval_iter(model, eval_data_batch_shape, cuda=self.cuda)
+            output = self._eval_iter(model, eval_data_batch_shape, cuda=self.cuda)
+            self.assertEqual(output.batch_shape, expected_batch_shape)
+            self.assertEqual(output.event_shape, self.event_shape)
+            return cg_mock, cholesky_mock, ciq_mock
+
+    def test_eval_smaller_pred_batch(self):
+        return self.test_eval_iteration(
+            model_batch_shape=(torch.Size([3, 4]) + self.batch_shape),
+            inducing_batch_shape=(torch.Size([3, 1]) + self.batch_shape),
+            data_batch_shape=(torch.Size([3, 4]) + self.batch_shape),
+            eval_data_batch_shape=(torch.Size([4]) + self.batch_shape),
+            expected_batch_shape=(torch.Size([3, 4]) + self.batch_shape),
+        )
+
+    def test_eval_larger_pred_batch(self):
+        return self.test_eval_iteration(
+            model_batch_shape=(torch.Size([4]) + self.batch_shape),
+            inducing_batch_shape=(self.batch_shape),
+            data_batch_shape=(torch.Size([4]) + self.batch_shape),
+            eval_data_batch_shape=(torch.Size([3, 4]) + self.batch_shape),
+            expected_batch_shape=(torch.Size([3, 4]) + self.batch_shape),
+        )
+
+    def test_training_iteration(
+        self,
+        data_batch_shape=None,
+        inducing_batch_shape=None,
+        model_batch_shape=None,
+        expected_batch_shape=None,
+        constant_mean=True,
+    ):
+        # Batch shapes
+        model_batch_shape = model_batch_shape if model_batch_shape is not None else self.batch_shape
+        data_batch_shape = data_batch_shape if data_batch_shape is not None else self.batch_shape
+        inducing_batch_shape = inducing_batch_shape if inducing_batch_shape is not None else self.batch_shape
+        expected_batch_shape = expected_batch_shape if expected_batch_shape is not None else self.batch_shape
+
+        # Mocks
+        _wrapped_cholesky = MagicMock(wraps=torch.linalg.cholesky_ex)
+        _wrapped_cg = MagicMock(wraps=linear_operator.utils.linear_cg)
+        _wrapped_ciq = MagicMock(wraps=linear_operator.utils.contour_integral_quad)
+        _cholesky_mock = patch("torch.linalg.cholesky_ex", new=_wrapped_cholesky)
+        _cg_mock = patch("linear_operator.utils.linear_cg", new=_wrapped_cg)
+        _ciq_mock = patch("linear_operator.utils.contour_integral_quad", new=_wrapped_ciq)
+
+        # Make model and likelihood
+        model, likelihood = self._make_model_and_likelihood(
+            batch_shape=model_batch_shape,
+            inducing_batch_shape=inducing_batch_shape,
+            distribution_cls=self.distribution_cls,
+            strategy_cls=self.strategy_cls,
+            constant_mean=constant_mean,
+        )
+
+        # Do forward pass
+        with _cholesky_mock as cholesky_mock, _cg_mock as cg_mock, _ciq_mock as ciq_mock:
+            # Iter 1
+            self.assertEqual(model.variational_strategy.variational_params_initialized.item(), 0)
+            self._training_iter(
+                model,
+                likelihood,
+                data_batch_shape,
+                mll_cls=self.mll_cls,
+                cuda=self.cuda,
+            )
+            self.assertEqual(model.variational_strategy.variational_params_initialized.item(), 1)
+            # Iter 2
+            output, loss = self._training_iter(
+                model,
+                likelihood,
+                data_batch_shape,
+                mll_cls=self.mll_cls,
+                cuda=self.cuda,
+            )
+            self.assertEqual(output.batch_shape, expected_batch_shape)
+            self.assertEqual(output.event_shape, self.event_shape)
+            self.assertEqual(loss.shape, expected_batch_shape)
+        return cg_mock, cholesky_mock, ciq_mock
+
+    def test_training_iteration_batch_inducing(self):
+        return self.test_training_iteration(
+            model_batch_shape=(torch.Size([3]) + self.batch_shape),
+            data_batch_shape=self.batch_shape,
+            inducing_batch_shape=(torch.Size([3]) + self.batch_shape),
+            expected_batch_shape=(torch.Size([3]) + self.batch_shape),
+        )
+
+    def test_training_iteration_batch_data(self):
+        return self.test_training_iteration(
+            model_batch_shape=self.batch_shape,
+            inducing_batch_shape=self.batch_shape,
+            data_batch_shape=(torch.Size([3]) + self.batch_shape),
+            expected_batch_shape=(torch.Size([3]) + self.batch_shape),
+        )
+
+    def test_training_iteration_batch_model(self):
+        return self.test_training_iteration(
+            model_batch_shape=(torch.Size([3]) + self.batch_shape),
+            inducing_batch_shape=self.batch_shape,
+            data_batch_shape=self.batch_shape,
+            expected_batch_shape=(torch.Size([3]) + self.batch_shape),
+        )
+
+    def test_training_all_batch_zero_mean(self):
+        return self.test_training_iteration(
+            model_batch_shape=(torch.Size([3, 4]) + self.batch_shape),
+            inducing_batch_shape=(torch.Size([3, 1]) + self.batch_shape),
+            data_batch_shape=(torch.Size([4]) + self.batch_shape),
+            expected_batch_shape=(torch.Size([3, 4]) + self.batch_shape),
+            constant_mean=False,
+        )
+
+    def test_fantasy_call(
+        self,
+        data_batch_shape=None,
+        inducing_batch_shape=None,
+        model_batch_shape=None,
+        expected_batch_shape=None,
+        constant_mean=True,
+    ):
+        # Batch shapes
+        model_batch_shape = model_batch_shape if model_batch_shape is not None else self.batch_shape
+        data_batch_shape = data_batch_shape if data_batch_shape is not None else self.batch_shape
+        inducing_batch_shape = inducing_batch_shape if inducing_batch_shape is not None else self.batch_shape
+        expected_batch_shape = expected_batch_shape if expected_batch_shape is not None else self.batch_shape
+
+        num_inducing = 16
+        num_fant = 10
+        # Make model and likelihood
+        model, likelihood = self._make_model_and_likelihood(
+            batch_shape=model_batch_shape,
+            inducing_batch_shape=inducing_batch_shape,
+            distribution_cls=self.distribution_cls,
+            strategy_cls=self.strategy_cls,
+            constant_mean=constant_mean,
+            num_inducing=num_inducing,
+        )
+
+        # we iterate through the covar and mean module possible settings
+        covar_mean_options = [
+            {"covar_module": None, "mean_module": None},
+            {"covar_module": qpytorch.kernels.MaternKernel(), "mean_module": qpytorch.means.ZeroMean()},
+        ]
+        for cm_dict in covar_mean_options:
+            fant_model = self._fantasy_iter(
+                model, likelihood, data_batch_shape, self.cuda, num_fant=num_fant, **cm_dict
+            )
+            self.assertTrue(isinstance(fant_model, {2.0: qpytorch.models.ExactGP, 1.0: qpytorch.models.ExactQEP}[self._power]))
+
+            # we check to ensure setting the covar_module and mean_modules are okay
+            if cm_dict["covar_module"] is None:
+                self.assertEqual(type(fant_model.covar_module), type(model.covar_module))
+            else:
+                self.assertNotEqual(type(fant_model.covar_module), type(model.covar_module))
+            if cm_dict["mean_module"] is None:
+                self.assertEqual(type(fant_model.mean_module), type(model.mean_module))
+            else:
+                self.assertNotEqual(type(fant_model.mean_module), type(model.mean_module))
+
+            # now we check to ensure the shapes of the fantasy strategy are correct
+            self.assertTrue(fant_model.prediction_strategy is not None)
+            for key in fant_model.prediction_strategy._memoize_cache.keys():
+                if key[0] == "mean_cache":
+                    break
+            mean_cache = fant_model.prediction_strategy._memoize_cache[key]
+            self.assertEqual(mean_cache.shape, torch.Size([*expected_batch_shape, num_inducing + num_fant]))
+
+        # we remove the mean_module and covar_module and check for errors
+        del model.mean_module
+        with self.assertRaises(ModuleNotFoundError):
+            self._fantasy_iter(model, likelihood, data_batch_shape, self.cuda, num_fant=num_fant)
+
+        model.mean_module = qpytorch.means.ZeroMean()
+        del model.covar_module
+        with self.assertRaises(ModuleNotFoundError):
+            self._fantasy_iter(model, likelihood, data_batch_shape, self.cuda, num_fant=num_fant)
+
+        # finally we check to ensure failure for a non-gaussian likelihood
+        with self.assertRaises(NotImplementedError):
+            self._fantasy_iter(
+                model,
+                qpytorch.likelihoods.BernoulliLikelihood(),
+                data_batch_shape,
+                self.cuda,
+                num_fant=num_fant,
+            )
+
+    def test_fantasy_call_batch_inducing(self):
+        return self.test_fantasy_call(
+            model_batch_shape=(torch.Size([3]) + self.batch_shape),
+            data_batch_shape=self.batch_shape,
+            inducing_batch_shape=(torch.Size([3]) + self.batch_shape),
+            expected_batch_shape=(torch.Size([3]) + self.batch_shape),
+        )
+
+    def test_fantasy_call_batch_data(self):
+        return self.test_fantasy_call(
+            model_batch_shape=self.batch_shape,
+            inducing_batch_shape=self.batch_shape,
+            data_batch_shape=(torch.Size([3]) + self.batch_shape),
+            expected_batch_shape=(torch.Size([3]) + self.batch_shape),
+        )
+
+    def test_fantasy_call_batch_model(self):
+        return self.test_fantasy_call(
+            model_batch_shape=(torch.Size([3]) + self.batch_shape),
+            inducing_batch_shape=self.batch_shape,
+            data_batch_shape=self.batch_shape,
+            expected_batch_shape=(torch.Size([3]) + self.batch_shape),
+        )
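For orientation, a concrete suite built on this base class only needs to pin down the abstract properties (batch_shape, distribution_cls, mll_cls, strategy_cls) and, optionally, set a _power attribute to exercise the Q-exponential branch. A minimal sketch, not part of the package; the class name and the choice of _power = 1.0 are illustrative (any value other than 2 selects the ApproximateQEP / QExponentialLikelihood path, and 1.0 is the value the fantasy tests above key on):

import unittest

import torch

import qpytorch
from qpytorch.test.variational_test_case import VariationalTestCase


class TestSVQEPRegression(VariationalTestCase, unittest.TestCase):
    _power = 1.0  # != 2 -> ApproximateQEP model with a QExponentialLikelihood

    @property
    def batch_shape(self):
        return torch.Size([])

    @property
    def distribution_cls(self):
        return qpytorch.variational.CholeskyVariationalDistribution

    @property
    def mll_cls(self):
        return qpytorch.mlls.VariationalELBO

    @property
    def strategy_cls(self):
        return qpytorch.variational.VariationalStrategy


if __name__ == "__main__":
    unittest.main()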
qpytorch/utils/__init__.py
@@ -0,0 +1,38 @@
+#!/usr/bin/env python3
+
+from __future__ import annotations
+
+import warnings as _warnings
+from typing import Any
+
+import linear_operator
+
+from gpytorch.utils import deprecation, errors, generic, grid, interpolation, quadrature, transforms
+from . import warnings
+from gpytorch.utils.memoize import cached
+from gpytorch.utils.nearest_neighbors import NNUtil
+from gpytorch.utils.sum_interaction_terms import sum_interaction_terms
+
+__all__ = [
+    "cached",
+    "deprecation",
+    "errors",
+    "generic",
+    "grid",
+    "interpolation",
+    "quadrature",
+    "sum_interaction_terms",
+    "transforms",
+    "warnings",
+    "NNUtil",
+]
+
+
+def __getattr__(name: str) -> Any:
+    if hasattr(linear_operator.utils, name):
+        _warnings.warn(
+            f"gpytorch.utils.{name} is deprecated. Use linear_operator.utils.{name} instead.",
+            DeprecationWarning,
+        )
+        return getattr(linear_operator.utils, name)
+    raise AttributeError(f"module gpytorch.utils has no attribute {name}")
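The module-level __getattr__ above is the PEP 562 fallback hook: any name not defined in the module is forwarded to linear_operator.utils with a DeprecationWarning (the message text still says gpytorch.utils because the module is carried over from GPyTorch). A usage sketch of what a caller sees; it assumes linear_cg, which the test file above wraps from linear_operator.utils, is reached through the fallback:

import warnings

import qpytorch.utils

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Not bound in qpytorch.utils itself, so __getattr__ fires and forwards it.
    linear_cg = qpytorch.utils.linear_cg

assert any(issubclass(w.category, DeprecationWarning) for w in caught)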
qpytorch/utils/warnings.py
@@ -0,0 +1,37 @@
+#!/usr/bin/env python3
+
+from linear_operator.utils.warnings import NumericalWarning
+
+
+class GPInputWarning(UserWarning):
+    """
+    Warning thrown when a GP model receives an unexpected input.
+    For example, when an :obj:`~gpytorch.models.ExactGP` in eval mode receives the training data as input.
+    """
+
+    pass
+
+
+class QEPInputWarning(UserWarning):
+    """
+    Warning thrown when a QEP model receives an unexpected input.
+    For example, when an :obj:`~qpytorch.models.ExactQEP` in eval mode receives the training data as input.
+    """
+
+    pass
+
+
+class OldVersionWarning(UserWarning):
+    """
+    Warning thrown when loading a saved model from an outdated version of GPyTorch/QPyTorch.
+    """
+
+    pass
+
+
+__all__ = [
+    "GPInputWarning",
+    "QEPInputWarning",
+    "OldVersionWarning",
+    "NumericalWarning",
+]
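Because these are distinct UserWarning subclasses, downstream code can filter on the category rather than on message text. For example, a script that deliberately re-evaluates a model on its training inputs might silence only the QEP variant (a usage sketch, not package code):

import warnings

from qpytorch.utils.warnings import QEPInputWarning

with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=QEPInputWarning)
    ...  # eval-mode call that would otherwise warn, e.g. preds = model(train_x)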
qpytorch/variational/__init__.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python3
+
+from ._variational_distribution import _VariationalDistribution
+from ._variational_strategy import _VariationalStrategy
+from .additive_grid_interpolation_variational_strategy import AdditiveGridInterpolationVariationalStrategy
+from .batch_decoupled_variational_strategy import BatchDecoupledVariationalStrategy
+from .cholesky_variational_distribution import CholeskyVariationalDistribution
+from .ciq_variational_strategy import CiqVariationalStrategy
+from .delta_variational_distribution import DeltaVariationalDistribution
+from .grid_interpolation_variational_strategy import GridInterpolationVariationalStrategy
+from .independent_multitask_variational_strategy import (
+    IndependentMultitaskVariationalStrategy,
+    # MultitaskVariationalStrategy,
+)
+from .lmc_variational_strategy import LMCVariationalStrategy
+from .multitask_variational_strategy import MultitaskVariationalStrategy
+from .mean_field_variational_distribution import MeanFieldVariationalDistribution
+from .natural_variational_distribution import _NaturalVariationalDistribution, NaturalVariationalDistribution
+from .nearest_neighbor_variational_strategy import NNVariationalStrategy
+from .orthogonally_decoupled_variational_strategy import OrthogonallyDecoupledVariationalStrategy
+from .tril_natural_variational_distribution import TrilNaturalVariationalDistribution
+from .uncorrelated_multitask_variational_strategy import UncorrelatedMultitaskVariationalStrategy
+from .unwhitened_variational_strategy import UnwhitenedVariationalStrategy
+from .variational_strategy import VariationalStrategy
+
+__all__ = [
+    "_VariationalStrategy",
+    "AdditiveGridInterpolationVariationalStrategy",
+    "BatchDecoupledVariationalStrategy",
+    "CiqVariationalStrategy",
+    "GridInterpolationVariationalStrategy",
+    "IndependentMultitaskVariationalStrategy",
+    "LMCVariationalStrategy",
+    "MultitaskVariationalStrategy",
+    "OrthogonallyDecoupledVariationalStrategy",
+    "VariationalStrategy",
+    "UnwhitenedVariationalStrategy",
+    "UncorrelatedMultitaskVariationalStrategy",
+    "_VariationalDistribution",
+    "CholeskyVariationalDistribution",
+    "MeanFieldVariationalDistribution",
+    "DeltaVariationalDistribution",
+    "_NaturalVariationalDistribution",
+    "NaturalVariationalDistribution",
+    "TrilNaturalVariationalDistribution",
+    "NNVariationalStrategy",
+]
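The distribution classes exported here parameterize q(u) over the inducing values, while the strategy classes map q(u) to the predictive distribution over function values. A sketch of how they compose, mirroring the _SV_PRegressionModel fixture from the test-case file earlier in this diff (the class name and the power value are illustrative):

import torch

import qpytorch


class SVQEPModel(qpytorch.models.ApproximateQEP):
    def __init__(self, inducing_points, power=torch.tensor(1.0)):
        self.power = power
        variational_distribution = qpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2), power=self.power
        )
        variational_strategy = qpytorch.variational.VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = qpytorch.means.ConstantMean()
        self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())

    def forward(self, x):
        return qpytorch.distributions.MultivariateQExponential(
            self.mean_module(x), self.covar_module(x), power=self.power
        )


model = SVQEPModel(torch.randn(16, 2))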
qpytorch/variational/_variational_distribution.py
@@ -0,0 +1,61 @@
+#!/usr/bin/env python3
+
+from typing import Union
+
+from abc import ABC, abstractmethod
+
+import torch
+
+from ..distributions import Distribution, MultivariateNormal, MultivariateQExponential
+from ..module import Module
+
+
+class _VariationalDistribution(Module, ABC):
+    r"""
+    Abstract base class for all Variational Distributions.
+
+    :ivar torch.dtype dtype: The dtype of the VariationalDistribution parameters
+    :ivar torch.device device: The device of the VariationalDistribution parameters
+    """
+
+    def __init__(self, num_inducing_points: int, batch_shape: torch.Size = torch.Size([]), mean_init_std: float = 1e-3):
+        super().__init__()
+        self.num_inducing_points = num_inducing_points
+        self.batch_shape = batch_shape
+        self.mean_init_std = mean_init_std
+
+    @property
+    def device(self) -> torch.device:
+        return next(self.parameters()).device
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return next(self.parameters()).dtype
+
+    def forward(self) -> Distribution:
+        r"""
+        Constructs and returns the variational distribution
+
+        :rtype: ~gpytorch.distributions.MultivariateNormal or ~qpytorch.distributions.MultivariateQExponential
+        :return: The distribution :math:`q(\mathbf u)`
+        """
+        raise NotImplementedError
+
+    def shape(self) -> torch.Size:
+        r"""
+        Event + batch shape of VariationalDistribution object
+        :rtype: torch.Size
+        """
+        return torch.Size([*self.batch_shape, self.num_inducing_points])
+
+    @abstractmethod
+    def initialize_variational_distribution(self, prior_dist: Union[MultivariateNormal, MultivariateQExponential]) -> None:
+        r"""
+        Method for initializing the variational distribution, based on the prior distribution.
+
+        :param ~gpytorch.distributions.Distribution prior_dist: The prior distribution :math:`p(\mathbf u)`.
+        """
+        raise NotImplementedError
+
+    def __call__(self) -> Distribution:
+        return self.forward()
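A concrete subclass supplies forward (build q(u) from learned parameters) and initialize_variational_distribution (seed those parameters from the prior p(u)). A minimal mean-field-style sketch, restricted to the Gaussian case for brevity (the shipped distributions also cover the Q-exponential case via a power argument); the class is illustrative, not part of the package:

import torch
from linear_operator.operators import DiagLinearOperator

from qpytorch.distributions import MultivariateNormal
from qpytorch.variational import _VariationalDistribution


class DiagonalVariationalDistribution(_VariationalDistribution):
    """Illustrative q(u) = N(m, diag(softplus(raw_s)^2)); not part of the package."""

    def __init__(self, num_inducing_points, batch_shape=torch.Size([]), mean_init_std=1e-3):
        super().__init__(num_inducing_points, batch_shape, mean_init_std)
        self.register_parameter(
            "variational_mean", torch.nn.Parameter(torch.zeros(*batch_shape, num_inducing_points))
        )
        self.register_parameter(
            "raw_variational_stddev", torch.nn.Parameter(torch.zeros(*batch_shape, num_inducing_points))
        )

    def forward(self):
        stddev = torch.nn.functional.softplus(self.raw_variational_stddev)
        return MultivariateNormal(self.variational_mean, DiagLinearOperator(stddev**2))

    def initialize_variational_distribution(self, prior_dist):
        # Start q(u) at the prior, with small jitter on the mean.
        self.variational_mean.data.copy_(prior_dist.mean)
        self.variational_mean.data.add_(torch.randn_like(prior_dist.mean) * self.mean_init_std)
        # Inverse softplus of the prior stddev: softplus(log(expm1(x))) == x for x > 0.
        self.raw_variational_stddev.data.copy_(torch.log(torch.expm1(prior_dist.stddev)))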