qpytorch 0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of qpytorch might be problematic. Click here for more details.
- qpytorch/__init__.py +327 -0
- qpytorch/constraints/__init__.py +3 -0
- qpytorch/distributions/__init__.py +21 -0
- qpytorch/distributions/delta.py +86 -0
- qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
- qpytorch/distributions/multivariate_qexponential.py +581 -0
- qpytorch/distributions/power.py +113 -0
- qpytorch/distributions/qexponential.py +153 -0
- qpytorch/functions/__init__.py +58 -0
- qpytorch/kernels/__init__.py +80 -0
- qpytorch/kernels/grid_interpolation_kernel.py +213 -0
- qpytorch/kernels/inducing_point_kernel.py +151 -0
- qpytorch/kernels/kernel.py +695 -0
- qpytorch/kernels/matern32_kernel_grad.py +155 -0
- qpytorch/kernels/matern52_kernel_grad.py +194 -0
- qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
- qpytorch/kernels/polynomial_kernel_grad.py +88 -0
- qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
- qpytorch/kernels/rbf_kernel_grad.py +125 -0
- qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
- qpytorch/kernels/rff_kernel.py +153 -0
- qpytorch/lazy/__init__.py +9 -0
- qpytorch/likelihoods/__init__.py +66 -0
- qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
- qpytorch/likelihoods/beta_likelihood.py +76 -0
- qpytorch/likelihoods/gaussian_likelihood.py +472 -0
- qpytorch/likelihoods/laplace_likelihood.py +59 -0
- qpytorch/likelihoods/likelihood.py +437 -0
- qpytorch/likelihoods/likelihood_list.py +60 -0
- qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
- qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
- qpytorch/likelihoods/noise_models.py +184 -0
- qpytorch/likelihoods/qexponential_likelihood.py +494 -0
- qpytorch/likelihoods/softmax_likelihood.py +97 -0
- qpytorch/likelihoods/student_t_likelihood.py +90 -0
- qpytorch/means/__init__.py +23 -0
- qpytorch/metrics/__init__.py +17 -0
- qpytorch/mlls/__init__.py +53 -0
- qpytorch/mlls/_approximate_mll.py +79 -0
- qpytorch/mlls/deep_approximate_mll.py +30 -0
- qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
- qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
- qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
- qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
- qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
- qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
- qpytorch/mlls/marginal_log_likelihood.py +48 -0
- qpytorch/mlls/predictive_log_likelihood.py +76 -0
- qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
- qpytorch/mlls/variational_elbo.py +77 -0
- qpytorch/models/__init__.py +72 -0
- qpytorch/models/approximate_qep.py +115 -0
- qpytorch/models/deep_qeps/__init__.py +22 -0
- qpytorch/models/deep_qeps/deep_qep.py +155 -0
- qpytorch/models/deep_qeps/dspp.py +114 -0
- qpytorch/models/exact_prediction_strategies.py +880 -0
- qpytorch/models/exact_qep.py +349 -0
- qpytorch/models/model_list.py +100 -0
- qpytorch/models/pyro/__init__.py +28 -0
- qpytorch/models/pyro/_pyro_mixin.py +57 -0
- qpytorch/models/pyro/distributions/__init__.py +5 -0
- qpytorch/models/pyro/pyro_qep.py +105 -0
- qpytorch/models/qep.py +7 -0
- qpytorch/models/qeplvm/__init__.py +6 -0
- qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
- qpytorch/models/qeplvm/latent_variable.py +102 -0
- qpytorch/module.py +30 -0
- qpytorch/optim/__init__.py +5 -0
- qpytorch/priors/__init__.py +42 -0
- qpytorch/priors/qep_priors.py +81 -0
- qpytorch/test/__init__.py +22 -0
- qpytorch/test/base_likelihood_test_case.py +106 -0
- qpytorch/test/model_test_case.py +150 -0
- qpytorch/test/variational_test_case.py +400 -0
- qpytorch/utils/__init__.py +38 -0
- qpytorch/utils/warnings.py +37 -0
- qpytorch/variational/__init__.py +47 -0
- qpytorch/variational/_variational_distribution.py +61 -0
- qpytorch/variational/_variational_strategy.py +391 -0
- qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
- qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
- qpytorch/variational/cholesky_variational_distribution.py +65 -0
- qpytorch/variational/ciq_variational_strategy.py +352 -0
- qpytorch/variational/delta_variational_distribution.py +41 -0
- qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
- qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
- qpytorch/variational/lmc_variational_strategy.py +248 -0
- qpytorch/variational/mean_field_variational_distribution.py +58 -0
- qpytorch/variational/multitask_variational_strategy.py +317 -0
- qpytorch/variational/natural_variational_distribution.py +152 -0
- qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
- qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
- qpytorch/variational/tril_natural_variational_distribution.py +130 -0
- qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
- qpytorch/variational/unwhitened_variational_strategy.py +225 -0
- qpytorch/variational/variational_strategy.py +280 -0
- qpytorch/version.py +4 -0
- qpytorch-0.1.dist-info/LICENSE +21 -0
- qpytorch-0.1.dist-info/METADATA +177 -0
- qpytorch-0.1.dist-info/RECORD +102 -0
- qpytorch-0.1.dist-info/WHEEL +5 -0
- qpytorch-0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,437 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
import warnings
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from copy import deepcopy
|
|
7
|
+
from typing import Any, Dict, Optional, Union
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
from torch import Tensor
|
|
11
|
+
from torch.distributions import Distribution as _Distribution
|
|
12
|
+
|
|
13
|
+
from .. import settings
|
|
14
|
+
from ..distributions import base_distributions, MultivariateNormal, QExponential, MultivariateQExponential
|
|
15
|
+
from ..module import Module
|
|
16
|
+
from gpytorch.utils.quadrature import GaussHermiteQuadrature1D
|
|
17
|
+
from ..utils.warnings import GPInputWarning
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class _Likelihood(Module, ABC):
    r"""
    Base class shared by the Pyro-enabled and Pyro-free :class:`Likelihood` variants.

    Subclasses implement :meth:`forward`, which maps latent function samples to the
    conditional distribution :math:`p(y \mid f)`. This class builds Monte Carlo versions
    of :meth:`expected_log_prob`, :meth:`log_marginal` and :meth:`marginal` on top of it.
    """

    # Whether p(y) can be computed in closed form; subclasses with analytic marginals override this.
    has_analytic_marginal: bool = False

    def __init__(self, max_plate_nesting: int = 1) -> None:
        """
        :param max_plate_nesting: Number of batch dimensions the function may have; used to pad
            the default Monte Carlo sample shape. (Default = 1)
        """
        super().__init__()
        self.max_plate_nesting: int = max_plate_nesting

    def _draw_likelihood_samples(
        self,
        function_dist: Union[MultivariateNormal, MultivariateQExponential],
        *args: Any,
        sample_shape: Optional[torch.Size] = None,
        **kwargs: Any,
    ) -> _Distribution:
        """
        Draw reparameterized samples of :math:`f` and return ``self.forward`` evaluated at them.

        :param function_dist: Distribution over the latent function values.
        :param args: Additional args (passed to the forward function).
        :param sample_shape: Shape of the samples to draw. If None, the shape is
            ``[num_likelihood_samples] + [1] * (max_plate_nesting - len(batch_shape) - 1)``.
            Otherwise the trailing ``len(function_dist.batch_shape) + 1`` entries are dropped.
        :param kwargs: Additional kwargs (passed to the forward function).
        :return: The conditional distribution returned by :meth:`forward`.
        """
        if sample_shape is None:
            sample_shape = torch.Size(
                [settings.num_likelihood_samples.value()]
                + [1] * (self.max_plate_nesting - len(function_dist.batch_shape) - 1)
            )
        else:
            sample_shape = sample_shape[: -len(function_dist.batch_shape) - 1]
        if self.training:
            # In training mode the joint distribution is replaced by independent univariate
            # marginals built from its mean and sqrt(variance).
            num_event_dims = len(function_dist.event_shape)
            # Test the Q-exponential case first so it is never shadowed by the Gaussian branch,
            # even if MultivariateQExponential happens to subclass MultivariateNormal.
            if isinstance(function_dist, MultivariateQExponential):
                function_dist = QExponential(function_dist.mean, function_dist.variance.sqrt(), function_dist.power)
            elif isinstance(function_dist, MultivariateNormal):
                function_dist = base_distributions.Normal(function_dist.mean, function_dist.variance.sqrt())
            function_dist = base_distributions.Independent(function_dist, num_event_dims - 1)
        function_samples = function_dist.rsample(sample_shape)
        return self.forward(function_samples, *args, **kwargs)

    def expected_log_prob(
        self,
        observations: Tensor,
        function_dist: Union[MultivariateNormal, MultivariateQExponential],
        *args: Any,
        **kwargs: Any,
    ) -> Tensor:
        """
        Monte Carlo estimate of :math:`E_{q(f)}[\\log p(y \\mid f)]`.

        :param observations: Values of :math:`y`.
        :param function_dist: Distribution for :math:`f(x)`.
        :param args: Additional args (passed to the forward function).
        :param kwargs: Additional kwargs (passed to the forward function).
        :return: The log-probabilities averaged over the leading sample dimension.
        """
        likelihood_samples = self._draw_likelihood_samples(function_dist, *args, **kwargs)
        # args/kwargs are meant for ``forward``; ``Distribution.log_prob`` accepts only the value.
        res = likelihood_samples.log_prob(observations).mean(dim=0)
        return res

    @abstractmethod
    def forward(self, function_samples: Tensor, *args: Any, **kwargs: Any) -> _Distribution:
        """Return the conditional distribution :math:`p(y \\mid f)` for the given function samples."""
        raise NotImplementedError

    def get_fantasy_likelihood(self, **kwargs: Any) -> "_Likelihood":
        """Return a deep copy of this likelihood (``kwargs`` are ignored here)."""
        return deepcopy(self)

    def log_marginal(
        self,
        observations: Tensor,
        function_dist: Union[MultivariateNormal, MultivariateQExponential],
        *args: Any,
        **kwargs: Any,
    ) -> Tensor:
        """
        Monte Carlo estimate of :math:`\\log E_{q(f)}[p(y \\mid f)]`.

        :param observations: Values of :math:`y`.
        :param function_dist: Distribution for :math:`f(x)`.
        :param args: Additional args (passed to the forward function).
        :param kwargs: Additional kwargs (passed to the forward function).
        :return: The log of the sample-mean likelihood, computed stably as a log-mean-exp.
        """
        likelihood_samples = self._draw_likelihood_samples(function_dist, *args, **kwargs)
        log_probs = likelihood_samples.log_prob(observations)
        # log(mean(exp(x))) == logsumexp(x - log(S)) over the S samples in dim 0.
        res = log_probs.sub(math.log(log_probs.size(0))).logsumexp(dim=0)
        return res

    def marginal(
        self, function_dist: Union[MultivariateNormal, MultivariateQExponential], *args: Any, **kwargs: Any
    ) -> _Distribution:
        """Return ``forward`` evaluated at Monte Carlo samples of ``function_dist``."""
        res = self._draw_likelihood_samples(function_dist, *args, **kwargs)
        return res

    def __call__(
        self, input: Union[Tensor, MultivariateNormal, MultivariateQExponential], *args: Any, **kwargs: Any
    ) -> _Distribution:
        """
        Dispatch on the input type.

        A tensor is treated as function samples and yields the conditional distribution.
        A MultivariateNormal or MultivariateQExponential yields :meth:`marginal`.
        Any other input raises RuntimeError.
        """
        # Conditional
        if torch.is_tensor(input):
            return super().__call__(input, *args, **kwargs)  # pyre-ignore[7]
        # Marginal
        elif isinstance(input, (MultivariateNormal, MultivariateQExponential)):
            return self.marginal(input, *args, **kwargs)
        # Error
        else:
            raise RuntimeError(
                "Likelihoods expects a MultivariateNormal or MultivariateQExponential input to make marginal predictions, or a "
                "torch.Tensor for conditional predictions. Got a {}".format(input.__class__.__name__)
            )
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
try:
    import pyro

    class Likelihood(_Likelihood):
        r"""
        A Likelihood in QPyTorch specifies the mapping from latent function values
        :math:`f(\mathbf X)` to observed labels :math:`y`.

        For example, in the case of regression this might be a Gaussian or Q-Exponential
        distribution, as :math:`y(\mathbf x)` is equal to :math:`f(\mathbf x)` plus Gaussian (Q-Exponential) noise:

        .. math::
            y(\mathbf x) = f(\mathbf x) + \epsilon, \:\:\:\:
            \epsilon \sim \mathcal N(0,\sigma^{2}_{n} \mathbf I) \text{ or }
            \epsilon \sim \text{Q-EP}(0,\sigma^{2}_{n} \mathbf I)

        In the case of classification, this might be a Bernoulli distribution,
        where the probability that :math:`y=1` is given by the latent function
        passed through some sigmoid or probit function:

        .. math::
            y(\mathbf x) = \begin{cases}
                1 & \text{w/ probability} \:\: \sigma(f(\mathbf x)) \\
                0 & \text{w/ probability} \:\: 1-\sigma(f(\mathbf x))
            \end{cases}

        In either case, to implement a likelihood function, QPyTorch only
        requires a forward method that computes the conditional distribution
        :math:`p(y \mid f(\mathbf x))`.

        :param bool has_analytic_marginal: Whether or not the marginal distribution :math:`p(\mathbf y)`
            can be computed in closed form. (See :meth:`~qpytorch.likelihoods.Likelihood.__call__` docstring.)
        :param max_plate_nesting: (For Pyro integration only.) How many batch dimensions are in the function.
            This should be modified if the likelihood uses plated random variables. (Default = 1)
        :param str name_prefix: (For Pyro integration only.) Prefix to assign to named Pyro latent variables.
        :param int num_data: (For Pyro integration only.) Total amount of observations.
        """

        @property
        def num_data(self) -> int:
            if hasattr(self, "_num_data"):
                return self._num_data
            else:
                warnings.warn(
                    "likelihood.num_data isn't set. This might result in incorrect ELBO scaling.", GPInputWarning
                )
                # Return a falsy int (matching the declared type and the non-Pyro fallback) so that
                # ``sample_target`` falls back to the batch size via ``num_data or ...``.
                return 0

        @num_data.setter
        def num_data(self, val: int) -> None:
            self._num_data = val

        @property
        def name_prefix(self) -> str:
            if hasattr(self, "_name_prefix"):
                return self._name_prefix
            else:
                return ""

        @name_prefix.setter
        def name_prefix(self, val: str) -> None:
            self._name_prefix = val

        def _draw_likelihood_samples(
            self,
            function_dist: Union[_Distribution, MultivariateQExponential],
            *args: Any,
            sample_shape: Optional[torch.Size] = None,
            **kwargs: Any,
        ) -> _Distribution:
            """
            Sample :math:`f` inside a vectorized Pyro particle plate and return ``self.forward``
            evaluated at the samples.

            :param function_dist: Distribution over the latent function values.
            :param args: Additional args (passed to the forward function).
            :param sample_shape: If None, samples are drawn through ``pyro.sample`` with
                ``settings.num_likelihood_samples`` particles; otherwise the trailing
                ``len(function_dist.batch_shape)`` entries are dropped and the result is used
                as the sample shape.
            :param kwargs: Additional kwargs (passed to the forward function).
            """
            if self.training:
                num_event_dims = len(function_dist.event_shape)
                # The Q-exponential case must be tested before the generic Distribution case:
                # MultivariateQExponential is presumably a torch Distribution subclass, so testing
                # ``_Distribution`` first made this branch unreachable and silently replaced
                # Q-exponential marginals with Gaussian ones.
                if isinstance(function_dist, MultivariateQExponential):
                    function_dist = QExponential(
                        function_dist.mean, function_dist.variance.sqrt(), function_dist.power
                    )
                elif isinstance(function_dist, _Distribution):
                    function_dist = base_distributions.Normal(function_dist.mean, function_dist.variance.sqrt())
                function_dist = base_distributions.Independent(function_dist, num_event_dims - 1)

            plate_name = self.name_prefix + ".num_particles_vectorized"
            num_samples = settings.num_likelihood_samples.value()
            max_plate_nesting = max(self.max_plate_nesting, len(function_dist.batch_shape))
            with pyro.plate(plate_name, size=num_samples, dim=(-max_plate_nesting - 1)):
                if sample_shape is None:
                    function_samples = pyro.sample(self.name_prefix, function_dist.mask(False))
                    # Deal with the fact that we're not assuming conditional independence over data points here
                    function_samples = function_samples.squeeze(-len(function_dist.event_shape) - 1)
                else:
                    sample_shape = sample_shape[: -len(function_dist.batch_shape)]
                    function_samples = function_dist(sample_shape)

                if not self.training:
                    function_samples = function_samples.squeeze(-len(function_dist.event_shape) - 1)
                return self.forward(function_samples, *args, **kwargs)

        def expected_log_prob(
            self,
            observations: Tensor,
            function_dist: Union[MultivariateNormal, MultivariateQExponential],
            *args: Any,
            **kwargs: Any,
        ) -> Tensor:
            r"""
            (Used by :obj:`~qpytorch.mlls.VariationalELBO` for variational inference.)

            Computes the expected log likelihood, where the expectation is over the GP (QEP) variational distribution.

            .. math::
                \sum_{\mathbf x, y} \mathbb{E}_{q\left( f(\mathbf x) \right)}
                \left[ \log p \left( y \mid f(\mathbf x) \right) \right]

            :param observations: Values of :math:`y`.
            :param function_dist: Distribution for :math:`f(x)`.
            :param args: Additional args (passed to the forward function).
            :param kwargs: Additional kwargs (passed to the forward function).
            """
            return super().expected_log_prob(observations, function_dist, *args, **kwargs)

        @abstractmethod
        def forward(
            self, function_samples: Tensor, *args: Any, data: Dict[str, Tensor] = {}, **kwargs: Any
        ) -> _Distribution:
            r"""
            Computes the conditional distribution :math:`p(\mathbf y \mid
            \mathbf f, \ldots)` that defines the likelihood.

            :param function_samples: Samples from the function (:math:`\mathbf f`)
            :param data: (Pyro integration only.) Additional variables that the likelihood needs to condition
                on. The keys of the dictionary will correspond to Pyro sample sites
                in the likelihood's model/guide.
            :param args: Additional args
            :param kwargs: Additional kwargs
            """
            raise NotImplementedError

        def get_fantasy_likelihood(self, **kwargs: Any) -> "_Likelihood":
            """Return a deep copy of this likelihood (``kwargs`` are ignored by the base implementation)."""
            return super().get_fantasy_likelihood(**kwargs)

        def log_marginal(
            self,
            observations: Tensor,
            function_dist: Union[MultivariateNormal, MultivariateQExponential],
            *args: Any,
            **kwargs: Any,
        ) -> Tensor:
            r"""
            (Used by :obj:`~qpytorch.mlls.PredictiveLogLikelihood` for approximate inference.)

            Computes the log marginal likelihood of the approximate predictive distribution

            .. math::
                \sum_{\mathbf x, y} \log \mathbb{E}_{q\left( f(\mathbf x) \right)}
                \left[ p \left( y \mid f(\mathbf x) \right) \right]

            Note that this differs from :meth:`expected_log_prob` because the :math:`\log` is on the outside
            of the expectation.

            :param observations: Values of :math:`y`.
            :param function_dist: Distribution for :math:`f(x)`.
            :param args: Additional args (passed to the forward function).
            :param kwargs: Additional kwargs (passed to the forward function).
            """
            return super().log_marginal(observations, function_dist, *args, **kwargs)

        def marginal(
            self, function_dist: Union[MultivariateNormal, MultivariateQExponential], *args: Any, **kwargs: Any
        ) -> _Distribution:
            r"""
            Computes a predictive distribution :math:`p(y^* | \mathbf x^*)` given either a posterior
            distribution :math:`p(\mathbf f | \mathcal D, \mathbf x)` or a
            prior distribution :math:`p(\mathbf f|\mathbf x)` as input.

            With both exact inference and variational inference, the form of
            :math:`p(\mathbf f|\mathcal D, \mathbf x)` or :math:`p(\mathbf f|
            \mathbf x)` should usually be Gaussian or Q-Exponential. As a result, function_dist
            should usually be a :obj:`~gpytorch.distributions.MultivariateNormal`
            or :obj:`~qpytorch.distributions.MultivariateQExponential` specified by the mean and
            (co)variance of :math:`p(\mathbf f|...)`.

            :param function_dist: Distribution for :math:`f(x)`.
            :param args: Additional args (passed to the forward function).
            :param kwargs: Additional kwargs (passed to the forward function).
            :return: The marginal distribution, or samples from it.
            """
            return super().marginal(function_dist, *args, **kwargs)

        def pyro_guide(
            self,
            function_dist: Union[MultivariateNormal, MultivariateQExponential],
            target: Tensor,
            *args: Any,
            **kwargs: Any,
        ) -> None:
            r"""
            (For Pyro integration only).

            Part of the guide function for the likelihood.
            This should be re-defined if the likelihood contains any latent variables that need to be inferred.

            :param function_dist: Distribution of latent function
                :math:`q(\mathbf f)`.
            :param target: Observed :math:`\mathbf y`.
            :param args: Additional args (passed to the forward function).
            :param kwargs: Additional kwargs (passed to the forward function).
            """
            with pyro.plate(self.name_prefix + ".data_plate", dim=-1):
                pyro.sample(self.name_prefix + ".f", function_dist)

        def pyro_model(
            self,
            function_dist: Union[MultivariateNormal, MultivariateQExponential],
            target: Tensor,
            *args: Any,
            **kwargs: Any,
        ) -> Tensor:
            r"""
            (For Pyro integration only).

            Part of the model function for the likelihood.
            It returns the observed sample site :math:`\mathbf y` produced by :meth:`sample_target`.
            This should be re-defined if the likelihood contains any latent variables that need to be inferred.

            :param function_dist: Distribution of latent function
                :math:`p(\mathbf f)`.
            :param target: Observed :math:`\mathbf y`.
            :param args: Additional args (passed to the forward function).
            :param kwargs: Additional kwargs (passed to the forward function).
            """
            with pyro.plate(self.name_prefix + ".data_plate", dim=-1):
                function_samples = pyro.sample(self.name_prefix + ".f", function_dist)
                output_dist = self(function_samples, *args, **kwargs)
                return self.sample_target(output_dist, target)

        def sample_target(
            self, output_dist: Union[MultivariateNormal, MultivariateQExponential], target: Tensor
        ) -> Tensor:
            """
            (For Pyro integration only.) Observe ``target`` under ``output_dist``.

            The log-likelihood is scaled by ``num_data / batch_size`` (with ``batch_size`` taken
            from ``output_dist.batch_shape[-1]``); when ``num_data`` is unset the scale is 1.
            """
            scale = (self.num_data or output_dist.batch_shape[-1]) / output_dist.batch_shape[-1]
            with pyro.poutine.scale(scale=scale):  # pyre-ignore[16]
                return pyro.sample(self.name_prefix + ".y", output_dist, obs=target)

        def __call__(
            self, input: Union[Tensor, MultivariateNormal, MultivariateQExponential], *args: Any, **kwargs: Any
        ) -> _Distribution:
            r"""
            Calling this object does one of two things:

                1. If likelihood is called with a :class:`torch.Tensor` object, then it is
                   assumed that the input is samples from :math:`f(\mathbf x)`. This
                   returns the *conditional* distribution :math:`p(y|f(\mathbf x))`.

                .. code-block:: python

                    f = torch.randn(20)
                    likelihood = qpytorch.likelihoods.GaussianLikelihood()  # or qpytorch.likelihoods.QExponentialLikelihood()
                    conditional = likelihood(f)
                    print(type(conditional), conditional.batch_shape, conditional.event_shape)
                    # >>> <class 'torch.distributions.normal.Normal'> torch.Size([20]) torch.Size([])
                    # or >>> <class 'qpytorch.distributions.qexponential.QExponential'> torch.Size([20]) torch.Size([])

                2. If likelihood is called with a :class:`~gpytorch.distributions.MultivariateNormal`
                   or :class:`~qpytorch.distributions.MultivariateQExponential` object,
                   then it is assumed that the input is the distribution :math:`f(\mathbf x)`.
                   This returns the *marginal* distribution :math:`p(y|\mathbf x)`.

                   The form of the marginal distribution depends on the likelihood.
                   For :class:`~qpytorch.likelihoods.BernoulliLikelihood`,
                   :class:`~qpytorch.likelihoods.GaussianLikelihood` and
                   :class:`~qpytorch.likelihoods.QExponentialLikelihood` objects, the marginal distribution
                   can be computed analytically, and the likelihood returns the analytic distribution.
                   For most other likelihoods, there is no analytic form for the marginal,
                   and so the likelihood instead returns a batch of Monte Carlo samples from the marginal.

                .. code-block:: python

                    mean = torch.randn(20)
                    covar = linear_operator.operators.DiagLinearOperator(torch.ones(20))
                    f = qpytorch.distributions.MultivariateNormal(mean, covar)
                    # or, for a Q-exponential process:
                    power = torch.tensor(1.0)
                    f = qpytorch.distributions.MultivariateQExponential(mean, covar, power)

                    # Analytic marginal computation - Bernoulli, Gaussian and Q-Exponential likelihoods only
                    analytic_marginal_likelihood = qpytorch.likelihoods.GaussianLikelihood()
                    # or qpytorch.likelihoods.QExponentialLikelihood()
                    marginal = analytic_marginal_likelihood(f)
                    print(type(marginal), marginal.batch_shape, marginal.event_shape)
                    # >>> <class 'gpytorch.distributions.multivariate_normal.MultivariateNormal'> torch.Size([]) torch.Size([20]) # noqa: E501
                    # or >>> <class 'qpytorch.distributions.multivariate_qexponential.MultivariateQExponential'> torch.Size([]) torch.Size([20])

                    # MC marginal computation - all other likelihoods
                    mc_marginal_likelihood = qpytorch.likelihoods.BetaLikelihood()
                    with qpytorch.settings.num_likelihood_samples(15):
                        marginal = mc_marginal_likelihood(f)
                        print(type(marginal), marginal.batch_shape, marginal.event_shape)
                        # >>> <class 'torch.distributions.beta.Beta'> torch.Size([15, 20]) torch.Size([])
                        # The batch_shape torch.Size([15, 20]) represents 15 MC samples for 20 data points.

                .. note::

                    If a Likelihood supports analytic marginals, the :attr:`has_analytic_marginal` property will be True.
                    If a Likelihood does not support analytic marginals, you can set the number of Monte Carlo
                    samples using the :class:`qpytorch.settings.num_likelihood_samples` context manager.

            :param input: Either a (... x N) sample from :math:`\mathbf f`
                or a (... x N) MVN (QEP) distribution of :math:`\mathbf f`.
            :param args: Additional args (passed to the forward function).
            :param kwargs: Additional kwargs (passed to the forward function).
            :return: Either a conditional :math:`p(\mathbf y \mid \mathbf f)`
                or marginal :math:`p(\mathbf y)`
                based on whether :attr:`input` is a Tensor or a MultivariateNormal or MultivariateQExponential (see above).
            """
            # Conditional
            if torch.is_tensor(input):
                return super().__call__(input, *args, **kwargs)
            # Marginal
            elif any(
                [
                    isinstance(input, (MultivariateNormal, MultivariateQExponential)),
                    isinstance(input, (pyro.distributions.Normal, QExponential)),  # pyre-ignore[16]
                    (
                        isinstance(input, pyro.distributions.Independent)  # pyre-ignore[16]
                        and isinstance(input.base_dist, pyro.distributions.Normal)  # pyre-ignore[16]
                    ),
                ]
            ):
                return self.marginal(input, *args, **kwargs)  # pyre-ignore[6]
            # Error
            else:
                raise RuntimeError(
                    "Likelihoods expects a MultivariateNormal (MultivariateQExponential) or Normal (QExponential) input to make marginal predictions, or a "
                    "torch.Tensor for conditional predictions. Got a {}".format(input.__class__.__name__)
                )

except ImportError:

    class Likelihood(_Likelihood):
        """
        Fallback Likelihood used when Pyro is not installed.

        The Pyro-only attributes ``num_data`` and ``name_prefix`` are accepted but ignored;
        every access emits a RuntimeWarning.
        """

        @property
        def num_data(self) -> int:
            warnings.warn("num_data is only used for likelihoods that are integrated with Pyro.", RuntimeWarning)
            return 0

        @num_data.setter
        def num_data(self, val: int) -> None:
            warnings.warn("num_data is only used for likelihoods that are integrated with Pyro.", RuntimeWarning)

        @property
        def name_prefix(self) -> str:
            warnings.warn("name_prefix is only used for likelihoods that are integrated with Pyro.", RuntimeWarning)
            return ""

        @name_prefix.setter
        def name_prefix(self, val: str) -> None:
            warnings.warn("name_prefix is only used for likelihoods that are integrated with Pyro.", RuntimeWarning)
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
class _OneDimensionalLikelihood(Likelihood, ABC):
    r"""
    A specific case of :obj:`~qpytorch.likelihoods.Likelihood` when the GP (QEP) represents a one-dimensional
    output. (I.e. for a specific :math:`\mathbf x`, :math:`f(\mathbf x) \in \mathbb{R}`.)

    Inheriting from this likelihood reduces the variance when computing approximate GP (QEP) objective functions
    by using 1D Gauss-Hermite quadrature.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # NOTE(review): GaussHermiteQuadrature1D comes from gpytorch; confirm how it treats
        # MultivariateQExponential inputs (presumably via their mean/variance only).
        self.quadrature = GaussHermiteQuadrature1D()

    def expected_log_prob(
        self,
        observations: Tensor,
        function_dist: Union[MultivariateNormal, MultivariateQExponential],
        *args: Any,
        **kwargs: Any,
    ) -> Tensor:
        """
        Quadrature estimate of :math:`E_{q(f)}[\\log p(y \\mid f)]`.

        :param observations: Values of :math:`y`.
        :param function_dist: Distribution for :math:`f(x)`.
        :param args: Additional args (passed to the forward function).
        :param kwargs: Additional kwargs (passed to the forward function).
        """

        def log_prob_fn(function_samples: Tensor) -> Tensor:
            return self.forward(function_samples, *args, **kwargs).log_prob(observations)

        log_prob = self.quadrature(log_prob_fn, function_dist)
        return log_prob

    def log_marginal(
        self,
        observations: Tensor,
        function_dist: Union[MultivariateNormal, MultivariateQExponential],
        *args: Any,
        **kwargs: Any,
    ) -> Tensor:
        """
        Quadrature estimate of :math:`\\log E_{q(f)}[p(y \\mid f)]`.

        :param observations: Values of :math:`y`.
        :param function_dist: Distribution for :math:`f(x)`.
        :param args: Additional args (passed to the forward function).
        :param kwargs: Additional kwargs (passed to the forward function).
        """

        def prob_fn(function_samples: Tensor) -> Tensor:
            # Forward args/kwargs exactly as expected_log_prob does; previously they were dropped,
            # which broke likelihoods whose forward needs them.
            return self.forward(function_samples, *args, **kwargs).log_prob(observations).exp()

        prob = self.quadrature(prob_fn, function_dist)
        return prob.log()
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
#! /usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
from torch.nn import ModuleList
|
|
4
|
+
|
|
5
|
+
from . import Likelihood
|
|
6
|
+
from gpytorch.utils.generic import length_safe_zip
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _get_tuple_args_(*args):
|
|
10
|
+
for arg in args:
|
|
11
|
+
if isinstance(arg, tuple):
|
|
12
|
+
yield arg
|
|
13
|
+
else:
|
|
14
|
+
yield (arg,)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class LikelihoodList(Likelihood):
    """
    A container that applies a sequence of likelihoods element-wise.

    The i-th positional argument is routed to the i-th likelihood. Non-tuple arguments are
    wrapped into a one-element tuple by ``_get_tuple_args_`` and then splatted. Keyword
    arguments are shared by every likelihood, except ``noise``, which is treated as an
    iterable holding one entry per likelihood.
    """

    def __init__(self, *likelihoods):
        super().__init__()
        self.likelihoods = ModuleList(likelihoods)

    def expected_log_prob(self, *args, **kwargs):
        """Return a list with each likelihood's ``expected_log_prob`` on its own positional args."""
        return [
            likelihood.expected_log_prob(*args_, **kwargs)
            for likelihood, args_ in length_safe_zip(self.likelihoods, _get_tuple_args_(*args))
        ]

    def forward(self, *args, **kwargs):
        """Return a list with each likelihood's ``forward`` output."""
        if "noise" in kwargs:
            noise = kwargs.pop("noise")
            # if noise kwarg is passed, assume it's an iterable of noise tensors; each entry must be
            # forwarded as the ``noise`` *keyword* (it was previously passed as a stray positional dict).
            return [
                likelihood.forward(*args_, **kwargs, noise=noise_)
                for likelihood, args_, noise_ in length_safe_zip(self.likelihoods, _get_tuple_args_(*args), noise)
            ]
        else:
            return [
                likelihood.forward(*args_, **kwargs)
                for likelihood, args_ in length_safe_zip(self.likelihoods, _get_tuple_args_(*args))
            ]

    def pyro_sample_output(self, *args, **kwargs):
        """Return a list with each likelihood's ``pyro_sample_output`` result."""
        # NOTE(review): the Likelihood base class in likelihood.py defines pyro_model/pyro_guide but no
        # pyro_sample_output; confirm the contained likelihoods actually provide it.
        return [
            likelihood.pyro_sample_output(*args_, **kwargs)
            for likelihood, args_ in length_safe_zip(self.likelihoods, _get_tuple_args_(*args))
        ]

    def __call__(self, *args, **kwargs):
        """Return a list with each likelihood called on its own positional args."""
        if "noise" in kwargs:
            noise = kwargs.pop("noise")
            # if noise kwarg is passed, assume it's an iterable of noise tensors; each entry must be
            # forwarded as the ``noise`` *keyword* (it was previously passed as a stray positional dict).
            return [
                likelihood(*args_, **kwargs, noise=noise_)
                for likelihood, args_, noise_ in length_safe_zip(self.likelihoods, _get_tuple_args_(*args), noise)
            ]
        else:
            return [
                likelihood(*args_, **kwargs)
                for likelihood, args_ in length_safe_zip(self.likelihoods, _get_tuple_args_(*args))
            ]
|