qpytorch 0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of qpytorch might be problematic. Click here for more details.
- qpytorch/__init__.py +327 -0
- qpytorch/constraints/__init__.py +3 -0
- qpytorch/distributions/__init__.py +21 -0
- qpytorch/distributions/delta.py +86 -0
- qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
- qpytorch/distributions/multivariate_qexponential.py +581 -0
- qpytorch/distributions/power.py +113 -0
- qpytorch/distributions/qexponential.py +153 -0
- qpytorch/functions/__init__.py +58 -0
- qpytorch/kernels/__init__.py +80 -0
- qpytorch/kernels/grid_interpolation_kernel.py +213 -0
- qpytorch/kernels/inducing_point_kernel.py +151 -0
- qpytorch/kernels/kernel.py +695 -0
- qpytorch/kernels/matern32_kernel_grad.py +155 -0
- qpytorch/kernels/matern52_kernel_grad.py +194 -0
- qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
- qpytorch/kernels/polynomial_kernel_grad.py +88 -0
- qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
- qpytorch/kernels/rbf_kernel_grad.py +125 -0
- qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
- qpytorch/kernels/rff_kernel.py +153 -0
- qpytorch/lazy/__init__.py +9 -0
- qpytorch/likelihoods/__init__.py +66 -0
- qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
- qpytorch/likelihoods/beta_likelihood.py +76 -0
- qpytorch/likelihoods/gaussian_likelihood.py +472 -0
- qpytorch/likelihoods/laplace_likelihood.py +59 -0
- qpytorch/likelihoods/likelihood.py +437 -0
- qpytorch/likelihoods/likelihood_list.py +60 -0
- qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
- qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
- qpytorch/likelihoods/noise_models.py +184 -0
- qpytorch/likelihoods/qexponential_likelihood.py +494 -0
- qpytorch/likelihoods/softmax_likelihood.py +97 -0
- qpytorch/likelihoods/student_t_likelihood.py +90 -0
- qpytorch/means/__init__.py +23 -0
- qpytorch/metrics/__init__.py +17 -0
- qpytorch/mlls/__init__.py +53 -0
- qpytorch/mlls/_approximate_mll.py +79 -0
- qpytorch/mlls/deep_approximate_mll.py +30 -0
- qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
- qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
- qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
- qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
- qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
- qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
- qpytorch/mlls/marginal_log_likelihood.py +48 -0
- qpytorch/mlls/predictive_log_likelihood.py +76 -0
- qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
- qpytorch/mlls/variational_elbo.py +77 -0
- qpytorch/models/__init__.py +72 -0
- qpytorch/models/approximate_qep.py +115 -0
- qpytorch/models/deep_qeps/__init__.py +22 -0
- qpytorch/models/deep_qeps/deep_qep.py +155 -0
- qpytorch/models/deep_qeps/dspp.py +114 -0
- qpytorch/models/exact_prediction_strategies.py +880 -0
- qpytorch/models/exact_qep.py +349 -0
- qpytorch/models/model_list.py +100 -0
- qpytorch/models/pyro/__init__.py +28 -0
- qpytorch/models/pyro/_pyro_mixin.py +57 -0
- qpytorch/models/pyro/distributions/__init__.py +5 -0
- qpytorch/models/pyro/pyro_qep.py +105 -0
- qpytorch/models/qep.py +7 -0
- qpytorch/models/qeplvm/__init__.py +6 -0
- qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
- qpytorch/models/qeplvm/latent_variable.py +102 -0
- qpytorch/module.py +30 -0
- qpytorch/optim/__init__.py +5 -0
- qpytorch/priors/__init__.py +42 -0
- qpytorch/priors/qep_priors.py +81 -0
- qpytorch/test/__init__.py +22 -0
- qpytorch/test/base_likelihood_test_case.py +106 -0
- qpytorch/test/model_test_case.py +150 -0
- qpytorch/test/variational_test_case.py +400 -0
- qpytorch/utils/__init__.py +38 -0
- qpytorch/utils/warnings.py +37 -0
- qpytorch/variational/__init__.py +47 -0
- qpytorch/variational/_variational_distribution.py +61 -0
- qpytorch/variational/_variational_strategy.py +391 -0
- qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
- qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
- qpytorch/variational/cholesky_variational_distribution.py +65 -0
- qpytorch/variational/ciq_variational_strategy.py +352 -0
- qpytorch/variational/delta_variational_distribution.py +41 -0
- qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
- qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
- qpytorch/variational/lmc_variational_strategy.py +248 -0
- qpytorch/variational/mean_field_variational_distribution.py +58 -0
- qpytorch/variational/multitask_variational_strategy.py +317 -0
- qpytorch/variational/natural_variational_distribution.py +152 -0
- qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
- qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
- qpytorch/variational/tril_natural_variational_distribution.py +130 -0
- qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
- qpytorch/variational/unwhitened_variational_strategy.py +225 -0
- qpytorch/variational/variational_strategy.py +280 -0
- qpytorch/version.py +4 -0
- qpytorch-0.1.dist-info/LICENSE +21 -0
- qpytorch-0.1.dist-info/METADATA +177 -0
- qpytorch-0.1.dist-info/RECORD +102 -0
- qpytorch-0.1.dist-info/WHEEL +5 -0
- qpytorch-0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from ..module import Module
|
|
10
|
+
from ..constraints import Interval, Positive
|
|
11
|
+
from ..priors import Prior
|
|
12
|
+
|
|
13
|
+
class Power(Module):
    """
    Constructs a power parameter for the (multivariate) q-exponential distribution.
    See :class:`qpytorch.distributions.QExponential` or :class:`qpytorch.distributions.MultivariateQExponential`
    for a description of the power parameter.

    .. note::

        This object behaves like a kernel hyperparameter: it can be given a prior and optimized
        together with the other model parameters.

    :param power_init: initial value of the power parameter of the qep distribution. Either a tensor or
        anything accepted by :func:`torch.as_tensor` (e.g. a Python float); non-floating-point values are
        converted to the default floating-point dtype. (Default: 1.0)
    :param power_constraint: Set this if you want to apply a constraint to the power parameter.
        (Default: :class:`~qpytorch.constraints.Positive`.)
    :param power_prior: Set this if you want to apply a prior to the power parameter.
        (Default: `None`.)

    :ivar torch.Size shape:
        The shape of the power parameter.
    :ivar torch.Tensor power:
        The (constrained) power parameter. Its shape is the same as that of `power_init`.
    :ivar torch.Tensor data:
        The data of the (constrained) power parameter as a plain :obj:`torch.Tensor`.

    Example:
        >>> power_init = torch.tensor(1.0)
        >>> power_prior = qpytorch.priors.GammaPrior(4.0, 2.0)
        >>> power = qpytorch.distributions.Power(power_init, power_prior=power_prior)
    """

    def __init__(
        self,
        power_init: torch.Tensor | float = 1.0,
        power_constraint: Optional[Interval] = None,
        power_prior: Optional[Prior] = None,
    ):
        super(Power, self).__init__()
        if power_constraint is None:
            power_constraint = Positive()

        # Build the initial value per call. A tensor default in the signature would be shared by every
        # instance, and the Parameter below may alias it (e.g. when inverse_transform is the identity),
        # so later in-place updates of one Power could leak into the next one.
        power_init = torch.as_tensor(power_init)
        if not power_init.is_floating_point():
            # The constraint transforms (softplus etc.) need floating-point input.
            power_init = power_init.to(torch.get_default_dtype())

        # set parameter: stored unconstrained; the `power` property maps it back through the constraint
        self.register_parameter(
            name="raw_power",
            parameter=torch.nn.Parameter(power_constraint.inverse_transform(power_init)),
        )
        self.shape = self.raw_power.shape
        # set constraint
        self.register_constraint("raw_power", power_constraint)
        # set prior
        if power_prior is not None:
            if not isinstance(power_prior, Prior):
                raise TypeError("Expected qpytorch.priors.Prior but got " + type(power_prior).__name__)
            self.register_prior("power_prior", power_prior, self._power_param, self._power_closure)

    def _power_param(self, q: Power) -> torch.Tensor:
        # Prior getter: the constrained power of module `q`.
        return q.power

    def _power_closure(self, q: Power, v: torch.Tensor) -> None:
        # Prior setter: writes the (constrained) value `v` into module `q`.
        return q._set_power(v)

    @property
    def power(self) -> torch.Tensor:
        """The constrained power value: the constraint's transform applied to ``raw_power``."""
        return self.raw_power_constraint.transform(self.raw_power)

    @power.setter
    def power(self, value: torch.Tensor) -> None:
        self._set_power(value)

    def _set_power(self, value: torch.Tensor) -> None:
        """Set the constrained power to ``value`` by storing its inverse transform in ``raw_power``."""
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_power)
        self.initialize(raw_power=self.raw_power_constraint.inverse_transform(value))

    @property
    def data(self) -> torch.Tensor:
        """The constrained power value detached from autograd (``power.data``)."""
        return self.power.data

    # Operator helpers so a Power can stand in for a tensor power in distribution formulas
    # (e.g. ``2. / power``, ``x ** power`` via __rpow__, ``power != 2``).
    # NOTE: __eq__ is intentionally not overridden: defining it would set __hash__ to None, while
    # torch.nn.Module machinery (e.g. named_modules) keeps modules in sets.
    def __truediv__(self, other):
        return self.power / other

    def __rtruediv__(self, other):
        return other / self.power

    def __rpow__(self, other):
        return other ** self.power

    def __ne__(self, other):
        return self.power != other

    def __lt__(self, other):
        return self.power < other

    def __gt__(self, other):
        return self.power > other

    def numel(self):
        """Number of elements of the power parameter."""
        return self.power.numel()

    def size(self):
        """Size of the power parameter."""
        return self.power.size()
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
from numbers import Number, Real
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from torch.distributions import constraints, Chi2
|
|
8
|
+
from torch.distributions.exp_family import ExponentialFamily
|
|
9
|
+
from torch.distributions.kl import register_kl
|
|
10
|
+
from torch.distributions.utils import _standard_normal, broadcast_all
|
|
11
|
+
|
|
12
|
+
from gpytorch.distributions.distribution import Distribution
|
|
13
|
+
|
|
14
|
+
# Public API of this module: only the QExponential distribution class is exported
# (the KL divergence below is registered with torch via @register_kl, not exported).
__all__ = ["QExponential"]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class QExponential(ExponentialFamily, Distribution):
    r"""
    Creates a q-exponential distribution parameterized by
    :attr:`loc`, :attr:`scale` and :attr:`power`, with the following density

    .. math::

        p(x; \mu, \sigma^2, q) = \frac{q}{2}(2\pi\sigma^2)^{-\frac{1}{2}}
        \left|\frac{x-\mu}{\sigma}\right|^{\frac{q}{2}-1}
        \exp\left\{-\frac{1}{2}\left|\frac{x-\mu}{\sigma}\right|^q\right\}

    Equivalently, :math:`x = \mu + \sigma\,\mathrm{sign}(z)|z|^{2/q}` with :math:`z \sim \mathcal{N}(0, 1)`;
    for :math:`q = 2` this is the normal distribution.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = QExponential(torch.tensor([0.0]), torch.tensor([1.0]))
        >>> m.sample()  # q-exponentially distributed with loc=0, scale=1 and power=2
        tensor([ 0.1046])

    Args:
        loc (float or Tensor): mean of the distribution (often referred to as mu)
        scale (float or Tensor): scale of the distribution (often referred to as sigma). It is the
            standard deviation only for power=2; in general the standard deviation is
            ``scale * rescalor``.
        power (float, Tensor or Power): power of the distribution (often referred to as q).
            Python numbers are converted to a tensor with the dtype/device of ``loc``.
    """

    arg_constraints = {"loc": constraints.real, "scale": constraints.positive, "power": constraints.positive}
    support = constraints.real
    has_rsample = True
    _mean_carrier_measure = 0

    @property
    def mean(self):
        return self.loc

    @property
    def mode(self):
        # NOTE(review): loc is the mode only for power <= 2. For power > 2 the density vanishes at loc and
        # the distribution is bimodal, with modes at loc +/- scale * ((q - 2) / q) ** (1 / q).
        return self.loc

    @property
    def stddev(self):
        # Reports the scale parameter (the convention of this class); the actual standard deviation
        # is scale * rescalor, which coincides with scale for power=2.
        return self.scale

    @property
    def variance(self):
        return self.stddev.pow(2)

    @property
    def rescalor(self):
        # sqrt(E[eps^2]) for the standardized variable eps = sign(z)|z|^(2/q), z ~ N(0, 1):
        # E[|z|^(4/q)] = 2^(2/q) * Gamma(1/2 + 2/q) / sqrt(pi).
        return torch.exp(
            (2.0 / self.power * math.log(2) + torch.lgamma(0.5 + 2.0 / self.power) - math.log(math.pi) / 2.0) / 2.0
        )

    def __init__(self, loc, scale, power=torch.tensor(2.0), validate_args=None):
        self.loc, self.scale = broadcast_all(loc, scale)
        if isinstance(loc, Number) and isinstance(scale, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.loc.size()
        if isinstance(power, Number):
            # torch.log / torch.lgamma (log_prob, entropy, rescalor) and argument validation need a tensor.
            power = torch.as_tensor(power, dtype=self.loc.dtype, device=self.loc.device)
        self.power = power
        super().__init__(batch_shape, validate_args=validate_args)

    def confidence(self, alpha=0.05):
        """Return the ``alpha/2`` and ``1 - alpha/2`` quantiles (a central ``1 - alpha`` interval)."""
        lower = self.icdf(torch.tensor(alpha / 2))
        upper = self.icdf(torch.tensor(1 - alpha / 2))
        return lower, upper

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(QExponential, _instance)
        batch_shape = torch.Size(batch_shape)
        new.loc = self.loc.expand(batch_shape)
        new.scale = self.scale.expand(batch_shape)
        # The power must be carried over too; every sampling/density method reads it. It is shared
        # as-is (tensor or Power module) and broadcasts in the formulas.
        new.power = self.power
        super(QExponential, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def sample(self, sample_shape=torch.Size(), rescale=False):
        shape = self._extended_shape(sample_shape)
        with torch.no_grad():
            # (chi2_1)^(1/q) has the law of |z|^(2/q); a random sign makes the draw symmetric.
            magnitude = Chi2(1).sample(shape).to(self.loc.device) ** (1.0 / self.power)
            sign = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device).sign()
            eps = magnitude * sign
            if rescale:
                eps = eps / self.rescalor
            return self.loc.expand(shape) + eps * self.scale.expand(shape)

    def rsample(self, sample_shape=torch.Size(), rescale=False):
        shape = self._extended_shape(sample_shape)
        eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        if self.power != 2:
            # sign(z) * |z|^(2/q), differentiable w.r.t. power
            eps = eps.abs() ** (2.0 / self.power - 1) * eps
        if rescale:
            eps = eps / self.rescalor
        return self.loc + eps * self.scale

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        # self.scale is always a tensor here (broadcast_all / expand), so .log() is safe.
        log_scale = self.scale.log()
        scaled_diff = ((value - self.loc) / self.scale).abs()
        res = -0.5 * (scaled_diff ** self.power + math.log(2 * math.pi)) - log_scale
        if self.power != 2:
            res = res + (self.power / 2.0 - 1) * scaled_diff.log() + torch.log(self.power / 2.0)
        return res

    def cdf(self, value):
        if self._validate_args:
            self._validate_sample(value)
        scaled_diff = (value - self.loc) * self.scale.reciprocal()
        if self.power != 2:
            # Map to the underlying standard normal: sign(r)|r|^(q/2). Out-of-place on purpose: abs()
            # saves scaled_diff for backward, so an in-place update would break autograd.
            scaled_diff = scaled_diff * scaled_diff.abs() ** (self.power / 2.0 - 1)
        return 0.5 * (1 + torch.erf(scaled_diff / math.sqrt(2)))

    def icdf(self, value):
        standard = torch.erfinv(2 * value - 1) * math.sqrt(2)
        if self.power != 2:
            # Inverse of the cdf mapping: sign(t)|t|^(2/q). Out-of-place for the same autograd reason.
            standard = standard * standard.abs() ** (2.0 / self.power - 1)
        return self.loc + self.scale * standard

    def entropy(self, exact=False):
        """
        Entropy of the distribution.

        With eps = sign(z)|z|^(2/q), z ~ N(0, 1), a change of variables gives
        ``H = 1/2 + log(2 pi)/2 + log(scale) - log(q/2) + (1/q - 1/2) * E[log chi2_1]``.
        With ``exact=False`` the last term is dropped (it vanishes for q=2).
        """
        res = 0.5 + 0.5 * math.log(2 * math.pi) + torch.log(self.scale)
        if self.power != 2:
            res = res - torch.log(self.power / 2.0)
            if exact:
                # E[log X], X ~ chi2(1): digamma(1/2) + log(2) = -(Euler-Mascheroni constant + log(2)).
                expected_log_chi2 = -(0.5772156649015329 + math.log(2.0))
                res = res + (1.0 / self.power - 0.5) * expected_log_chi2
        return res

    @property
    def _natural_params(self):
        if self.power != 2:
            raise ValueError(f"Q-Exponential distribution with power {self.power} does not belong to exponential family!")
        else:
            return (self.loc / self.scale.pow(2), -0.5 * self.scale.pow(2).reciprocal())

    def _log_normalizer(self, x, y):
        if self.power != 2:
            raise ValueError(f"Q-Exponential distribution with power {self.power} does not belong to exponential family!")
        else:
            return -0.25 * x.pow(2) / y + 0.5 * torch.log(-math.pi / y)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
@register_kl(QExponential, QExponential)
def _kl_qexponential_qexponential(p, q, exact=False):
    """
    KL(p || q) between two q-exponential distributions.

    Reduces to the Gaussian KL when q.power == 2. Otherwise, a power-dependent correction is added;
    the ``exact`` flag additionally includes a chi-square(1) term.

    NOTE(review): the ``exact`` term uses ``Chi2(1).entropy()``; the expected log of a chi-square(1)
    variable, digamma(1/2) + log 2, is a different quantity — confirm the intended derivation.
    """
    ratio = (p.scale / q.scale).pow(2)
    shift = ((p.loc - q.loc) / q.scale).pow(2)
    total = ratio + shift
    kl = 0.5 * (total.pow(q.power / 2.0) - 1 - ratio.log())
    if q.power != 2:
        correction = -2.0 / p.power * Chi2(1).entropy() if exact else 0
        kl += 0.5 * (-(q.power / 2.0 - 1) * torch.log(total) + (p.power / 2.0 - 1) * correction)
    return kl
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import warnings
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import linear_operator
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
from gpytorch.functions._log_normal_cdf import LogNormalCDF
|
|
12
|
+
from gpytorch.functions.matern_covariance import MaternCovariance
|
|
13
|
+
from gpytorch.functions.rbf_covariance import RBFCovariance
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def log_normal_cdf(x):
    r"""
    Element-wise :math:`\log \Phi(x)`, the log of the standard normal CDF, for a tensor ``x``.

    Use this instead of calling normal_cdf and taking the log by hand: this
    dedicated function is more numerically stable.
    """
    log_cdf = LogNormalCDF.apply(x)
    return log_cdf
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def logdet(mat):
    """Deprecated alias of :func:`torch.logdet`; emits a DeprecationWarning before delegating."""
    notice = "gpytorch.logdet is deprecated. Use torch.logdet instead."
    warnings.warn(notice, DeprecationWarning)
    return torch.logdet(mat)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def matmul(mat, rhs):
    """Deprecated alias of :func:`torch.matmul`; emits a DeprecationWarning before delegating."""
    notice = "gpytorch.matmul is deprecated. Use torch.matmul instead."
    warnings.warn(notice, DeprecationWarning)
    return torch.matmul(mat, rhs)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def inv_matmul(mat, right_tensor, left_tensor=None):
    """
    Deprecated: solve against ``mat``.

    Returns ``mat^{-1} @ right_tensor``, or ``left_tensor @ mat^{-1} @ right_tensor`` when
    ``left_tensor`` is given, by delegating to :func:`linear_operator.solve`.

    :param mat: the (positive definite) matrix / linear operator to solve against.
    :param right_tensor: right-hand side of the solve.
    :param left_tensor: optional left-hand factor applied to the solve result.
    """
    warnings.warn("gpytorch.inv_matmul is deprecated. Use gpytorch.solve instead.", DeprecationWarning)
    # solve(input, rhs, lhs): `mat` is the operator and `left_tensor` must be forwarded, not dropped.
    return linear_operator.solve(mat, right_tensor, left_tensor)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# Public names of this module. The deprecated wrappers (inv_matmul, logdet, matmul) are still exported;
# other linear_operator.functions names are forwarded on access by the module-level __getattr__.
__all__ = [
    "MaternCovariance",
    "RBFCovariance",
    "inv_matmul",
    "logdet",
    "log_normal_cdf",
    "matmul",
]
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def __getattr__(name: str) -> Any:
    """
    Module-level attribute fallback (PEP 562).

    Names that exist on ``linear_operator.functions`` are returned from there with a
    DeprecationWarning; any other name raises AttributeError.
    """
    source = linear_operator.functions
    if not hasattr(source, name):
        raise AttributeError(f"module gpytorch.functions has no attribute {name}.")
    warnings.warn(
        f"gpytorch.functions.{name} is deprecated. Use linear_operator.functions.{name} instead.",
        DeprecationWarning,
    )
    return getattr(source, name)
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
from gpytorch.kernels import keops
|
|
3
|
+
from gpytorch.kernels.additive_structure_kernel import AdditiveStructureKernel
|
|
4
|
+
from gpytorch.kernels.arc_kernel import ArcKernel
|
|
5
|
+
from gpytorch.kernels.constant_kernel import ConstantKernel
|
|
6
|
+
from gpytorch.kernels.cosine_kernel import CosineKernel
|
|
7
|
+
from gpytorch.kernels.cylindrical_kernel import CylindricalKernel
|
|
8
|
+
from gpytorch.kernels.distributional_input_kernel import DistributionalInputKernel
|
|
9
|
+
from gpytorch.kernels.gaussian_symmetrized_kl_kernel import GaussianSymmetrizedKLKernel
|
|
10
|
+
from .grid_interpolation_kernel import GridInterpolationKernel
|
|
11
|
+
from gpytorch.kernels.grid_kernel import GridKernel
|
|
12
|
+
from gpytorch.kernels.hamming_kernel import HammingIMQKernel
|
|
13
|
+
from gpytorch.kernels.index_kernel import IndexKernel
|
|
14
|
+
from .inducing_point_kernel import InducingPointKernel
|
|
15
|
+
from .kernel import AdditiveKernel, Kernel, ProductKernel
|
|
16
|
+
from gpytorch.kernels.lcm_kernel import LCMKernel
|
|
17
|
+
from gpytorch.kernels.linear_kernel import LinearKernel
|
|
18
|
+
from .matern32_kernel_grad import Matern32KernelGrad
|
|
19
|
+
from .matern52_kernel_grad import Matern52KernelGrad
|
|
20
|
+
from .matern52_kernel_gradgrad import Matern52KernelGradGrad
|
|
21
|
+
from gpytorch.kernels.matern_kernel import MaternKernel
|
|
22
|
+
from gpytorch.kernels.multi_device_kernel import MultiDeviceKernel
|
|
23
|
+
from gpytorch.kernels.multitask_kernel import MultitaskKernel
|
|
24
|
+
from gpytorch.kernels.newton_girard_additive_kernel import NewtonGirardAdditiveKernel
|
|
25
|
+
from gpytorch.kernels.periodic_kernel import PeriodicKernel
|
|
26
|
+
from gpytorch.kernels.piecewise_polynomial_kernel import PiecewisePolynomialKernel
|
|
27
|
+
from gpytorch.kernels.polynomial_kernel import PolynomialKernel
|
|
28
|
+
from .polynomial_kernel_grad import PolynomialKernelGrad
|
|
29
|
+
from gpytorch.kernels.product_structure_kernel import ProductStructureKernel
|
|
30
|
+
from .qexponential_symmetrized_kl_kernel import QExponentialSymmetrizedKLKernel
|
|
31
|
+
from gpytorch.kernels.rbf_kernel import RBFKernel
|
|
32
|
+
from .rbf_kernel_grad import RBFKernelGrad
|
|
33
|
+
from .rbf_kernel_gradgrad import RBFKernelGradGrad
|
|
34
|
+
from .rff_kernel import RFFKernel
|
|
35
|
+
from gpytorch.kernels.rq_kernel import RQKernel
|
|
36
|
+
from gpytorch.kernels.scale_kernel import ScaleKernel
|
|
37
|
+
from gpytorch.kernels.spectral_delta_kernel import SpectralDeltaKernel
|
|
38
|
+
from gpytorch.kernels.spectral_mixture_kernel import SpectralMixtureKernel
|
|
39
|
+
|
|
40
|
+
# Public kernels of this package: gpytorch kernels re-exported unchanged, plus the qpytorch
# implementations imported from the local modules (grid interpolation, inducing point, *Grad kernels,
# RFF, q-exponential symmetrized KL, and the Kernel/AdditiveKernel/ProductKernel base classes).
__all__ = [
    "keops",
    "Kernel",
    "ArcKernel",
    "AdditiveKernel",
    "AdditiveStructureKernel",
    "ConstantKernel",
    "CylindricalKernel",
    "MultiDeviceKernel",
    "CosineKernel",
    "DistributionalInputKernel",
    "GaussianSymmetrizedKLKernel",
    "GridKernel",
    "GridInterpolationKernel",
    "HammingIMQKernel",
    "IndexKernel",
    "InducingPointKernel",
    "LCMKernel",
    "LinearKernel",
    "MaternKernel",
    "MultitaskKernel",
    "NewtonGirardAdditiveKernel",
    "PeriodicKernel",
    "PiecewisePolynomialKernel",
    "PolynomialKernel",
    "PolynomialKernelGrad",
    "ProductKernel",
    "ProductStructureKernel",
    "QExponentialSymmetrizedKLKernel",
    "RBFKernel",
    "RFFKernel",
    "RBFKernelGrad",
    "RBFKernelGradGrad",
    "RQKernel",
    "ScaleKernel",
    "SpectralDeltaKernel",
    "SpectralMixtureKernel",
    "Matern32KernelGrad",
    "Matern52KernelGrad",
    "Matern52KernelGradGrad",
]
|
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
from typing import List, Optional, Tuple, Union
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from linear_operator import to_linear_operator
|
|
7
|
+
from linear_operator.operators import InterpolatedLinearOperator
|
|
8
|
+
|
|
9
|
+
from ..models.exact_prediction_strategies import InterpolatedPredictionStrategy
|
|
10
|
+
from gpytorch.utils.grid import create_grid
|
|
11
|
+
from gpytorch.utils.interpolation import Interpolation
|
|
12
|
+
from gpytorch.kernels.grid_kernel import GridKernel
|
|
13
|
+
from .kernel import Kernel
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class GridInterpolationKernel(GridKernel):
    r"""
    Implements the KISS-QEP (or SKI) approximation for a given kernel.
    It was proposed in `Kernel Interpolation for Scalable Structured Gaussian Processes`_,
    and offers extremely fast and accurate Kernel approximations for large datasets.

    Given a base kernel `k`, the covariance :math:`k(\mathbf{x_1}, \mathbf{x_2})` is approximated by
    using a grid of regularly spaced *inducing points*:

    .. math::

       \begin{equation*}
          k(\mathbf{x_1}, \mathbf{x_2}) = \mathbf{w_{x_1}}^\top K_{U,U} \mathbf{w_{x_2}}
       \end{equation*}

    where

    * :math:`U` is the set of gridded inducing points

    * :math:`K_{U,U}` is the kernel matrix between the inducing points

    * :math:`\mathbf{w_{x_1}}` and :math:`\mathbf{w_{x_2}}` are sparse vectors based on
      :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}` that apply cubic interpolation.

    The user should supply the size of the grid (using the grid_size attribute).
    To choose a reasonable grid value, we highly recommend using the
    :func:`gpytorch.utils.grid.choose_grid_size` helper function.
    The bounds of the grid will automatically be determined by data: they are fitted on the first
    call to :meth:`forward` and refitted only when later inputs fall outside the interior of the
    current grid (the grid minus a margin of about two grid spacings on each side).

    (Alternatively, you can hard-code bounds using the grid_bounds, which
    will speed up this kernel's computations.)

    .. note::

        `GridInterpolationKernel` can only wrap **stationary kernels** (such as RBF, Matern,
        Periodic, Spectral Mixture, etc.)

    Args:
        base_kernel (Kernel):
            The kernel to approximate with KISS-QEP
        grid_size (Union[int, List[int]]):
            The size of the grid in each dimension.
            If a single int is provided, then every dimension will have the same grid size.
        num_dims (int):
            The dimension of the input data. Required if `grid_bounds=None`
        grid_bounds (tuple of (float, float) pairs, optional):
            The bounds of the grid, if known (high performance mode).
            The length of the tuple must match the number of dimensions.
            The entries represent the min/max values for each dimension.
        active_dims (tuple of ints, optional):
            Passed down to the `base_kernel`.

    .. _Kernel Interpolation for Scalable Structured Gaussian Processes:
        http://proceedings.mlr.press/v37/wilson15.pdf
    """

    def __init__(
        self,
        base_kernel: Kernel,
        grid_size: Union[int, List[int]],
        num_dims: Optional[int] = None,
        grid_bounds: Optional[Tuple[Tuple[float, float], ...]] = None,
        active_dims: Optional[Tuple[int, ...]] = None,
    ):
        has_initialized_grid = 0
        grid_is_dynamic = True

        # Make some temporary grid bounds, if none exist
        if grid_bounds is None:
            if num_dims is None:
                raise RuntimeError("num_dims must be supplied if grid_bounds is None")
            else:
                # Create some temporary grid bounds - they'll be replaced on the first forward call
                grid_bounds = tuple((-1.0, 1.0) for _ in range(num_dims))
        else:
            has_initialized_grid = 1
            grid_is_dynamic = False
            if num_dims is None:
                num_dims = len(grid_bounds)
            elif num_dims != len(grid_bounds):
                raise RuntimeError(
                    "num_dims ({}) disagrees with the number of supplied "
                    "grid_bounds ({})".format(num_dims, len(grid_bounds))
                )

        if isinstance(grid_size, int):
            grid_sizes = [grid_size for _ in range(num_dims)]
        else:
            grid_sizes = list(grid_size)

        if len(grid_sizes) != num_dims:
            raise RuntimeError("The number of grid sizes provided through grid_size do not match num_dims.")

        # Initialize values and the grid
        self.grid_is_dynamic = grid_is_dynamic
        self.num_dims = num_dims
        self.grid_sizes = grid_sizes
        self.grid_bounds = grid_bounds
        grid = create_grid(self.grid_sizes, self.grid_bounds)

        super(GridInterpolationKernel, self).__init__(
            base_kernel=base_kernel,
            grid=grid,
            interpolation_mode=True,
            active_dims=active_dims,
        )
        self.register_buffer("has_initialized_grid", torch.tensor(has_initialized_grid, dtype=torch.bool))

    @property
    def _tight_grid_bounds(self):
        """Grid bounds shrunk by 2.01 grid spacings per side; inputs outside them trigger a grid refit."""
        # presumably the margin keeps cubic interpolation stencils (2 neighbours per side) inside the grid
        grid_spacings = tuple((bound[1] - bound[0]) / self.grid_sizes[i] for i, bound in enumerate(self.grid_bounds))
        return tuple(
            (bound[0] + 2.01 * spacing, bound[1] - 2.01 * spacing)
            for bound, spacing in zip(self.grid_bounds, grid_spacings)
        )

    def _compute_grid(self, inputs, last_dim_is_batch=False):
        """Interpolation indices/values of ``inputs`` on the grid, each viewed as ``(*batch, n_data, -1)``."""
        n_data, n_dimensions = inputs.size(-2), inputs.size(-1)
        if last_dim_is_batch:
            inputs = inputs.transpose(-1, -2).unsqueeze(-1)
            n_dimensions = 1
        batch_shape = inputs.shape[:-2]

        inputs = inputs.reshape(-1, n_dimensions)
        interp_indices, interp_values = Interpolation().interpolate(self.grid, inputs)
        interp_indices = interp_indices.view(*batch_shape, n_data, -1)
        interp_values = interp_values.view(*batch_shape, n_data, -1)
        return interp_indices, interp_values

    def _inducing_forward(self, last_dim_is_batch, **params):
        """Base-kernel covariance between the grid points (K_UU)."""
        return super().forward(self.grid, self.grid, last_dim_is_batch=last_dim_is_batch, **params)

    def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
        same_inputs = torch.equal(x1, x2)

        # See if we need to update the grid or not
        if self.grid_is_dynamic:  # This is true if a grid_bounds wasn't passed in
            if same_inputs:
                x = x1.reshape(-1, self.num_dims)
            else:
                x = torch.cat([x1.reshape(-1, self.num_dims), x2.reshape(-1, self.num_dims)])
            x_maxs = x.max(0)[0].tolist()
            x_mins = x.min(0)[0].tolist()

            # We need to update the grid if
            # 1) it hasn't ever been initialized, or
            # 2) if any of the grid points are "out of bounds"
            needs_update = (not self.has_initialized_grid.item()) or any(
                x_min < bound[0] or x_max > bound[1]
                for x_min, x_max, bound in zip(x_mins, x_maxs, self._tight_grid_bounds)
            )

            # Update the grid if needed
            if needs_update:
                grid_spacings = tuple(
                    (x_max - x_min) / (gs - 4.02) for gs, x_min, x_max in zip(self.grid_sizes, x_mins, x_maxs)
                )
                self.grid_bounds = tuple(
                    (x_min - 2.01 * spacing, x_max + 2.01 * spacing)
                    for x_min, x_max, spacing in zip(x_mins, x_maxs, grid_spacings)
                )
                grid = create_grid(
                    self.grid_sizes,
                    self.grid_bounds,
                    dtype=self.grid[0].dtype,
                    device=self.grid[0].device,
                )
                self.update_grid(grid)
                # Record the fit; otherwise condition 1) above stays true forever and the grid would be
                # rebuilt (and the kernel cache cleared) on every call.
                self.has_initialized_grid.fill_(True)

        base_lazy_tsr = to_linear_operator(self._inducing_forward(last_dim_is_batch=last_dim_is_batch, **params))
        if last_dim_is_batch and base_lazy_tsr.size(-3) == 1:
            base_lazy_tsr = base_lazy_tsr.repeat(*x1.shape[:-2], x1.size(-1), 1, 1)

        left_interp_indices, left_interp_values = self._compute_grid(x1, last_dim_is_batch)
        if same_inputs:
            right_interp_indices = left_interp_indices
            right_interp_values = left_interp_values
        else:
            right_interp_indices, right_interp_values = self._compute_grid(x2, last_dim_is_batch)

        batch_shape = torch.broadcast_shapes(
            base_lazy_tsr.batch_shape,
            left_interp_indices.shape[:-2],
            right_interp_indices.shape[:-2],
        )
        res = InterpolatedLinearOperator(
            base_lazy_tsr.expand(*batch_shape, *base_lazy_tsr.matrix_shape),
            left_interp_indices.detach().expand(*batch_shape, *left_interp_indices.shape[-2:]),
            left_interp_values.expand(*batch_shape, *left_interp_values.shape[-2:]),
            right_interp_indices.detach().expand(*batch_shape, *right_interp_indices.shape[-2:]),
            right_interp_values.expand(*batch_shape, *right_interp_values.shape[-2:]),
        )

        if diag:
            return res.diagonal(dim1=-1, dim2=-2)
        else:
            return res

    def prediction_strategy(self, train_inputs, train_prior_dist, train_labels, likelihood):
        return InterpolatedPredictionStrategy(train_inputs, train_prior_dist, train_labels, likelihood)
|