qpytorch-0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of qpytorch might be problematic.
- qpytorch/__init__.py +327 -0
- qpytorch/constraints/__init__.py +3 -0
- qpytorch/distributions/__init__.py +21 -0
- qpytorch/distributions/delta.py +86 -0
- qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
- qpytorch/distributions/multivariate_qexponential.py +581 -0
- qpytorch/distributions/power.py +113 -0
- qpytorch/distributions/qexponential.py +153 -0
- qpytorch/functions/__init__.py +58 -0
- qpytorch/kernels/__init__.py +80 -0
- qpytorch/kernels/grid_interpolation_kernel.py +213 -0
- qpytorch/kernels/inducing_point_kernel.py +151 -0
- qpytorch/kernels/kernel.py +695 -0
- qpytorch/kernels/matern32_kernel_grad.py +155 -0
- qpytorch/kernels/matern52_kernel_grad.py +194 -0
- qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
- qpytorch/kernels/polynomial_kernel_grad.py +88 -0
- qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
- qpytorch/kernels/rbf_kernel_grad.py +125 -0
- qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
- qpytorch/kernels/rff_kernel.py +153 -0
- qpytorch/lazy/__init__.py +9 -0
- qpytorch/likelihoods/__init__.py +66 -0
- qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
- qpytorch/likelihoods/beta_likelihood.py +76 -0
- qpytorch/likelihoods/gaussian_likelihood.py +472 -0
- qpytorch/likelihoods/laplace_likelihood.py +59 -0
- qpytorch/likelihoods/likelihood.py +437 -0
- qpytorch/likelihoods/likelihood_list.py +60 -0
- qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
- qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
- qpytorch/likelihoods/noise_models.py +184 -0
- qpytorch/likelihoods/qexponential_likelihood.py +494 -0
- qpytorch/likelihoods/softmax_likelihood.py +97 -0
- qpytorch/likelihoods/student_t_likelihood.py +90 -0
- qpytorch/means/__init__.py +23 -0
- qpytorch/metrics/__init__.py +17 -0
- qpytorch/mlls/__init__.py +53 -0
- qpytorch/mlls/_approximate_mll.py +79 -0
- qpytorch/mlls/deep_approximate_mll.py +30 -0
- qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
- qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
- qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
- qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
- qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
- qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
- qpytorch/mlls/marginal_log_likelihood.py +48 -0
- qpytorch/mlls/predictive_log_likelihood.py +76 -0
- qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
- qpytorch/mlls/variational_elbo.py +77 -0
- qpytorch/models/__init__.py +72 -0
- qpytorch/models/approximate_qep.py +115 -0
- qpytorch/models/deep_qeps/__init__.py +22 -0
- qpytorch/models/deep_qeps/deep_qep.py +155 -0
- qpytorch/models/deep_qeps/dspp.py +114 -0
- qpytorch/models/exact_prediction_strategies.py +880 -0
- qpytorch/models/exact_qep.py +349 -0
- qpytorch/models/model_list.py +100 -0
- qpytorch/models/pyro/__init__.py +28 -0
- qpytorch/models/pyro/_pyro_mixin.py +57 -0
- qpytorch/models/pyro/distributions/__init__.py +5 -0
- qpytorch/models/pyro/pyro_qep.py +105 -0
- qpytorch/models/qep.py +7 -0
- qpytorch/models/qeplvm/__init__.py +6 -0
- qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
- qpytorch/models/qeplvm/latent_variable.py +102 -0
- qpytorch/module.py +30 -0
- qpytorch/optim/__init__.py +5 -0
- qpytorch/priors/__init__.py +42 -0
- qpytorch/priors/qep_priors.py +81 -0
- qpytorch/test/__init__.py +22 -0
- qpytorch/test/base_likelihood_test_case.py +106 -0
- qpytorch/test/model_test_case.py +150 -0
- qpytorch/test/variational_test_case.py +400 -0
- qpytorch/utils/__init__.py +38 -0
- qpytorch/utils/warnings.py +37 -0
- qpytorch/variational/__init__.py +47 -0
- qpytorch/variational/_variational_distribution.py +61 -0
- qpytorch/variational/_variational_strategy.py +391 -0
- qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
- qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
- qpytorch/variational/cholesky_variational_distribution.py +65 -0
- qpytorch/variational/ciq_variational_strategy.py +352 -0
- qpytorch/variational/delta_variational_distribution.py +41 -0
- qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
- qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
- qpytorch/variational/lmc_variational_strategy.py +248 -0
- qpytorch/variational/mean_field_variational_distribution.py +58 -0
- qpytorch/variational/multitask_variational_strategy.py +317 -0
- qpytorch/variational/natural_variational_distribution.py +152 -0
- qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
- qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
- qpytorch/variational/tril_natural_variational_distribution.py +130 -0
- qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
- qpytorch/variational/unwhitened_variational_strategy.py +225 -0
- qpytorch/variational/variational_strategy.py +280 -0
- qpytorch/version.py +4 -0
- qpytorch-0.1.dist-info/LICENSE +21 -0
- qpytorch-0.1.dist-info/METADATA +177 -0
- qpytorch-0.1.dist-info/RECORD +102 -0
- qpytorch-0.1.dist-info/WHEEL +5 -0
- qpytorch-0.1.dist-info/top_level.txt +1 -0
qpytorch/kernels/qexponential_symmetrized_kl_kernel.py
@@ -0,0 +1,61 @@
#!/usr/bin/env python3

import torch
from torch.distributions import Chi2
from gpytorch.kernels.distributional_input_kernel import DistributionalInputKernel


def _symmetrized_kl(dist1, dist2, eps=1e-8, **kwargs):
    """
    Symmetrized KL distance between two q-exponential distributions. We assume that
    the first half of the distribution tensors are the means, and the second half
    are the log variances.
    Args:
        dist1 (torch.Tensor) has shape batch x n x dimensions. The first half
            of the last dimension are the means, while the second half are the log-variances.
        dist2 (torch.Tensor) has shape batch x n x dimensions. The first half
            of the last dimension are the means, while the second half are the log-variances.
        eps (float) jitter term for the noise variance
    """

    power = kwargs.pop('power', torch.tensor(1.0))
    exact = kwargs.pop('exact', False)
    num_dims = int(dist1.shape[-1] / 2)

    dist1_mean = dist1[..., :num_dims].unsqueeze(-3)
    dist1_logvar = dist1[..., num_dims:].unsqueeze(-3)
    dist1_var = eps + dist1_logvar.exp()

    dist2_mean = dist2[..., :num_dims].unsqueeze(-2)
    dist2_logvar = dist2[..., num_dims:].unsqueeze(-2)
    dist2_var = eps + dist2_logvar.exp()

    var_ratio12 = dist1_var / dist2_var
    # log_var_ratio12 = var_ratio12.log()
    # note that the log variance ratio cancels because of the summed KL.
    loc_sqdiffs = (dist1_mean - dist2_mean).pow(2)
    kl1 = 0.5 * ((var_ratio12 + loc_sqdiffs / dist2_var).pow(power / 2.) - 1)
    kl2 = 0.5 * ((var_ratio12.reciprocal() + loc_sqdiffs / dist1_var).pow(power / 2.) - 1)
    if power != 2:
        kl1 += 0.5 * (-(1 - 2. / power) * torch.log(2 * kl1 + 1) + (power / 2. - 1) * (-2. / power * Chi2(1).entropy() if exact else 0))
        kl2 += 0.5 * (-(1 - 2. / power) * torch.log(2 * kl2 + 1) + (power / 2. - 1) * (-2. / power * Chi2(1).entropy() if exact else 0))
    symmetrized_kl = kl1 + kl2
    return symmetrized_kl.sum(-1).transpose(-1, -2)


class QExponentialSymmetrizedKLKernel(DistributionalInputKernel):
    r"""
    Computes a kernel based on the symmetrized KL divergence, assuming that two q-exponential
    distributions are given as inputs. Inputs are assumed to be `batch x N x 2d` tensors where `d` is the
    dimension of the distribution. The first `d` dimensions are the mean parameters of the
    `batch x N` distributions, while the second `d` dimensions are the log variances.

    The original citation is Moreno et al., 2004
    (https://papers.nips.cc/paper/2351-a-kullback-leibler-divergence-based-kernel-for-svm-\
    classification-in-multimedia-applications.pdf), which introduced the symmetrized KL divergence kernel
    between two Gaussian distributions.
    """

    def __init__(self, **kwargs):
        distance_function = lambda dist1, dist2: _symmetrized_kl(dist1, dist2, **kwargs)
        super(QExponentialSymmetrizedKLKernel, self).__init__(distance_function=distance_function, **kwargs)
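A minimal usage sketch (not from the packaged source): it builds the `N x 2d` mean/log-variance layout the docstring describes and evaluates the kernel through GPyTorch's `DistributionalInputKernel` machinery. Only the direct module path is confirmed by this diff; a re-export via `qpytorch.kernels` is an assumption.

import torch
from qpytorch.kernels.qexponential_symmetrized_kl_kernel import QExponentialSymmetrizedKLKernel

N, d = 10, 3
means = torch.randn(N, d)
log_vars = torch.randn(N, d)
x = torch.cat([means, log_vars], dim=-1)  # shape N x 2d: first d columns are means, last d are log-variances

kernel = QExponentialSymmetrizedKLKernel()  # 'power' / 'exact' kwargs would be forwarded to _symmetrized_kl
K = kernel(x, x).to_dense()                 # N x N kernel matrix built from the symmetrized KL distances
print(K.shape)                              # torch.Size([10, 10])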
qpytorch/kernels/rbf_kernel_grad.py
@@ -0,0 +1,125 @@
#!/usr/bin/env python3

import torch
from linear_operator.operators import KroneckerProductLinearOperator

from gpytorch.kernels.rbf_kernel import postprocess_rbf, RBFKernel


class RBFKernelGrad(RBFKernel):
    r"""
    Computes a covariance matrix of the RBF kernel that models the covariance
    between the values and partial derivatives for inputs :math:`\mathbf{x_1}`
    and :math:`\mathbf{x_2}`.

    See :class:`qpytorch.kernels.Kernel` for descriptions of the lengthscale options.

    .. note::

        This kernel does not have an `outputscale` parameter. To add a scaling parameter,
        decorate this kernel with a :class:`gpytorch.kernels.ScaleKernel`.

    :param ard_num_dims: Set this if you want a separate lengthscale for each input
        dimension. It should be `d` if x1 is a `n x d` matrix. (Default: `None`.)
    :param batch_shape: Set this if you want a separate lengthscale for each batch of input
        data. It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf x1` is
        a :math:`B_1 \times \ldots \times B_k \times N \times D` tensor.
    :param active_dims: Set this if you want to compute the covariance of only
        a few input dimensions. The ints correspond to the indices of the
        dimensions. (Default: `None`.)
    :param lengthscale_prior: Set this if you want to apply a prior to the
        lengthscale parameter. (Default: `None`.)
    :param lengthscale_constraint: Set this if you want to apply a constraint
        to the lengthscale parameter. (Default: `Positive`.)
    :param eps: The minimum value that the lengthscale can take (prevents
        divide by zero errors). (Default: `1e-6`.)

    :ivar torch.Tensor lengthscale: The lengthscale parameter. Size/shape of parameter depends on the
        ard_num_dims and batch_shape arguments.

    Example:
        >>> x = torch.randn(10, 5)
        >>> # Non-batch: Simple option
        >>> covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernelGrad())
        >>> covar = covar_module(x)  # Output: LinearOperator of size (60 x 60), where 60 = n * (d + 1)
        >>>
        >>> batch_x = torch.randn(2, 10, 5)
        >>> # Batch: Simple option
        >>> covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernelGrad())
        >>> # Batch: different lengthscale for each batch
        >>> covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernelGrad(batch_shape=torch.Size([2])))
        >>> covar = covar_module(x)  # Output: LinearOperator of size (2 x 60 x 60)
    """

    def __init__(self, **kwargs):
        super(RBFKernelGrad, self).__init__(**kwargs)
        self._interleaved = kwargs.pop('interleaved', True)

    def forward(self, x1, x2, diag=False, **params):
        batch_shape = x1.shape[:-2]
        n_batch_dims = len(batch_shape)
        n1, d = x1.shape[-2:]
        n2 = x2.shape[-2]

        if not diag:
            K = torch.zeros(*batch_shape, n1 * (d + 1), n2 * (d + 1), device=x1.device, dtype=x1.dtype)

            # Scale the inputs by the lengthscale (for stability)
            x1_ = x1.div(self.lengthscale)
            x2_ = x2.div(self.lengthscale)

            # Form all possible rank-1 products for the gradient and Hessian blocks
            outer = x1_.view(*batch_shape, n1, 1, d) - x2_.view(*batch_shape, 1, n2, d)
            outer = outer / self.lengthscale.unsqueeze(-2)
            outer = torch.transpose(outer, -1, -2).contiguous()

            # 1) Kernel block
            diff = self.covar_dist(x1_, x2_, square_dist=True, **params)
            K_11 = postprocess_rbf(diff)
            K[..., :n1, :n2] = K_11

            # 2) First gradient block
            outer1 = outer.view(*batch_shape, n1, n2 * d)
            K[..., :n1, n2:] = outer1 * K_11.repeat([*([1] * (n_batch_dims + 1)), d])

            # 3) Second gradient block
            outer2 = outer.transpose(-1, -3).reshape(*batch_shape, n2, n1 * d)
            outer2 = outer2.transpose(-1, -2)
            K[..., n1:, :n2] = -outer2 * K_11.repeat([*([1] * n_batch_dims), d, 1])

            # 4) Hessian block
            outer3 = outer1.repeat([*([1] * n_batch_dims), d, 1]) * outer2.repeat([*([1] * (n_batch_dims + 1)), d])
            kp = KroneckerProductLinearOperator(
                torch.eye(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2),
                torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
            )
            chain_rule = kp.to_dense() - outer3
            K[..., n1:, n2:] = chain_rule * K_11.repeat([*([1] * n_batch_dims), d, d])

            # Symmetrize for stability
            if n1 == n2 and torch.eq(x1, x2).all():
                K = 0.5 * (K.transpose(-1, -2) + K)

            # Apply a perfect shuffle permutation to match the MultiTask ordering
            if self._interleaved:
                pi1 = torch.arange(n1 * (d + 1)).view(d + 1, n1).t().reshape((n1 * (d + 1)))
                pi2 = torch.arange(n2 * (d + 1)).view(d + 1, n2).t().reshape((n2 * (d + 1)))
                K = K[..., pi1, :][..., :, pi2]

            return K

        else:
            if not (n1 == n2 and torch.eq(x1, x2).all()):
                raise RuntimeError("diag=True only works when x1 == x2")

            kernel_diag = super(RBFKernelGrad, self).forward(x1, x2, diag=True)
            grad_diag = torch.ones(*batch_shape, n2, d, device=x1.device, dtype=x1.dtype) / self.lengthscale.pow(2)
            grad_diag = grad_diag.transpose(-1, -2).reshape(*batch_shape, n2 * d)
            k_diag = torch.cat((kernel_diag, grad_diag), dim=-1)
            if self._interleaved:
                pi = torch.arange(n2 * (d + 1)).view(d + 1, n2).t().reshape((n2 * (d + 1)))
                k_diag = k_diag[..., pi]
            return k_diag

    def num_outputs_per_input(self, x1, x2):
        return x1.size(-1) + 1
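A minimal usage sketch (not from the packaged source) of the `n * (d + 1)` output structure documented above. It assumes `qpytorch.kernels` re-exports `RBFKernelGrad` and `ScaleKernel`, as the docstring example suggests.

import torch
import qpytorch

x = torch.randn(10, 5)                       # n = 10 points, d = 5 input dimensions
base = qpytorch.kernels.RBFKernelGrad()
covar = qpytorch.kernels.ScaleKernel(base)(x)
print(covar.shape)                           # torch.Size([60, 60]) since 60 = n * (d + 1)
print(base.num_outputs_per_input(x, x))      # 6: one function value plus d partial derivatives per point

With the default `interleaved=True`, the perfect-shuffle permutation groups each point's value together with its d partial derivatives, matching the multitask ordering mentioned in the forward pass.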
qpytorch/kernels/rbf_kernel_gradgrad.py
@@ -0,0 +1,186 @@
#!/usr/bin/env python3

import torch
from linear_operator.operators import KroneckerProductLinearOperator

from gpytorch.kernels.rbf_kernel import postprocess_rbf, RBFKernel


class RBFKernelGradGrad(RBFKernel):
    r"""
    Computes a covariance matrix of the RBF kernel that models the covariance
    between the values and first and second (non-mixed) partial derivatives for inputs :math:`\mathbf{x_1}`
    and :math:`\mathbf{x_2}`.

    See :class:`qpytorch.kernels.Kernel` for descriptions of the lengthscale options.

    .. note::

        This kernel does not have an `outputscale` parameter. To add a scaling parameter,
        decorate this kernel with a :class:`gpytorch.kernels.ScaleKernel`.

    :param ard_num_dims: Set this if you want a separate lengthscale for each input
        dimension. It should be `d` if x1 is a `n x d` matrix. (Default: `None`.)
    :param batch_shape: Set this if you want a separate lengthscale for each batch of input
        data. It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf x1` is
        a :math:`B_1 \times \ldots \times B_k \times N \times D` tensor.
    :param active_dims: Set this if you want to compute the covariance of only
        a few input dimensions. The ints correspond to the indices of the
        dimensions. (Default: `None`.)
    :param lengthscale_prior: Set this if you want to apply a prior to the
        lengthscale parameter. (Default: `None`.)
    :param lengthscale_constraint: Set this if you want to apply a constraint
        to the lengthscale parameter. (Default: `Positive`.)
    :param eps: The minimum value that the lengthscale can take (prevents
        divide by zero errors). (Default: `1e-6`.)

    :ivar torch.Tensor lengthscale: The lengthscale parameter. Size/shape of parameter depends on the
        ard_num_dims and batch_shape arguments.

    Example:
        >>> x = torch.randn(10, 5)
        >>> # Non-batch: Simple option
        >>> covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernelGradGrad())
        >>> covar = covar_module(x)  # Output: LinearOperator of size (110 x 110), where 110 = n * (2*d + 1)
        >>>
        >>> batch_x = torch.randn(2, 10, 5)
        >>> # Batch: Simple option
        >>> covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernelGradGrad())
        >>> # Batch: different lengthscale for each batch
        >>> covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernelGradGrad(batch_shape=torch.Size([2])))
        >>> covar = covar_module(x)  # Output: LinearOperator of size (2 x 110 x 110)
    """

    def __init__(self, **kwargs):
        super(RBFKernelGradGrad, self).__init__(**kwargs)
        self._interleaved = kwargs.pop('interleaved', True)

    def forward(self, x1, x2, diag=False, **params):
        batch_shape = x1.shape[:-2]
        n_batch_dims = len(batch_shape)
        n1, d = x1.shape[-2:]
        n2 = x2.shape[-2]

        if not diag:
            K = torch.zeros(*batch_shape, n1 * (2 * d + 1), n2 * (2 * d + 1), device=x1.device, dtype=x1.dtype)

            # Scale the inputs by the lengthscale (for stability)
            x1_ = x1.div(self.lengthscale)
            x2_ = x2.div(self.lengthscale)

            # Form all possible rank-1 products for the gradient and Hessian blocks
            outer = x1_.view(*batch_shape, n1, 1, d) - x2_.view(*batch_shape, 1, n2, d)
            outer = outer / self.lengthscale.unsqueeze(-2)
            outer = torch.transpose(outer, -1, -2).contiguous()

            # 1) Kernel block
            diff = self.covar_dist(x1_, x2_, square_dist=True, **params)
            K_11 = postprocess_rbf(diff)
            K[..., :n1, :n2] = K_11

            # 2) First gradient block
            outer1 = outer.view(*batch_shape, n1, n2 * d)
            K[..., :n1, n2 : (n2 * (d + 1))] = outer1 * K_11.repeat([*([1] * (n_batch_dims + 1)), d])

            # 3) Second gradient block
            outer2 = outer.transpose(-1, -3).reshape(*batch_shape, n2, n1 * d)
            outer2 = outer2.transpose(-1, -2)
            K[..., n1 : (n1 * (d + 1)), :n2] = -outer2 * K_11.repeat([*([1] * n_batch_dims), d, 1])

            # 4) Hessian block
            outer3 = outer1.repeat([*([1] * n_batch_dims), d, 1]) * outer2.repeat([*([1] * (n_batch_dims + 1)), d])
            kp = KroneckerProductLinearOperator(
                torch.eye(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2),
                torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
            )
            chain_rule = kp.to_dense() - outer3
            K[..., n1 : (n1 * (d + 1)), n2 : (n2 * (d + 1))] = chain_rule * K_11.repeat([*([1] * n_batch_dims), d, d])

            # 5) 1-3 block
            douter1dx2 = KroneckerProductLinearOperator(
                torch.ones(1, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2),
                torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
            ).to_dense()

            K_13 = (-douter1dx2 + outer1 * outer1) * K_11.repeat(
                [*([1] * (n_batch_dims + 1)), d]
            )  # verified for n1=n2=1 case
            K[..., :n1, (n2 * (d + 1)) :] = K_13

            if d > 1:
                douter1dx2 = KroneckerProductLinearOperator(
                    (torch.ones(1, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2)).transpose(-1, -2),
                    torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
                ).to_dense()
            K_31 = (-douter1dx2 + outer2 * outer2) * K_11.repeat(
                [*([1] * n_batch_dims), d, 1]
            )  # verified for n1=n2=1 case
            K[..., (n1 * (d + 1)) :, :n2] = K_31

            # rest of the blocks are all of size (n1*d,n2*d)
            outer1 = outer1.repeat([*([1] * n_batch_dims), d, 1])
            outer2 = outer2.repeat([*([1] * (n_batch_dims + 1)), d])
            # II = (torch.eye(d,d,device=x1.device,dtype=x1.dtype)/lengthscale.pow(2)).repeat(*batch_shape,n1,n2)
            kp2 = KroneckerProductLinearOperator(
                torch.ones(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2),
                torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
            ).to_dense()

            # II may not be the correct thing to use. It might be more appropriate to use kp instead??
            II = kp.to_dense()
            K_11dd = K_11.repeat([*([1] * (n_batch_dims)), d, d])

            K_23 = ((-kp2 + outer1 * outer1) * (-outer2) + 2.0 * II * outer1) * K_11dd  # verified for n1=n2=1 case

            K[..., n1 : (n1 * (d + 1)), (n2 * (d + 1)) :] = K_23

            if d > 1:
                kp2t = KroneckerProductLinearOperator(
                    (torch.ones(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2)).transpose(-1, -2),
                    torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
                ).to_dense()
            K_32 = (
                (-(kp2t if d > 1 else kp2) + outer2 * outer2) * outer1 - 2.0 * II * outer2
            ) * K_11dd  # verified for n1=n2=1 case

            K[..., (n1 * (d + 1)) :, n2 : (n2 * (d + 1))] = K_32

            K_33 = (
                (-(kp2t if d > 1 else kp2) + outer2 * outer2) * (-kp2) - 2.0 * II * outer2 * outer1 + 2.0 * (II) ** 2
            ) * K_11dd + (
                (-(kp2t if d > 1 else kp2) + outer2 * outer2) * outer1 - 2.0 * II * outer2
            ) * outer1 * K_11dd  # verified for n1=n2=1 case

            K[..., (n1 * (d + 1)) :, (n2 * (d + 1)) :] = K_33

            # Symmetrize for stability
            if n1 == n2 and torch.eq(x1, x2).all():
                K = 0.5 * (K.transpose(-1, -2) + K)

            # Apply a perfect shuffle permutation to match the MultiTask ordering
            if self._interleaved:
                pi1 = torch.arange(n1 * (2 * d + 1)).view(2 * d + 1, n1).t().reshape((n1 * (2 * d + 1)))
                pi2 = torch.arange(n2 * (2 * d + 1)).view(2 * d + 1, n2).t().reshape((n2 * (2 * d + 1)))
                K = K[..., pi1, :][..., :, pi2]

            return K

        else:
            if not (n1 == n2 and torch.eq(x1, x2).all()):
                raise RuntimeError("diag=True only works when x1 == x2")

            kernel_diag = super(RBFKernelGradGrad, self).forward(x1, x2, diag=True)
            grad_diag = torch.ones(*batch_shape, n2, d, device=x1.device, dtype=x1.dtype) / self.lengthscale.pow(2)
            grad_diag = grad_diag.transpose(-1, -2).reshape(*batch_shape, n2 * d)
            gradgrad_diag = (
                3 * torch.ones(*batch_shape, n2, d, device=x1.device, dtype=x1.dtype) / self.lengthscale.pow(4)
            )
            gradgrad_diag = gradgrad_diag.transpose(-1, -2).reshape(*batch_shape, n2 * d)
            k_diag = torch.cat((kernel_diag, grad_diag, gradgrad_diag), dim=-1)
            if self._interleaved:
                pi = torch.arange(n2 * (2 * d + 1)).view(2 * d + 1, n2).t().reshape((n2 * (2 * d + 1)))
                k_diag = k_diag[..., pi]
            return k_diag

    def num_outputs_per_input(self, x1, x2):
        return x1.size(-1) * 2 + 1
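The same usage pattern applies here; a short sketch (again assuming the `qpytorch.kernels` re-exports used in the docstring): with second derivatives included, each point contributes 2*d + 1 outputs.

import torch
import qpytorch

x = torch.randn(10, 5)                       # n = 10, d = 5
base = qpytorch.kernels.RBFKernelGradGrad()
covar = qpytorch.kernels.ScaleKernel(base)(x)
print(covar.shape)                           # torch.Size([110, 110]) since 110 = n * (2*d + 1)
print(base.num_outputs_per_input(x, x))      # 11 = 2*d + 1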
qpytorch/kernels/rff_kernel.py
@@ -0,0 +1,153 @@
#!/usr/bin/env python3

import math
from typing import Optional

import torch
from linear_operator.operators import LowRankRootLinearOperator, MatmulLinearOperator, RootLinearOperator
from torch import Tensor

from ..models import exact_prediction_strategies
from .kernel import Kernel


class RFFKernel(Kernel):
    r"""
    Computes a covariance matrix based on Random Fourier Features with the RBFKernel.

    Random Fourier features were originally proposed in
    'Random Features for Large-Scale Kernel Machines' by Rahimi and Recht (2008).
    Instead of the shifted cosine features from Rahimi and Recht (2008), we use
    the sine and cosine features, which give a lower-variance estimator --- see
    'On the Error of Random Fourier Features' by Sutherland and Schneider (2015).

    By Bochner's theorem, any continuous kernel :math:`k` is positive definite
    if and only if it is the Fourier transform of a non-negative measure :math:`p(\omega)`, i.e.,

    .. math::
        \begin{equation}
            k(x, x') = k(x - x') = \int p(\omega) e^{i(\omega^\top (x - x'))} d\omega,
        \end{equation}

    where :math:`p(\omega)` is a normalized probability measure if :math:`k(0)=1`.

    For the RBF kernel,

    .. math::
        \begin{equation}
            k(\Delta) = \exp{(-\frac{\Delta^2}{2\sigma^2})} \text{ and } p(\omega) = \exp{(-\frac{\sigma^2\omega^2}{2})}
        \end{equation}

    where :math:`\Delta = x - x'`.

    Given datapoint :math:`x\in \mathbb{R}^d`, we can construct its random Fourier features
    :math:`z(x) \in \mathbb{R}^{2D}` by

    .. math::
        \begin{equation}
            z(x) = \sqrt{\frac{1}{D}}
            \begin{bmatrix}
                \cos(\omega_1^\top x)\\
                \sin(\omega_1^\top x)\\
                \cdots \\
                \cos(\omega_D^\top x)\\
                \sin(\omega_D^\top x)
            \end{bmatrix}, \omega_1, \ldots, \omega_D \sim p(\omega)
        \end{equation}

    such that we have an unbiased Monte Carlo estimator

    .. math::
        \begin{equation}
            k(x, x') = k(x - x') \approx z(x)^\top z(x') = \frac{1}{D}\sum_{i=1}^D \cos(\omega_i^\top (x - x')).
        \end{equation}

    .. note::
        When this kernel is used in batch mode, the random frequencies are drawn
        independently across the batch dimension as well by default.

    :param num_samples: Number of random frequencies to draw. This is :math:`D` in the above
        papers. This will produce :math:`D` sine features and :math:`D` cosine
        features for a total of :math:`2D` random Fourier features.
    :type num_samples: int
    :param num_dims: (Default `None`.) Dimensionality of the data space.
        This is :math:`d` in the above papers. Note that if you want an
        independent lengthscale for each dimension, set `ard_num_dims` equal to
        `num_dims`. If unspecified, it will be inferred the first time `forward`
        is called.
    :type num_dims: int, optional

    :var torch.Tensor randn_weights: The random frequencies that are drawn once and then fixed.

    Example:

        >>> # This will infer `num_dims` automatically
        >>> kernel = qpytorch.kernels.RFFKernel(num_samples=5)
        >>> x = torch.randn(10, 3)
        >>> kxx = kernel(x, x).to_dense()
        >>> print(kernel.randn_weights.size())
        torch.Size([3, 5])

    """

    has_lengthscale = True

    def __init__(self, num_samples: int, num_dims: Optional[int] = None, **kwargs):
        super().__init__(**kwargs)
        self.num_samples = num_samples
        if num_dims is not None:
            self._init_weights(num_dims, num_samples)

    def _init_weights(
        self, num_dims: Optional[int] = None, num_samples: Optional[int] = None, randn_weights: Optional[Tensor] = None
    ):
        if num_dims is not None and num_samples is not None:
            d = num_dims
            D = num_samples
        if randn_weights is None:
            randn_shape = torch.Size([*self._batch_shape, d, D])
            randn_weights = torch.randn(
                randn_shape, dtype=self.raw_lengthscale.dtype, device=self.raw_lengthscale.device
            )
        self.register_buffer("randn_weights", randn_weights)

    def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, last_dim_is_batch: bool = False, **kwargs) -> Tensor:
        if last_dim_is_batch:
            x1 = x1.transpose(-1, -2).unsqueeze(-1)
            x2 = x2.transpose(-1, -2).unsqueeze(-1)
        num_dims = x1.size(-1)
        if not hasattr(self, "randn_weights"):
            self._init_weights(num_dims, self.num_samples)
        x1_eq_x2 = torch.equal(x1, x2)
        z1 = self._featurize(x1, normalize=False)
        if not x1_eq_x2:
            z2 = self._featurize(x2, normalize=False)
        else:
            z2 = z1
        D = float(self.num_samples)
        if diag:
            return (z1 * z2).sum(-1) / D
        if x1_eq_x2:
            # Exploit low rank structure, if there are fewer features than data points
            if z1.size(-1) < z2.size(-2):
                return LowRankRootLinearOperator(z1 / math.sqrt(D))
            else:
                return RootLinearOperator(z1 / math.sqrt(D))
        else:
            return MatmulLinearOperator(z1 / D, z2.transpose(-1, -2))

    def _featurize(self, x: Tensor, normalize: bool = False) -> Tensor:
        # Recompute division each time to allow backprop through lengthscale
        # Transpose lengthscale to allow for ARD
        x = x.matmul(self.randn_weights / self.lengthscale.transpose(-1, -2))
        z = torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
        if normalize:
            D = self.num_samples
            z = z / math.sqrt(D)
        return z

    def prediction_strategy(self, train_inputs, train_prior_dist, train_labels, likelihood):
        # Allow for fast sampling
        return exact_prediction_strategies.RFFPredictionStrategy(
            train_inputs, train_prior_dist, train_labels, likelihood
        )
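A minimal usage sketch (not from the packaged source) of the fixed random weights and the low-rank evaluation path; it assumes the `qpytorch.kernels.RFFKernel` export used in the docstring example.

import torch
import qpytorch

kernel = qpytorch.kernels.RFFKernel(num_samples=5, num_dims=3)  # D = 5 frequencies -> 10 features per point
x = torch.randn(100, 3)
print(kernel.randn_weights.shape)   # torch.Size([3, 5]); drawn once at init because num_dims was given
K = kernel(x, x).to_dense()         # evaluated via a (low-rank) root decomposition since 10 features < 100 points
print(K.shape)                      # torch.Size([100, 100])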
qpytorch/likelihoods/__init__.py
@@ -0,0 +1,66 @@
#!/usr/bin/env python3

from .bernoulli_likelihood import BernoulliLikelihood
from .beta_likelihood import BetaLikelihood
from .gaussian_likelihood import (
    _GaussianLikelihoodBase,
    GaussianLikelihood,
    GaussianLikelihoodWithMissingObs,
    FixedNoiseGaussianLikelihood,
    DirichletClassificationLikelihood,
)
from .qexponential_likelihood import (
    _QExponentialLikelihoodBase,
    QExponentialLikelihood,
    QExponentialLikelihoodWithMissingObs,
    FixedNoiseQExponentialLikelihood,
    QExponentialDirichletClassificationLikelihood,
)
from .laplace_likelihood import LaplaceLikelihood
from .likelihood import _OneDimensionalLikelihood, Likelihood
from .likelihood_list import LikelihoodList
from .multitask_gaussian_likelihood import (
    _MultitaskGaussianLikelihoodBase,
    MultitaskGaussianLikelihood,
    MultitaskFixedNoiseGaussianLikelihood,
    MultitaskDirichletClassificationLikelihood,
)
from .multitask_qexponential_likelihood import (
    _MultitaskQExponentialLikelihoodBase,
    MultitaskQExponentialLikelihood,
    MultitaskFixedNoiseQExponentialLikelihood,
    MultitaskQExponentialDirichletClassificationLikelihood,
)
from .noise_models import HeteroskedasticNoise
from .softmax_likelihood import SoftmaxLikelihood
from .student_t_likelihood import StudentTLikelihood

__all__ = [
    "_GaussianLikelihoodBase",
    "_QExponentialLikelihoodBase",
    "_OneDimensionalLikelihood",
    "_MultitaskGaussianLikelihoodBase",
    "_MultitaskQExponentialLikelihoodBase",
    "BernoulliLikelihood",
    "BetaLikelihood",
    "DirichletClassificationLikelihood",
    "QExponentialDirichletClassificationLikelihood",
    "FixedNoiseGaussianLikelihood",
    "FixedNoiseQExponentialLikelihood",
    "GaussianLikelihood",
    "QExponentialLikelihood",
    "GaussianLikelihoodWithMissingObs",
    "QExponentialLikelihoodWithMissingObs",
    "HeteroskedasticNoise",
    "LaplaceLikelihood",
    "Likelihood",
    "LikelihoodList",
    "MultitaskGaussianLikelihood",
    "MultitaskFixedNoiseGaussianLikelihood",
    "MultitaskDirichletClassificationLikelihood",
    "MultitaskQExponentialLikelihood",
    "MultitaskFixedNoiseQExponentialLikelihood",
    "MultitaskQExponentialDirichletClassificationLikelihood",
    "SoftmaxLikelihood",
    "StudentTLikelihood",
]
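Everything listed in `__all__` above is importable directly from `qpytorch.likelihoods`; a tiny sketch (only the argument-free `BernoulliLikelihood` constructor is confirmed by this diff, in the next file):

from qpytorch.likelihoods import BernoulliLikelihood, GaussianLikelihood, QExponentialLikelihood

likelihood = BernoulliLikelihood()  # argument-free constructor, per bernoulli_likelihood.py below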
qpytorch/likelihoods/bernoulli_likelihood.py
@@ -0,0 +1,75 @@
#!/usr/bin/env python3

import warnings
from typing import Any, Union

import torch
from torch import Tensor
from torch.distributions import Bernoulli

from ..distributions import base_distributions, MultivariateNormal, MultivariateQExponential
from ..functions import log_normal_cdf
from .likelihood import _OneDimensionalLikelihood


class BernoulliLikelihood(_OneDimensionalLikelihood):
    r"""
    Implements the Bernoulli likelihood used for GP/QEP classification, using
    Probit regression (i.e., the latent function is warped to be in [0,1]
    using the standard Normal CDF :math:`\Phi(x)`). Given the identity
    :math:`\Phi(-x) = 1-\Phi(x)`, we can write the likelihood compactly as:

    .. math::
        \begin{equation*}
            p(Y=y|f)=\Phi((2y - 1)f)
        \end{equation*}

    .. note::
        BernoulliLikelihood has an analytic marginal distribution.

    .. note::
        The labels should take values in {0, 1}.
    """

    has_analytic_marginal: bool = True

    def __init__(self) -> None:
        return super().__init__()

    def forward(self, function_samples: Tensor, *args: Any, **kwargs: Any) -> Bernoulli:
        output_probs = base_distributions.Normal(0, 1).cdf(function_samples)
        return base_distributions.Bernoulli(probs=output_probs)

    def log_marginal(
        self, observations: Tensor, function_dist: Union[MultivariateNormal, MultivariateQExponential], *args: Any, **kwargs: Any
    ) -> Tensor:
        marginal = self.marginal(function_dist, *args, **kwargs)
        return marginal.log_prob(observations)

    def marginal(self, function_dist: Union[MultivariateNormal, MultivariateQExponential], *args: Any, **kwargs: Any) -> Bernoulli:
        r"""
        :return: Analytic marginal :math:`p(\mathbf y)`.
        """
        mean = function_dist.mean
        var = function_dist.variance
        link = mean.div(torch.sqrt(1 + var))
        output_probs = base_distributions.Normal(0, 1).cdf(link)
        return base_distributions.Bernoulli(probs=output_probs)

    def expected_log_prob(
        self, observations: Tensor, function_dist: Union[MultivariateNormal, MultivariateQExponential], *params: Any, **kwargs: Any
    ) -> Tensor:
        if torch.any(observations.eq(-1)):
            # Remove after 1.0
            warnings.warn(
                "BernoulliLikelihood.expected_log_prob expects observations with labels in {0, 1}. "
                "Observations with labels in {-1, 1} are deprecated.",
                DeprecationWarning,
            )
        else:
            observations = observations.mul(2).sub(1)
        # Custom function here so we can use log_normal_cdf rather than Normal.cdf
        # This is going to be less prone to overflow errors
        log_prob_lambda = lambda function_samples: log_normal_cdf(function_samples.mul(observations))
        log_prob = self.quadrature(log_prob_lambda, function_dist)
        return log_prob
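A minimal sketch (not from the packaged source) of the analytic marginal above. `MultivariateNormal` is taken from GPyTorch here; per the type hints, the likelihood also accepts a `MultivariateQExponential` latent distribution.

import torch
from gpytorch.distributions import MultivariateNormal
from qpytorch.likelihoods import BernoulliLikelihood

likelihood = BernoulliLikelihood()
f = MultivariateNormal(torch.zeros(4), torch.eye(4))  # latent function values at 4 test points
y_dist = likelihood.marginal(f)                       # Bernoulli with probs = Phi(mean / sqrt(1 + var))
labels = torch.tensor([0., 1., 1., 0.])
print(likelihood.log_marginal(labels, f))             # elementwise log p(y) under the analytic marginal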