qpytorch-0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of qpytorch might be problematic.

Files changed (102)
  1. qpytorch/__init__.py +327 -0
  2. qpytorch/constraints/__init__.py +3 -0
  3. qpytorch/distributions/__init__.py +21 -0
  4. qpytorch/distributions/delta.py +86 -0
  5. qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
  6. qpytorch/distributions/multivariate_qexponential.py +581 -0
  7. qpytorch/distributions/power.py +113 -0
  8. qpytorch/distributions/qexponential.py +153 -0
  9. qpytorch/functions/__init__.py +58 -0
  10. qpytorch/kernels/__init__.py +80 -0
  11. qpytorch/kernels/grid_interpolation_kernel.py +213 -0
  12. qpytorch/kernels/inducing_point_kernel.py +151 -0
  13. qpytorch/kernels/kernel.py +695 -0
  14. qpytorch/kernels/matern32_kernel_grad.py +155 -0
  15. qpytorch/kernels/matern52_kernel_grad.py +194 -0
  16. qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
  17. qpytorch/kernels/polynomial_kernel_grad.py +88 -0
  18. qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
  19. qpytorch/kernels/rbf_kernel_grad.py +125 -0
  20. qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
  21. qpytorch/kernels/rff_kernel.py +153 -0
  22. qpytorch/lazy/__init__.py +9 -0
  23. qpytorch/likelihoods/__init__.py +66 -0
  24. qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
  25. qpytorch/likelihoods/beta_likelihood.py +76 -0
  26. qpytorch/likelihoods/gaussian_likelihood.py +472 -0
  27. qpytorch/likelihoods/laplace_likelihood.py +59 -0
  28. qpytorch/likelihoods/likelihood.py +437 -0
  29. qpytorch/likelihoods/likelihood_list.py +60 -0
  30. qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
  31. qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
  32. qpytorch/likelihoods/noise_models.py +184 -0
  33. qpytorch/likelihoods/qexponential_likelihood.py +494 -0
  34. qpytorch/likelihoods/softmax_likelihood.py +97 -0
  35. qpytorch/likelihoods/student_t_likelihood.py +90 -0
  36. qpytorch/means/__init__.py +23 -0
  37. qpytorch/metrics/__init__.py +17 -0
  38. qpytorch/mlls/__init__.py +53 -0
  39. qpytorch/mlls/_approximate_mll.py +79 -0
  40. qpytorch/mlls/deep_approximate_mll.py +30 -0
  41. qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
  42. qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
  43. qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
  44. qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
  45. qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
  46. qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
  47. qpytorch/mlls/marginal_log_likelihood.py +48 -0
  48. qpytorch/mlls/predictive_log_likelihood.py +76 -0
  49. qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
  50. qpytorch/mlls/variational_elbo.py +77 -0
  51. qpytorch/models/__init__.py +72 -0
  52. qpytorch/models/approximate_qep.py +115 -0
  53. qpytorch/models/deep_qeps/__init__.py +22 -0
  54. qpytorch/models/deep_qeps/deep_qep.py +155 -0
  55. qpytorch/models/deep_qeps/dspp.py +114 -0
  56. qpytorch/models/exact_prediction_strategies.py +880 -0
  57. qpytorch/models/exact_qep.py +349 -0
  58. qpytorch/models/model_list.py +100 -0
  59. qpytorch/models/pyro/__init__.py +28 -0
  60. qpytorch/models/pyro/_pyro_mixin.py +57 -0
  61. qpytorch/models/pyro/distributions/__init__.py +5 -0
  62. qpytorch/models/pyro/pyro_qep.py +105 -0
  63. qpytorch/models/qep.py +7 -0
  64. qpytorch/models/qeplvm/__init__.py +6 -0
  65. qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
  66. qpytorch/models/qeplvm/latent_variable.py +102 -0
  67. qpytorch/module.py +30 -0
  68. qpytorch/optim/__init__.py +5 -0
  69. qpytorch/priors/__init__.py +42 -0
  70. qpytorch/priors/qep_priors.py +81 -0
  71. qpytorch/test/__init__.py +22 -0
  72. qpytorch/test/base_likelihood_test_case.py +106 -0
  73. qpytorch/test/model_test_case.py +150 -0
  74. qpytorch/test/variational_test_case.py +400 -0
  75. qpytorch/utils/__init__.py +38 -0
  76. qpytorch/utils/warnings.py +37 -0
  77. qpytorch/variational/__init__.py +47 -0
  78. qpytorch/variational/_variational_distribution.py +61 -0
  79. qpytorch/variational/_variational_strategy.py +391 -0
  80. qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
  81. qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
  82. qpytorch/variational/cholesky_variational_distribution.py +65 -0
  83. qpytorch/variational/ciq_variational_strategy.py +352 -0
  84. qpytorch/variational/delta_variational_distribution.py +41 -0
  85. qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
  86. qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
  87. qpytorch/variational/lmc_variational_strategy.py +248 -0
  88. qpytorch/variational/mean_field_variational_distribution.py +58 -0
  89. qpytorch/variational/multitask_variational_strategy.py +317 -0
  90. qpytorch/variational/natural_variational_distribution.py +152 -0
  91. qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
  92. qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
  93. qpytorch/variational/tril_natural_variational_distribution.py +130 -0
  94. qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
  95. qpytorch/variational/unwhitened_variational_strategy.py +225 -0
  96. qpytorch/variational/variational_strategy.py +280 -0
  97. qpytorch/version.py +4 -0
  98. qpytorch-0.1.dist-info/LICENSE +21 -0
  99. qpytorch-0.1.dist-info/METADATA +177 -0
  100. qpytorch-0.1.dist-info/RECORD +102 -0
  101. qpytorch-0.1.dist-info/WHEEL +5 -0
  102. qpytorch-0.1.dist-info/top_level.txt +1 -0
qpytorch/models/exact_qep.py ADDED
@@ -0,0 +1,349 @@
+ #!/usr/bin/env python3
+
+ from __future__ import annotations
+
+ import warnings
+
+ from collections.abc import Iterable
+ from copy import deepcopy
+
+ import torch
+ from torch import Tensor
+
+ from .. import settings
+ from ..distributions import MultitaskMultivariateQExponential, MultivariateQExponential
+ from ..likelihoods import _QExponentialLikelihoodBase
+ from gpytorch.utils.generic import length_safe_zip
+ from ..utils.warnings import QEPInputWarning
+ from .exact_prediction_strategies import prediction_strategy
+ from .qep import QEP
+
+
+ class ExactQEP(QEP):
+     r"""
+     The base class for any Q-Exponential process latent function to be used in conjunction
+     with exact inference.
+
+     :param torch.Tensor train_inputs: (size n x d) The training features :math:`\mathbf X`.
+     :param torch.Tensor train_targets: (size n) The training targets :math:`\mathbf y`.
+     :param ~qpytorch.likelihoods.QExponentialLikelihood likelihood: The Q-Exponential likelihood that defines
+         the observational distribution. Since we're using exact inference, the likelihood must be Q-Exponential.
+
+     The :meth:`forward` function should describe how to compute the prior latent distribution
+     on a given input. Typically, this will involve a mean and kernel function.
+     The result must be a :obj:`~qpytorch.distributions.MultivariateQExponential`.
+
+     Calling this model will return the posterior of the latent Q-Exponential process when conditioned
+     on the training data. The output will be a :obj:`~qpytorch.distributions.MultivariateQExponential`.
+
+     Example:
+         >>> class MyQEP(qpytorch.models.ExactQEP):
+         >>>     def __init__(self, train_x, train_y, likelihood):
+         >>>         super().__init__(train_x, train_y, likelihood)
+         >>>         self.mean_module = qpytorch.means.ZeroMean()
+         >>>         self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())
+         >>>
+         >>>     def forward(self, x):
+         >>>         mean = self.mean_module(x)
+         >>>         covar = self.covar_module(x)
+         >>>         return qpytorch.distributions.MultivariateQExponential(mean, covar, self.likelihood.power)
+         >>>
+         >>> # train_x = ...; train_y = ...
+         >>> likelihood = qpytorch.likelihoods.QExponentialLikelihood(power=torch.tensor(1.0))
+         >>> model = MyQEP(train_x, train_y, likelihood)
+         >>>
+         >>> # test_x = ...;
+         >>> model(test_x)  # Returns the QEP latent function at test_x
+         >>> likelihood(model(test_x))  # Returns the (approximate) predictive posterior distribution at test_x
+     """
+
+     def __init__(
+         self,
+         train_inputs: Tensor | Iterable[Tensor] | None,
+         train_targets: Tensor | None,
+         likelihood: _QExponentialLikelihoodBase,
+     ):
+         if train_inputs is not None and isinstance(train_inputs, Tensor):
+             train_inputs = (train_inputs,)
+         if train_inputs is not None and not all(isinstance(train_input, Tensor) for train_input in train_inputs):
+             raise RuntimeError("Train inputs must be a tensor, or a list/tuple of tensors")
+         if not isinstance(likelihood, _QExponentialLikelihoodBase):
+             raise RuntimeError("ExactQEP can only handle Q-Exponential likelihoods")
+
+         super().__init__()
+         if train_inputs is not None:
+             self.train_inputs = tuple(tri.unsqueeze(-1) if tri.ndimension() == 1 else tri for tri in train_inputs)
+             self.train_targets = train_targets
+         else:
+             self.train_inputs = None
+             self.train_targets = None
+         self.likelihood = likelihood
+
+         self.prediction_strategy = None
+
+     @property
+     def train_targets(self) -> Tensor | None:
+         return self._train_targets
+
+     @train_targets.setter
+     def train_targets(self, value: Tensor | None) -> None:
+         object.__setattr__(self, "_train_targets", value)
+
+     def _apply(self, fn):
+         if self.train_inputs is not None:
+             self.train_inputs = tuple(fn(train_input) for train_input in self.train_inputs)
+             self.train_targets = fn(self.train_targets)
+         return super()._apply(fn)
+
+     def _clear_cache(self) -> None:
+         # The precomputed caches from test time live in prediction_strategy
+         self.prediction_strategy = None
+
+     def local_load_samples(self, samples_dict, memo, prefix):
+         """
+         Replace the model's learned hyperparameters with samples from a posterior distribution.
+         """
+         # Pyro always puts the samples in the first batch dimension
+         num_samples = next(iter(samples_dict.values())).size(0)
+         self.train_inputs = tuple(tri.unsqueeze(0).expand(num_samples, *tri.shape) for tri in self.train_inputs)
+         self.train_targets = self.train_targets.unsqueeze(0).expand(num_samples, *self.train_targets.shape)
+         super().local_load_samples(samples_dict, memo, prefix)
+
+     def set_train_data(
+         self, inputs: Tensor | Iterable[Tensor] | None = None, targets: Tensor | None = None, strict: bool = True
+     ) -> None:
+         """
+         Set training data (does not re-fit model hyper-parameters).
+
+         :param inputs: The new training inputs.
+         :param targets: The new training targets.
+         :param strict: If `True`, the new inputs and targets must have the same shape,
+             dtype, and device as the current inputs and targets. Otherwise, any
+             shape/dtype/device are allowed.
+         """
+         if inputs is not None:
+             if isinstance(inputs, Tensor):
+                 inputs = (inputs,)
+             inputs = tuple(input_.unsqueeze(-1) if input_.ndimension() == 1 else input_ for input_ in inputs)
+             if strict:
+                 for input_, t_input in length_safe_zip(inputs, self.train_inputs or (None,)):
+                     for attr in {"shape", "dtype", "device"}:
+                         expected_attr = getattr(t_input, attr, None)
+                         found_attr = getattr(input_, attr, None)
+                         if expected_attr != found_attr:
+                             msg = "Cannot modify {attr} of inputs (expected {e_attr}, found {f_attr})."
+                             msg = msg.format(attr=attr, e_attr=expected_attr, f_attr=found_attr)
+                             raise RuntimeError(msg)
+             self.train_inputs = inputs
+         if targets is not None:
+             if strict:
+                 for attr in {"shape", "dtype", "device"}:
+                     expected_attr = getattr(self.train_targets, attr, None)
+                     found_attr = getattr(targets, attr, None)
+                     if expected_attr != found_attr:
+                         msg = "Cannot modify {attr} of targets (expected {e_attr}, found {f_attr})."
+                         msg = msg.format(attr=attr, e_attr=expected_attr, f_attr=found_attr)
+                         raise RuntimeError(msg)
+             self.train_targets = targets
+         self.prediction_strategy = None
+
+     def get_fantasy_model(self, inputs, targets, **kwargs):
+         """
+         Returns a new QEP model that incorporates the specified inputs and targets as new training data.
+
+         Using this method is more efficient than updating with `set_train_data` when the number of inputs is relatively
+         small, because any computed test-time caches will be updated in linear time rather than computed from scratch.
+
+         .. note::
+             If `targets` is a batch (e.g. `b x m`), then the QEP returned from this method will be a batch mode QEP.
+             If `inputs` is of the same (or lesser) dimension as `targets`, then it is assumed that the fantasy points
+             are the same for each target batch.
+
+         :param torch.Tensor inputs: (`b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`) Locations of fantasy
+             observations.
+         :param torch.Tensor targets: (`b1 x ... x bk x m` or `f x b1 x ... x bk x m`) Labels of fantasy observations.
+         :return: An `ExactQEP` model with `n + m` training examples, where the `m` fantasy examples have been added
+             and all test-time caches have been updated.
+         :rtype: ~qpytorch.models.ExactQEP
+         """
+         if self.prediction_strategy is None:
+             raise RuntimeError(
+                 "Fantasy observations can only be added after making predictions with a model so that "
+                 "all test independent caches exist. Call the model on some data first!"
+             )
+
+         model_batch_shape = self.train_inputs[0].shape[:-2]
+
+         if not isinstance(inputs, list):
+             inputs = [inputs]
+
+         inputs = [i.unsqueeze(-1) if i.ndimension() == 1 else i for i in inputs]
+
+         if not isinstance(self.prediction_strategy.train_prior_dist, MultitaskMultivariateQExponential):
+             data_dim_start = -1
+         else:
+             data_dim_start = -2
+
+         target_batch_shape = targets.shape[:data_dim_start]
+         input_batch_shape = inputs[0].shape[:-2]
+         tbdim, ibdim = len(target_batch_shape), len(input_batch_shape)
+
+         if not (tbdim == ibdim + 1 or tbdim == ibdim):
+             raise RuntimeError(
+                 f"Unsupported batch shapes: The target batch shape ({target_batch_shape}) must have either the "
+                 f"same dimension as or one more dimension than the input batch shape ({input_batch_shape})"
+             )
+
+         # Check whether we can properly broadcast batch dimensions
+         try:
+             torch.broadcast_shapes(model_batch_shape, target_batch_shape)
+         except RuntimeError:
+             raise RuntimeError(
+                 f"Model batch shape ({model_batch_shape}) and target batch shape "
+                 f"({target_batch_shape}) are not broadcastable."
+             )
+
+         if len(model_batch_shape) > len(input_batch_shape):
+             input_batch_shape = model_batch_shape
+         if len(model_batch_shape) > len(target_batch_shape):
+             target_batch_shape = model_batch_shape
+
+         # If input has no fantasy batch dimension but target does, we can save memory and computation by not
+         # computing the covariance for each element of the batch. Therefore we don't expand the inputs to the
+         # size of the fantasy model here - this is done below, after the evaluation and fast fantasy update
+         train_inputs = [tin.expand(input_batch_shape + tin.shape[-2:]) for tin in self.train_inputs]
+         train_targets = self.train_targets.expand(target_batch_shape + self.train_targets.shape[data_dim_start:])
+
+         full_inputs = [
+             torch.cat(
+                 [train_input, input.expand(input_batch_shape + input.shape[-2:])],
+                 dim=-2,
+             )
+             for train_input, input in length_safe_zip(train_inputs, inputs)
+         ]
+         full_targets = torch.cat(
+             [train_targets, targets.expand(target_batch_shape + targets.shape[data_dim_start:])], dim=data_dim_start
+         )
+
+         try:
+             fantasy_kwargs = {"noise": kwargs.pop("noise")}
+         except KeyError:
+             fantasy_kwargs = {}
+
+         full_output = super().__call__(*full_inputs, **kwargs)
+
+         # Copy model without copying training data or prediction strategy (since we'll overwrite those)
+         old_pred_strat = self.prediction_strategy
+         old_train_inputs = self.train_inputs
+         old_train_targets = self.train_targets
+         old_likelihood = self.likelihood
+         self.prediction_strategy = None
+         self.train_inputs = None
+         self.train_targets = None
+         self.likelihood = None
+         new_model = deepcopy(self)
+         self.prediction_strategy = old_pred_strat
+         self.train_inputs = old_train_inputs
+         self.train_targets = old_train_targets
+         self.likelihood = old_likelihood
+
+         new_model.likelihood = old_likelihood.get_fantasy_likelihood(**fantasy_kwargs)
+         new_model.prediction_strategy = old_pred_strat.get_fantasy_strategy(
+             inputs, targets, full_inputs, full_targets, full_output, **fantasy_kwargs
+         )
+
+         # if the fantasies are at the same points, we need to expand the inputs for the new model
+         if tbdim == ibdim + 1:
+             new_model.train_inputs = [fi.expand(target_batch_shape + fi.shape[-2:]) for fi in full_inputs]
+         else:
+             new_model.train_inputs = full_inputs
+         new_model.train_targets = full_targets
+
+         return new_model
+
+     def __call__(self, *args, **kwargs):
+         train_inputs = list(self.train_inputs) if self.train_inputs is not None else []
+         inputs = [i.unsqueeze(-1) if i.ndimension() == 1 else i for i in args]
+
+         # Training mode: optimizing
+         if self.training:
+             if self.train_inputs is None:
+                 raise RuntimeError(
+                     "train_inputs cannot be None in training mode. "
+                     "Call .eval() for prior predictions, or call .set_train_data() to add training data."
+                 )
+             if settings.debug.on():
+                 if not all(
+                     torch.equal(train_input, input) for train_input, input in length_safe_zip(train_inputs, inputs)
+                 ):
+                     raise RuntimeError("You must train on the training inputs!")
+             res = super().__call__(*inputs, **kwargs)
+             return res
+
+         # Prior mode
+         elif settings.prior_mode.on() or self.train_inputs is None or self.train_targets is None:
+             full_inputs = args
+             full_output = super().__call__(*full_inputs, **kwargs)
+             if settings.debug().on():
+                 if not isinstance(full_output, MultivariateQExponential):
+                     raise RuntimeError("ExactQEP.forward must return a MultivariateQExponential")
+             return full_output
+
+         # Posterior mode
+         else:
+             if settings.debug.on():
+                 if all(torch.equal(train_input, input) for train_input, input in length_safe_zip(train_inputs, inputs)):
+                     warnings.warn(
+                         "The input matches the stored training data. Did you forget to call model.train()?",
+                         QEPInputWarning,
+                     )
+
+             # Get the terms that only depend on training data
+             if self.prediction_strategy is None:
+                 train_output = super().__call__(*train_inputs, **kwargs)
+
+                 # Create the prediction strategy
+                 self.prediction_strategy = prediction_strategy(
+                     train_inputs=train_inputs,
+                     train_prior_dist=train_output,
+                     train_labels=self.train_targets,
+                     likelihood=self.likelihood,
+                 )
+
+             # Concatenate the input to the training input
+             full_inputs = []
+             batch_shape = train_inputs[0].shape[:-2]
+             for train_input, input in length_safe_zip(train_inputs, inputs):
+                 # Make sure the batch shapes agree for training/test data
+                 if batch_shape != train_input.shape[:-2]:
+                     batch_shape = torch.broadcast_shapes(batch_shape, train_input.shape[:-2])
+                     train_input = train_input.expand(*batch_shape, *train_input.shape[-2:])
+                 if batch_shape != input.shape[:-2]:
+                     batch_shape = torch.broadcast_shapes(batch_shape, input.shape[:-2])
+                     train_input = train_input.expand(*batch_shape, *train_input.shape[-2:])
+                     input = input.expand(*batch_shape, *input.shape[-2:])
+                 full_inputs.append(torch.cat([train_input, input], dim=-2))
+
+             # Get the joint distribution for training/test data
+             full_output = super().__call__(*full_inputs, **kwargs)
+             if settings.debug().on():
+                 if not isinstance(full_output, MultivariateQExponential):
+                     raise RuntimeError("ExactQEP.forward must return a MultivariateQExponential")
+             full_mean, full_covar = full_output.loc, full_output.lazy_covariance_matrix
+
+             # Determine the shape of the joint distribution
+             batch_shape = full_output.batch_shape
+             joint_shape = full_output.event_shape
+             tasks_shape = joint_shape[1:]  # For multitask learning
+             test_shape = torch.Size([joint_shape[0] - self.prediction_strategy.train_shape[0], *tasks_shape])
+
+             # Make the prediction
+             with settings.cg_tolerance(settings.eval_cg_tolerance.value()):
+                 (
+                     predictive_mean,
+                     predictive_covar,
+                 ) = self.prediction_strategy.exact_prediction(full_mean, full_covar)
+
+             # Reshape predictive mean to match the appropriate event shape
+             predictive_mean = predictive_mean.view(*batch_shape, *test_shape).contiguous()
+             return full_output.__class__(predictive_mean, predictive_covar, power=full_output.power)
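
For orientation, here is a minimal sketch of the ExactQEP workflow described in the docstring above: train-mode optimization, eval-mode prediction, and a fantasy update. It assumes that qpytorch.mlls.ExactMarginalLogLikelihood and the mean/kernel modules shipped in this release (see the file list) mirror their gpytorch counterparts; the data and iteration counts are placeholders.

    import torch
    import qpytorch

    class MyQEP(qpytorch.models.ExactQEP):
        def __init__(self, train_x, train_y, likelihood):
            super().__init__(train_x, train_y, likelihood)
            self.mean_module = qpytorch.means.ZeroMean()
            self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())

        def forward(self, x):
            mean = self.mean_module(x)
            covar = self.covar_module(x)
            return qpytorch.distributions.MultivariateQExponential(mean, covar, self.likelihood.power)

    train_x = torch.linspace(0, 1, 50).unsqueeze(-1)
    train_y = torch.sin(6 * train_x).squeeze(-1) + 0.1 * torch.randn(50)
    likelihood = qpytorch.likelihoods.QExponentialLikelihood(power=torch.tensor(1.0))
    model = MyQEP(train_x, train_y, likelihood)

    # Training mode: the model must be called on its stored training inputs
    model.train()
    mll = qpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    for _ in range(50):
        optimizer.zero_grad()
        loss = -mll(model(train_x), train_y)
        loss.backward()
        optimizer.step()

    # Eval mode: the first call builds prediction_strategy (the test-time cache)
    model.eval()
    test_x = torch.linspace(0, 1.2, 20).unsqueeze(-1)
    with torch.no_grad():
        predictive = likelihood(model(test_x))

    # Fantasy update: folds m new observations into the cached strategy in linear time
    new_x = torch.tensor([[1.3], [1.4]])
    new_y = torch.sin(6 * new_x).squeeze(-1)
    fantasy_model = model.get_fantasy_model(new_x, new_y)

Note that get_fantasy_model requires a prior eval-mode call: it is the cached prediction_strategy that gets updated, which is why the method raises a RuntimeError when no predictions have been made yet.
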
qpytorch/models/model_list.py ADDED
@@ -0,0 +1,100 @@
+ #!/usr/bin/env python3
+
+ from abc import ABC
+
+ import torch
+ from torch.nn import ModuleList
+
+ from ..likelihoods import LikelihoodList
+ from . import GP, QEP
+ from gpytorch.utils.generic import length_safe_zip
+
+
+ class AbstractModelList(GP, QEP, ABC):
+     def forward_i(self, i, *args, **kwargs):
+         """Forward restricted to the i-th model only."""
+         raise NotImplementedError
+
+     def likelihood_i(self, i, *args, **kwargs):
+         """Evaluate the likelihood of the i-th model only."""
+         raise NotImplementedError
+
+
+ class IndependentModelList(AbstractModelList):
+     def __init__(self, *models):
+         super().__init__()
+         self.models = ModuleList(models)
+         for m in models:
+             if not hasattr(m, "likelihood"):
+                 raise ValueError(
+                     "IndependentModelList currently only supports models that have a likelihood (e.g. ExactGPs or ExactQEPs)"
+                 )
+         self.likelihood = LikelihoodList(*[m.likelihood for m in models])
+
+     def forward_i(self, i, *args, **kwargs):
+         return self.models[i].forward(*args, **kwargs)
+
+     def likelihood_i(self, i, *args, **kwargs):
+         return self.likelihood.likelihoods[i](*args, **kwargs)
+
+     def forward(self, *args, **kwargs):
+         return [
+             model.forward(*args_, **kwargs) for model, args_ in length_safe_zip(self.models, _get_tensor_args(*args))
+         ]
+
+     def get_fantasy_model(self, inputs, targets, **kwargs):
+         """
+         Returns a new GP (QEP) model that incorporates the specified inputs and targets as new training data.
+
+         This is a simple wrapper that creates fantasy models for each of the models in the model list,
+         and returns the same class of fantasy models.
+
+         Args:
+             inputs: List of locations of fantasy observations, one for each model.
+             targets: List of labels of fantasy observations, one for each model.
+
+         Returns:
+             An `IndependentModelList` model, where each sub-model is the fantasy model of the respective
+             sub-model in the original model at the corresponding input locations / labels.
+         """
+
+         if "noise" in kwargs:
+             noise = kwargs.pop("noise")
+             kwargs = [{**kwargs, "noise": noise_} if noise_ is not None else kwargs for noise_ in noise]
+         else:
+             kwargs = [kwargs] * len(inputs)
+
+         fantasy_models = [
+             model.get_fantasy_model(*inputs_, *targets_, **kwargs_)
+             for model, inputs_, targets_, kwargs_ in length_safe_zip(
+                 self.models,
+                 _get_tensor_args(*inputs),
+                 _get_tensor_args(*targets),
+                 kwargs,
+             )
+         ]
+         return self.__class__(*fantasy_models)
+
+     def __call__(self, *args, **kwargs):
+         return [
+             model.__call__(*args_, **kwargs) for model, args_ in length_safe_zip(self.models, _get_tensor_args(*args))
+         ]
+
+     @property
+     def train_inputs(self):
+         return [model.train_inputs for model in self.models]
+
+     @property
+     def train_targets(self):
+         return [model.train_targets for model in self.models]
+
+
+ def _get_tensor_args(*args):
+     for arg in args:
+         if torch.is_tensor(arg):
+             yield (arg,)
+         else:
+             yield arg
+
+
+ class UncorrelatedModelList(IndependentModelList):
+     pass
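
A sketch of how IndependentModelList is typically driven, reusing MyQEP from the ExactQEP sketch above and assuming qpytorch.mlls.SumMarginalLogLikelihood (present in this release's file list) keeps gpytorch's (likelihood_list, model_list) signature; the data are placeholders.

    import torch
    import qpytorch

    # Two unrelated regression tasks, modeled independently
    x1, y1 = torch.rand(30, 1), torch.randn(30)
    x2, y2 = torch.rand(40, 1), torch.randn(40)
    power = torch.tensor(1.0)
    model = qpytorch.models.IndependentModelList(
        MyQEP(x1, y1, qpytorch.likelihoods.QExponentialLikelihood(power=power)),
        MyQEP(x2, y2, qpytorch.likelihoods.QExponentialLikelihood(power=power)),
    )
    mll = qpytorch.mlls.SumMarginalLogLikelihood(model.likelihood, model)

    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    for _ in range(50):
        optimizer.zero_grad()
        # __call__ zips each sub-model with its own argument tuple via _get_tensor_args,
        # so *model.train_inputs routes every model its own training inputs
        output = model(*model.train_inputs)
        loss = -mll(output, model.train_targets)
        loss.backward()
        optimizer.step()
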
qpytorch/models/pyro/__init__.py ADDED
@@ -0,0 +1,28 @@
+ #!/usr/bin/env python3
+
+ try:
+     from ._pyro_mixin import _PyroMixin
+     from gpytorch.models.pyro.pyro_gp import PyroGP
+     from .pyro_qep import PyroQEP
+ except ImportError:
+
+     class PyroGP(object):
+         def __init__(self, *args, **kwargs):
+             raise RuntimeError("Cannot use a PyroGP because you don't have Pyro installed.")
+
+     class PyroQEP(object):
+         def __init__(self, *args, **kwargs):
+             raise RuntimeError("Cannot use a PyroQEP because you don't have Pyro installed.")
+
+     class _PyroMixin(object):
+         def pyro_factors(self, *args, **kwargs):
+             raise RuntimeError("Cannot call `pyro_factors` because you don't have Pyro installed.")
+
+         def pyro_guide(self, *args, **kwargs):
+             raise RuntimeError("Cannot call `pyro_guide` because you don't have Pyro installed.")
+
+         def pyro_model(self, *args, **kwargs):
+             raise RuntimeError("Cannot call `pyro_model` because you don't have Pyro installed.")
+
+
+ __all__ = ["PyroGP", "_PyroMixin", "PyroQEP"]
qpytorch/models/pyro/_pyro_mixin.py ADDED
@@ -0,0 +1,57 @@
+ #!/usr/bin/env python3
+
+ import pyro
+ import torch
+
+ # from .distributions import QExponential
+
+
+ class _PyroMixin(object):
+     def pyro_guide(self, input, beta=1.0, name_prefix=""):
+         # Inducing values q(u)
+         with pyro.poutine.scale(scale=beta):
+             variational_distribution = self.variational_strategy.variational_distribution
+             variational_distribution = variational_distribution.to_event(len(variational_distribution.batch_shape))
+             pyro.sample(name_prefix + ".u", variational_distribution)
+
+         # Draw samples from q(f), dispatching on the type of the latent function distribution
+         function_dist = self(input, prior=False)
+         if 'Normal' in function_dist.__class__.__name__:
+             function_dist = pyro.distributions.Normal(loc=function_dist.mean, scale=function_dist.stddev).to_event(
+                 len(function_dist.event_shape) - 1
+             )
+         elif 'QExponential' in function_dist.__class__.__name__:
+             function_dist = pyro.distributions.QExponential(
+                 loc=function_dist.mean, scale=function_dist.stddev, power=function_dist.power
+             ).to_event(len(function_dist.event_shape) - 1)
+         return function_dist.mask(False)
+
+     def pyro_model(self, input, beta=1.0, name_prefix=""):
+         # Inducing values p(u)
+         with pyro.poutine.scale(scale=beta):
+             prior_distribution = self.variational_strategy.prior_distribution
+             prior_distribution = prior_distribution.to_event(len(prior_distribution.batch_shape))
+             u_samples = pyro.sample(name_prefix + ".u", prior_distribution)
+
+         # Include term for QPyTorch priors
+         log_prior = torch.tensor(0.0, dtype=u_samples.dtype, device=u_samples.device)
+         for _, module, prior, closure, _ in self.named_priors():
+             log_prior.add_(prior.log_prob(closure(module)).sum())
+         pyro.factor(name_prefix + ".log_prior", log_prior)
+
+         # Include factor for added loss terms
+         added_loss = torch.tensor(0.0, dtype=u_samples.dtype, device=u_samples.device)
+         for added_loss_term in self.added_loss_terms():
+             added_loss.add_(added_loss_term.loss())
+         pyro.factor(name_prefix + ".added_loss", added_loss)
+
+         # Draw samples from p(f), dispatching on the type of the latent function distribution
+         function_dist = self(input, prior=True)
+         if 'Normal' in function_dist.__class__.__name__:
+             function_dist = pyro.distributions.Normal(loc=function_dist.mean, scale=function_dist.stddev).to_event(
+                 len(function_dist.event_shape) - 1
+             )
+         elif 'QExponential' in function_dist.__class__.__name__:
+             function_dist = pyro.distributions.QExponential(
+                 loc=function_dist.mean, scale=function_dist.stddev, power=function_dist.power
+             ).to_event(len(function_dist.event_shape) - 1)
+         return function_dist.mask(False)
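
A note on the design, which this mixin carries over from gpytorch's _PyroMixin: the returned function distribution is wrapped in .mask(False) so that the sampled function values contribute no log-probability of their own to the ELBO. The data term is added separately by the likelihood's pyro_model/pyro_guide, while the KL contribution comes from the ".u" sample sites, scaled by beta.
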
qpytorch/models/pyro/distributions/__init__.py ADDED
@@ -0,0 +1,5 @@
+ #!/usr/bin/env python3
+
+ from ....distributions.qexponential import QExponential
+
+ __all__ = ["QExponential"]
qpytorch/models/pyro/pyro_qep.py ADDED
@@ -0,0 +1,105 @@
+ #!/usr/bin/env python3
+
+ import pyro
+
+ from ..qep import QEP
+ from ._pyro_mixin import _PyroMixin
+
+
+ class PyroQEP(QEP, _PyroMixin):
+     """
+     A :obj:`~qpytorch.models.ApproximateQEP` designed to work with Pyro.
+
+     This module makes it possible to include QEP models with more complex probabilistic models,
+     or to use likelihood functions with additional variational/approximate distributions.
+
+     The parameters of these models are learned using Pyro's inference tools, unlike other models
+     that are optimized with respect to a :obj:`~qpytorch.mlls.MarginalLogLikelihood`.
+     See `the Pyro examples <examples/09_Pyro_Integration/index.html>`_ for detailed examples.
+
+     Args:
+         variational_strategy (:obj:`~qpytorch.variational.VariationalStrategy`):
+             The variational strategy that defines the variational distribution and
+             the marginalization strategy.
+         likelihood (:obj:`~qpytorch.likelihoods.Likelihood`):
+             The likelihood for the model
+         num_data (int):
+             The total number of training data points (necessary for SGD)
+         name_prefix (str, optional):
+             A prefix to put in front of pyro sample/plate sites
+         beta (float - default 1.):
+             A multiplicative factor for the KL divergence term.
+             Setting it to 1 (default) recovers true variational inference
+             (as derived in `Scalable Variational Gaussian Process Classification`_).
+             Setting it to anything less than 1 reduces the regularization effect of the model
+             (similarly to what was proposed in `the beta-VAE paper`_).
+
+     Example:
+         >>> class MyVariationalQEP(qpytorch.models.PyroQEP):
+         >>>     # implementation
+         >>>
+         >>> # variational_strategy = ...
+         >>> likelihood = qpytorch.likelihoods.QExponentialLikelihood()
+         >>> model = MyVariationalQEP(variational_strategy, likelihood, train_y.size(0))
+         >>>
+         >>> optimizer = pyro.optim.Adam({"lr": 0.01})
+         >>> elbo = pyro.infer.Trace_ELBO(num_particles=64, vectorize_particles=True)
+         >>> svi = pyro.infer.SVI(model.model, model.guide, optimizer, elbo)
+         >>>
+         >>> # Optimize variational parameters
+         >>> for _ in range(n_iter):
+         >>>     loss = svi.step(train_x, train_y)
+
+     .. _Scalable Variational Gaussian Process Classification:
+         http://proceedings.mlr.press/v38/hensman15.pdf
+     .. _the beta-VAE paper:
+         https://openreview.net/pdf?id=Sy2fzU9gl
+     """
+
+     def __init__(self, variational_strategy, likelihood, num_data, name_prefix="", beta=1.0):
+         super().__init__()
+         self.variational_strategy = variational_strategy
+         self.name_prefix = name_prefix
+         self.likelihood = likelihood
+         self.num_data = num_data
+         self.beta = beta
+
+         # Set values for the likelihood
+         self.likelihood.num_data = num_data
+         self.likelihood.name_prefix = name_prefix
+
+     def guide(self, input, target, *args, **kwargs):
+         r"""
+         Guide function for Pyro inference.
+         Includes the guide for the QEP's likelihood function as well.
+
+         :param torch.Tensor input: :math:`\mathbf X` The input values
+         :param torch.Tensor target: :math:`\mathbf y` The target values
+         :param args: Additional arguments passed to the likelihood's forward function.
+         :param kwargs: Additional keyword arguments passed to the likelihood's forward function.
+         """
+         # Get q(f)
+         function_dist = self.pyro_guide(input, beta=self.beta, name_prefix=self.name_prefix)
+         return self.likelihood.pyro_guide(function_dist, target, *args, **kwargs)
+
+     def model(self, input, target, *args, **kwargs):
+         r"""
+         Model function for Pyro inference.
+         Includes the model for the QEP's likelihood function as well.
+
+         :param torch.Tensor input: :math:`\mathbf X` The input values
+         :param torch.Tensor target: :math:`\mathbf y` The target values
+         :param args: Additional arguments passed to the likelihood's forward function.
+         :param kwargs: Additional keyword arguments passed to the likelihood's forward function.
+         """
+         # Include module
+         pyro.module(self.name_prefix + ".qep", self)
+
+         # Get p(f)
+         function_dist = self.pyro_model(input, beta=self.beta, name_prefix=self.name_prefix)
+         return self.likelihood.pyro_model(function_dist, target, *args, **kwargs)
+
+     def __call__(self, inputs, prior=False):
+         if inputs.dim() == 1:
+             inputs = inputs.unsqueeze(-1)
+         return self.variational_strategy(inputs, prior=prior)
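
The docstring's example leaves the model definition as "# implementation". Below is a minimal sketch of what such a subclass could look like, assuming the variational modules in this release (CholeskyVariationalDistribution, VariationalStrategy, both present in the file list) keep gpytorch's constructor signatures; the inducing-point selection, mean, and kernel choices are illustrative.

    import qpytorch

    class MyVariationalQEP(qpytorch.models.PyroQEP):
        def __init__(self, train_x, train_y, likelihood, num_inducing=64, name_prefix="qep"):
            # Initialize inducing points from a subset of the training inputs (illustrative choice)
            inducing_points = train_x[:num_inducing, :]
            variational_distribution = qpytorch.variational.CholeskyVariationalDistribution(inducing_points.size(0))
            variational_strategy = qpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            )
            super().__init__(variational_strategy, likelihood, num_data=train_y.numel(), name_prefix=name_prefix)
            self.mean_module = qpytorch.means.ConstantMean()
            self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())

        def forward(self, x):
            # Prior latent distribution p(f(x)); the power parameter comes from the QEP likelihood
            mean = self.mean_module(x)
            covar = self.covar_module(x)
            return qpytorch.distributions.MultivariateQExponential(mean, covar, self.likelihood.power)

Training then proceeds with pyro.infer.SVI on model.model and model.guide exactly as in the docstring above.
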
qpytorch/models/qep.py ADDED
@@ -0,0 +1,7 @@
+ #!/usr/bin/env python3
+
+ from ..module import Module
+
+
+ class QEP(Module):
+     pass
qpytorch/models/qeplvm/__init__.py ADDED
@@ -0,0 +1,6 @@
+ #!/usr/bin/env python3
+
+ from .bayesian_qeplvm import BayesianQEPLVM
+ from .latent_variable import MAPLatentVariable, PointLatentVariable, VariationalLatentVariable
+
+ __all__ = ["BayesianQEPLVM", "PointLatentVariable", "MAPLatentVariable", "VariationalLatentVariable"]