qpytorch-0.1-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries, as they appear in those registries. It is provided for informational purposes only.

Potentially problematic release.

Files changed (102)
  1. qpytorch/__init__.py +327 -0
  2. qpytorch/constraints/__init__.py +3 -0
  3. qpytorch/distributions/__init__.py +21 -0
  4. qpytorch/distributions/delta.py +86 -0
  5. qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
  6. qpytorch/distributions/multivariate_qexponential.py +581 -0
  7. qpytorch/distributions/power.py +113 -0
  8. qpytorch/distributions/qexponential.py +153 -0
  9. qpytorch/functions/__init__.py +58 -0
  10. qpytorch/kernels/__init__.py +80 -0
  11. qpytorch/kernels/grid_interpolation_kernel.py +213 -0
  12. qpytorch/kernels/inducing_point_kernel.py +151 -0
  13. qpytorch/kernels/kernel.py +695 -0
  14. qpytorch/kernels/matern32_kernel_grad.py +155 -0
  15. qpytorch/kernels/matern52_kernel_grad.py +194 -0
  16. qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
  17. qpytorch/kernels/polynomial_kernel_grad.py +88 -0
  18. qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
  19. qpytorch/kernels/rbf_kernel_grad.py +125 -0
  20. qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
  21. qpytorch/kernels/rff_kernel.py +153 -0
  22. qpytorch/lazy/__init__.py +9 -0
  23. qpytorch/likelihoods/__init__.py +66 -0
  24. qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
  25. qpytorch/likelihoods/beta_likelihood.py +76 -0
  26. qpytorch/likelihoods/gaussian_likelihood.py +472 -0
  27. qpytorch/likelihoods/laplace_likelihood.py +59 -0
  28. qpytorch/likelihoods/likelihood.py +437 -0
  29. qpytorch/likelihoods/likelihood_list.py +60 -0
  30. qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
  31. qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
  32. qpytorch/likelihoods/noise_models.py +184 -0
  33. qpytorch/likelihoods/qexponential_likelihood.py +494 -0
  34. qpytorch/likelihoods/softmax_likelihood.py +97 -0
  35. qpytorch/likelihoods/student_t_likelihood.py +90 -0
  36. qpytorch/means/__init__.py +23 -0
  37. qpytorch/metrics/__init__.py +17 -0
  38. qpytorch/mlls/__init__.py +53 -0
  39. qpytorch/mlls/_approximate_mll.py +79 -0
  40. qpytorch/mlls/deep_approximate_mll.py +30 -0
  41. qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
  42. qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
  43. qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
  44. qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
  45. qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
  46. qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
  47. qpytorch/mlls/marginal_log_likelihood.py +48 -0
  48. qpytorch/mlls/predictive_log_likelihood.py +76 -0
  49. qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
  50. qpytorch/mlls/variational_elbo.py +77 -0
  51. qpytorch/models/__init__.py +72 -0
  52. qpytorch/models/approximate_qep.py +115 -0
  53. qpytorch/models/deep_qeps/__init__.py +22 -0
  54. qpytorch/models/deep_qeps/deep_qep.py +155 -0
  55. qpytorch/models/deep_qeps/dspp.py +114 -0
  56. qpytorch/models/exact_prediction_strategies.py +880 -0
  57. qpytorch/models/exact_qep.py +349 -0
  58. qpytorch/models/model_list.py +100 -0
  59. qpytorch/models/pyro/__init__.py +28 -0
  60. qpytorch/models/pyro/_pyro_mixin.py +57 -0
  61. qpytorch/models/pyro/distributions/__init__.py +5 -0
  62. qpytorch/models/pyro/pyro_qep.py +105 -0
  63. qpytorch/models/qep.py +7 -0
  64. qpytorch/models/qeplvm/__init__.py +6 -0
  65. qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
  66. qpytorch/models/qeplvm/latent_variable.py +102 -0
  67. qpytorch/module.py +30 -0
  68. qpytorch/optim/__init__.py +5 -0
  69. qpytorch/priors/__init__.py +42 -0
  70. qpytorch/priors/qep_priors.py +81 -0
  71. qpytorch/test/__init__.py +22 -0
  72. qpytorch/test/base_likelihood_test_case.py +106 -0
  73. qpytorch/test/model_test_case.py +150 -0
  74. qpytorch/test/variational_test_case.py +400 -0
  75. qpytorch/utils/__init__.py +38 -0
  76. qpytorch/utils/warnings.py +37 -0
  77. qpytorch/variational/__init__.py +47 -0
  78. qpytorch/variational/_variational_distribution.py +61 -0
  79. qpytorch/variational/_variational_strategy.py +391 -0
  80. qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
  81. qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
  82. qpytorch/variational/cholesky_variational_distribution.py +65 -0
  83. qpytorch/variational/ciq_variational_strategy.py +352 -0
  84. qpytorch/variational/delta_variational_distribution.py +41 -0
  85. qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
  86. qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
  87. qpytorch/variational/lmc_variational_strategy.py +248 -0
  88. qpytorch/variational/mean_field_variational_distribution.py +58 -0
  89. qpytorch/variational/multitask_variational_strategy.py +317 -0
  90. qpytorch/variational/natural_variational_distribution.py +152 -0
  91. qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
  92. qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
  93. qpytorch/variational/tril_natural_variational_distribution.py +130 -0
  94. qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
  95. qpytorch/variational/unwhitened_variational_strategy.py +225 -0
  96. qpytorch/variational/variational_strategy.py +280 -0
  97. qpytorch/version.py +4 -0
  98. qpytorch-0.1.dist-info/LICENSE +21 -0
  99. qpytorch-0.1.dist-info/METADATA +177 -0
  100. qpytorch-0.1.dist-info/RECORD +102 -0
  101. qpytorch-0.1.dist-info/WHEEL +5 -0
  102. qpytorch-0.1.dist-info/top_level.txt +1 -0
qpytorch/distributions/power.py
@@ -0,0 +1,113 @@
+ #!/usr/bin/env python3
+
+ from __future__ import annotations
+
+ from typing import Optional
+
+ import torch
+
+ from ..module import Module
+ from ..constraints import Interval, Positive
+ from ..priors import Prior
+
+ class Power(Module):
+     """
+     Constructs a power parameter for the (multivariate) q-exponential distribution.
+     See :class:`qpytorch.distributions.QExponential` or :class:`qpytorch.distributions.MultivariateQExponential`
+     for a description of the power parameter.
+
+     .. note::
+
+         This object works similarly to a kernel hyperparameter: it can be given a prior and optimized over.
+
+     :param power_init: Initial value of the power parameter of the QEP distribution. (Default: 1.0.)
+     :param power_constraint: Set this if you want to apply a constraint to the power parameter.
+         (Default: :class:`~qpytorch.constraints.Positive`.)
+     :param power_prior: Set this if you want to apply a prior to the power parameter.
+         (Default: `None`.)
+
+     :ivar torch.Size shape:
+         The dimension of the power object.
+     :ivar torch.Tensor power:
+         The power parameter. The size/shape is the same as the `power_init` argument.
+     :ivar torch.Tensor data:
+         The data of the power object in :obj:`torch.Tensor` format.
+
+     Example:
+         >>> power_init = torch.tensor(1.0)
+         >>> power_prior = qpytorch.priors.GammaPrior(4.0, 2.0)
+         >>> power = qpytorch.distributions.Power(power_init, power_prior=power_prior)
+     """
+
+     def __init__(
+         self,
+         power_init: torch.Tensor = torch.tensor(1.0),
+         power_constraint: Optional[Interval] = None,
+         power_prior: Optional[Prior] = None,
+     ):
+         super(Power, self).__init__()
+         if power_constraint is None:
+             power_constraint = Positive()
+
+         # set parameter
+         self.register_parameter(
+             name="raw_power",
+             parameter=torch.nn.Parameter(power_constraint.inverse_transform(power_init)),
+         )
+         self.shape = self.raw_power.shape
+         # set constraint
+         self.register_constraint("raw_power", power_constraint)
+         # set prior
+         if power_prior is not None:
+             if not isinstance(power_prior, Prior):
+                 raise TypeError("Expected qpytorch.priors.Prior but got " + type(power_prior).__name__)
+             self.register_prior("power_prior", power_prior, self._power_param, self._power_closure)
+
+     def _power_param(self, q: Power) -> torch.Tensor:
+         # Used by the prior to read the (constrained) power value
+         return q.power
+
+     def _power_closure(self, q: Power, v: torch.Tensor) -> torch.Tensor:
+         # Used by the prior to set the power value
+         return q._set_power(v)
+
+     @property
+     def power(self) -> torch.Tensor:
+         return self.raw_power_constraint.transform(self.raw_power)
+
+     @power.setter
+     def power(self, value: torch.Tensor) -> None:
+         self._set_power(value)
+
+     def _set_power(self, value: torch.Tensor):
+         if not torch.is_tensor(value):
+             value = torch.as_tensor(value).to(self.raw_power)
+         self.initialize(raw_power=self.raw_power_constraint.inverse_transform(value))
+
+     @property
+     def data(self) -> torch.Tensor:
+         return self.power.data
+
+     def __truediv__(self, other):
+         return self.power / other
+
+     def __rtruediv__(self, other):
+         return other / self.power
+
+     def __rpow__(self, other):
+         return other ** self.power
+
+     def __ne__(self, other):
+         return self.power != other
+
+     def __lt__(self, other):
+         return self.power < other
+
+     def __gt__(self, other):
+         return self.power > other
+
+     def numel(self):
+         return self.power.numel()
+
+     def size(self):
+         return self.power.size()
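A minimal usage sketch (not part of the diff; it mirrors the docstring example and assumes `Power` and `GammaPrior` are re-exported from `qpytorch.distributions` and `qpytorch.priors` as shown there):

import torch
import qpytorch

# Construct a power parameter with the GammaPrior from the docstring example
power = qpytorch.distributions.Power(
    torch.tensor(1.5),
    power_prior=qpytorch.priors.GammaPrior(4.0, 2.0),
)

print(power.power)  # constrained (positive) value, initially ~1.5
print(2.0 / power)  # the arithmetic dunders above delegate to `power`

# raw_power is a torch.nn.Parameter, so it trains like any other hyperparameter
optimizer = torch.optim.Adam(power.parameters(), lr=0.01)

Because the comparison and division operators delegate to the constrained value, a `Power` instance can be dropped into expressions such as `2.0 / power` or `power != 2` wherever the distribution code expects the scalar power.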
qpytorch/distributions/qexponential.py
@@ -0,0 +1,153 @@
+ #!/usr/bin/env python3
+
+ import math
+ from numbers import Number, Real
+
+ import torch
+ from torch.distributions import constraints, Chi2
+ from torch.distributions.exp_family import ExponentialFamily
+ from torch.distributions.kl import register_kl
+ from torch.distributions.utils import _standard_normal, broadcast_all
+
+ from gpytorch.distributions.distribution import Distribution
+
+ __all__ = ["QExponential"]
+
+
+ class QExponential(ExponentialFamily, Distribution):
+     r"""
+     Creates a q-exponential distribution parameterized by
+     :attr:`loc`, :attr:`scale` and :attr:`power`, with the following density:
+
+     .. math::
+
+         p(x; \mu, \sigma^2) = \frac{q}{2} (2\pi\sigma^2)^{-\frac{1}{2}}
+         \left| \frac{x - \mu}{\sigma} \right|^{\frac{q}{2} - 1}
+         \exp\left\{ -\frac{1}{2} \left| \frac{x - \mu}{\sigma} \right|^q \right\}
+
+     Example::
+
+         >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+         >>> m = QExponential(torch.tensor([0.0]), torch.tensor([1.0]))
+         >>> m.sample()  # q-exponentially distributed with loc=0, scale=1 and power=2
+         tensor([ 0.1046])
+
+     Args:
+         loc (float or Tensor): mean of the distribution (often referred to as mu)
+         scale (float or Tensor): standard deviation of the distribution
+             (often referred to as sigma)
+         power (float or Tensor): power of the distribution
+     """
+     arg_constraints = {"loc": constraints.real, "scale": constraints.positive, "power": constraints.positive}
+     support = constraints.real
+     has_rsample = True
+     _mean_carrier_measure = 0
+
+     @property
+     def mean(self):
+         return self.loc
+
+     @property
+     def mode(self):
+         return self.loc
+
+     @property
+     def stddev(self):
+         return self.scale
+
+     @property
+     def variance(self):
+         return self.stddev.pow(2)
+
+     @property
+     def rescalor(self):
+         return torch.exp((2. / self.power * math.log(2) + torch.lgamma(0.5 + 2. / self.power) - math.log(math.pi) / 2.) / 2.)
+
+     def __init__(self, loc, scale, power=torch.tensor(2.0), validate_args=None):
+         self.loc, self.scale = broadcast_all(loc, scale)
+         if isinstance(loc, Number) and isinstance(scale, Number):
+             batch_shape = torch.Size()
+         else:
+             batch_shape = self.loc.size()
+         self.power = power
+         super().__init__(batch_shape, validate_args=validate_args)
+
+     def confidence(self, alpha=0.05):
+         lower = self.icdf(torch.tensor(alpha / 2))
+         upper = self.icdf(torch.tensor(1 - alpha / 2))
+         return lower, upper
+
+     def expand(self, batch_shape, _instance=None):
+         new = self._get_checked_instance(QExponential, _instance)
+         batch_shape = torch.Size(batch_shape)
+         new.loc = self.loc.expand(batch_shape)
+         new.scale = self.scale.expand(batch_shape)
+         new.power = self.power  # carry the power over to the expanded instance
+         super(QExponential, new).__init__(batch_shape, validate_args=False)
+         new._validate_args = self._validate_args
+         return new
+
+     def sample(self, sample_shape=torch.Size(), rescale=False):
+         shape = self._extended_shape(sample_shape)
+         with torch.no_grad():
+             eps = Chi2(1).sample(shape).to(self.loc.device) ** (1. / self.power) \
+                 * _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device).sign()
+             if rescale:
+                 eps /= self.rescalor
+             return self.loc.expand(shape) + eps * self.scale.expand(shape)
+
+     def rsample(self, sample_shape=torch.Size(), rescale=False):
+         shape = self._extended_shape(sample_shape)
+         eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
+         if self.power != 2:
+             eps = eps.abs() ** (2. / self.power - 1) * eps
+         if rescale:
+             eps /= self.rescalor
+         return self.loc + eps * self.scale
+
+     def log_prob(self, value):
+         if self._validate_args:
+             self._validate_sample(value)
+         log_scale = (
+             math.log(self.scale) if isinstance(self.scale, Real) else self.scale.log()
+         )
+         scaled_diff = ((value - self.loc) / self.scale).abs()
+         res = -0.5 * (scaled_diff ** self.power + math.log(2 * math.pi)) - log_scale
+         if self.power != 2:
+             res += (self.power / 2. - 1) * scaled_diff.log() + torch.log(self.power / 2.)
+         return res
+
+     def cdf(self, value):
+         if self._validate_args:
+             self._validate_sample(value)
+         scaled_diff = (value - self.loc) * self.scale.reciprocal()
+         if self.power != 2:
+             scaled_diff *= scaled_diff.abs() ** (self.power / 2. - 1)
+         return 0.5 * (
+             1 + torch.erf(scaled_diff / math.sqrt(2))
+         )
+
+     def icdf(self, value):
+         erfinv = torch.erfinv(2 * value - 1) * math.sqrt(2)
+         if self.power != 2:
+             erfinv *= erfinv.abs() ** (2. / self.power - 1)
+         return self.loc + self.scale * erfinv
+
+     def entropy(self, exact=False):
+         res = 0.5 + 0.5 * math.log(2 * math.pi) + torch.log(self.scale)
+         if self.power != 2:
+             res += 0.5 * (self.power / 2. - 1) * (2. / self.power * Chi2(1).entropy() if exact else 0) - torch.log(self.power / 2.)
+         return res
+
+     @property
+     def _natural_params(self):
+         if self.power != 2:
+             raise ValueError(f"Q-Exponential distribution with power {self.power} does not belong to the exponential family!")
+         else:
+             return (self.loc / self.scale.pow(2), -0.5 * self.scale.pow(2).reciprocal())
+
+     def _log_normalizer(self, x, y):
+         if self.power != 2:
+             raise ValueError(f"Q-Exponential distribution with power {self.power} does not belong to the exponential family!")
+         else:
+             return -0.25 * x.pow(2) / y + 0.5 * torch.log(-math.pi / y)
+
+
+ @register_kl(QExponential, QExponential)
+ def _kl_qexponential_qexponential(p, q, exact=False):
+     var_ratio = (p.scale / q.scale).pow(2)
+     t1 = ((p.loc - q.loc) / q.scale).pow(2)
+     res = 0.5 * ((var_ratio + t1).pow(q.power / 2.) - 1 - var_ratio.log())
+     if q.power != 2:
+         res += 0.5 * (-(q.power / 2. - 1) * torch.log(var_ratio + t1) + (p.power / 2. - 1) * (-2. / p.power * Chi2(1).entropy() if exact else 0))
+     return res
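A usage sketch (not part of the diff; it assumes `QExponential` is re-exported from `qpytorch.distributions`, as `__all__` suggests). Note that `rsample` maps a standard normal draw eps to sign(eps)|eps|^(2/q), so gradients flow through `loc` and `scale`, and that for `power=2` the extra log-prob terms vanish and the density reduces to a Gaussian:

import torch
from qpytorch.distributions import QExponential  # assumed export path

q1 = QExponential(torch.tensor([0.0]), torch.tensor([1.0]), power=torch.tensor(1.0))
x = q1.rsample(torch.Size([5]))  # reparameterized, differentiable draws
print(q1.log_prob(x))

# With the default power=2 we recover the standard normal log-density
q2 = QExponential(torch.tensor([0.0]), torch.tensor([1.0]))
v = torch.tensor([0.3])
print(q2.log_prob(v))                                    # matches ...
print(torch.distributions.Normal(0.0, 1.0).log_prob(v))  # ... the Gaussian value

# The @register_kl hook above plugs into torch's KL dispatch
q3 = QExponential(torch.tensor([1.0]), torch.tensor([2.0]))
print(torch.distributions.kl_divergence(q2, q3))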
qpytorch/functions/__init__.py
@@ -0,0 +1,58 @@
+ #!/usr/bin/env python3
+
+ from __future__ import annotations
+
+ import warnings
+ from typing import Any
+
+ import linear_operator
+ import torch
+
+ from gpytorch.functions._log_normal_cdf import LogNormalCDF
+ from gpytorch.functions.matern_covariance import MaternCovariance
+ from gpytorch.functions.rbf_covariance import RBFCovariance
+
+
+ def log_normal_cdf(x):
+     """
+     Computes the element-wise log standard normal CDF of an input tensor x.
+
+     This function should always be preferred over calling normal_cdf and taking the log
+     manually, as it is more numerically stable.
+     """
+     return LogNormalCDF.apply(x)
+
+
+ def logdet(mat):
+     warnings.warn("gpytorch.logdet is deprecated. Use torch.logdet instead.", DeprecationWarning)
+     return torch.logdet(mat)
+
+
+ def matmul(mat, rhs):
+     warnings.warn("gpytorch.matmul is deprecated. Use torch.matmul instead.", DeprecationWarning)
+     return torch.matmul(mat, rhs)
+
+
+ def inv_matmul(mat, right_tensor, left_tensor=None):
+     warnings.warn("gpytorch.inv_matmul is deprecated. Use gpytorch.solve instead.", DeprecationWarning)
+     # Pass the operator and both tensors through (the original dropped `mat`
+     # and always forwarded left_tensor=None)
+     return linear_operator.solve(mat, right_tensor, left_tensor)
+
+
+ __all__ = [
+     "MaternCovariance",
+     "RBFCovariance",
+     "inv_matmul",
+     "logdet",
+     "log_normal_cdf",
+     "matmul",
+ ]
+
+
+ def __getattr__(name: str) -> Any:
+     if hasattr(linear_operator.functions, name):
+         warnings.warn(
+             f"gpytorch.functions.{name} is deprecated. Use linear_operator.functions.{name} instead.",
+             DeprecationWarning,
+         )
+         return getattr(linear_operator.functions, name)
+     raise AttributeError(f"module gpytorch.functions has no attribute {name}.")
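For instance, `log_normal_cdf` avoids the underflow that hits the naive cdf-then-log composition for very negative inputs. A small sketch (assuming the module is importable as `qpytorch.functions`, per the file list):

import torch
from qpytorch.functions import log_normal_cdf

x = torch.tensor([-20.0, 0.0, 3.0])
print(log_normal_cdf(x))                                  # finite everywhere (~ -203.9, -0.693, -0.0014)
print(torch.distributions.Normal(0.0, 1.0).cdf(x).log())  # -inf at -20: the cdf underflows to 0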
qpytorch/kernels/__init__.py
@@ -0,0 +1,80 @@
+ #!/usr/bin/env python3
+ from gpytorch.kernels import keops
+ from gpytorch.kernels.additive_structure_kernel import AdditiveStructureKernel
+ from gpytorch.kernels.arc_kernel import ArcKernel
+ from gpytorch.kernels.constant_kernel import ConstantKernel
+ from gpytorch.kernels.cosine_kernel import CosineKernel
+ from gpytorch.kernels.cylindrical_kernel import CylindricalKernel
+ from gpytorch.kernels.distributional_input_kernel import DistributionalInputKernel
+ from gpytorch.kernels.gaussian_symmetrized_kl_kernel import GaussianSymmetrizedKLKernel
+ from .grid_interpolation_kernel import GridInterpolationKernel
+ from gpytorch.kernels.grid_kernel import GridKernel
+ from gpytorch.kernels.hamming_kernel import HammingIMQKernel
+ from gpytorch.kernels.index_kernel import IndexKernel
+ from .inducing_point_kernel import InducingPointKernel
+ from .kernel import AdditiveKernel, Kernel, ProductKernel
+ from gpytorch.kernels.lcm_kernel import LCMKernel
+ from gpytorch.kernels.linear_kernel import LinearKernel
+ from .matern32_kernel_grad import Matern32KernelGrad
+ from .matern52_kernel_grad import Matern52KernelGrad
+ from .matern52_kernel_gradgrad import Matern52KernelGradGrad
+ from gpytorch.kernels.matern_kernel import MaternKernel
+ from gpytorch.kernels.multi_device_kernel import MultiDeviceKernel
+ from gpytorch.kernels.multitask_kernel import MultitaskKernel
+ from gpytorch.kernels.newton_girard_additive_kernel import NewtonGirardAdditiveKernel
+ from gpytorch.kernels.periodic_kernel import PeriodicKernel
+ from gpytorch.kernels.piecewise_polynomial_kernel import PiecewisePolynomialKernel
+ from gpytorch.kernels.polynomial_kernel import PolynomialKernel
+ from .polynomial_kernel_grad import PolynomialKernelGrad
+ from gpytorch.kernels.product_structure_kernel import ProductStructureKernel
+ from .qexponential_symmetrized_kl_kernel import QExponentialSymmetrizedKLKernel
+ from gpytorch.kernels.rbf_kernel import RBFKernel
+ from .rbf_kernel_grad import RBFKernelGrad
+ from .rbf_kernel_gradgrad import RBFKernelGradGrad
+ from .rff_kernel import RFFKernel
+ from gpytorch.kernels.rq_kernel import RQKernel
+ from gpytorch.kernels.scale_kernel import ScaleKernel
+ from gpytorch.kernels.spectral_delta_kernel import SpectralDeltaKernel
+ from gpytorch.kernels.spectral_mixture_kernel import SpectralMixtureKernel
+
+ __all__ = [
+     "keops",
+     "Kernel",
+     "ArcKernel",
+     "AdditiveKernel",
+     "AdditiveStructureKernel",
+     "ConstantKernel",
+     "CylindricalKernel",
+     "MultiDeviceKernel",
+     "CosineKernel",
+     "DistributionalInputKernel",
+     "GaussianSymmetrizedKLKernel",
+     "GridKernel",
+     "GridInterpolationKernel",
+     "HammingIMQKernel",
+     "IndexKernel",
+     "InducingPointKernel",
+     "LCMKernel",
+     "LinearKernel",
+     "MaternKernel",
+     "MultitaskKernel",
+     "NewtonGirardAdditiveKernel",
+     "PeriodicKernel",
+     "PiecewisePolynomialKernel",
+     "PolynomialKernel",
+     "PolynomialKernelGrad",
+     "ProductKernel",
+     "ProductStructureKernel",
+     "QExponentialSymmetrizedKLKernel",
+     "RBFKernel",
+     "RFFKernel",
+     "RBFKernelGrad",
+     "RBFKernelGradGrad",
+     "RQKernel",
+     "ScaleKernel",
+     "SpectralDeltaKernel",
+     "SpectralMixtureKernel",
+     "Matern32KernelGrad",
+     "Matern52KernelGrad",
+     "Matern52KernelGradGrad",
+ ]
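Because this module mixes qpytorch's own kernels (relative imports) with gpytorch re-exports, the two compose freely. A minimal sketch (the `num_samples` value is arbitrary, and `to_dense` is assumed available via linear_operator's lazy-tensor interface):

import torch
from qpytorch.kernels import RFFKernel, ScaleKernel  # RFFKernel is local; ScaleKernel is a gpytorch re-export

covar_module = ScaleKernel(RFFKernel(num_samples=100))  # random-Fourier-feature approximation, scaled
x = torch.randn(10, 2)
K = covar_module(x).to_dense()  # densify the lazily evaluated 10 x 10 covariance
print(K.shape)  # torch.Size([10, 10])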
qpytorch/kernels/grid_interpolation_kernel.py
@@ -0,0 +1,213 @@
+ #!/usr/bin/env python3
+
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ from linear_operator import to_linear_operator
+ from linear_operator.operators import InterpolatedLinearOperator
+
+ from ..models.exact_prediction_strategies import InterpolatedPredictionStrategy
+ from gpytorch.utils.grid import create_grid
+ from gpytorch.utils.interpolation import Interpolation
+ from gpytorch.kernels.grid_kernel import GridKernel
+ from .kernel import Kernel
+
+
+ class GridInterpolationKernel(GridKernel):
+     r"""
+     Implements the KISS-QEP (or SKI) approximation for a given kernel.
+     It was proposed in `Kernel Interpolation for Scalable Structured Gaussian Processes`_,
+     and offers extremely fast and accurate kernel approximations for large datasets.
+
+     Given a base kernel `k`, the covariance :math:`k(\mathbf{x_1}, \mathbf{x_2})` is approximated
+     using a grid of regularly spaced *inducing points*:
+
+     .. math::
+
+         \begin{equation*}
+             k(\mathbf{x_1}, \mathbf{x_2}) \approx \mathbf{w_{x_1}}^\top K_{U,U} \mathbf{w_{x_2}}
+         \end{equation*}
+
+     where
+
+     * :math:`U` is the set of gridded inducing points
+
+     * :math:`K_{U,U}` is the kernel matrix between the inducing points
+
+     * :math:`\mathbf{w_{x_1}}` and :math:`\mathbf{w_{x_2}}` are sparse vectors based on
+       :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}` that apply cubic interpolation.
+
+     The user should supply the size of the grid (using the `grid_size` argument).
+     To choose a reasonable grid value, we highly recommend using the
+     :func:`gpytorch.utils.grid.choose_grid_size` helper function.
+     The bounds of the grid will automatically be determined by the data.
+
+     (Alternatively, you can hard-code bounds using the `grid_bounds` argument, which
+     will speed up this kernel's computations.)
+
+     .. note::
+
+         `GridInterpolationKernel` can only wrap **stationary kernels** (such as RBF, Matern,
+         Periodic, Spectral Mixture, etc.)
+
+     Args:
+         base_kernel (Kernel):
+             The kernel to approximate with KISS-QEP.
+         grid_size (Union[int, List[int]]):
+             The size of the grid in each dimension.
+             If a single int is provided, then every dimension will have the same grid size.
+         num_dims (int):
+             The dimension of the input data. Required if `grid_bounds=None`.
+         grid_bounds (tuple of (float, float) tuples, optional):
+             The bounds of the grid, if known (high performance mode).
+             The length of the tuple must match the number of dimensions.
+             The entries represent the min/max values for each dimension.
+         active_dims (tuple of ints, optional):
+             Passed down to the `base_kernel`.
+
+     .. _Kernel Interpolation for Scalable Structured Gaussian Processes:
+         http://proceedings.mlr.press/v37/wilson15.pdf
+     """
+
+     def __init__(
+         self,
+         base_kernel: Kernel,
+         grid_size: Union[int, List[int]],
+         num_dims: Optional[int] = None,
+         grid_bounds: Optional[Tuple[Tuple[float, float], ...]] = None,
+         active_dims: Optional[Tuple[int, ...]] = None,
+     ):
+         has_initialized_grid = 0
+         grid_is_dynamic = True
+
+         # Make some temporary grid bounds, if none exist
+         if grid_bounds is None:
+             if num_dims is None:
+                 raise RuntimeError("num_dims must be supplied if grid_bounds is None")
+             else:
+                 # Create some temporary grid bounds - they'll be changed soon
+                 grid_bounds = tuple((-1.0, 1.0) for _ in range(num_dims))
+         else:
+             has_initialized_grid = 1
+             grid_is_dynamic = False
+             if num_dims is None:
+                 num_dims = len(grid_bounds)
+             elif num_dims != len(grid_bounds):
+                 raise RuntimeError(
+                     "num_dims ({}) disagrees with the number of supplied "
+                     "grid_bounds ({})".format(num_dims, len(grid_bounds))
+                 )
+
+         if isinstance(grid_size, int):
+             grid_sizes = [grid_size for _ in range(num_dims)]
+         else:
+             grid_sizes = list(grid_size)
+
+         if len(grid_sizes) != num_dims:
+             raise RuntimeError("The number of grid sizes provided through grid_size does not match num_dims.")
+
+         # Initialize values and the grid
+         self.grid_is_dynamic = grid_is_dynamic
+         self.num_dims = num_dims
+         self.grid_sizes = grid_sizes
+         self.grid_bounds = grid_bounds
+         grid = create_grid(self.grid_sizes, self.grid_bounds)
+
+         super(GridInterpolationKernel, self).__init__(
+             base_kernel=base_kernel,
+             grid=grid,
+             interpolation_mode=True,
+             active_dims=active_dims,
+         )
+         self.register_buffer("has_initialized_grid", torch.tensor(has_initialized_grid, dtype=torch.bool))
+
+     @property
+     def _tight_grid_bounds(self):
+         grid_spacings = tuple((bound[1] - bound[0]) / self.grid_sizes[i] for i, bound in enumerate(self.grid_bounds))
+         return tuple(
+             (bound[0] + 2.01 * spacing, bound[1] - 2.01 * spacing)
+             for bound, spacing in zip(self.grid_bounds, grid_spacings)
+         )
+
+     def _compute_grid(self, inputs, last_dim_is_batch=False):
+         n_data, n_dimensions = inputs.size(-2), inputs.size(-1)
+         if last_dim_is_batch:
+             inputs = inputs.transpose(-1, -2).unsqueeze(-1)
+             n_dimensions = 1
+         batch_shape = inputs.shape[:-2]
+
+         inputs = inputs.reshape(-1, n_dimensions)
+         interp_indices, interp_values = Interpolation().interpolate(self.grid, inputs)
+         interp_indices = interp_indices.view(*batch_shape, n_data, -1)
+         interp_values = interp_values.view(*batch_shape, n_data, -1)
+         return interp_indices, interp_values
+
+     def _inducing_forward(self, last_dim_is_batch, **params):
+         return super().forward(self.grid, self.grid, last_dim_is_batch=last_dim_is_batch, **params)
+
+     def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
+         # See if we need to update the grid or not
+         if self.grid_is_dynamic:  # This is true if a grid_bounds wasn't passed in
+             if torch.equal(x1, x2):
+                 x = x1.reshape(-1, self.num_dims)
+             else:
+                 x = torch.cat([x1.reshape(-1, self.num_dims), x2.reshape(-1, self.num_dims)])
+             x_maxs = x.max(0)[0].tolist()
+             x_mins = x.min(0)[0].tolist()
+
+             # We need to update the grid if
+             # 1) it hasn't ever been initialized, or
+             # 2) if any of the grid points are "out of bounds"
+             update_grid = (not self.has_initialized_grid.item()) or any(
+                 x_min < bound[0] or x_max > bound[1]
+                 for x_min, x_max, bound in zip(x_mins, x_maxs, self._tight_grid_bounds)
+             )
+
+             # Update the grid if needed
+             if update_grid:
+                 grid_spacings = tuple(
+                     (x_max - x_min) / (gs - 4.02) for gs, x_min, x_max in zip(self.grid_sizes, x_mins, x_maxs)
+                 )
+                 self.grid_bounds = tuple(
+                     (x_min - 2.01 * spacing, x_max + 2.01 * spacing)
+                     for x_min, x_max, spacing in zip(x_mins, x_maxs, grid_spacings)
+                 )
+                 grid = create_grid(
+                     self.grid_sizes,
+                     self.grid_bounds,
+                     dtype=self.grid[0].dtype,
+                     device=self.grid[0].device,
+                 )
+                 self.update_grid(grid)
+
+         base_lazy_tsr = to_linear_operator(self._inducing_forward(last_dim_is_batch=last_dim_is_batch, **params))
+         if last_dim_is_batch and base_lazy_tsr.size(-3) == 1:
+             base_lazy_tsr = base_lazy_tsr.repeat(*x1.shape[:-2], x1.size(-1), 1, 1)
+
+         left_interp_indices, left_interp_values = self._compute_grid(x1, last_dim_is_batch)
+         if torch.equal(x1, x2):
+             right_interp_indices = left_interp_indices
+             right_interp_values = left_interp_values
+         else:
+             right_interp_indices, right_interp_values = self._compute_grid(x2, last_dim_is_batch)
+
+         batch_shape = torch.broadcast_shapes(
+             base_lazy_tsr.batch_shape,
+             left_interp_indices.shape[:-2],
+             right_interp_indices.shape[:-2],
+         )
+         res = InterpolatedLinearOperator(
+             base_lazy_tsr.expand(*batch_shape, *base_lazy_tsr.matrix_shape),
+             left_interp_indices.detach().expand(*batch_shape, *left_interp_indices.shape[-2:]),
+             left_interp_values.expand(*batch_shape, *left_interp_values.shape[-2:]),
+             right_interp_indices.detach().expand(*batch_shape, *right_interp_indices.shape[-2:]),
+             right_interp_values.expand(*batch_shape, *right_interp_values.shape[-2:]),
+         )
+
+         if diag:
+             return res.diagonal(dim1=-1, dim2=-2)
+         else:
+             return res
+
+     def prediction_strategy(self, train_inputs, train_prior_dist, train_labels, likelihood):
+         return InterpolatedPredictionStrategy(train_inputs, train_prior_dist, train_labels, likelihood)
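A usage sketch following the docstring's own recommendation (not part of the diff; `choose_grid_size` is the gpytorch helper the docstring points to, and the data here is illustrative):

import torch
from qpytorch.kernels import GridInterpolationKernel, RBFKernel, ScaleKernel
from gpytorch.utils.grid import choose_grid_size

train_x = torch.linspace(0, 1, 1000).unsqueeze(-1)
grid_size = choose_grid_size(train_x)  # pick a grid size from the data

covar_module = ScaleKernel(
    GridInterpolationKernel(RBFKernel(), grid_size=grid_size, num_dims=1)
)
K = covar_module(train_x)  # lazy kernel; evaluation uses the InterpolatedLinearOperator
                           # above, so storage stays O(n) instead of O(n^2)

Since no `grid_bounds` were passed, the kernel starts in dynamic-grid mode and recenters the grid on the first forward pass (and whenever inputs fall outside the tight bounds), as implemented in `forward` above.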