qpytorch 0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of qpytorch might be problematic.
- qpytorch/__init__.py +327 -0
- qpytorch/constraints/__init__.py +3 -0
- qpytorch/distributions/__init__.py +21 -0
- qpytorch/distributions/delta.py +86 -0
- qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
- qpytorch/distributions/multivariate_qexponential.py +581 -0
- qpytorch/distributions/power.py +113 -0
- qpytorch/distributions/qexponential.py +153 -0
- qpytorch/functions/__init__.py +58 -0
- qpytorch/kernels/__init__.py +80 -0
- qpytorch/kernels/grid_interpolation_kernel.py +213 -0
- qpytorch/kernels/inducing_point_kernel.py +151 -0
- qpytorch/kernels/kernel.py +695 -0
- qpytorch/kernels/matern32_kernel_grad.py +155 -0
- qpytorch/kernels/matern52_kernel_grad.py +194 -0
- qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
- qpytorch/kernels/polynomial_kernel_grad.py +88 -0
- qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
- qpytorch/kernels/rbf_kernel_grad.py +125 -0
- qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
- qpytorch/kernels/rff_kernel.py +153 -0
- qpytorch/lazy/__init__.py +9 -0
- qpytorch/likelihoods/__init__.py +66 -0
- qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
- qpytorch/likelihoods/beta_likelihood.py +76 -0
- qpytorch/likelihoods/gaussian_likelihood.py +472 -0
- qpytorch/likelihoods/laplace_likelihood.py +59 -0
- qpytorch/likelihoods/likelihood.py +437 -0
- qpytorch/likelihoods/likelihood_list.py +60 -0
- qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
- qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
- qpytorch/likelihoods/noise_models.py +184 -0
- qpytorch/likelihoods/qexponential_likelihood.py +494 -0
- qpytorch/likelihoods/softmax_likelihood.py +97 -0
- qpytorch/likelihoods/student_t_likelihood.py +90 -0
- qpytorch/means/__init__.py +23 -0
- qpytorch/metrics/__init__.py +17 -0
- qpytorch/mlls/__init__.py +53 -0
- qpytorch/mlls/_approximate_mll.py +79 -0
- qpytorch/mlls/deep_approximate_mll.py +30 -0
- qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
- qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
- qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
- qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
- qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
- qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
- qpytorch/mlls/marginal_log_likelihood.py +48 -0
- qpytorch/mlls/predictive_log_likelihood.py +76 -0
- qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
- qpytorch/mlls/variational_elbo.py +77 -0
- qpytorch/models/__init__.py +72 -0
- qpytorch/models/approximate_qep.py +115 -0
- qpytorch/models/deep_qeps/__init__.py +22 -0
- qpytorch/models/deep_qeps/deep_qep.py +155 -0
- qpytorch/models/deep_qeps/dspp.py +114 -0
- qpytorch/models/exact_prediction_strategies.py +880 -0
- qpytorch/models/exact_qep.py +349 -0
- qpytorch/models/model_list.py +100 -0
- qpytorch/models/pyro/__init__.py +28 -0
- qpytorch/models/pyro/_pyro_mixin.py +57 -0
- qpytorch/models/pyro/distributions/__init__.py +5 -0
- qpytorch/models/pyro/pyro_qep.py +105 -0
- qpytorch/models/qep.py +7 -0
- qpytorch/models/qeplvm/__init__.py +6 -0
- qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
- qpytorch/models/qeplvm/latent_variable.py +102 -0
- qpytorch/module.py +30 -0
- qpytorch/optim/__init__.py +5 -0
- qpytorch/priors/__init__.py +42 -0
- qpytorch/priors/qep_priors.py +81 -0
- qpytorch/test/__init__.py +22 -0
- qpytorch/test/base_likelihood_test_case.py +106 -0
- qpytorch/test/model_test_case.py +150 -0
- qpytorch/test/variational_test_case.py +400 -0
- qpytorch/utils/__init__.py +38 -0
- qpytorch/utils/warnings.py +37 -0
- qpytorch/variational/__init__.py +47 -0
- qpytorch/variational/_variational_distribution.py +61 -0
- qpytorch/variational/_variational_strategy.py +391 -0
- qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
- qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
- qpytorch/variational/cholesky_variational_distribution.py +65 -0
- qpytorch/variational/ciq_variational_strategy.py +352 -0
- qpytorch/variational/delta_variational_distribution.py +41 -0
- qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
- qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
- qpytorch/variational/lmc_variational_strategy.py +248 -0
- qpytorch/variational/mean_field_variational_distribution.py +58 -0
- qpytorch/variational/multitask_variational_strategy.py +317 -0
- qpytorch/variational/natural_variational_distribution.py +152 -0
- qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
- qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
- qpytorch/variational/tril_natural_variational_distribution.py +130 -0
- qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
- qpytorch/variational/unwhitened_variational_strategy.py +225 -0
- qpytorch/variational/variational_strategy.py +280 -0
- qpytorch/version.py +4 -0
- qpytorch-0.1.dist-info/LICENSE +21 -0
- qpytorch-0.1.dist-info/METADATA +177 -0
- qpytorch-0.1.dist-info/RECORD +102 -0
- qpytorch-0.1.dist-info/WHEEL +5 -0
- qpytorch-0.1.dist-info/top_level.txt +1 -0
qpytorch/kernels/matern32_kernel_grad.py
@@ -0,0 +1,155 @@
#!/usr/bin/env python3

import math

import torch
from linear_operator.operators import KroneckerProductLinearOperator

from gpytorch.kernels.matern_kernel import MaternKernel

sqrt3 = math.sqrt(3)


class Matern32KernelGrad(MaternKernel):
    r"""
    Computes a covariance matrix of the Matern32 kernel that models the covariance
    between the values and partial derivatives for inputs :math:`\mathbf{x_1}`
    and :math:`\mathbf{x_2}`.

    See :class:`qpytorch.kernels.Kernel` for descriptions of the lengthscale options.

    .. note::

        This kernel does not have an `outputscale` parameter. To add a scaling parameter,
        decorate this kernel with a :class:`gpytorch.kernels.ScaleKernel`.

    :param ard_num_dims: Set this if you want a separate lengthscale for each input
        dimension. It should be `d` if x1 is a `n x d` matrix. (Default: `None`.)
    :param batch_shape: Set this if you want a separate lengthscale for each batch of input
        data. It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf x1` is
        a :math:`B_1 \times \ldots \times B_k \times N \times D` tensor.
    :param active_dims: Set this if you want to compute the covariance of only
        a few input dimensions. The ints correspond to the indices of the
        dimensions. (Default: `None`.)
    :param lengthscale_prior: Set this if you want to apply a prior to the
        lengthscale parameter. (Default: `None`)
    :param lengthscale_constraint: Set this if you want to apply a constraint
        to the lengthscale parameter. (Default: `Positive`.)
    :param eps: The minimum value that the lengthscale can take (prevents
        divide by zero errors). (Default: `1e-6`.)

    :ivar torch.Tensor lengthscale: The lengthscale parameter. Size/shape of parameter depends on the
        ard_num_dims and batch_shape arguments.

    Example:
        >>> x = torch.randn(10, 5)
        >>> # Non-batch: Simple option
        >>> covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.Matern32KernelGrad())
        >>> covar = covar_module(x)  # Output: LinearOperator of size (60 x 60), where 60 = n * (d + 1)
        >>>
        >>> batch_x = torch.randn(2, 10, 5)
        >>> # Batch: Simple option
        >>> covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.Matern32KernelGrad())
        >>> # Batch: different lengthscale for each batch
        >>> covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.Matern32KernelGrad(batch_shape=torch.Size([2])))  # noqa: E501
        >>> covar = covar_module(x)  # Output: LinearOperator of size (2 x 60 x 60)
    """

    def __init__(self, **kwargs):

        # pop interleaved before calling super() so it is not forwarded; remove nu in case it was set
        self._interleaved = kwargs.pop("interleaved", True)
        kwargs.pop("nu", None)
        super(Matern32KernelGrad, self).__init__(nu=1.5, **kwargs)

    def forward(self, x1, x2, diag=False, **params):

        lengthscale = self.lengthscale

        batch_shape = x1.shape[:-2]
        n_batch_dims = len(batch_shape)
        n1, d = x1.shape[-2:]
        n2 = x2.shape[-2]

        if not diag:

            K = torch.zeros(*batch_shape, n1 * (d + 1), n2 * (d + 1), device=x1.device, dtype=x1.dtype)

            distance_matrix = self.covar_dist(x1.div(lengthscale), x2.div(lengthscale), diag=diag, **params)
            exp_neg_sqrt3r = torch.exp(-sqrt3 * distance_matrix)

            # differences matrix in each dimension to be used for derivatives
            # shape of n1 x n2 x d
            outer = x1.view(*batch_shape, n1, 1, d) - x2.view(*batch_shape, 1, n2, d)
            outer = outer / lengthscale.unsqueeze(-2) ** 2
            # shape of n1 x d x n2
            outer = torch.transpose(outer, -1, -2).contiguous()

            # 1) Kernel block, cov(f^m, f^n)
            # shape is n1 x n2
            # exp_component = torch.exp(-sqrt3 * distance_matrix)
            constant_component = (sqrt3 * distance_matrix).add(1)

            K[..., :n1, :n2] = constant_component * exp_neg_sqrt3r  # exp_component

            # 2) First gradient block, cov(f^m, omega^n_i)
            outer1 = outer.view(*batch_shape, n1, n2 * d)
            # the - signs on -outer1 and -3 cancel out
            K[..., :n1, n2:] = 3 * outer1 * exp_neg_sqrt3r.repeat(
                [*([1] * (n_batch_dims + 1)), d]
            )

            # 3) Second gradient block, cov(omega^m_j, f^n)
            outer2 = outer.transpose(-1, -3).reshape(*batch_shape, n2, n1 * d)
            outer2 = outer2.transpose(-1, -2)
            K[..., n1:, :n2] = -3 * outer2 * exp_neg_sqrt3r.repeat(
                [*([1] * n_batch_dims), d, 1]
            )

            # 4) Hessian block, cov(omega^m_j, omega^n_i)
            outer3 = outer1.repeat([*([1] * n_batch_dims), d, 1]) * outer2.repeat([*([1] * (n_batch_dims + 1)), d])
            kp = KroneckerProductLinearOperator(
                torch.eye(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / lengthscale**2,
                torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
            )

            # part1 = -3 * exp_neg_sqrt3r
            # part2 = sqrt3 * invrdd * outer3
            invrdd = (distance_matrix + self.eps).pow(-1)
            # invrdd[torch.arange(min(n1,n2)),torch.arange(min(n1,n2))] = distance_matrix.diagonal()
            invrdd = invrdd.repeat([*([1] * n_batch_dims), d, d])
            # invrdd = distance_matrix.pow(-1).fill_diagonal_(0).repeat([*([1] * (n_batch_dims)), d, d]).fill_diagonal_(1)

            K[..., n1:, n2:] = -3 * exp_neg_sqrt3r.repeat([*([1] * n_batch_dims), d, d]).mul_(
                (sqrt3 * invrdd * outer3).sub_(kp.to_dense())
            )

            # Symmetrize for stability
            if n1 == n2 and torch.eq(x1, x2).all():
                K = 0.5 * (K.transpose(-1, -2) + K)

            # Apply a perfect shuffle permutation to match the MultiTask ordering
            if self._interleaved:
                pi1 = torch.arange(n1 * (d + 1)).view(d + 1, n1).t().reshape((n1 * (d + 1)))
                pi2 = torch.arange(n2 * (d + 1)).view(d + 1, n2).t().reshape((n2 * (d + 1)))
                K = K[..., pi1, :][..., :, pi2]

            return K
        else:
            if not (n1 == n2 and torch.eq(x1, x2).all()):
                raise RuntimeError("diag=True only works when x1 == x2")

            # nu is set to 1.5
            kernel_diag = super(Matern32KernelGrad, self).forward(x1, x2, diag=True)
            grad_diag = (
                3 * torch.ones(*batch_shape, n2, d, device=x1.device, dtype=x1.dtype)
            ) / lengthscale**2
            grad_diag = grad_diag.transpose(-1, -2).reshape(*batch_shape, n2 * d)
            k_diag = torch.cat((kernel_diag, grad_diag), dim=-1)
            if self._interleaved:
                pi = torch.arange(n2 * (d + 1)).view(d + 1, n2).t().reshape((n2 * (d + 1)))
                k_diag = k_diag[..., pi]
            return k_diag

    def num_outputs_per_input(self, x1, x2):
        return x1.size(-1) + 1
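The first gradient block above encodes cov(f(x1), df(x2)/dx2_i) = 3 * (x1_i - x2_i) / l_i^2 * exp(-sqrt(3) r). As an editorial sanity check (a minimal sketch, not part of the packaged file; it assumes only torch and gpytorch, which the module above already imports), that formula can be compared against autograd on the plain Matern-3/2 kernel:

# Sketch: autograd check of the first gradient block formula (not package code)
import math

import torch
from gpytorch.kernels.matern_kernel import MaternKernel

torch.manual_seed(0)
kernel = MaternKernel(nu=1.5)
x1 = torch.randn(1, 3)
x2 = torch.randn(1, 3, requires_grad=True)

# Plain Matern-3/2 value block k(x1, x2), reduced to a scalar so we can backprop
k = kernel.forward(x1, x2).squeeze()
k.backward()

# Analytic gradient with respect to x2, matching K[..., :n1, n2:] above
lengthscale = kernel.lengthscale.detach()
r = ((x1 - x2.detach()) / lengthscale).norm()
analytic = 3 * (x1 - x2.detach()) / lengthscale**2 * torch.exp(-math.sqrt(3) * r)

print(torch.allclose(x2.grad, analytic, atol=1e-5))  # expected: True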
qpytorch/kernels/matern52_kernel_grad.py
@@ -0,0 +1,194 @@
#!/usr/bin/env python3

import math

import torch
from linear_operator.operators import KroneckerProductLinearOperator

from gpytorch.kernels.matern_kernel import MaternKernel

sqrt5 = math.sqrt(5)
five_thirds = 5.0 / 3.0


class Matern52KernelGrad(MaternKernel):
    r"""
    Computes a covariance matrix of the Matern52 kernel that models the covariance
    between the values and partial derivatives for inputs :math:`\mathbf{x_1}`
    and :math:`\mathbf{x_2}`.

    See :class:`qpytorch.kernels.Kernel` for descriptions of the lengthscale options.

    .. note::

        This kernel does not have an `outputscale` parameter. To add a scaling parameter,
        decorate this kernel with a :class:`gpytorch.kernels.ScaleKernel`.

    .. note::

        A perfect shuffle permutation is applied after the calculation of the matrix blocks
        in order to match the MultiTask ordering.

    The Matern52 kernel is defined as

    .. math::

        k(r) = (1 + \sqrt{5}r + \frac{5}{3} r^2) \exp(- \sqrt{5}r)

    where :math:`r` is defined as

    .. math::

        r(\mathbf{x}^m , \mathbf{x}^n) = \sqrt{\sum_d{\frac{(x^m_d - x^n_d)^2}{l^2_d}}}

    The first gradient block containing :math:`\frac{\partial k}{\partial x^n_i}` is defined as

    .. math::

        \frac{\partial k}{\partial x^n_i} = \frac{5}{3} \left( 1 + \sqrt{5}r \right) \exp(- \sqrt{5}r) \left(\frac{x^m_i - x^n_i}{l^2_i} \right)

    The second gradient block containing :math:`\frac{\partial k}{\partial x^m_j}` is defined as

    .. math::

        \frac{\partial k}{\partial x^m_j} = - \frac{5}{3} \left( 1 + \sqrt{5}r \right) \exp(- \sqrt{5}r) \left(\frac{x^m_j - x^n_j}{l^2_j} \right)

    The Hessian block containing :math:`\frac{\partial^2 k}{\partial x^m_j \partial x^n_i}` is defined as

    .. math::

        \frac{\partial^2 k}{\partial x^m_j \partial x^n_i} = - \frac{5}{3} \exp(- \sqrt{5}r) \left[ 5\left(\frac{x^m_i - x^n_i}{l^2_i} \right) \left( \frac{x^m_j - x^n_j}{l^2_j} \right) - \frac{\delta_{ij}}{l^2_i} \left( 1 + \sqrt{5}r \right) \right]

    The derivations can be found `here <https://github.com/cornellius-gp/gpytorch/pull/2512>`__.

    :param ard_num_dims: Set this if you want a separate lengthscale for each input
        dimension. It should be `d` if x1 is a `n x d` matrix. (Default: `None`.)
    :param batch_shape: Set this if you want a separate lengthscale for each batch of input
        data. It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf x1` is
        a :math:`B_1 \times \ldots \times B_k \times N \times D` tensor.
    :param active_dims: Set this if you want to compute the covariance of only
        a few input dimensions. The ints correspond to the indices of the
        dimensions. (Default: `None`.)
    :param lengthscale_prior: Set this if you want to apply a prior to the
        lengthscale parameter. (Default: `None`)
    :param lengthscale_constraint: Set this if you want to apply a constraint
        to the lengthscale parameter. (Default: `Positive`.)
    :param eps: The minimum value that the lengthscale can take (prevents
        divide by zero errors). (Default: `1e-6`.)

    :ivar torch.Tensor lengthscale: The lengthscale parameter. Size/shape of parameter depends on the
        ard_num_dims and batch_shape arguments.

    Example:
        >>> x = torch.randn(10, 5)
        >>> # Non-batch: Simple option
        >>> covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.Matern52KernelGrad())
        >>> covar = covar_module(x)  # Output: LinearOperator of size (60 x 60), where 60 = n * (d + 1)
        >>>
        >>> batch_x = torch.randn(2, 10, 5)
        >>> # Batch: Simple option
        >>> covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.Matern52KernelGrad())
        >>> # Batch: different lengthscale for each batch
        >>> covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.Matern52KernelGrad(batch_shape=torch.Size([2])))  # noqa: E501
        >>> covar = covar_module(x)  # Output: LinearOperator of size (2 x 60 x 60)
    """

    def __init__(self, **kwargs):

        # pop interleaved before calling super() so it is not forwarded; remove nu in case it was set
        self._interleaved = kwargs.pop("interleaved", True)
        kwargs.pop("nu", None)
        super(Matern52KernelGrad, self).__init__(nu=2.5, **kwargs)

    def forward(self, x1, x2, diag=False, **params):

        lengthscale = self.lengthscale

        batch_shape = x1.shape[:-2]
        n_batch_dims = len(batch_shape)
        n1, d = x1.shape[-2:]
        n2 = x2.shape[-2]

        if not diag:

            K = torch.zeros(*batch_shape, n1 * (d + 1), n2 * (d + 1), device=x1.device, dtype=x1.dtype)

            distance_matrix = self.covar_dist(x1.div(lengthscale), x2.div(lengthscale), diag=diag, **params)
            exp_neg_sqrt5r = torch.exp(-sqrt5 * distance_matrix)
            one_plus_sqrt5r = 1 + sqrt5 * distance_matrix

            # differences matrix in each dimension to be used for derivatives
            # shape of n1 x n2 x d
            outer = x1.view(*batch_shape, n1, 1, d) - x2.view(*batch_shape, 1, n2, d)
            outer = outer / lengthscale.unsqueeze(-2) ** 2
            # shape of n1 x d x n2
            outer = torch.transpose(outer, -1, -2).contiguous()

            # 1) Kernel block, cov(f^m, f^n)
            # shape is n1 x n2
            # exp_component = torch.exp(-sqrt5 * distance_matrix)
            constant_component = one_plus_sqrt5r.add(five_thirds * distance_matrix**2)

            K[..., :n1, :n2] = constant_component * exp_neg_sqrt5r  # exp_component

            # 2) First gradient block, cov(f^m, omega^n_i)
            outer1 = outer.view(*batch_shape, n1, n2 * d)
            # the - signs on -outer1 and -five_thirds cancel out
            K[..., :n1, n2:] = five_thirds * outer1 * (one_plus_sqrt5r * exp_neg_sqrt5r).repeat(
                [*([1] * (n_batch_dims + 1)), d]
            )

            # 3) Second gradient block, cov(omega^m_j, f^n)
            outer2 = outer.transpose(-1, -3).reshape(*batch_shape, n2, n1 * d)
            outer2 = outer2.transpose(-1, -2)
            K[..., n1:, :n2] = -five_thirds * outer2 * (one_plus_sqrt5r * exp_neg_sqrt5r).repeat(
                [*([1] * n_batch_dims), d, 1]
            )

            # 4) Hessian block, cov(omega^m_j, omega^n_i)
            outer3 = outer1.repeat([*([1] * n_batch_dims), d, 1]) * outer2.repeat([*([1] * (n_batch_dims + 1)), d])
            kp = KroneckerProductLinearOperator(
                torch.eye(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / lengthscale**2,
                torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
            )

            # part1 = -five_thirds * exp_neg_sqrt5r
            # part2 = 5 * outer3
            # part3 = 1 + sqrt5 * distance_matrix

            K[..., n1:, n2:] = -five_thirds * exp_neg_sqrt5r.repeat([*([1] * n_batch_dims), d, d]).mul_(
                # need to use kp.to_dense().mul instead of kp.to_dense().mul_
                # because otherwise a RuntimeError is raised due to how autograd works with
                # view + inplace operations in the case of 1-dimensional input
                (5 * outer3).sub_(kp.to_dense().mul(one_plus_sqrt5r.repeat([*([1] * n_batch_dims), d, d])))
            )

            # Symmetrize for stability
            if n1 == n2 and torch.eq(x1, x2).all():
                K = 0.5 * (K.transpose(-1, -2) + K)

            # Apply a perfect shuffle permutation to match the MultiTask ordering
            if self._interleaved:
                pi1 = torch.arange(n1 * (d + 1)).view(d + 1, n1).t().reshape((n1 * (d + 1)))
                pi2 = torch.arange(n2 * (d + 1)).view(d + 1, n2).t().reshape((n2 * (d + 1)))
                K = K[..., pi1, :][..., :, pi2]

            return K
        else:
            if not (n1 == n2 and torch.eq(x1, x2).all()):
                raise RuntimeError("diag=True only works when x1 == x2")

            # nu is set to 2.5
            kernel_diag = super(Matern52KernelGrad, self).forward(x1, x2, diag=True)
            grad_diag = (
                five_thirds * torch.ones(*batch_shape, n2, d, device=x1.device, dtype=x1.dtype)
            ) / lengthscale**2
            grad_diag = grad_diag.transpose(-1, -2).reshape(*batch_shape, n2 * d)
            k_diag = torch.cat((kernel_diag, grad_diag), dim=-1)
            if self._interleaved:
                pi = torch.arange(n2 * (d + 1)).view(d + 1, n2).t().reshape((n2 * (d + 1)))
                k_diag = k_diag[..., pi]
            return k_diag

    def num_outputs_per_input(self, x1, x2):
        return x1.size(-1) + 1
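The Hessian-block formula quoted in the docstring above can be checked numerically with nested autograd on the plain Matern-5/2 value k(r). A self-contained sketch (editorial, not part of the package), using a single scalar lengthscale for simplicity:

# Sketch: nested-autograd check of the Hessian block d^2 k / (dx1_j dx2_i) (not package code)
import math

import torch

torch.manual_seed(0)
d, l = 3, 0.7  # one shared lengthscale for simplicity
x1 = torch.randn(d, requires_grad=True)
x2 = torch.randn(d, requires_grad=True)

r = ((x1 - x2) / l).norm()
k = (1 + math.sqrt(5) * r + 5.0 / 3.0 * r**2) * torch.exp(-math.sqrt(5) * r)

(dk_dx1,) = torch.autograd.grad(k, x1, create_graph=True)
rows = [torch.autograd.grad(dk_dx1[j], x2, retain_graph=True)[0] for j in range(d)]
hess = torch.stack(rows)  # hess[j, i] = d^2 k / (dx1_j dx2_i)

# Analytic block: -(5/3) exp(-sqrt(5) r) * [5 u_i u_j / l^4 - delta_ij / l^2 * (1 + sqrt(5) r)]
u = (x1 - x2).detach()
r_ = r.detach()
analytic = -5.0 / 3.0 * torch.exp(-math.sqrt(5) * r_) * (
    5 * torch.outer(u, u) / l**4 - torch.eye(d) / l**2 * (1 + math.sqrt(5) * r_)
)

print(torch.allclose(hess, analytic, atol=1e-5))  # expected: True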
qpytorch/kernels/matern52_kernel_gradgrad.py
@@ -0,0 +1,248 @@
#!/usr/bin/env python3

import math

import torch
from linear_operator.operators import KroneckerProductLinearOperator

from gpytorch.kernels.matern_kernel import MaternKernel

sqrt5 = math.sqrt(5)
five_thirds = 5.0 / 3.0


class Matern52KernelGradGrad(MaternKernel):
    r"""
    Computes a covariance matrix of the Matern52 kernel that models the covariance
    between the values and first and second (non-mixed) partial derivatives for inputs :math:`\mathbf{x_1}`
    and :math:`\mathbf{x_2}`.

    See :class:`qpytorch.kernels.Kernel` for descriptions of the lengthscale options.

    .. note::

        This kernel does not have an `outputscale` parameter. To add a scaling parameter,
        decorate this kernel with a :class:`gpytorch.kernels.ScaleKernel`.

    :param ard_num_dims: Set this if you want a separate lengthscale for each input
        dimension. It should be `d` if x1 is a `n x d` matrix. (Default: `None`.)
    :param batch_shape: Set this if you want a separate lengthscale for each batch of input
        data. It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf x1` is
        a :math:`B_1 \times \ldots \times B_k \times N \times D` tensor.
    :param active_dims: Set this if you want to compute the covariance of only
        a few input dimensions. The ints correspond to the indices of the
        dimensions. (Default: `None`.)
    :param lengthscale_prior: Set this if you want to apply a prior to the
        lengthscale parameter. (Default: `None`)
    :param lengthscale_constraint: Set this if you want to apply a constraint
        to the lengthscale parameter. (Default: `Positive`.)
    :param eps: The minimum value that the lengthscale can take (prevents
        divide by zero errors). (Default: `1e-6`.)

    :ivar torch.Tensor lengthscale: The lengthscale parameter. Size/shape of parameter depends on the
        ard_num_dims and batch_shape arguments.

    Example:
        >>> x = torch.randn(10, 5)
        >>> # Non-batch: Simple option
        >>> covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.Matern52KernelGradGrad())
        >>> covar = covar_module(x)  # Output: LinearOperator of size (110 x 110), where 110 = n * (2*d + 1)
        >>>
        >>> batch_x = torch.randn(2, 10, 5)
        >>> # Batch: Simple option
        >>> covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.Matern52KernelGradGrad())
        >>> # Batch: different lengthscale for each batch
        >>> covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.Matern52KernelGradGrad(batch_shape=torch.Size([2])))  # noqa: E501
        >>> covar = covar_module(x)  # Output: LinearOperator of size (2 x 110 x 110)
    """

    def __init__(self, **kwargs):

        # pop interleaved before calling super() so it is not forwarded; remove nu in case it was set
        self._interleaved = kwargs.pop("interleaved", True)
        kwargs.pop("nu", None)
        super(Matern52KernelGradGrad, self).__init__(nu=2.5, **kwargs)

    def forward(self, x1, x2, diag=False, **params):

        lengthscale = self.lengthscale

        batch_shape = x1.shape[:-2]
        n_batch_dims = len(batch_shape)
        n1, d = x1.shape[-2:]
        n2 = x2.shape[-2]

        mask_idx1 = params.pop('mask_idx1', None)  # mask off-diagonal covariance
        if mask_idx1 is not None:
            mask_idx2 = params.pop('mask_idx2', None)
            if mask_idx2 is None:
                mask_idx2 = mask_idx1
            else:
                assert mask_idx1.shape[:-1] == mask_idx2.shape[:-1], 'Batch shapes of mask indices do not match!'

        if not diag:

            K = torch.zeros(*batch_shape, n1 * (2 * d + 1), n2 * (2 * d + 1), device=x1.device, dtype=x1.dtype)

            distance_matrix = self.covar_dist(x1.div(lengthscale), x2.div(lengthscale), diag=diag, **params)
            exp_neg_sqrt5r = torch.exp(-sqrt5 * distance_matrix)
            one_plus_sqrt5r = 1 + sqrt5 * distance_matrix

            # differences matrix in each dimension to be used for derivatives
            # shape of n1 x n2 x d
            outer = x1.view(*batch_shape, n1, 1, d) - x2.view(*batch_shape, 1, n2, d)
            outer = outer / lengthscale.unsqueeze(-2) ** 2
            # shape of n1 x d x n2
            outer = torch.transpose(outer, -1, -2).contiguous()

            # 1) Kernel block, cov(f^m, f^n)
            # shape is n1 x n2
            # exp_component = torch.exp(-sqrt5 * distance_matrix)
            constant_component = one_plus_sqrt5r.add(five_thirds * distance_matrix**2)

            K[..., :n1, :n2] = constant_component * exp_neg_sqrt5r  # exp_component

            # 2) First gradient block, cov(f^m, omega^n_i)
            outer1 = outer.view(*batch_shape, n1, n2 * d)
            # the - signs on -outer1 and -five_thirds cancel out
            K[..., :n1, n2: (n2 * (d + 1))] = five_thirds * outer1 * (one_plus_sqrt5r * exp_neg_sqrt5r).repeat(
                [*([1] * (n_batch_dims + 1)), d]
            )

            # 3) Second gradient block, cov(omega^m_j, f^n)
            outer2 = outer.transpose(-1, -3).reshape(*batch_shape, n2, n1 * d)
            outer2 = outer2.transpose(-1, -2)
            K[..., n1: (n1 * (d + 1)), :n2] = -five_thirds * outer2 * (one_plus_sqrt5r * exp_neg_sqrt5r).repeat(
                [*([1] * n_batch_dims), d, 1]
            )

            # 4) Hessian block, cov(omega^m_j, omega^n_i)
            outer3 = outer1.repeat([*([1] * n_batch_dims), d, 1]) * outer2.repeat([*([1] * (n_batch_dims + 1)), d])
            kp = KroneckerProductLinearOperator(
                torch.eye(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / lengthscale**2,
                torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
            )

            # part1 = -five_thirds * exp_neg_sqrt5r
            # part2 = 5 * outer3
            # part3 = 1 + sqrt5 * distance_matrix
            exp_neg_sqrt5rdd = exp_neg_sqrt5r.repeat([*([1] * n_batch_dims), d, d])

            K[..., n1: (n1 * (d + 1)), n2: (n2 * (d + 1))] = -five_thirds * exp_neg_sqrt5rdd.mul(
                # need to use kp.to_dense().mul instead of kp.to_dense().mul_
                # because otherwise a RuntimeError is raised due to how autograd works with
                # view + inplace operations in the case of 1-dimensional input
                (5 * outer3).sub(kp.to_dense().mul(one_plus_sqrt5r.repeat([*([1] * n_batch_dims), d, d])))
            )

            # 5) 1-3 block
            douter1dx2 = KroneckerProductLinearOperator(
                torch.ones(1, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2),
                torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
            ).to_dense()

            K_13 = five_thirds * (-douter1dx2 * one_plus_sqrt5r.repeat([*([1] * (n_batch_dims + 1)), d]) + 5 * outer1 * outer1) * exp_neg_sqrt5r.repeat(
                [*([1] * (n_batch_dims + 1)), d]
            )  # verified for n1=n2=1 case
            K[..., :n1, (n2 * (d + 1)):] = K_13

            if d > 1:
                douter1dx2 = KroneckerProductLinearOperator(
                    (torch.ones(1, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2)).transpose(-1, -2),
                    torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
                ).to_dense()
            K_31 = five_thirds * (-douter1dx2 * one_plus_sqrt5r.repeat([*([1] * n_batch_dims), d, 1]) + 5 * outer2 * outer2) * exp_neg_sqrt5r.repeat(
                [*([1] * n_batch_dims), d, 1]
            )  # verified for n1=n2=1 case
            K[..., (n1 * (d + 1)):, :n2] = K_31

            # rest of the blocks are all of size (n1*d, n2*d)
            outer1 = outer1.repeat([*([1] * n_batch_dims), d, 1])
            outer2 = outer2.repeat([*([1] * (n_batch_dims + 1)), d])
            # II = (torch.eye(d,d,device=x1.device,dtype=x1.dtype)/lengthscale.pow(2)).repeat(*batch_shape,n1,n2)
            kp2 = KroneckerProductLinearOperator(
                torch.ones(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2),
                torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
            ).to_dense()

            # II may not be the correct thing to use. It might be more appropriate to use kp instead??
            II = kp.to_dense()
            # exp_neg_sqrt5rdd = exp_neg_sqrt5r.repeat([*([1] * (n_batch_dims)), d, d])
            invrdd = (distance_matrix + self.eps).pow(-1)
            # invrdd[torch.arange(min(n1,n2)),torch.arange(min(n1,n2))] = distance_matrix.diagonal()
            invrdd = invrdd.repeat([*([1] * n_batch_dims), d, d])
            # invrdd = distance_matrix.pow(-1).fill_diagonal_(0).repeat([*([1] * (n_batch_dims)), d, d]).fill_diagonal_(1)

            K_23 = five_thirds * 5 * ((kp2 - sqrt5 * invrdd * outer1 * outer1) * outer2 + 2.0 * II * outer1) * exp_neg_sqrt5rdd  # verified for n1=n2=1 case

            K[..., n1: (n1 * (d + 1)), (n2 * (d + 1)):] = K_23

            if d > 1:
                kp2t = KroneckerProductLinearOperator(
                    (torch.ones(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / self.lengthscale.pow(2)).transpose(-1, -2),
                    torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
                ).to_dense()
            K_32 = five_thirds * 5 * (
                (-(kp2t if d > 1 else kp2) + sqrt5 * invrdd * outer2 * outer2) * outer1 - 2.0 * II * outer2
            ) * exp_neg_sqrt5rdd  # verified for n1=n2=1 case

            K[..., (n1 * (d + 1)):, n2: (n2 * (d + 1))] = K_32

            # K_33 = five_thirds * 5*(
            #     ((-(kp2t if d>1 else kp2) + sqrt5*invrdd*outer2 * outer2) * (-kp2) - 2.0 *sqrt5*invrdd * II * outer2 * outer1 + 2.0 * (II) ** 2
            #     ) + (
            #     (-(kp2t if d>1 else kp2)*sqrt5*invrdd + (5+sqrt5*invrdd)*invrdd**2*outer2 * outer2) * outer1 - 2.0 *sqrt5*invrdd * II * outer2
            #     ) * outer1) * exp_neg_sqrt5rdd  # verified for n1=n2=1 case
            K_33 = five_thirds * 5 * (
                (kp2 - sqrt5 * invrdd * outer1 * outer1) * ((kp2t if d > 1 else kp2) - sqrt5 * invrdd * outer2 * outer2)
                + sqrt5 * invrdd * (invrdd**2 * outer3 - 4 * II) * outer3 + 2 * II**2
            ) * exp_neg_sqrt5rdd

            K[..., (n1 * (d + 1)):, (n2 * (d + 1)):] = K_33

            # Symmetrize for stability
            if n1 == n2 and torch.eq(x1, x2).all():
                K = 0.5 * (K.transpose(-1, -2) + K)

            # Apply a perfect shuffle permutation to match the MultiTask ordering
            if self._interleaved:
                pi1 = torch.arange(n1 * (2 * d + 1)).view(2 * d + 1, n1).t().reshape((n1 * (2 * d + 1)))
                pi2 = torch.arange(n2 * (2 * d + 1)).view(2 * d + 1, n2).t().reshape((n2 * (2 * d + 1)))
                K = K[..., pi1, :][..., :, pi2]

            if mask_idx1 is not None:
                if mask_idx1.ndim == 1:
                    diag2keep = K[..., mask_idx1, mask_idx2]
                    K[..., mask_idx1, :] = 0; K[..., mask_idx2] = 0
                    K[..., mask_idx1, mask_idx2] = diag2keep * self.eps
                elif mask_idx1.ndim == 2:
                    for b in range(mask_idx1.shape[0]):
                        diag2keep = K[b, ..., mask_idx1[b], mask_idx2[b]]
                        K[b, ..., mask_idx1[b], :] = 0; K[b, ..., mask_idx2[b]] = 0
                        K[b, ..., mask_idx1[b], mask_idx2[b]] = diag2keep * self.eps
                else:
                    raise NotImplementedError('Mask indices of batch dimension bigger than 1 not implemented!')

            return K
        else:
            if not (n1 == n2 and torch.eq(x1, x2).all()):
                raise RuntimeError("diag=True only works when x1 == x2")

            # nu is set to 2.5
            kernel_diag = super(Matern52KernelGradGrad, self).forward(x1, x2, diag=True)
            grad_diag = (
                five_thirds * torch.ones(*batch_shape, n2, d, device=x1.device, dtype=x1.dtype)
            ) / lengthscale**2
            grad_diag = grad_diag.transpose(-1, -2).reshape(*batch_shape, n2 * d)
            gradgrad_diag = (
                5**2 * torch.ones(*batch_shape, n2, d, device=x1.device, dtype=x1.dtype) / lengthscale.pow(4)
            )
            gradgrad_diag = gradgrad_diag.transpose(-1, -2).reshape(*batch_shape, n2 * d)
            k_diag = torch.cat((kernel_diag, grad_diag, gradgrad_diag), dim=-1)
            if self._interleaved:
                pi = torch.arange(n2 * (2 * d + 1)).view(2 * d + 1, n2).t().reshape((n2 * (2 * d + 1)))
                k_diag = k_diag[..., pi]
            return k_diag

    def num_outputs_per_input(self, x1, x2):
        return x1.size(-1) * 2 + 1
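All four derivative kernels in this release finish with the same "perfect shuffle" permutation. A small standalone illustration (editorial, not packaged code) of what that permutation does for the GradGrad layout, where each input contributes 2*d + 1 outputs:

# Sketch: the perfect-shuffle permutation used above (not package code).
# Before the shuffle, rows are grouped by block: n values, then n first derivatives per
# dimension, then n second derivatives per dimension. The permutation interleaves them
# per data point, which is the ordering a multitask distribution expects.
import torch

n, d = 3, 2
pi = torch.arange(n * (2 * d + 1)).view(2 * d + 1, n).t().reshape(n * (2 * d + 1))
print(pi.tolist())
# [0, 3, 6, 9, 12, 1, 4, 7, 10, 13, 2, 5, 8, 11, 14]
# i.e. for point 1: value (row 0), d/dx^(1) (row 3), d/dx^(2) (row 6),
#      d^2/d(x^(1))^2 (row 9), d^2/d(x^(2))^2 (row 12); then the same pattern for points 2 and 3.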
qpytorch/kernels/polynomial_kernel_grad.py
@@ -0,0 +1,88 @@
#!/usr/bin/env python3

from typing import Optional

import torch

from gpytorch.kernels.polynomial_kernel import PolynomialKernel


class PolynomialKernelGrad(PolynomialKernel):

    def __init__(self, **kwargs):
        self._interleaved = kwargs.pop('interleaved', True)  # pop before super() so it is not forwarded
        super(PolynomialKernelGrad, self).__init__(**kwargs)

    def forward(
        self,
        x1: torch.Tensor,
        x2: torch.Tensor,
        diag: Optional[bool] = False,
        last_dim_is_batch: Optional[bool] = False,
        **params,
    ) -> torch.Tensor:
        offset = self.offset.view(*self.batch_shape, 1, 1)

        batch_shape = x1.shape[:-2]
        n1, d = x1.shape[-2:]
        n2 = x2.shape[-2]

        if diag:
            base_diag = (x1 * x2).sum(dim=-1) + self.offset
            K11_diag = base_diag.pow(self.power)

            all_outers_diag = (x1 * x2).transpose(-2, -1).reshape(*batch_shape, -1)
            K22_base_diag = self.power * (self.power - 1) * base_diag.pow(self.power - 2)
            K12_base_diag = self.power * base_diag.pow(self.power - 1)

            K22_diag = torch.add(
                all_outers_diag * K22_base_diag.repeat(*([1] * (K22_base_diag.dim() - 1)), d),
                K12_base_diag.repeat(*([1] * (K12_base_diag.dim() - 1)), d),
            )

            K_diag = torch.cat([K11_diag, K22_diag], dim=-1)
            # Apply perfect shuffle
            if self._interleaved:
                pi1 = torch.arange(n1 * (d + 1)).view(d + 1, n1).t().reshape((n1 * (d + 1)))
                K_diag = K_diag[..., pi1]
            return K_diag

        else:
            base_inner_prod = torch.matmul(x1, x2.transpose(-2, -1)) + offset
            K11 = base_inner_prod.pow(self.power)

            K12_base = self.power * base_inner_prod.pow(self.power - 1)
            K12 = torch.zeros(*batch_shape, n1, n2 * d, dtype=x1.dtype, device=x1.device)

            ones_ = torch.ones(*batch_shape, d, 1, n2, dtype=x1.dtype, device=x1.device)
            K12_outer_prods = torch.matmul(x1.transpose(-2, -1).unsqueeze(-1), ones_)
            K12 = (K12_base.unsqueeze(-3) * K12_outer_prods).transpose(-3, -2).reshape(*batch_shape, n1, d * n2)

            ones_ = torch.ones(*batch_shape, d, n1, 1, dtype=x1.dtype, device=x1.device)
            K21_outer_prods = torch.matmul(ones_, x2.transpose(-2, -1).unsqueeze(-2))
            K21 = (K12_base.unsqueeze(-3) * K21_outer_prods).view(*batch_shape, d * n1, n2)

            K22_base = self.power * (self.power - 1) * base_inner_prod.pow(self.power - 2)
            K22 = torch.zeros(*batch_shape, n1 * d, n2 * d, dtype=x1.dtype, device=x1.device)
            all_outers = x1.unsqueeze(-2).unsqueeze(-2).transpose(-2, -1).matmul(x2.unsqueeze(-3).unsqueeze(-2))
            all_outers = all_outers.transpose(-4, -2).transpose(-3, -1)
            K22 = K22_base.unsqueeze(-3).unsqueeze(-3) * all_outers  # d x d x n1 x n2

            # Can't avoid this for loop without unnecessary memory duplication, which is worse.
            for i in range(d):
                K22[..., i, i, :, :] = K22[..., i, i, :, :] + K12_base

            K22 = K22.transpose(-4, -3).transpose(-3, -2).reshape(*batch_shape, n1 * d, n2 * d)

            K = torch.cat([torch.cat([K11, K12], dim=-1), torch.cat([K21, K22], dim=-1)], dim=-2)

            # Apply perfect shuffle
            if self._interleaved:
                pi1 = torch.arange(n1 * (d + 1)).view(d + 1, n1).t().reshape((n1 * (d + 1)))
                pi2 = torch.arange(n2 * (d + 1)).view(d + 1, n2).t().reshape((n2 * (d + 1)))
                K = K[..., pi1, :][..., :, pi2]

            return K

    def num_outputs_per_input(self, x1, x2):
        return x1.size(-1) + 1
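The K12 block above rests on the identity dk/dx2_i = power * (x1 . x2 + offset)^(power - 1) * x1_i for the polynomial kernel. A quick autograd check of that identity (editorial sketch, not part of the package; gpytorch's PolynomialKernel is the base class imported above):

# Sketch: autograd check of the polynomial-kernel gradient identity (not package code)
import torch
from gpytorch.kernels.polynomial_kernel import PolynomialKernel

torch.manual_seed(0)
kernel = PolynomialKernel(power=3)
x1 = torch.randn(1, 4)
x2 = torch.randn(1, 4, requires_grad=True)

# Scalar kernel value k(x1, x2), then backprop to get dk/dx2
k = kernel.forward(x1, x2).squeeze()
k.backward()

offset = kernel.offset.detach()
analytic = kernel.power * (x1 @ x2.detach().t() + offset).pow(kernel.power - 1) * x1

print(torch.allclose(x2.grad, analytic, atol=1e-5))  # expected: True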