qpytorch-0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of qpytorch might be problematic.

Files changed (102)
  1. qpytorch/__init__.py +327 -0
  2. qpytorch/constraints/__init__.py +3 -0
  3. qpytorch/distributions/__init__.py +21 -0
  4. qpytorch/distributions/delta.py +86 -0
  5. qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
  6. qpytorch/distributions/multivariate_qexponential.py +581 -0
  7. qpytorch/distributions/power.py +113 -0
  8. qpytorch/distributions/qexponential.py +153 -0
  9. qpytorch/functions/__init__.py +58 -0
  10. qpytorch/kernels/__init__.py +80 -0
  11. qpytorch/kernels/grid_interpolation_kernel.py +213 -0
  12. qpytorch/kernels/inducing_point_kernel.py +151 -0
  13. qpytorch/kernels/kernel.py +695 -0
  14. qpytorch/kernels/matern32_kernel_grad.py +155 -0
  15. qpytorch/kernels/matern52_kernel_grad.py +194 -0
  16. qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
  17. qpytorch/kernels/polynomial_kernel_grad.py +88 -0
  18. qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
  19. qpytorch/kernels/rbf_kernel_grad.py +125 -0
  20. qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
  21. qpytorch/kernels/rff_kernel.py +153 -0
  22. qpytorch/lazy/__init__.py +9 -0
  23. qpytorch/likelihoods/__init__.py +66 -0
  24. qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
  25. qpytorch/likelihoods/beta_likelihood.py +76 -0
  26. qpytorch/likelihoods/gaussian_likelihood.py +472 -0
  27. qpytorch/likelihoods/laplace_likelihood.py +59 -0
  28. qpytorch/likelihoods/likelihood.py +437 -0
  29. qpytorch/likelihoods/likelihood_list.py +60 -0
  30. qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
  31. qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
  32. qpytorch/likelihoods/noise_models.py +184 -0
  33. qpytorch/likelihoods/qexponential_likelihood.py +494 -0
  34. qpytorch/likelihoods/softmax_likelihood.py +97 -0
  35. qpytorch/likelihoods/student_t_likelihood.py +90 -0
  36. qpytorch/means/__init__.py +23 -0
  37. qpytorch/metrics/__init__.py +17 -0
  38. qpytorch/mlls/__init__.py +53 -0
  39. qpytorch/mlls/_approximate_mll.py +79 -0
  40. qpytorch/mlls/deep_approximate_mll.py +30 -0
  41. qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
  42. qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
  43. qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
  44. qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
  45. qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
  46. qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
  47. qpytorch/mlls/marginal_log_likelihood.py +48 -0
  48. qpytorch/mlls/predictive_log_likelihood.py +76 -0
  49. qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
  50. qpytorch/mlls/variational_elbo.py +77 -0
  51. qpytorch/models/__init__.py +72 -0
  52. qpytorch/models/approximate_qep.py +115 -0
  53. qpytorch/models/deep_qeps/__init__.py +22 -0
  54. qpytorch/models/deep_qeps/deep_qep.py +155 -0
  55. qpytorch/models/deep_qeps/dspp.py +114 -0
  56. qpytorch/models/exact_prediction_strategies.py +880 -0
  57. qpytorch/models/exact_qep.py +349 -0
  58. qpytorch/models/model_list.py +100 -0
  59. qpytorch/models/pyro/__init__.py +28 -0
  60. qpytorch/models/pyro/_pyro_mixin.py +57 -0
  61. qpytorch/models/pyro/distributions/__init__.py +5 -0
  62. qpytorch/models/pyro/pyro_qep.py +105 -0
  63. qpytorch/models/qep.py +7 -0
  64. qpytorch/models/qeplvm/__init__.py +6 -0
  65. qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
  66. qpytorch/models/qeplvm/latent_variable.py +102 -0
  67. qpytorch/module.py +30 -0
  68. qpytorch/optim/__init__.py +5 -0
  69. qpytorch/priors/__init__.py +42 -0
  70. qpytorch/priors/qep_priors.py +81 -0
  71. qpytorch/test/__init__.py +22 -0
  72. qpytorch/test/base_likelihood_test_case.py +106 -0
  73. qpytorch/test/model_test_case.py +150 -0
  74. qpytorch/test/variational_test_case.py +400 -0
  75. qpytorch/utils/__init__.py +38 -0
  76. qpytorch/utils/warnings.py +37 -0
  77. qpytorch/variational/__init__.py +47 -0
  78. qpytorch/variational/_variational_distribution.py +61 -0
  79. qpytorch/variational/_variational_strategy.py +391 -0
  80. qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
  81. qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
  82. qpytorch/variational/cholesky_variational_distribution.py +65 -0
  83. qpytorch/variational/ciq_variational_strategy.py +352 -0
  84. qpytorch/variational/delta_variational_distribution.py +41 -0
  85. qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
  86. qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
  87. qpytorch/variational/lmc_variational_strategy.py +248 -0
  88. qpytorch/variational/mean_field_variational_distribution.py +58 -0
  89. qpytorch/variational/multitask_variational_strategy.py +317 -0
  90. qpytorch/variational/natural_variational_distribution.py +152 -0
  91. qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
  92. qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
  93. qpytorch/variational/tril_natural_variational_distribution.py +130 -0
  94. qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
  95. qpytorch/variational/unwhitened_variational_strategy.py +225 -0
  96. qpytorch/variational/variational_strategy.py +280 -0
  97. qpytorch/version.py +4 -0
  98. qpytorch-0.1.dist-info/LICENSE +21 -0
  99. qpytorch-0.1.dist-info/METADATA +177 -0
  100. qpytorch-0.1.dist-info/RECORD +102 -0
  101. qpytorch-0.1.dist-info/WHEEL +5 -0
  102. qpytorch-0.1.dist-info/top_level.txt +1 -0
qpytorch/kernels/kernel.py (new file, +695 lines):
#!/usr/bin/env python3

from __future__ import annotations

import warnings
from abc import abstractmethod
from copy import deepcopy
from typing import Callable, Dict, Iterable, Optional, Tuple, Union

import torch
from linear_operator import to_dense, to_linear_operator
from linear_operator.operators import LinearOperator, ZeroLinearOperator
from torch import Tensor
from torch.nn import ModuleList

from .. import settings
from ..constraints import Interval, Positive
from ..distributions import MultivariateNormal, MultivariateQExponential
from ..lazy import LazyEvaluatedKernelTensor
from ..likelihoods import GaussianLikelihood, QExponentialLikelihood
from ..models import exact_prediction_strategies
from ..module import Module
from ..priors import Prior


def sq_dist(x1, x2, x1_eq_x2=False):
    """Equivalent to the square of `torch.cdist` with p=2."""
    # TODO: use torch squared cdist once implemented: https://github.com/pytorch/pytorch/pull/25799
    adjustment = x1.mean(-2, keepdim=True)
    x1 = x1 - adjustment

    # Compute squared distance matrix using quadratic expansion
    x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
    x1_pad = torch.ones_like(x1_norm)
    if x1_eq_x2 and not x1.requires_grad and not x2.requires_grad:
        x2, x2_norm, x2_pad = x1, x1_norm, x1_pad
    else:
        x2 = x2 - adjustment  # x1 and x2 should be identical in all dims except -2 at this point
        x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
        x2_pad = torch.ones_like(x2_norm)
    x1_ = torch.cat([-2.0 * x1, x1_norm, x1_pad], dim=-1)
    x2_ = torch.cat([x2, x2_pad, x2_norm], dim=-1)
    res = x1_.matmul(x2_.transpose(-2, -1))

    if x1_eq_x2 and not x1.requires_grad and not x2.requires_grad:
        res.diagonal(dim1=-2, dim2=-1).fill_(0)

    # Zero out negative values
    return res.clamp_min_(0)


def dist(x1, x2, x1_eq_x2=False):
    """
    Equivalent to `torch.cdist` with p=2, but clamps the minimum element to 1e-15.
    """
    if not x1_eq_x2:
        res = torch.cdist(x1, x2)
        return res.clamp_min(1e-15)
    res = sq_dist(x1, x2, x1_eq_x2=x1_eq_x2)
    return res.clamp_min_(1e-30).sqrt_()


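# --- Editor's note: illustrative sketch, not part of the packaged qpytorch source. ---
# The quadratic expansion in `sq_dist` relies on the identity
# ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * (a . b), computed on mean-centered inputs for
# numerical stability. A minimal sanity check of both helpers against torch.cdist
# (the function name below is the editor's own) could look like this:
def _check_dist_helpers():
    x1 = torch.randn(4, 10, 3)
    x2 = torch.randn(4, 8, 3)
    assert torch.allclose(sq_dist(x1, x2), torch.cdist(x1, x2).pow(2), atol=1e-5)
    assert torch.allclose(dist(x1, x2), torch.cdist(x1, x2), atol=1e-5)

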
# only necessary for legacy purposes
class Distance(torch.nn.Module):
    def __init__(self, postprocess: Optional[Callable] = None):
        super().__init__()
        if postprocess is not None:
            warnings.warn(
                "The `postprocess` argument is deprecated. "
                "See https://github.com/cornellius-gp/gpytorch/pull/2205 for details.",
                DeprecationWarning,
            )
        self._postprocess = postprocess

    def _sq_dist(self, x1, x2, x1_eq_x2=False, postprocess=False):
        res = sq_dist(x1, x2, x1_eq_x2=x1_eq_x2)
        return self._postprocess(res) if postprocess else res

    def _dist(self, x1, x2, x1_eq_x2=False, postprocess=False):
        res = dist(x1, x2, x1_eq_x2=x1_eq_x2)
        return self._postprocess(res) if postprocess else res


class Kernel(Module):
    r"""
    Kernels in GPyTorch/QPyTorch are implemented as a :class:`gpytorch.Module` that, when called on two :class:`torch.Tensor`
    objects :math:`\mathbf x_1` and :math:`\mathbf x_2`, returns either a :obj:`torch.Tensor` or a
    :obj:`~linear_operator.operators.LinearOperator` that represents the
    covariance matrix between :math:`\mathbf x_1` and :math:`\mathbf x_2`.

    In the typical use case, extending this class simply requires implementing a
    :py:meth:`~qpytorch.kernels.Kernel.forward` method.

    .. note::
        The :py:meth:`~qpytorch.kernels.Kernel.__call__` method does some additional internal work. In particular,
        all kernels are lazily evaluated so that we can index into the kernel matrix before actually
        computing it. Furthermore, many built-in kernel modules return
        :class:`~linear_operator.LinearOperators` that allow for more efficient
        inference than if we explicitly computed the kernel matrix itself.

        As a result, if you want to get an actual
        :obj:`torch.tensor` representing the covariance matrix, you may need to call the
        :func:`~linear_operator.operators.LinearOperator.to_dense` method on the output.

    This base :class:`Kernel` class includes a lengthscale parameter
    :math:`\Theta`, which is used by many common kernel functions.
    There are a few options for the lengthscale:

    * Default: No lengthscale (i.e. :math:`\Theta` is the identity matrix).

    * Single lengthscale: One lengthscale can be applied to all input dimensions/batches
      (i.e. :math:`\Theta` is a constant diagonal matrix).
      This is controlled by setting the attribute `has_lengthscale=True`.

    * ARD: Each input dimension gets its own separate lengthscale
      (i.e. :math:`\Theta` is a non-constant diagonal matrix).
      This is controlled by the `ard_num_dims` keyword argument (as well as `has_lengthscale=True`).

    In batch mode (i.e. when :math:`\mathbf x_1` and :math:`\mathbf x_2` are batches of input matrices), each
    batch of data can have its own lengthscale parameter by setting the `batch_shape`
    keyword argument to the appropriate number of batches.

    .. note::

        You can set a prior on the lengthscale parameter using the lengthscale_prior argument.

    :param ard_num_dims: Set this if you want a separate lengthscale for each input
        dimension. It should be `D` if :math:`\mathbf x` is a `... x N x D` matrix. (Default: `None`.)
    :param batch_shape: Set this if you want a separate lengthscale for each batch of input
        data. It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf x_1` is
        a :math:`B_1 \times \ldots \times B_k \times N \times D` tensor.
    :param active_dims: Set this if you want to compute the covariance of only
        a few input dimensions. The ints correspond to the indices of the
        dimensions. (Default: `None`.)
    :param lengthscale_prior: Set this if you want to apply a prior to the
        lengthscale parameter. (Default: `None`.)
    :param lengthscale_constraint: Set this if you want to apply a constraint
        to the lengthscale parameter. (Default: :class:`~gpytorch.constraints.Positive`.)
    :param eps: A small positive value added to the lengthscale to prevent
        divide by zero errors. (Default: `1e-6`.)

    :ivar torch.Size batch_shape:
        The (minimum) number of batch dimensions supported by this kernel.
        Typically, this captures the batch shape of the lengthscale and other parameters,
        and is usually set by the `batch_shape` argument in the constructor.
    :ivar torch.dtype dtype:
        The dtype supported by this kernel.
        Typically, this depends on the dtype of the lengthscale and other parameters.
    :ivar bool is_stationary:
        Set to True if the Kernel represents a stationary function
        (one that depends only on :math:`\mathbf x_1 - \mathbf x_2`).
    :ivar torch.Tensor lengthscale:
        The lengthscale parameter. Size/shape of parameter depends on the
        `ard_num_dims` and `batch_shape` arguments.

    Example:
        >>> covar_module = qpytorch.kernels.LinearKernel()
        >>> x1 = torch.randn(50, 3)
        >>> lazy_covar_matrix = covar_module(x1)  # Returns a RootLinearOperator
        >>> tensor_covar_matrix = lazy_covar_matrix.to_dense()  # Gets the actual tensor for this kernel matrix
    """

    has_lengthscale = False

    def __init__(
        self,
        ard_num_dims: Optional[int] = None,
        batch_shape: Optional[torch.Size] = None,
        active_dims: Optional[Tuple[int, ...]] = None,
        lengthscale_prior: Optional[Prior] = None,
        lengthscale_constraint: Optional[Interval] = None,
        eps: float = 1e-6,
        **kwargs,
    ):
        super(Kernel, self).__init__()
        self._batch_shape = torch.Size([]) if batch_shape is None else batch_shape
        if active_dims is not None and not torch.is_tensor(active_dims):
            active_dims = torch.tensor(active_dims, dtype=torch.long)
        self.register_buffer("active_dims", active_dims)
        self.ard_num_dims = ard_num_dims

        self.eps = eps

        param_transform = kwargs.get("param_transform")

        if lengthscale_constraint is None:
            lengthscale_constraint = Positive()

        if param_transform is not None:
            warnings.warn(
                "The 'param_transform' argument is now deprecated. If you want to use a different "
                "transformation, specify a different 'lengthscale_constraint' instead.",
                DeprecationWarning,
            )

        if self.has_lengthscale:
            lengthscale_num_dims = 1 if ard_num_dims is None else ard_num_dims
            self.register_parameter(
                name="raw_lengthscale",
                parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1, lengthscale_num_dims)),
            )
            if lengthscale_prior is not None:
                if not isinstance(lengthscale_prior, Prior):
                    raise TypeError("Expected qpytorch.priors.Prior but got " + type(lengthscale_prior).__name__)
                self.register_prior(
                    "lengthscale_prior", lengthscale_prior, self._lengthscale_param, self._lengthscale_closure
                )

            self.register_constraint("raw_lengthscale", lengthscale_constraint)

        self.distance_module = None
        # TODO: Remove this on next official PyTorch release.
        self.__pdist_supports_batch = True

    def _lengthscale_param(self, m: Kernel) -> Tensor:
        # Used by the lengthscale_prior
        return m.lengthscale

    def _lengthscale_closure(self, m: Kernel, v: Tensor) -> Tensor:
        # Used by the lengthscale_prior
        return m._set_lengthscale(v)

    def _set_lengthscale(self, value: Tensor):
        # Used by the lengthscale_prior
        if not self.has_lengthscale:
            raise RuntimeError("Kernel has no lengthscale.")

        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_lengthscale)

        self.initialize(raw_lengthscale=self.raw_lengthscale_constraint.inverse_transform(value))

    @abstractmethod
    def forward(
        self, x1: Tensor, x2: Tensor, diag: bool = False, last_dim_is_batch: bool = False, **params
    ) -> Union[Tensor, LinearOperator]:
        r"""
        Computes the covariance between :math:`\mathbf x_1` and :math:`\mathbf x_2`.
        This method should be implemented by all Kernel subclasses.

        :param x1: First set of data (... x N x D).
        :param x2: Second set of data (... x M x D).
        :param diag: Should the Kernel compute the whole kernel, or just the diag?
            If True, it must be the case that `x1 == x2`. (Default: False.)
        :param last_dim_is_batch: If True, treat the last dimension
            of `x1` and `x2` as another batch dimension.
            (Useful for additive structure over the dimensions). (Default: False.)

        :return: The kernel matrix or vector. The shape depends on the kernel's evaluation mode:

            * `full_covar`: `... x N x M`
            * `full_covar` with `last_dim_is_batch=True`: `... x K x N x M`
            * `diag`: `... x N`
            * `diag` with `last_dim_is_batch=True`: `... x K x N`
        """
        raise NotImplementedError()

    @property
    def batch_shape(self) -> torch.Size:
        kernels = list(self.sub_kernels())
        if len(kernels):
            return torch.broadcast_shapes(self._batch_shape, *[k.batch_shape for k in kernels])
        else:
            return self._batch_shape

    @batch_shape.setter
    def batch_shape(self, val: torch.Size):
        self._batch_shape = val

    @property
    def device(self) -> Optional[torch.device]:
        if self.has_lengthscale:
            return self.lengthscale.device
        devices = {param.device for param in self.parameters()}
        if len(devices) > 1:
            raise RuntimeError(f"The kernel's parameters are on multiple devices: {devices}.")
        elif devices:
            return devices.pop()
        return None

    @property
    def dtype(self) -> torch.dtype:
        if self.has_lengthscale:
            return self.lengthscale.dtype
        dtypes = {param.dtype for param in self.parameters()}
        if len(dtypes) > 1:
            raise RuntimeError(f"The kernel's parameters have multiple dtypes: {dtypes}.")
        elif dtypes:
            return dtypes.pop()
        return torch.get_default_dtype()

    @property
    def lengthscale(self) -> Tensor:
        if self.has_lengthscale:
            return self.raw_lengthscale_constraint.transform(self.raw_lengthscale)
        else:
            return None

    @lengthscale.setter
    def lengthscale(self, value: Tensor):
        self._set_lengthscale(value)

    @property
    def is_stationary(self) -> bool:
        return self.has_lengthscale

    def local_load_samples(self, samples_dict: Dict[str, Tensor], memo: set, prefix: str):
        num_samples = next(iter(samples_dict.values())).size(0)
        self.batch_shape = torch.Size([num_samples]) + self.batch_shape
        super().local_load_samples(samples_dict, memo, prefix)

    def covar_dist(
        self,
        x1: Tensor,
        x2: Tensor,
        diag: bool = False,
        last_dim_is_batch: bool = False,
        square_dist: bool = False,
        **params,
    ) -> Tensor:
        r"""
        This is a helper method for computing the Euclidean distance between
        all pairs of points in :math:`\mathbf x_1` and :math:`\mathbf x_2`.

        :param x1: First set of data (... x N x D).
        :param x2: Second set of data (... x M x D).
        :param diag: Should the Kernel compute the whole kernel, or just the diag?
            If True, it must be the case that `x1 == x2`. (Default: False.)
        :param last_dim_is_batch: If True, treat the last dimension
            of `x1` and `x2` as another batch dimension.
            (Useful for additive structure over the dimensions). (Default: False.)
        :param square_dist:
            If True, returns the squared distance rather than the standard distance. (Default: False.)
        :return: The kernel matrix or vector. The shape depends on the kernel's evaluation mode:

            * `full_covar`: `... x N x M`
            * `full_covar` with `last_dim_is_batch=True`: `... x K x N x M`
            * `diag`: `... x N`
            * `diag` with `last_dim_is_batch=True`: `... x K x N`
        """
        if last_dim_is_batch:
            x1 = x1.transpose(-1, -2).unsqueeze(-1)
            x2 = x2.transpose(-1, -2).unsqueeze(-1)

        x1_eq_x2 = torch.equal(x1, x2)
        res = None

        if diag:
            # Special case the diagonal because we can return all zeros most of the time.
            if x1_eq_x2:
                return torch.zeros(*x1.shape[:-2], x1.shape[-2], dtype=x1.dtype, device=x1.device)
            else:
                res = torch.linalg.norm(x1 - x2, dim=-1)  # 2-norm by default
                return res.pow(2) if square_dist else res
        else:
            dist_func = sq_dist if square_dist else dist
            return dist_func(x1, x2, x1_eq_x2)

    def expand_batch(self, *sizes: Union[torch.Size, Tuple[int, ...]]) -> Kernel:
        r"""
        Constructs a new kernel where the lengthscale (and other kernel parameters)
        are expanded to match the batch dimension determined by `sizes`.

        :param sizes: The batch shape of the new tensor
        """
        # Type checking
        if len(sizes) == 1 and hasattr(sizes, "__iter__"):
            new_batch_shape = torch.Size(sizes[0])
        elif all(isinstance(size, int) for size in sizes):
            new_batch_shape = torch.Size(sizes)
        else:
            raise RuntimeError("Invalid arguments {} to expand_batch.".format(sizes))

        # Check for easy case:
        orig_batch_shape = self.batch_shape
        if new_batch_shape == orig_batch_shape:
            return self

        # Ensure that the expansion size is compatible with the given batch shape
        try:
            torch.broadcast_shapes(new_batch_shape, orig_batch_shape)
        except RuntimeError:
            raise RuntimeError(
                f"Cannot expand a kernel with batch shape {self.batch_shape} to new shape {new_batch_shape}"
            )

        # Create a new kernel with updated batch shape
        new_kernel = deepcopy(self)
        new_kernel._batch_shape = new_batch_shape

        # Reshape the parameters of the kernel
        for param_name, param in self.named_parameters(recurse=False):
            # For a given parameter, get the number of dimensions that do not correspond to the batch shape
            non_batch_shape = param.shape[len(orig_batch_shape) :]
            new_param_shape = torch.Size([*new_batch_shape, *non_batch_shape])
            new_kernel.__getattr__(param_name).data = param.expand(new_param_shape)

        # Reshape the buffers of the kernel
        for buffr_name, buffr in self.named_buffers(recurse=False):
            # For a given buffer, get the number of dimensions that do not correspond to the batch shape
            non_batch_shape = buffr.shape[len(orig_batch_shape) :]
            new_buffer_shape = torch.Size([*new_batch_shape, *non_batch_shape])
            new_kernel.__getattr__(buffr_name).data = buffr.expand(new_buffer_shape)

        # Recurse, if necessary
        for sub_module_name, sub_module in self.named_sub_kernels():
            new_kernel.__setattr__(sub_module_name, sub_module.expand_batch(new_batch_shape))

        return new_kernel

    def named_sub_kernels(self) -> Iterable[Tuple[str, Kernel]]:
        """
        For compositional Kernel classes (e.g. :class:`~qpytorch.kernels.AdditiveKernel`
        or :class:`~qpytorch.kernels.ProductKernel`).

        :return: An iterator over the component kernel objects,
            along with the name of each component kernel.
        """
        for name, module in self.named_modules():
            if module is not self and isinstance(module, Kernel):
                yield name, module

    def num_outputs_per_input(self, x1: Tensor, x2: Tensor) -> int:
        """
        For most kernels, `num_outputs_per_input = 1`.

        However, some kernels (e.g. multitask kernels or interdomain kernels) return a
        `num_outputs_per_input x num_outputs_per_input` matrix of covariance values for
        every pair of data points.

        I.e. if `x1` is size `... x N x D` and `x2` is size `... x M x D`, then the size of the kernel
        will be `... x (N * num_outputs_per_input) x (M * num_outputs_per_input)`.

        :return: `num_outputs_per_input` (usually 1).
        """
        return 1

    def prediction_strategy(
        self,
        train_inputs: Tensor,
        train_prior_dist: Union[MultivariateNormal, MultivariateQExponential],
        train_labels: Tensor,
        likelihood: Union[GaussianLikelihood, QExponentialLikelihood],
    ) -> exact_prediction_strategies.PredictionStrategy:
        return exact_prediction_strategies.DefaultPredictionStrategy(
            train_inputs, train_prior_dist, train_labels, likelihood
        )

    def sub_kernels(self) -> Iterable[Kernel]:
        """
        For compositional Kernel classes (e.g. :class:`~qpytorch.kernels.AdditiveKernel`
        or :class:`~qpytorch.kernels.ProductKernel`).

        :return: An iterator over the component kernel objects.
        """
        for _, kernel in self.named_sub_kernels():
            yield kernel

    def __call__(
        self, x1: Tensor, x2: Optional[Tensor] = None, diag: bool = False, last_dim_is_batch: bool = False, **params
    ) -> Union[LazyEvaluatedKernelTensor, LinearOperator, Tensor]:
        r"""
        Computes the covariance between :math:`\mathbf x_1` and :math:`\mathbf x_2`.

        .. note::
            Following PyTorch convention, all :class:`~gpytorch.models.GP` (:class:`~qpytorch.models.QEP`) objects should use `__call__`
            rather than :py:meth:`~qpytorch.kernels.Kernel.forward`.
            The `__call__` method applies additional pre- and post-processing to the `forward` method,
            and additionally employs a lazy evaluation scheme to reduce memory and computational costs.

        :param x1: First set of data (... x N x D).
        :param x2: Second set of data (... x M x D).
            (If `None`, then `x2` is set to `x1`.)
        :param diag: Should the Kernel compute the whole kernel, or just the diag?
            If True, it must be the case that `x1 == x2`. (Default: False.)
        :param last_dim_is_batch: If True, treat the last dimension
            of `x1` and `x2` as another batch dimension.
            (Useful for additive structure over the dimensions). (Default: False.)

        :return: An object that will lazily evaluate to the kernel matrix or vector.
            The shape depends on the kernel's evaluation mode:

            * `full_covar`: `... x N x M`
            * `full_covar` with `last_dim_is_batch=True`: `... x K x N x M`
            * `diag`: `... x N`
            * `diag` with `last_dim_is_batch=True`: `... x K x N`
        """
        if last_dim_is_batch:
            warnings.warn(
                "The last_dim_is_batch argument is deprecated, and will be removed in GPyTorch 2.0. "
                "If you are using it as part of AdditiveStructureKernel or ProductStructureKernel, "
                'please update your code according to the "Kernels with Additive or Product Structure" '
                "tutorial in the GPyTorch docs.",
                DeprecationWarning,
            )

        x1_, x2_ = x1, x2

        # Select the active dimensions
        if self.active_dims is not None:
            x1_ = x1_.index_select(-1, self.active_dims)
            if x2_ is not None:
                x2_ = x2_.index_select(-1, self.active_dims)

        # Give x1_ and x2_ a last dimension, if necessary
        if x1_.ndimension() == 1:
            x1_ = x1_.unsqueeze(1)
        if x2_ is not None:
            if x2_.ndimension() == 1:
                x2_ = x2_.unsqueeze(1)
            if not x1_.size(-1) == x2_.size(-1):
                raise RuntimeError("x1_ and x2_ must have the same number of dimensions!")

        if x2_ is None:
            x2_ = x1_

        # Check that ard_num_dims matches the supplied number of dimensions
        if settings.debug.on():
            if self.ard_num_dims is not None and self.ard_num_dims != x1_.size(-1):
                raise RuntimeError(
                    "Expected the input to have {} dimensionality "
                    "(based on the ard_num_dims argument). Got {}.".format(self.ard_num_dims, x1_.size(-1))
                )

        if diag:
            res = super(Kernel, self).__call__(x1_, x2_, diag=True, last_dim_is_batch=last_dim_is_batch, **params)
            # Did this Kernel eat the diag option?
            # If it does not return a LazyEvaluatedKernelTensor, we can call diag on the output
            if not isinstance(res, LazyEvaluatedKernelTensor):
                if res.dim() == x1_.dim() and res.shape[-2:] == torch.Size((x1_.size(-2), x2_.size(-2))):
                    res = res.diagonal(dim1=-1, dim2=-2)
            return res

        else:
            if settings.lazily_evaluate_kernels.on():
                res = LazyEvaluatedKernelTensor(x1_, x2_, kernel=self, last_dim_is_batch=last_dim_is_batch, **params)
            else:
                res = to_linear_operator(
                    super(Kernel, self).__call__(x1_, x2_, last_dim_is_batch=last_dim_is_batch, **params)
                )
            return res

    def __getstate__(self):
        # JIT ScriptModules cannot be pickled
        self.distance_module = None
        return self.__dict__

    def __add__(self, other: Kernel) -> Kernel:
        kernels = []
        kernels += self.kernels if isinstance(self, AdditiveKernel) else [self]
        kernels += other.kernels if isinstance(other, AdditiveKernel) else [other]
        return AdditiveKernel(*kernels)

    def __mul__(self, other: Kernel) -> Kernel:
        kernels = []
        kernels += self.kernels if isinstance(self, ProductKernel) else [self]
        kernels += other.kernels if isinstance(other, ProductKernel) else [other]
        return ProductKernel(*kernels)

    def __setstate__(self, d):
        self.__dict__ = d

    def __getitem__(self, index) -> Kernel:
        r"""
        Constructs a new kernel where the lengthscale (and other kernel parameters)
        are modified by an indexing operation.

        :param index: Index to apply to all parameters.
        """

        if len(self.batch_shape) == 0:
            return self

        new_kernel = deepcopy(self)
        # Process the index
        index = index if isinstance(index, tuple) else (index,)

        for param_name, param in self.named_parameters(recurse=False):
            new_param = new_kernel.__getattr__(param_name)
            new_param.data = new_param.__getitem__(index)
            ndim_removed = len(param.shape) - len(new_param.shape)
            new_batch_shape_len = len(self.batch_shape) - ndim_removed
            new_kernel.batch_shape = new_param.shape[:new_batch_shape_len]

        for buffr_name, buffr in self.named_buffers(recurse=False):
            # For a given buffer, get the number of dimensions that do not correspond to the batch shape
            new_buffr = new_kernel.__getattr__(buffr_name)
            new_buffr.data = new_buffr.__getitem__(index)
            ndim_removed = len(buffr.shape) - len(new_buffr.shape)
            new_batch_shape_len = len(self.batch_shape) - ndim_removed
            new_kernel.batch_shape = new_buffr.shape[:new_batch_shape_len]

        for sub_module_name, sub_module in self.named_sub_kernels():
            new_kernel.__setattr__(sub_module_name, sub_module.__getitem__(index))

        return new_kernel


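# --- Editor's note: illustrative sketch, not part of the packaged qpytorch source. ---
# A minimal stationary kernel built on the base class above: declare
# `has_lengthscale = True`, implement `forward` with `covar_dist`, and let
# `__call__` handle active dimensions, shape checks, and lazy evaluation.
# The class and function names below are the editor's own.
class _ExampleRBFKernel(Kernel):
    has_lengthscale = True

    def forward(self, x1, x2, diag=False, **params):
        # Scale inputs by the (possibly ARD) lengthscale, then exponentiate squared distances.
        x1_ = x1.div(self.lengthscale)
        x2_ = x2.div(self.lengthscale)
        return torch.exp(-0.5 * self.covar_dist(x1_, x2_, diag=diag, square_dist=True, **params))


def _example_custom_kernel_usage():
    # One lengthscale per input dimension (ARD) and a batch of 2 independent kernels.
    kern = _ExampleRBFKernel(ard_num_dims=3, batch_shape=torch.Size([2]))
    x = torch.randn(2, 50, 3)
    lazy_covar = kern(x)          # lazily evaluated kernel matrix (by default)
    covar = to_dense(lazy_covar)  # 2 x 50 x 50 dense covariance matrix
    return covar.shape

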
class AdditiveKernel(Kernel):
    """
    A Kernel that supports summing over multiple component kernels.

    Example:
        >>> covar_module = RBFKernel(active_dims=torch.tensor([0])) + RBFKernel(active_dims=torch.tensor([1]))
        >>> x1 = torch.randn(50, 2)
        >>> additive_kernel_matrix = covar_module(x1)

    :param kernels: Kernels to add together.
    """

    @property
    def is_stationary(self) -> bool:
        return all(k.is_stationary for k in self.kernels)

    def __init__(self, *kernels: Iterable[Kernel]):
        super(AdditiveKernel, self).__init__()
        self.kernels = ModuleList(kernels)

    def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Union[Tensor, LinearOperator]:
        res = ZeroLinearOperator() if not diag else 0
        for kern in self.kernels:
            next_term = kern(x1, x2, diag=diag, **params)
            if not diag:
                res = res + to_linear_operator(next_term)
            else:
                res = res + next_term

        return res

    def num_outputs_per_input(self, x1, x2):
        return self.kernels[0].num_outputs_per_input(x1, x2)

    def __getitem__(self, index) -> Kernel:
        new_kernel = deepcopy(self)
        for i, kernel in enumerate(self.kernels):
            new_kernel.kernels[i] = kernel.__getitem__(index)

        return new_kernel


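# --- Editor's note: illustrative sketch, not part of the packaged qpytorch source. ---
# `Kernel.__add__` (defined above) wraps its operands in an AdditiveKernel, so sums of
# kernels can be written with plain `+`. This sketch reuses the editor's
# `_ExampleRBFKernel` from above; the function name is likewise the editor's own.
def _example_additive_composition():
    k1 = _ExampleRBFKernel(active_dims=torch.tensor([0]))
    k2 = _ExampleRBFKernel(active_dims=torch.tensor([1]))
    covar_module = k1 + k2             # AdditiveKernel over the two components
    x = torch.randn(50, 2)
    covar = to_dense(covar_module(x))  # 50 x 50: sum of the two component matrices
    return covar

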
class ProductKernel(Kernel):
    """
    A Kernel that supports elementwise multiplying multiple component kernels together.

    Example:
        >>> covar_module = RBFKernel(active_dims=torch.tensor([0])) * RBFKernel(active_dims=torch.tensor([1]))
        >>> x1 = torch.randn(50, 2)
        >>> kernel_matrix = covar_module(x1)  # The RBF Kernel already decomposes multiplicatively, so this is foolish!

    :param kernels: Kernels to multiply together.
    """

    @property
    def is_stationary(self) -> bool:
        return all(k.is_stationary for k in self.kernels)

    def __init__(self, *kernels: Iterable[Kernel]):
        super(ProductKernel, self).__init__()
        self.kernels = ModuleList(kernels)

    def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Union[Tensor, LinearOperator]:
        x1_eq_x2 = torch.equal(x1, x2)

        if not x1_eq_x2:
            # If x1 != x2, then we can't make a MulLinearOperator because the kernel won't necessarily be
            # square/symmetric
            res = to_dense(self.kernels[0](x1, x2, diag=diag, **params))
        else:
            res = self.kernels[0](x1, x2, diag=diag, **params)

            if not diag:
                res = to_linear_operator(res)

        for kern in self.kernels[1:]:
            next_term = kern(x1, x2, diag=diag, **params)
            if not x1_eq_x2:
                # Again to_dense if x1 != x2
                res = res * to_dense(next_term)
            else:
                if not diag:
                    res = res * to_linear_operator(next_term)
                else:
                    res = res * next_term

        return res

    def num_outputs_per_input(self, x1: Tensor, x2: Tensor) -> int:
        return self.kernels[0].num_outputs_per_input(x1, x2)

    def __getitem__(self, index) -> Kernel:
        new_kernel = deepcopy(self)
        for i, kernel in enumerate(self.kernels):
            new_kernel.kernels[i] = kernel.__getitem__(index)

        return new_kernel
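

# --- Editor's note: illustrative sketch, not part of the packaged qpytorch source. ---
# `Kernel.__mul__` wraps its operands in a ProductKernel (an elementwise product of the
# component matrices), while `__getitem__` and `expand_batch` reshape a kernel's batch of
# hyperparameters. This sketch reuses the editor's `_ExampleRBFKernel`; the function name
# is likewise the editor's own.
def _example_product_and_batching():
    k1 = _ExampleRBFKernel(ard_num_dims=3, batch_shape=torch.Size([4]))
    k2 = _ExampleRBFKernel(ard_num_dims=3, batch_shape=torch.Size([4]))
    covar_module = k1 * k2             # ProductKernel with two components
    x = torch.randn(4, 30, 3)
    covar = to_dense(covar_module(x))  # 4 x 30 x 30

    sliced = covar_module[0]           # keep only the first batch of hyperparameters
    expanded = k1.expand_batch(2, 4)   # broadcast parameters to batch shape 2 x 4
    return covar.shape, sliced.batch_shape, expanded.batch_shape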