qpytorch-0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of qpytorch might be problematic.

Files changed (102)
  1. qpytorch/__init__.py +327 -0
  2. qpytorch/constraints/__init__.py +3 -0
  3. qpytorch/distributions/__init__.py +21 -0
  4. qpytorch/distributions/delta.py +86 -0
  5. qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
  6. qpytorch/distributions/multivariate_qexponential.py +581 -0
  7. qpytorch/distributions/power.py +113 -0
  8. qpytorch/distributions/qexponential.py +153 -0
  9. qpytorch/functions/__init__.py +58 -0
  10. qpytorch/kernels/__init__.py +80 -0
  11. qpytorch/kernels/grid_interpolation_kernel.py +213 -0
  12. qpytorch/kernels/inducing_point_kernel.py +151 -0
  13. qpytorch/kernels/kernel.py +695 -0
  14. qpytorch/kernels/matern32_kernel_grad.py +155 -0
  15. qpytorch/kernels/matern52_kernel_grad.py +194 -0
  16. qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
  17. qpytorch/kernels/polynomial_kernel_grad.py +88 -0
  18. qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
  19. qpytorch/kernels/rbf_kernel_grad.py +125 -0
  20. qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
  21. qpytorch/kernels/rff_kernel.py +153 -0
  22. qpytorch/lazy/__init__.py +9 -0
  23. qpytorch/likelihoods/__init__.py +66 -0
  24. qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
  25. qpytorch/likelihoods/beta_likelihood.py +76 -0
  26. qpytorch/likelihoods/gaussian_likelihood.py +472 -0
  27. qpytorch/likelihoods/laplace_likelihood.py +59 -0
  28. qpytorch/likelihoods/likelihood.py +437 -0
  29. qpytorch/likelihoods/likelihood_list.py +60 -0
  30. qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
  31. qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
  32. qpytorch/likelihoods/noise_models.py +184 -0
  33. qpytorch/likelihoods/qexponential_likelihood.py +494 -0
  34. qpytorch/likelihoods/softmax_likelihood.py +97 -0
  35. qpytorch/likelihoods/student_t_likelihood.py +90 -0
  36. qpytorch/means/__init__.py +23 -0
  37. qpytorch/metrics/__init__.py +17 -0
  38. qpytorch/mlls/__init__.py +53 -0
  39. qpytorch/mlls/_approximate_mll.py +79 -0
  40. qpytorch/mlls/deep_approximate_mll.py +30 -0
  41. qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
  42. qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
  43. qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
  44. qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
  45. qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
  46. qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
  47. qpytorch/mlls/marginal_log_likelihood.py +48 -0
  48. qpytorch/mlls/predictive_log_likelihood.py +76 -0
  49. qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
  50. qpytorch/mlls/variational_elbo.py +77 -0
  51. qpytorch/models/__init__.py +72 -0
  52. qpytorch/models/approximate_qep.py +115 -0
  53. qpytorch/models/deep_qeps/__init__.py +22 -0
  54. qpytorch/models/deep_qeps/deep_qep.py +155 -0
  55. qpytorch/models/deep_qeps/dspp.py +114 -0
  56. qpytorch/models/exact_prediction_strategies.py +880 -0
  57. qpytorch/models/exact_qep.py +349 -0
  58. qpytorch/models/model_list.py +100 -0
  59. qpytorch/models/pyro/__init__.py +28 -0
  60. qpytorch/models/pyro/_pyro_mixin.py +57 -0
  61. qpytorch/models/pyro/distributions/__init__.py +5 -0
  62. qpytorch/models/pyro/pyro_qep.py +105 -0
  63. qpytorch/models/qep.py +7 -0
  64. qpytorch/models/qeplvm/__init__.py +6 -0
  65. qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
  66. qpytorch/models/qeplvm/latent_variable.py +102 -0
  67. qpytorch/module.py +30 -0
  68. qpytorch/optim/__init__.py +5 -0
  69. qpytorch/priors/__init__.py +42 -0
  70. qpytorch/priors/qep_priors.py +81 -0
  71. qpytorch/test/__init__.py +22 -0
  72. qpytorch/test/base_likelihood_test_case.py +106 -0
  73. qpytorch/test/model_test_case.py +150 -0
  74. qpytorch/test/variational_test_case.py +400 -0
  75. qpytorch/utils/__init__.py +38 -0
  76. qpytorch/utils/warnings.py +37 -0
  77. qpytorch/variational/__init__.py +47 -0
  78. qpytorch/variational/_variational_distribution.py +61 -0
  79. qpytorch/variational/_variational_strategy.py +391 -0
  80. qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
  81. qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
  82. qpytorch/variational/cholesky_variational_distribution.py +65 -0
  83. qpytorch/variational/ciq_variational_strategy.py +352 -0
  84. qpytorch/variational/delta_variational_distribution.py +41 -0
  85. qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
  86. qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
  87. qpytorch/variational/lmc_variational_strategy.py +248 -0
  88. qpytorch/variational/mean_field_variational_distribution.py +58 -0
  89. qpytorch/variational/multitask_variational_strategy.py +317 -0
  90. qpytorch/variational/natural_variational_distribution.py +152 -0
  91. qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
  92. qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
  93. qpytorch/variational/tril_natural_variational_distribution.py +130 -0
  94. qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
  95. qpytorch/variational/unwhitened_variational_strategy.py +225 -0
  96. qpytorch/variational/variational_strategy.py +280 -0
  97. qpytorch/version.py +4 -0
  98. qpytorch-0.1.dist-info/LICENSE +21 -0
  99. qpytorch-0.1.dist-info/METADATA +177 -0
  100. qpytorch-0.1.dist-info/RECORD +102 -0
  101. qpytorch-0.1.dist-info/WHEEL +5 -0
  102. qpytorch-0.1.dist-info/top_level.txt +1 -0
qpytorch/test/variational_test_case.py
@@ -0,0 +1,400 @@
+ #!/usr/bin/env python3
+
+ from abc import abstractproperty
+ from unittest.mock import MagicMock, patch
+
+ import linear_operator
+ import torch
+
+ import qpytorch
+
+ from gpytorch.test.base_test_case import BaseTestCase
+
+
+ class VariationalTestCase(BaseTestCase):
+     def _make_model_and_likelihood(
+         self,
+         num_inducing=16,
+         batch_shape=torch.Size([]),
+         inducing_batch_shape=torch.Size([]),
+         strategy_cls=qpytorch.variational.VariationalStrategy,
+         distribution_cls=qpytorch.variational.CholeskyVariationalDistribution,
+         constant_mean=True,
+     ):
+         _power = getattr(self, '_power', 2.0)
+         class _SV_PRegressionModel(qpytorch.models.ApproximateGP if _power==2 else qpytorch.models.ApproximateQEP):
+             def __init__(self, inducing_points):
+                 if _power!=2: self.power = torch.tensor(_power)
+                 variational_distribution = distribution_cls(num_inducing, batch_shape=batch_shape, power=self.power) if hasattr(self, 'power') \
+                     else distribution_cls(num_inducing, batch_shape=batch_shape)
+                 variational_strategy = strategy_cls(
+                     self,
+                     inducing_points,
+                     variational_distribution,
+                     learn_inducing_locations=True,
+                 )
+                 super().__init__(variational_strategy)
+                 if constant_mean:
+                     self.mean_module = qpytorch.means.ConstantMean()
+                     self.mean_module.initialize(constant=1.0)
+                 else:
+                     self.mean_module = qpytorch.means.ZeroMean()
+                 self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())
+
+             def forward(self, x):
+                 mean_x = self.mean_module(x)
+                 covar_x = self.covar_module(x)
+                 latent_pred = qpytorch.distributions.MultivariateQExponential(mean_x, covar_x, power=self.power) if hasattr(self, 'power') \
+                     else qpytorch.distributions.MultivariateNormal(mean_x, covar_x)
+                 return latent_pred
+
+         inducing_points = torch.randn(num_inducing, 2).repeat(*inducing_batch_shape, 1, 1)
+         return _SV_PRegressionModel(inducing_points), self.likelihood_cls()
+
+     def _training_iter(
+         self,
+         model,
+         likelihood,
+         batch_shape=torch.Size([]),
+         mll_cls=qpytorch.mlls.VariationalELBO,
+         cuda=False,
+     ):
+         train_x = torch.randn(*batch_shape, 32, 2).clamp(-2.5, 2.5)
+         train_y = torch.linspace(-1, 1, self.event_shape[0])
+         train_y = train_y.view(self.event_shape[0], *([1] * (len(self.event_shape) - 1)))
+         train_y = train_y.expand(*self.event_shape)
+         mll = mll_cls(likelihood, model, num_data=train_x.size(-2))
+         if cuda:
+             train_x = train_x.cuda()
+             train_y = train_y.cuda()
+             model = model.cuda()
+             likelihood = likelihood.cuda()
+
+         # Single optimization iteration
+         model.train()
+         likelihood.train()
+         output = model(train_x)
+         loss = -mll(output, train_y)
+         loss.sum().backward()
+
+         # Make sure we have gradients for all parameters
+         for _, param in model.named_parameters():
+             self.assertTrue(param.grad is not None)
+             self.assertGreater(param.grad.norm().item(), 0)
+         for _, param in likelihood.named_parameters():
+             self.assertTrue(param.grad is not None)
+             self.assertGreater(param.grad.norm().item(), 0)
+
+         return output, loss
+
+     def _eval_iter(self, model, batch_shape=torch.Size([]), cuda=False):
+         test_x = torch.randn(*batch_shape, 32, 2).clamp(-2.5, 2.5)
+         if cuda:
+             test_x = test_x.cuda()
+             model = model.cuda()
+
+         # Single optimization iteration
+         model.eval()
+         with torch.no_grad():
+             output = model(test_x)
+
+         return output
+
+     def _fantasy_iter(
+         self,
+         model,
+         likelihood,
+         batch_shape=torch.Size([]),
+         cuda=False,
+         num_fant=10,
+         covar_module=None,
+         mean_module=None,
+     ):
+         model.likelihood = likelihood
+         val_x = torch.randn(*batch_shape, num_fant, 2).clamp(-2.5, 2.5)
+         val_y = torch.linspace(-1, 1, num_fant)
+         val_y = val_y.view(num_fant, *([1] * (len(self.event_shape) - 1)))
+         val_y = val_y.expand(*batch_shape, num_fant, *self.event_shape[1:])
+         if cuda:
+             model = model.cuda()
+             val_x = val_x.cuda()
+             val_y = val_y.cuda()
+         updated_model = model.get_fantasy_model(val_x, val_y, covar_module=covar_module, mean_module=mean_module)
+         return updated_model
+
+     @abstractproperty
+     def batch_shape(self):
+         raise NotImplementedError
+
+     @abstractproperty
+     def distribution_cls(self):
+         raise NotImplementedError
+
+     @property
+     def event_shape(self):
+         return torch.Size([32])
+
+     @property
+     def likelihood_cls(self):
+         return qpytorch.likelihoods.GaussianLikelihood if self._power==2 else qpytorch.likelihoods.QExponentialLikelihood
+
+     @abstractproperty
+     def mll_cls(self):
+         raise NotImplementedError
+
+     @abstractproperty
+     def strategy_cls(self):
+         raise NotImplementedError
+
+     @property
+     def cuda(self):
+         return False
+
+     def test_eval_iteration(
+         self,
+         data_batch_shape=None,
+         inducing_batch_shape=None,
+         model_batch_shape=None,
+         eval_data_batch_shape=None,
+         expected_batch_shape=None,
+     ):
+         # Batch shapes
+         model_batch_shape = model_batch_shape if model_batch_shape is not None else self.batch_shape
+         data_batch_shape = data_batch_shape if data_batch_shape is not None else self.batch_shape
+         inducing_batch_shape = inducing_batch_shape if inducing_batch_shape is not None else self.batch_shape
+         expected_batch_shape = expected_batch_shape if expected_batch_shape is not None else self.batch_shape
+         eval_data_batch_shape = eval_data_batch_shape if eval_data_batch_shape is not None else self.batch_shape
+
+         # Mocks
+         _wrapped_cholesky = MagicMock(wraps=torch.linalg.cholesky_ex)
+         _wrapped_cg = MagicMock(wraps=linear_operator.utils.linear_cg)
+         _wrapped_ciq = MagicMock(wraps=linear_operator.utils.contour_integral_quad)
+         _cholesky_mock = patch("torch.linalg.cholesky_ex", new=_wrapped_cholesky)
+         _cg_mock = patch("linear_operator.utils.linear_cg", new=_wrapped_cg)
+         _ciq_mock = patch("linear_operator.utils.contour_integral_quad", new=_wrapped_ciq)
+
+         # Make model and likelihood
+         model, likelihood = self._make_model_and_likelihood(
+             batch_shape=model_batch_shape,
+             inducing_batch_shape=inducing_batch_shape,
+             distribution_cls=self.distribution_cls,
+             strategy_cls=self.strategy_cls,
+         )
+
+         # Do one forward pass
+         self._training_iter(model, likelihood, data_batch_shape, mll_cls=self.mll_cls, cuda=self.cuda)
+
+         # Now do evaluation
+         with _cholesky_mock as cholesky_mock, _cg_mock as cg_mock, _ciq_mock as ciq_mock:
+             # Iter 1
+             _ = self._eval_iter(model, eval_data_batch_shape, cuda=self.cuda)
+             output = self._eval_iter(model, eval_data_batch_shape, cuda=self.cuda)
+             self.assertEqual(output.batch_shape, expected_batch_shape)
+             self.assertEqual(output.event_shape, self.event_shape)
+             return cg_mock, cholesky_mock, ciq_mock
+
+     def test_eval_smaller_pred_batch(self):
+         return self.test_eval_iteration(
+             model_batch_shape=(torch.Size([3, 4]) + self.batch_shape),
+             inducing_batch_shape=(torch.Size([3, 1]) + self.batch_shape),
+             data_batch_shape=(torch.Size([3, 4]) + self.batch_shape),
+             eval_data_batch_shape=(torch.Size([4]) + self.batch_shape),
+             expected_batch_shape=(torch.Size([3, 4]) + self.batch_shape),
+         )
+
+     def test_eval_larger_pred_batch(self):
+         return self.test_eval_iteration(
+             model_batch_shape=(torch.Size([4]) + self.batch_shape),
+             inducing_batch_shape=(self.batch_shape),
+             data_batch_shape=(torch.Size([4]) + self.batch_shape),
+             eval_data_batch_shape=(torch.Size([3, 4]) + self.batch_shape),
+             expected_batch_shape=(torch.Size([3, 4]) + self.batch_shape),
+         )
+
+     def test_training_iteration(
+         self,
+         data_batch_shape=None,
+         inducing_batch_shape=None,
+         model_batch_shape=None,
+         expected_batch_shape=None,
+         constant_mean=True,
+     ):
+         # Batch shapes
+         model_batch_shape = model_batch_shape if model_batch_shape is not None else self.batch_shape
+         data_batch_shape = data_batch_shape if data_batch_shape is not None else self.batch_shape
+         inducing_batch_shape = inducing_batch_shape if inducing_batch_shape is not None else self.batch_shape
+         expected_batch_shape = expected_batch_shape if expected_batch_shape is not None else self.batch_shape
+
+         # Mocks
+         _wrapped_cholesky = MagicMock(wraps=torch.linalg.cholesky_ex)
+         _wrapped_cg = MagicMock(wraps=linear_operator.utils.linear_cg)
+         _wrapped_ciq = MagicMock(wraps=linear_operator.utils.contour_integral_quad)
+         _cholesky_mock = patch("torch.linalg.cholesky_ex", new=_wrapped_cholesky)
+         _cg_mock = patch("linear_operator.utils.linear_cg", new=_wrapped_cg)
+         _ciq_mock = patch("linear_operator.utils.contour_integral_quad", new=_wrapped_ciq)
+
+         # Make model and likelihood
+         model, likelihood = self._make_model_and_likelihood(
+             batch_shape=model_batch_shape,
+             inducing_batch_shape=inducing_batch_shape,
+             distribution_cls=self.distribution_cls,
+             strategy_cls=self.strategy_cls,
+             constant_mean=constant_mean,
+         )
+
+         # Do forward pass
+         with _cholesky_mock as cholesky_mock, _cg_mock as cg_mock, _ciq_mock as ciq_mock:
+             # Iter 1
+             self.assertEqual(model.variational_strategy.variational_params_initialized.item(), 0)
+             self._training_iter(
+                 model,
+                 likelihood,
+                 data_batch_shape,
+                 mll_cls=self.mll_cls,
+                 cuda=self.cuda,
+             )
+             self.assertEqual(model.variational_strategy.variational_params_initialized.item(), 1)
+             # Iter 2
+             output, loss = self._training_iter(
+                 model,
+                 likelihood,
+                 data_batch_shape,
+                 mll_cls=self.mll_cls,
+                 cuda=self.cuda,
+             )
+             self.assertEqual(output.batch_shape, expected_batch_shape)
+             self.assertEqual(output.event_shape, self.event_shape)
+             self.assertEqual(loss.shape, expected_batch_shape)
+         return cg_mock, cholesky_mock, ciq_mock
+
+     def test_training_iteration_batch_inducing(self):
+         return self.test_training_iteration(
+             model_batch_shape=(torch.Size([3]) + self.batch_shape),
+             data_batch_shape=self.batch_shape,
+             inducing_batch_shape=(torch.Size([3]) + self.batch_shape),
+             expected_batch_shape=(torch.Size([3]) + self.batch_shape),
+         )
+
+     def test_training_iteration_batch_data(self):
+         return self.test_training_iteration(
+             model_batch_shape=self.batch_shape,
+             inducing_batch_shape=self.batch_shape,
+             data_batch_shape=(torch.Size([3]) + self.batch_shape),
+             expected_batch_shape=(torch.Size([3]) + self.batch_shape),
+         )
+
+     def test_training_iteration_batch_model(self):
+         return self.test_training_iteration(
+             model_batch_shape=(torch.Size([3]) + self.batch_shape),
+             inducing_batch_shape=self.batch_shape,
+             data_batch_shape=self.batch_shape,
+             expected_batch_shape=(torch.Size([3]) + self.batch_shape),
+         )
+
+     def test_training_all_batch_zero_mean(self):
+         return self.test_training_iteration(
+             model_batch_shape=(torch.Size([3, 4]) + self.batch_shape),
+             inducing_batch_shape=(torch.Size([3, 1]) + self.batch_shape),
+             data_batch_shape=(torch.Size([4]) + self.batch_shape),
+             expected_batch_shape=(torch.Size([3, 4]) + self.batch_shape),
+             constant_mean=False,
+         )
+
+     def test_fantasy_call(
+         self,
+         data_batch_shape=None,
+         inducing_batch_shape=None,
+         model_batch_shape=None,
+         expected_batch_shape=None,
+         constant_mean=True,
+     ):
+         # Batch shapes
+         model_batch_shape = model_batch_shape if model_batch_shape is not None else self.batch_shape
+         data_batch_shape = data_batch_shape if data_batch_shape is not None else self.batch_shape
+         inducing_batch_shape = inducing_batch_shape if inducing_batch_shape is not None else self.batch_shape
+         expected_batch_shape = expected_batch_shape if expected_batch_shape is not None else self.batch_shape
+
+         num_inducing = 16
+         num_fant = 10
+         # Make model and likelihood
+         model, likelihood = self._make_model_and_likelihood(
+             batch_shape=model_batch_shape,
+             inducing_batch_shape=inducing_batch_shape,
+             distribution_cls=self.distribution_cls,
+             strategy_cls=self.strategy_cls,
+             constant_mean=constant_mean,
+             num_inducing=num_inducing,
+         )
+
+         # we iterate through the covar and mean module possible settings
+         covar_mean_options = [
+             {"covar_module": None, "mean_module": None},
+             {"covar_module": qpytorch.kernels.MaternKernel(), "mean_module": qpytorch.means.ZeroMean()},
+         ]
+         for cm_dict in covar_mean_options:
+             fant_model = self._fantasy_iter(
+                 model, likelihood, data_batch_shape, self.cuda, num_fant=num_fant, **cm_dict
+             )
+             self.assertTrue(isinstance(fant_model, {2.0: qpytorch.models.ExactGP, 1.0: qpytorch.models.ExactQEP}[self._power]))
+
+             # we check to ensure setting the covar_module and mean_modules are okay
+             if cm_dict["covar_module"] is None:
+                 self.assertEqual(type(fant_model.covar_module), type(model.covar_module))
+             else:
+                 self.assertNotEqual(type(fant_model.covar_module), type(model.covar_module))
+             if cm_dict["mean_module"] is None:
+                 self.assertEqual(type(fant_model.mean_module), type(model.mean_module))
+             else:
+                 self.assertNotEqual(type(fant_model.mean_module), type(model.mean_module))
+
+             # now we check to ensure the shapes of the fantasy strategy are correct
+             self.assertTrue(fant_model.prediction_strategy is not None)
+             for key in fant_model.prediction_strategy._memoize_cache.keys():
+                 if key[0] == "mean_cache":
+                     break
+             mean_cache = fant_model.prediction_strategy._memoize_cache[key]
+             self.assertEqual(mean_cache.shape, torch.Size([*expected_batch_shape, num_inducing + num_fant]))
+
+         # we remove the mean_module and covar_module and check for errors
+         del model.mean_module
+         with self.assertRaises(ModuleNotFoundError):
+             self._fantasy_iter(model, likelihood, data_batch_shape, self.cuda, num_fant=num_fant)
+
+         model.mean_module = qpytorch.means.ZeroMean()
+         del model.covar_module
+         with self.assertRaises(ModuleNotFoundError):
+             self._fantasy_iter(model, likelihood, data_batch_shape, self.cuda, num_fant=num_fant)
+
+         # finally we check to ensure failure for a non-gaussian likelihood
+         with self.assertRaises(NotImplementedError):
+             self._fantasy_iter(
+                 model,
+                 qpytorch.likelihoods.BernoulliLikelihood(),
+                 data_batch_shape,
+                 self.cuda,
+                 num_fant=num_fant,
+             )
+
+     def test_fantasy_call_batch_inducing(self):
+         return self.test_fantasy_call(
+             model_batch_shape=(torch.Size([3]) + self.batch_shape),
+             data_batch_shape=self.batch_shape,
+             inducing_batch_shape=(torch.Size([3]) + self.batch_shape),
+             expected_batch_shape=(torch.Size([3]) + self.batch_shape),
+         )
+
+     def test_fantasy_call_batch_data(self):
+         return self.test_fantasy_call(
+             model_batch_shape=self.batch_shape,
+             inducing_batch_shape=self.batch_shape,
+             data_batch_shape=(torch.Size([3]) + self.batch_shape),
+             expected_batch_shape=(torch.Size([3]) + self.batch_shape),
+         )
+
+     def test_fantasy_call_batch_model(self):
+         return self.test_fantasy_call(
+             model_batch_shape=(torch.Size([3]) + self.batch_shape),
+             inducing_batch_shape=self.batch_shape,
+             data_batch_shape=self.batch_shape,
+             expected_batch_shape=(torch.Size([3]) + self.batch_shape),
+         )
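Note: the class above is an abstract test harness; a concrete test case supplies batch_shape, distribution_cls, mll_cls, and strategy_cls, and may set _power to exercise the Q-exponential path instead of the Gaussian one. A minimal sketch of such a subclass (hypothetical, not part of this release; it only assumes classes shown elsewhere in this diff):

import unittest

import torch
import qpytorch
from qpytorch.test.variational_test_case import VariationalTestCase


class TestSVQEPVariationalELBO(VariationalTestCase, unittest.TestCase):
    # Hypothetical concrete test case; _power != 2 selects the QEP code path
    _power = 1.5

    @property
    def batch_shape(self):
        return torch.Size([])

    @property
    def distribution_cls(self):
        return qpytorch.variational.CholeskyVariationalDistribution

    @property
    def mll_cls(self):
        return qpytorch.mlls.VariationalELBO

    @property
    def strategy_cls(self):
        return qpytorch.variational.VariationalStrategy


if __name__ == "__main__":
    unittest.main()

Mixing in unittest.TestCase follows the usual GPyTorch testing pattern, so the inherited test_* methods run as ordinary unit tests.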
qpytorch/utils/__init__.py
@@ -0,0 +1,38 @@
+ #!/usr/bin/env python3
+
+ from __future__ import annotations
+
+ import warnings as _warnings
+ from typing import Any
+
+ import linear_operator
+
+ from gpytorch.utils import deprecation, errors, generic, grid, interpolation, quadrature, transforms
+ from . import warnings
+ from gpytorch.utils.memoize import cached
+ from gpytorch.utils.nearest_neighbors import NNUtil
+ from gpytorch.utils.sum_interaction_terms import sum_interaction_terms
+
+ __all__ = [
+     "cached",
+     "deprecation",
+     "errors",
+     "generic",
+     "grid",
+     "interpolation",
+     "quadrature",
+     "sum_interaction_terms",
+     "transforms",
+     "warnings",
+     "NNUtil",
+ ]
+
+
+ def __getattr__(name: str) -> Any:
+     if hasattr(linear_operator.utils, name):
+         _warnings.warn(
+             f"gpytorch.utils.{name} is deprecated. Use linear_operator.utils.{name} instead.",
+             DeprecationWarning,
+         )
+         return getattr(linear_operator.utils, name)
+     raise AttributeError(f"module gpytorch.utils has no attribute {name}")
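The module-level __getattr__ above keeps old utils attribute accesses working: any name not re-exported here but still present in linear_operator.utils is forwarded, with a DeprecationWarning. A small sketch of that behaviour (assuming linear_cg still lives in linear_operator.utils, as the mocks in variational_test_case.py above suggest):

import warnings

import qpytorch.utils

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    linear_cg = qpytorch.utils.linear_cg  # not defined in this module, so __getattr__ forwards it

# The forwarded access should have emitted a DeprecationWarning
assert any(issubclass(w.category, DeprecationWarning) for w in caught)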
qpytorch/utils/warnings.py
@@ -0,0 +1,37 @@
+ #!/usr/bin/env python3
+
+ from linear_operator.utils.warnings import NumericalWarning
+
+
+ class GPInputWarning(UserWarning):
+     """
+     Warning thrown when a GP model receives an unexpected input.
+     For example, when an :obj:`~gpytorch.models.ExactGP` in eval mode receives the training data as input.
+     """
+
+     pass
+
+
+ class QEPInputWarning(UserWarning):
+     """
+     Warning thrown when a QEP model receives an unexpected input.
+     For example, when an :obj:`~qpytorch.models.ExactQEP` in eval mode receives the training data as input.
+     """
+
+     pass
+
+
+ class OldVersionWarning(UserWarning):
+     """
+     Warning thrown when loading a saved model from an outdated version of GPyTorch/QPyTorch.
+     """
+
+     pass
+
+
+ __all__ = [
+     "GPInputWarning",
+     "QEPInputWarning",
+     "OldVersionWarning",
+     "NumericalWarning",
+ ]
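These are plain UserWarning subclasses, so they can be filtered with the standard warnings machinery; for example (a sketch, using the import path implied by the file list above):

import warnings

from qpytorch.utils.warnings import QEPInputWarning

# Silence the warning an ExactQEP model raises when it receives its training
# inputs while in eval mode; other warnings are unaffected.
warnings.filterwarnings("ignore", category=QEPInputWarning)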
qpytorch/variational/__init__.py
@@ -0,0 +1,47 @@
+ #!/usr/bin/env python3
+
+ from ._variational_distribution import _VariationalDistribution
+ from ._variational_strategy import _VariationalStrategy
+ from .additive_grid_interpolation_variational_strategy import AdditiveGridInterpolationVariationalStrategy
+ from .batch_decoupled_variational_strategy import BatchDecoupledVariationalStrategy
+ from .cholesky_variational_distribution import CholeskyVariationalDistribution
+ from .ciq_variational_strategy import CiqVariationalStrategy
+ from .delta_variational_distribution import DeltaVariationalDistribution
+ from .grid_interpolation_variational_strategy import GridInterpolationVariationalStrategy
+ from .independent_multitask_variational_strategy import (
+     IndependentMultitaskVariationalStrategy,
+     # MultitaskVariationalStrategy,
+ )
+ from .lmc_variational_strategy import LMCVariationalStrategy
+ from .multitask_variational_strategy import MultitaskVariationalStrategy
+ from .mean_field_variational_distribution import MeanFieldVariationalDistribution
+ from .natural_variational_distribution import _NaturalVariationalDistribution, NaturalVariationalDistribution
+ from .nearest_neighbor_variational_strategy import NNVariationalStrategy
+ from .orthogonally_decoupled_variational_strategy import OrthogonallyDecoupledVariationalStrategy
+ from .tril_natural_variational_distribution import TrilNaturalVariationalDistribution
+ from .uncorrelated_multitask_variational_strategy import UncorrelatedMultitaskVariationalStrategy
+ from .unwhitened_variational_strategy import UnwhitenedVariationalStrategy
+ from .variational_strategy import VariationalStrategy
+
+ __all__ = [
+     "_VariationalStrategy",
+     "AdditiveGridInterpolationVariationalStrategy",
+     "BatchDecoupledVariationalStrategy",
+     "CiqVariationalStrategy",
+     "GridInterpolationVariationalStrategy",
+     "IndependentMultitaskVariationalStrategy",
+     "LMCVariationalStrategy",
+     "MultitaskVariationalStrategy",
+     "OrthogonallyDecoupledVariationalStrategy",
+     "VariationalStrategy",
+     "UnwhitenedVariationalStrategy",
+     "UncorrelatedMultitaskVariationalStrategy",
+     "_VariationalDistribution",
+     "CholeskyVariationalDistribution",
+     "MeanFieldVariationalDistribution",
+     "DeltaVariationalDistribution",
+     "_NaturalVariationalDistribution",
+     "NaturalVariationalDistribution",
+     "TrilNaturalVariationalDistribution",
+     "NNVariationalStrategy",
+ ]
qpytorch/variational/_variational_distribution.py
@@ -0,0 +1,61 @@
+ #!/usr/bin/env python3
+
+ from typing import Union
+
+ from abc import ABC, abstractmethod
+
+ import torch
+
+ from ..distributions import Distribution, MultivariateNormal, MultivariateQExponential
+ from ..module import Module
+
+
+ class _VariationalDistribution(Module, ABC):
+     r"""
+     Abstract base class for all Variational Distributions.
+
+     :ivar torch.dtype dtype: The dtype of the VariationalDistribution parameters
+     :ivar torch.dtype device: The device of the VariationalDistribution parameters
+     """
+
+     def __init__(self, num_inducing_points: int, batch_shape: torch.Size = torch.Size([]), mean_init_std: float = 1e-3):
+         super().__init__()
+         self.num_inducing_points = num_inducing_points
+         self.batch_shape = batch_shape
+         self.mean_init_std = mean_init_std
+
+     @property
+     def device(self) -> torch.device:
+         return next(self.parameters()).device
+
+     @property
+     def dtype(self) -> torch.dtype:
+         return next(self.parameters()).dtype
+
+     def forward(self) -> Distribution:
+         r"""
+         Constructs and returns the variational distribution
+
+         :rtype: ~gpytorch.distributions.MultivariateNormal or ~qpytorch.distributions.MultivariateQExponential
+         :return: The distribution :math:`q(\mathbf u)`
+         """
+         raise NotImplementedError
+
+     def shape(self) -> torch.Size:
+         r"""
+         Event + batch shape of VariationalDistribution object
+         :rtype: torch.Size
+         """
+         return torch.Size([*self.batch_shape, self.num_inducing_points])
+
+     @abstractmethod
+     def initialize_variational_distribution(self, prior_dist: Union[MultivariateNormal, MultivariateQExponential]) -> None:
+         r"""
+         Method for initializing the variational distribution, based on the prior distribution.
+
+         :param ~gpytorch.distributions.Distribution prior_dist: The prior distribution :math:`p(\mathbf u)`.
+         """
+         raise NotImplementedError
+
+     def __call__(self) -> Distribution:
+         return self.forward()
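Concrete variational distributions (for example CholeskyVariationalDistribution or MeanFieldVariationalDistribution from the __init__ above) register their own parameters and implement forward and initialize_variational_distribution. A stripped-down, hypothetical subclass illustrating that contract for the Gaussian (power = 2) case; it is a sketch only, not part of the package:

import torch

from qpytorch.distributions import MultivariateNormal
from qpytorch.variational import _VariationalDistribution


class DiagonalVariationalDistribution(_VariationalDistribution):
    """Toy mean-field q(u) with a diagonal covariance (illustrative only)."""

    def __init__(self, num_inducing_points, batch_shape=torch.Size([]), mean_init_std=1e-3):
        super().__init__(num_inducing_points, batch_shape=batch_shape, mean_init_std=mean_init_std)
        self.register_parameter(
            "variational_mean", torch.nn.Parameter(torch.zeros(*batch_shape, num_inducing_points))
        )
        # Log of the per-point standard deviation, so the stddev stays positive
        self.register_parameter(
            "log_variational_stddev", torch.nn.Parameter(torch.zeros(*batch_shape, num_inducing_points))
        )

    def forward(self):
        variance = self.log_variational_stddev.exp().pow(2)
        return MultivariateNormal(self.variational_mean, torch.diag_embed(variance))

    def initialize_variational_distribution(self, prior_dist):
        # Start q(u) at the prior, with a small jitter on the mean
        self.variational_mean.data.copy_(prior_dist.mean)
        self.variational_mean.data.add_(torch.randn_like(prior_dist.mean), alpha=self.mean_init_std)
        self.log_variational_stddev.data.copy_(prior_dist.stddev.log())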