qpytorch-0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of qpytorch might be problematic.

Files changed (102)
  1. qpytorch/__init__.py +327 -0
  2. qpytorch/constraints/__init__.py +3 -0
  3. qpytorch/distributions/__init__.py +21 -0
  4. qpytorch/distributions/delta.py +86 -0
  5. qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
  6. qpytorch/distributions/multivariate_qexponential.py +581 -0
  7. qpytorch/distributions/power.py +113 -0
  8. qpytorch/distributions/qexponential.py +153 -0
  9. qpytorch/functions/__init__.py +58 -0
  10. qpytorch/kernels/__init__.py +80 -0
  11. qpytorch/kernels/grid_interpolation_kernel.py +213 -0
  12. qpytorch/kernels/inducing_point_kernel.py +151 -0
  13. qpytorch/kernels/kernel.py +695 -0
  14. qpytorch/kernels/matern32_kernel_grad.py +155 -0
  15. qpytorch/kernels/matern52_kernel_grad.py +194 -0
  16. qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
  17. qpytorch/kernels/polynomial_kernel_grad.py +88 -0
  18. qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
  19. qpytorch/kernels/rbf_kernel_grad.py +125 -0
  20. qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
  21. qpytorch/kernels/rff_kernel.py +153 -0
  22. qpytorch/lazy/__init__.py +9 -0
  23. qpytorch/likelihoods/__init__.py +66 -0
  24. qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
  25. qpytorch/likelihoods/beta_likelihood.py +76 -0
  26. qpytorch/likelihoods/gaussian_likelihood.py +472 -0
  27. qpytorch/likelihoods/laplace_likelihood.py +59 -0
  28. qpytorch/likelihoods/likelihood.py +437 -0
  29. qpytorch/likelihoods/likelihood_list.py +60 -0
  30. qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
  31. qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
  32. qpytorch/likelihoods/noise_models.py +184 -0
  33. qpytorch/likelihoods/qexponential_likelihood.py +494 -0
  34. qpytorch/likelihoods/softmax_likelihood.py +97 -0
  35. qpytorch/likelihoods/student_t_likelihood.py +90 -0
  36. qpytorch/means/__init__.py +23 -0
  37. qpytorch/metrics/__init__.py +17 -0
  38. qpytorch/mlls/__init__.py +53 -0
  39. qpytorch/mlls/_approximate_mll.py +79 -0
  40. qpytorch/mlls/deep_approximate_mll.py +30 -0
  41. qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
  42. qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
  43. qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
  44. qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
  45. qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
  46. qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
  47. qpytorch/mlls/marginal_log_likelihood.py +48 -0
  48. qpytorch/mlls/predictive_log_likelihood.py +76 -0
  49. qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
  50. qpytorch/mlls/variational_elbo.py +77 -0
  51. qpytorch/models/__init__.py +72 -0
  52. qpytorch/models/approximate_qep.py +115 -0
  53. qpytorch/models/deep_qeps/__init__.py +22 -0
  54. qpytorch/models/deep_qeps/deep_qep.py +155 -0
  55. qpytorch/models/deep_qeps/dspp.py +114 -0
  56. qpytorch/models/exact_prediction_strategies.py +880 -0
  57. qpytorch/models/exact_qep.py +349 -0
  58. qpytorch/models/model_list.py +100 -0
  59. qpytorch/models/pyro/__init__.py +28 -0
  60. qpytorch/models/pyro/_pyro_mixin.py +57 -0
  61. qpytorch/models/pyro/distributions/__init__.py +5 -0
  62. qpytorch/models/pyro/pyro_qep.py +105 -0
  63. qpytorch/models/qep.py +7 -0
  64. qpytorch/models/qeplvm/__init__.py +6 -0
  65. qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
  66. qpytorch/models/qeplvm/latent_variable.py +102 -0
  67. qpytorch/module.py +30 -0
  68. qpytorch/optim/__init__.py +5 -0
  69. qpytorch/priors/__init__.py +42 -0
  70. qpytorch/priors/qep_priors.py +81 -0
  71. qpytorch/test/__init__.py +22 -0
  72. qpytorch/test/base_likelihood_test_case.py +106 -0
  73. qpytorch/test/model_test_case.py +150 -0
  74. qpytorch/test/variational_test_case.py +400 -0
  75. qpytorch/utils/__init__.py +38 -0
  76. qpytorch/utils/warnings.py +37 -0
  77. qpytorch/variational/__init__.py +47 -0
  78. qpytorch/variational/_variational_distribution.py +61 -0
  79. qpytorch/variational/_variational_strategy.py +391 -0
  80. qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
  81. qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
  82. qpytorch/variational/cholesky_variational_distribution.py +65 -0
  83. qpytorch/variational/ciq_variational_strategy.py +352 -0
  84. qpytorch/variational/delta_variational_distribution.py +41 -0
  85. qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
  86. qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
  87. qpytorch/variational/lmc_variational_strategy.py +248 -0
  88. qpytorch/variational/mean_field_variational_distribution.py +58 -0
  89. qpytorch/variational/multitask_variational_strategy.py +317 -0
  90. qpytorch/variational/natural_variational_distribution.py +152 -0
  91. qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
  92. qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
  93. qpytorch/variational/tril_natural_variational_distribution.py +130 -0
  94. qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
  95. qpytorch/variational/unwhitened_variational_strategy.py +225 -0
  96. qpytorch/variational/variational_strategy.py +280 -0
  97. qpytorch/version.py +4 -0
  98. qpytorch-0.1.dist-info/LICENSE +21 -0
  99. qpytorch-0.1.dist-info/METADATA +177 -0
  100. qpytorch-0.1.dist-info/RECORD +102 -0
  101. qpytorch-0.1.dist-info/WHEEL +5 -0
  102. qpytorch-0.1.dist-info/top_level.txt +1 -0
qpytorch/likelihoods/likelihood.py
@@ -0,0 +1,437 @@
+ #!/usr/bin/env python3
+
+ import math
+ import warnings
+ from abc import ABC, abstractmethod
+ from copy import deepcopy
+ from typing import Any, Dict, Optional, Union
+
+ import torch
+ from torch import Tensor
+ from torch.distributions import Distribution as _Distribution
+
+ from .. import settings
+ from ..distributions import base_distributions, MultivariateNormal, QExponential, MultivariateQExponential
+ from ..module import Module
+ from gpytorch.utils.quadrature import GaussHermiteQuadrature1D
+ from ..utils.warnings import GPInputWarning
+
+
+ class _Likelihood(Module, ABC):
+     has_analytic_marginal: bool = False
+
+     def __init__(self, max_plate_nesting: int = 1) -> None:
+         super().__init__()
+         self.max_plate_nesting: int = max_plate_nesting
+
+     def _draw_likelihood_samples(
+         self, function_dist: Union[MultivariateNormal, MultivariateQExponential], *args: Any, sample_shape: Optional[torch.Size] = None, **kwargs: Any
+     ) -> _Distribution:
+         if sample_shape is None:
+             sample_shape = torch.Size(
+                 [settings.num_likelihood_samples.value()]
+                 + [1] * (self.max_plate_nesting - len(function_dist.batch_shape) - 1)
+             )
+         else:
+             sample_shape = sample_shape[: -len(function_dist.batch_shape) - 1]
+         if self.training:
+             num_event_dims = len(function_dist.event_shape)
+             if isinstance(function_dist, MultivariateNormal):
+                 function_dist = base_distributions.Normal(function_dist.mean, function_dist.variance.sqrt())
+             elif isinstance(function_dist, MultivariateQExponential):
+                 function_dist = QExponential(function_dist.mean, function_dist.variance.sqrt(), function_dist.power)
+             function_dist = base_distributions.Independent(function_dist, num_event_dims - 1)
+         function_samples = function_dist.rsample(sample_shape)
+         return self.forward(function_samples, *args, **kwargs)
+
+     def expected_log_prob(
+         self, observations: Tensor, function_dist: Union[MultivariateNormal, MultivariateQExponential], *args: Any, **kwargs: Any
+     ) -> Tensor:
+         likelihood_samples = self._draw_likelihood_samples(function_dist, *args, **kwargs)
+         res = likelihood_samples.log_prob(observations, *args, **kwargs).mean(dim=0)
+         return res
+
+     @abstractmethod
+     def forward(self, function_samples: Tensor, *args: Any, **kwargs: Any) -> _Distribution:
+         raise NotImplementedError
+
+     def get_fantasy_likelihood(self, **kwargs: Any) -> "_Likelihood":
+         return deepcopy(self)
+
+     def log_marginal(
+         self, observations: Tensor, function_dist: Union[MultivariateNormal, MultivariateQExponential], *args: Any, **kwargs: Any
+     ) -> Tensor:
+         likelihood_samples = self._draw_likelihood_samples(function_dist, *args, **kwargs)
+         log_probs = likelihood_samples.log_prob(observations)
+         res = log_probs.sub(math.log(log_probs.size(0))).logsumexp(dim=0)
+         return res
+
+     def marginal(self, function_dist: Union[MultivariateNormal, MultivariateQExponential], *args: Any, **kwargs: Any) -> _Distribution:
+         res = self._draw_likelihood_samples(function_dist, *args, **kwargs)
+         return res
+
+     def __call__(self, input: Union[Tensor, MultivariateNormal, MultivariateQExponential], *args: Any, **kwargs: Any) -> _Distribution:
+         # Conditional
+         if torch.is_tensor(input):
+             return super().__call__(input, *args, **kwargs)  # pyre-ignore[7]
+         # Marginal
+         elif isinstance(input, (MultivariateNormal, MultivariateQExponential)):
+             return self.marginal(input, *args, **kwargs)
+         # Error
+         else:
+             raise RuntimeError(
+                 "Likelihoods expect a MultivariateNormal or MultivariateQExponential input to make marginal predictions, or a "
+                 "torch.Tensor for conditional predictions. Got a {}".format(input.__class__.__name__)
+             )
+
+
+ try:
+     import pyro
+
+     class Likelihood(_Likelihood):
+         r"""
+         A Likelihood in QPyTorch specifies the mapping from latent function values
+         :math:`f(\mathbf X)` to observed labels :math:`y`.
+
+         For example, in the case of regression this might be a Gaussian or Q-Exponential
+         distribution, as :math:`y(\mathbf x)` is equal to :math:`f(\mathbf x)` plus Gaussian (Q-Exponential) noise:
+
+         .. math::
+             y(\mathbf x) = f(\mathbf x) + \epsilon, \:\:\:\: \epsilon \sim
+             N(0, \sigma^{2}_{n} \mathbf I) \text{ or } \text{Q-EP}(0, \sigma^{2}_{n} \mathbf I)
+
+         In the case of classification, this might be a Bernoulli distribution,
+         where the probability that :math:`y=1` is given by the latent function
+         passed through some sigmoid or probit function:
+
+         .. math::
+             y(\mathbf x) = \begin{cases}
+                 1 & \text{w/ probability} \:\: \sigma(f(\mathbf x)) \\
+                 0 & \text{w/ probability} \:\: 1-\sigma(f(\mathbf x))
+             \end{cases}
+
+         In either case, to implement a likelihood function, QPyTorch only
+         requires a :meth:`forward` method that computes the conditional distribution
+         :math:`p(y \mid f(\mathbf x))`.
+
+         :param bool has_analytic_marginal: Whether or not the marginal distribution :math:`p(\mathbf y)`
+             can be computed in closed form. (See :meth:`~qpytorch.likelihoods.Likelihood.__call__` docstring.)
+         :param max_plate_nesting: (For Pyro integration only.) How many batch dimensions are in the function.
+             This should be modified if the likelihood uses plated random variables. (Default = 1)
+         :param str name_prefix: (For Pyro integration only.) Prefix to assign to named Pyro latent variables.
+         :param int num_data: (For Pyro integration only.) Total number of observations.
+         """
+
+         @property
+         def num_data(self) -> int:
+             if hasattr(self, "_num_data"):
+                 return self._num_data
+             else:
+                 warnings.warn(
+                     "likelihood.num_data isn't set. This might result in incorrect ELBO scaling.", GPInputWarning
+                 )
+                 return 0
+
+         @num_data.setter
+         def num_data(self, val: int) -> None:
+             self._num_data = val
+
+         @property
+         def name_prefix(self) -> str:
+             if hasattr(self, "_name_prefix"):
+                 return self._name_prefix
+             else:
+                 return ""
+
+         @name_prefix.setter
+         def name_prefix(self, val: str) -> None:
+             self._name_prefix = val
+
+         def _draw_likelihood_samples(
+             self, function_dist: Union[_Distribution, MultivariateQExponential], *args: Any, sample_shape: Optional[torch.Size] = None, **kwargs: Any
+         ) -> _Distribution:
+             if self.training:
+                 num_event_dims = len(function_dist.event_shape)
+                 # Check MultivariateQExponential first: it is itself a torch Distribution,
+                 # so testing _Distribution first would make this branch unreachable.
+                 if isinstance(function_dist, MultivariateQExponential):
+                     function_dist = QExponential(function_dist.mean, function_dist.variance.sqrt(), function_dist.power)
+                 elif isinstance(function_dist, _Distribution):
+                     function_dist = base_distributions.Normal(function_dist.mean, function_dist.variance.sqrt())
+                 function_dist = base_distributions.Independent(function_dist, num_event_dims - 1)
+
+             plate_name = self.name_prefix + ".num_particles_vectorized"
+             num_samples = settings.num_likelihood_samples.value()
+             max_plate_nesting = max(self.max_plate_nesting, len(function_dist.batch_shape))
+             with pyro.plate(plate_name, size=num_samples, dim=(-max_plate_nesting - 1)):
+                 if sample_shape is None:
+                     function_samples = pyro.sample(self.name_prefix, function_dist.mask(False))
+                     # Deal with the fact that we're not assuming conditional independence over data points here
+                     function_samples = function_samples.squeeze(-len(function_dist.event_shape) - 1)
+                 else:
+                     sample_shape = sample_shape[: -len(function_dist.batch_shape)]
+                     # Distribution objects are not callable; draw reparameterized samples explicitly
+                     function_samples = function_dist.rsample(sample_shape)
+
+                 if not self.training:
+                     function_samples = function_samples.squeeze(-len(function_dist.event_shape) - 1)
+                 return self.forward(function_samples, *args, **kwargs)
+
+         def expected_log_prob(
+             self, observations: Tensor, function_dist: Union[MultivariateNormal, MultivariateQExponential], *args: Any, **kwargs: Any
+         ) -> Tensor:
+             r"""
+             (Used by :obj:`~qpytorch.mlls.VariationalELBO` for variational inference.)
+
+             Computes the expected log likelihood, where the expectation is over the GP (QEP) variational distribution.
+
+             .. math::
+                 \sum_{\mathbf x, y} \mathbb{E}_{q\left( f(\mathbf x) \right)}
+                 \left[ \log p \left( y \mid f(\mathbf x) \right) \right]
+
+             :param observations: Values of :math:`y`.
+             :param function_dist: Distribution for :math:`f(x)`.
+             :param args: Additional args (passed to the forward function).
+             :param kwargs: Additional kwargs (passed to the forward function).
+             """
+             return super().expected_log_prob(observations, function_dist, *args, **kwargs)
+
+         @abstractmethod
+         def forward(
+             self, function_samples: Tensor, *args: Any, data: Dict[str, Tensor] = {}, **kwargs: Any
+         ) -> _Distribution:
+             r"""
+             Computes the conditional distribution :math:`p(\mathbf y \mid
+             \mathbf f, \ldots)` that defines the likelihood.
+
+             :param function_samples: Samples from the function (:math:`\mathbf f`).
+             :param data: (Pyro integration only.) Additional variables that the likelihood needs to condition
+                 on. The keys of the dictionary will correspond to Pyro sample sites
+                 in the likelihood's model/guide.
+             :param args: Additional args
+             :param kwargs: Additional kwargs
+             """
+             raise NotImplementedError
+
+         def get_fantasy_likelihood(self, **kwargs: Any) -> "_Likelihood":
+             """"""
+             return super().get_fantasy_likelihood(**kwargs)
+
+         def log_marginal(
+             self, observations: Tensor, function_dist: Union[MultivariateNormal, MultivariateQExponential], *args: Any, **kwargs: Any
+         ) -> Tensor:
+             r"""
+             (Used by :obj:`~qpytorch.mlls.PredictiveLogLikelihood` for approximate inference.)
+
+             Computes the log marginal likelihood of the approximate predictive distribution
+
+             .. math::
+                 \sum_{\mathbf x, y} \log \mathbb{E}_{q\left( f(\mathbf x) \right)}
+                 \left[ p \left( y \mid f(\mathbf x) \right) \right]
+
+             Note that this differs from :meth:`expected_log_prob` because the :math:`\log` is on the outside
+             of the expectation.
+
+             :param observations: Values of :math:`y`.
+             :param function_dist: Distribution for :math:`f(x)`.
+             :param args: Additional args (passed to the forward function).
+             :param kwargs: Additional kwargs (passed to the forward function).
+             """
+             return super().log_marginal(observations, function_dist, *args, **kwargs)
+
+         def marginal(self, function_dist: Union[MultivariateNormal, MultivariateQExponential], *args: Any, **kwargs: Any) -> _Distribution:
+             r"""
+             Computes a predictive distribution :math:`p(y^* | \mathbf x^*)` given either a posterior
+             distribution :math:`p(\mathbf f | \mathcal D, \mathbf x)` or a
+             prior distribution :math:`p(\mathbf f|\mathbf x)` as input.
+
+             With both exact inference and variational inference, the form of
+             :math:`p(\mathbf f|\mathcal D, \mathbf x)` or :math:`p(\mathbf f|
+             \mathbf x)` should usually be Gaussian or Q-Exponential. As a result, function_dist
+             should usually be a :obj:`~gpytorch.distributions.MultivariateNormal`
+             or :obj:`~qpytorch.distributions.MultivariateQExponential` specified by the mean and
+             (co)variance of :math:`p(\mathbf f|...)`.
+
+             :param function_dist: Distribution for :math:`f(x)`.
+             :param args: Additional args (passed to the forward function).
+             :param kwargs: Additional kwargs (passed to the forward function).
+             :return: The marginal distribution, or samples from it.
+             """
+             return super().marginal(function_dist, *args, **kwargs)
+
+         def pyro_guide(self, function_dist: Union[MultivariateNormal, MultivariateQExponential], target: Tensor, *args: Any, **kwargs: Any) -> None:
+             r"""
+             (For Pyro integration only.)
+
+             Part of the guide function for the likelihood.
+             This should be re-defined if the likelihood contains any latent variables that need to be inferred.
+
+             :param function_dist: Distribution of the latent function,
+                 :math:`q(\mathbf f)`.
+             :param target: Observed :math:`\mathbf y`.
+             :param args: Additional args (passed to the forward function).
+             :param kwargs: Additional kwargs (passed to the forward function).
+             """
+             with pyro.plate(self.name_prefix + ".data_plate", dim=-1):
+                 pyro.sample(self.name_prefix + ".f", function_dist)
+
+         def pyro_model(self, function_dist: Union[MultivariateNormal, MultivariateQExponential], target: Tensor, *args: Any, **kwargs: Any) -> Tensor:
+             r"""
+             (For Pyro integration only.)
+
+             Part of the model function for the likelihood.
+             It should return the sampled target :math:`\mathbf y` (the output of :meth:`sample_target`).
+             This should be re-defined if the likelihood contains any latent variables that need to be inferred.
+
+             :param function_dist: Distribution of the latent function,
+                 :math:`p(\mathbf f)`.
+             :param target: Observed :math:`\mathbf y`.
+             :param args: Additional args (passed to the forward function).
+             :param kwargs: Additional kwargs (passed to the forward function).
+             """
+             with pyro.plate(self.name_prefix + ".data_plate", dim=-1):
+                 function_samples = pyro.sample(self.name_prefix + ".f", function_dist)
+                 output_dist = self(function_samples, *args, **kwargs)
+                 return self.sample_target(output_dist, target)
+
+         def sample_target(self, output_dist: Union[MultivariateNormal, MultivariateQExponential], target: Tensor) -> Tensor:
+             scale = (self.num_data or output_dist.batch_shape[-1]) / output_dist.batch_shape[-1]
+             with pyro.poutine.scale(scale=scale):  # pyre-ignore[16]
+                 return pyro.sample(self.name_prefix + ".y", output_dist, obs=target)
+
+         def __call__(self, input: Union[Tensor, MultivariateNormal, MultivariateQExponential], *args: Any, **kwargs: Any) -> _Distribution:
+             r"""
+             Calling this object does one of two things:
+
+             1. If the likelihood is called with a :class:`torch.Tensor` object, then it is
+                assumed that the input is samples from :math:`f(\mathbf x)`. This
+                returns the *conditional* distribution :math:`p(y|f(\mathbf x))`.
+
+                .. code-block:: python
+
+                    f = torch.randn(20)
+                    likelihood = qpytorch.likelihoods.GaussianLikelihood()  # or qpytorch.likelihoods.QExponentialLikelihood()
+                    conditional = likelihood(f)
+                    print(type(conditional), conditional.batch_shape, conditional.event_shape)
+                    # >>> <class 'torch.distributions.normal.Normal'> torch.Size([20]) torch.Size([])
+                    # or >>> <class 'qpytorch.distributions.qexponential.QExponential'> torch.Size([20]) torch.Size([])
+
+             2. If the likelihood is called with a :class:`~gpytorch.distributions.MultivariateNormal`
+                or :class:`~qpytorch.distributions.MultivariateQExponential` object,
+                then it is assumed that the input is the distribution of :math:`f(\mathbf x)`.
+                This returns the *marginal* distribution :math:`p(y|\mathbf x)`.
+
+                The form of the marginal distribution depends on the likelihood.
+                For :class:`~qpytorch.likelihoods.BernoulliLikelihood`,
+                :class:`~qpytorch.likelihoods.GaussianLikelihood`, and
+                :class:`~qpytorch.likelihoods.QExponentialLikelihood` objects, the marginal distribution
+                can be computed analytically, and the likelihood returns the analytic distribution.
+                For most other likelihoods, there is no analytic form for the marginal,
+                and so the likelihood instead returns a batch of Monte Carlo samples from the marginal.
+
+                .. code-block:: python
+
+                    mean = torch.randn(20)
+                    covar = linear_operator.operators.DiagLinearOperator(torch.ones(20))
+                    f = qpytorch.distributions.MultivariateNormal(mean, covar)
+                    # or:
+                    power = torch.tensor(1.0)
+                    f = qpytorch.distributions.MultivariateQExponential(mean, covar, power)
+
+                    # Analytic marginal computation - Bernoulli, Gaussian, and Q-Exponential likelihoods only
+                    analytic_marginal_likelihood = qpytorch.likelihoods.GaussianLikelihood()
+                    # or qpytorch.likelihoods.QExponentialLikelihood()
+                    marginal = analytic_marginal_likelihood(f)
+                    print(type(marginal), marginal.batch_shape, marginal.event_shape)
+                    # >>> <class 'gpytorch.distributions.multivariate_normal.MultivariateNormal'> torch.Size([]) torch.Size([20])  # noqa: E501
+                    # or >>> <class 'qpytorch.distributions.multivariate_qexponential.MultivariateQExponential'> torch.Size([]) torch.Size([20])
+
+                    # MC marginal computation - all other likelihoods
+                    mc_marginal_likelihood = qpytorch.likelihoods.BetaLikelihood()
+                    with qpytorch.settings.num_likelihood_samples(15):
+                        marginal = mc_marginal_likelihood(f)
+                    print(type(marginal), marginal.batch_shape, marginal.event_shape)
+                    # >>> <class 'torch.distributions.beta.Beta'> torch.Size([15, 20]) torch.Size([])
+                    # The batch_shape torch.Size([15, 20]) represents 15 MC samples for 20 data points.
+
+             .. note::
+
+                 If a likelihood supports analytic marginals, the :attr:`has_analytic_marginal` property will be True.
+                 If a likelihood does not support analytic marginals, you can set the number of Monte Carlo
+                 samples using the :class:`qpytorch.settings.num_likelihood_samples` context manager.
+
+             :param input: Either a (... x N) sample from :math:`\mathbf f`
+                 or a (... x N) MVN (QEP) distribution of :math:`\mathbf f`.
+             :param args: Additional args (passed to the forward function).
+             :param kwargs: Additional kwargs (passed to the forward function).
+             :return: Either a conditional :math:`p(\mathbf y \mid \mathbf f)`
+                 or marginal :math:`p(\mathbf y)`,
+                 based on whether :attr:`input` is a Tensor or a MultivariateNormal or MultivariateQExponential (see above).
+             """
+             # Conditional
+             if torch.is_tensor(input):
+                 return super().__call__(input, *args, **kwargs)
+             # Marginal
+             elif any(
+                 [
+                     isinstance(input, (MultivariateNormal, MultivariateQExponential)),
+                     isinstance(input, (pyro.distributions.Normal, QExponential)),  # pyre-ignore[16]
+                     (
+                         isinstance(input, pyro.distributions.Independent)  # pyre-ignore[16]
+                         and isinstance(input.base_dist, pyro.distributions.Normal)  # pyre-ignore[16]
+                     ),
+                 ]
+             ):
+                 return self.marginal(input, *args, **kwargs)  # pyre-ignore[6]
+             # Error
+             else:
+                 raise RuntimeError(
+                     "Likelihoods expect a MultivariateNormal (MultivariateQExponential) or Normal (QExponential) input to make marginal predictions, or a "
+                     "torch.Tensor for conditional predictions. Got a {}".format(input.__class__.__name__)
+                 )
+
+ except ImportError:
+
+     class Likelihood(_Likelihood):
+         @property
+         def num_data(self) -> int:
+             warnings.warn("num_data is only used for likelihoods that are integrated with Pyro.", RuntimeWarning)
+             return 0
+
+         @num_data.setter
+         def num_data(self, val: int) -> None:
+             warnings.warn("num_data is only used for likelihoods that are integrated with Pyro.", RuntimeWarning)
+
+         @property
+         def name_prefix(self) -> str:
+             warnings.warn("name_prefix is only used for likelihoods that are integrated with Pyro.", RuntimeWarning)
+             return ""
+
+         @name_prefix.setter
+         def name_prefix(self, val: str) -> None:
+             warnings.warn("name_prefix is only used for likelihoods that are integrated with Pyro.", RuntimeWarning)
+
+
+ class _OneDimensionalLikelihood(Likelihood, ABC):
+     r"""
+     A specific case of :obj:`~qpytorch.likelihoods.Likelihood` when the GP (QEP) represents a one-dimensional
+     output. (I.e. for a specific :math:`\mathbf x`, :math:`f(\mathbf x) \in \mathbb{R}`.)
+
+     Inheriting from this likelihood reduces the variance when computing approximate GP (QEP) objective functions
+     by using 1D Gauss-Hermite quadrature.
+     """
+
+     def __init__(self, *args: Any, **kwargs: Any) -> None:
+         super().__init__(*args, **kwargs)
+         self.quadrature = GaussHermiteQuadrature1D()
+
+     def expected_log_prob(
+         self, observations: Tensor, function_dist: Union[MultivariateNormal, MultivariateQExponential], *args: Any, **kwargs: Any
+     ) -> Tensor:
+         log_prob_lambda = lambda function_samples: self.forward(function_samples, *args, **kwargs).log_prob(
+             observations
+         )
+         log_prob = self.quadrature(log_prob_lambda, function_dist)
+         return log_prob
+
+     def log_marginal(
+         self, observations: Tensor, function_dist: Union[MultivariateNormal, MultivariateQExponential], *args: Any, **kwargs: Any
+     ) -> Tensor:
+         prob_lambda = lambda function_samples: self.forward(function_samples).log_prob(observations).exp()
+         prob = self.quadrature(prob_lambda, function_dist)
+         return prob.log()
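
As the docstrings in this file state, implementing a likelihood only requires a forward method that computes p(y | f(x)). Below is a minimal, hypothetical sketch of such a subclass; it assumes qpytorch exposes _OneDimensionalLikelihood under qpytorch.likelihoods the way gpytorch does, and the class name is invented for illustration, not part of this release.

    # Hypothetical usage sketch, not part of this diff.
    import torch
    from qpytorch.likelihoods import _OneDimensionalLikelihood  # assumed re-export


    class SigmoidBernoulliLikelihood(_OneDimensionalLikelihood):
        """Classification likelihood: p(y = 1 | f) = sigmoid(f)."""

        def forward(self, function_samples, *args, **kwargs):
            # The one required method: the conditional distribution p(y | f(x)).
            return torch.distributions.Bernoulli(logits=function_samples)


    likelihood = SigmoidBernoulliLikelihood()

    # Tensor input -> conditional distribution p(y | f), per __call__ above.
    f = torch.randn(20)
    conditional = likelihood(f)  # torch.distributions.Bernoulli, batch_shape [20]

Because the class inherits from _OneDimensionalLikelihood, expected_log_prob and log_marginal are computed with the 1D Gauss-Hermite quadrature set up in __init__ rather than with Monte Carlo samples.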
qpytorch/likelihoods/likelihood_list.py
@@ -0,0 +1,60 @@
+ #!/usr/bin/env python3
+
+ from torch.nn import ModuleList
+
+ from . import Likelihood
+ from gpytorch.utils.generic import length_safe_zip
+
+
+ def _get_tuple_args_(*args):
+     for arg in args:
+         if isinstance(arg, tuple):
+             yield arg
+         else:
+             yield (arg,)
+
+
+ class LikelihoodList(Likelihood):
+     def __init__(self, *likelihoods):
+         super().__init__()
+         self.likelihoods = ModuleList(likelihoods)
+
+     def expected_log_prob(self, *args, **kwargs):
+         return [
+             likelihood.expected_log_prob(*args_, **kwargs)
+             for likelihood, args_ in length_safe_zip(self.likelihoods, _get_tuple_args_(*args))
+         ]
+
+     def forward(self, *args, **kwargs):
+         if "noise" in kwargs:
+             noise = kwargs.pop("noise")
+             # if a noise kwarg is passed, assume it's an iterable of noise tensors, one per likelihood;
+             # pass each entry through as a keyword argument rather than as a stray positional dict
+             return [
+                 likelihood.forward(*args_, noise=noise_, **kwargs)
+                 for likelihood, args_, noise_ in length_safe_zip(self.likelihoods, _get_tuple_args_(*args), noise)
+             ]
+         else:
+             return [
+                 likelihood.forward(*args_, **kwargs)
+                 for likelihood, args_ in length_safe_zip(self.likelihoods, _get_tuple_args_(*args))
+             ]
+
+     def pyro_sample_output(self, *args, **kwargs):
+         return [
+             likelihood.pyro_sample_output(*args_, **kwargs)
+             for likelihood, args_ in length_safe_zip(self.likelihoods, _get_tuple_args_(*args))
+         ]
+
+     def __call__(self, *args, **kwargs):
+         if "noise" in kwargs:
+             noise = kwargs.pop("noise")
+             # if a noise kwarg is passed, assume it's an iterable of noise tensors, one per likelihood
+             return [
+                 likelihood(*args_, noise=noise_, **kwargs)
+                 for likelihood, args_, noise_ in length_safe_zip(self.likelihoods, _get_tuple_args_(*args), noise)
+             ]
+         else:
+             return [
+                 likelihood(*args_, **kwargs)
+                 for likelihood, args_ in length_safe_zip(self.likelihoods, _get_tuple_args_(*args))
+             ]
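
For context, a hypothetical usage sketch of LikelihoodList (not part of this diff): positional arguments are paired one-to-one with the wrapped likelihoods via _get_tuple_args_ and length_safe_zip, and a noise kwarg is assumed to be an iterable with one entry per likelihood. That GaussianLikelihood accepts a noise kwarg is an assumption carried over from gpytorch.

    # Hypothetical usage sketch, not part of this diff.
    import torch
    from qpytorch.likelihoods import GaussianLikelihood, LikelihoodList

    lls = LikelihoodList(GaussianLikelihood(), GaussianLikelihood())

    f1, f2 = torch.randn(10), torch.randn(10)

    # Each bare tensor is wrapped by _get_tuple_args_, so likelihood_i receives f_i.
    out = lls(f1, f2)  # -> one conditional Normal per likelihood

    # `noise` is routed entry-by-entry; length_safe_zip should fail loudly
    # if the number of noise tensors does not match the number of likelihoods.
    noise = [torch.full((10,), 0.1), torch.full((10,), 0.2)]
    out = lls(f1, f2, noise=noise)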