qpytorch 0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (102)
  1. qpytorch/__init__.py +327 -0
  2. qpytorch/constraints/__init__.py +3 -0
  3. qpytorch/distributions/__init__.py +21 -0
  4. qpytorch/distributions/delta.py +86 -0
  5. qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
  6. qpytorch/distributions/multivariate_qexponential.py +581 -0
  7. qpytorch/distributions/power.py +113 -0
  8. qpytorch/distributions/qexponential.py +153 -0
  9. qpytorch/functions/__init__.py +58 -0
  10. qpytorch/kernels/__init__.py +80 -0
  11. qpytorch/kernels/grid_interpolation_kernel.py +213 -0
  12. qpytorch/kernels/inducing_point_kernel.py +151 -0
  13. qpytorch/kernels/kernel.py +695 -0
  14. qpytorch/kernels/matern32_kernel_grad.py +155 -0
  15. qpytorch/kernels/matern52_kernel_grad.py +194 -0
  16. qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
  17. qpytorch/kernels/polynomial_kernel_grad.py +88 -0
  18. qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
  19. qpytorch/kernels/rbf_kernel_grad.py +125 -0
  20. qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
  21. qpytorch/kernels/rff_kernel.py +153 -0
  22. qpytorch/lazy/__init__.py +9 -0
  23. qpytorch/likelihoods/__init__.py +66 -0
  24. qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
  25. qpytorch/likelihoods/beta_likelihood.py +76 -0
  26. qpytorch/likelihoods/gaussian_likelihood.py +472 -0
  27. qpytorch/likelihoods/laplace_likelihood.py +59 -0
  28. qpytorch/likelihoods/likelihood.py +437 -0
  29. qpytorch/likelihoods/likelihood_list.py +60 -0
  30. qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
  31. qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
  32. qpytorch/likelihoods/noise_models.py +184 -0
  33. qpytorch/likelihoods/qexponential_likelihood.py +494 -0
  34. qpytorch/likelihoods/softmax_likelihood.py +97 -0
  35. qpytorch/likelihoods/student_t_likelihood.py +90 -0
  36. qpytorch/means/__init__.py +23 -0
  37. qpytorch/metrics/__init__.py +17 -0
  38. qpytorch/mlls/__init__.py +53 -0
  39. qpytorch/mlls/_approximate_mll.py +79 -0
  40. qpytorch/mlls/deep_approximate_mll.py +30 -0
  41. qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
  42. qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
  43. qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
  44. qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
  45. qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
  46. qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
  47. qpytorch/mlls/marginal_log_likelihood.py +48 -0
  48. qpytorch/mlls/predictive_log_likelihood.py +76 -0
  49. qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
  50. qpytorch/mlls/variational_elbo.py +77 -0
  51. qpytorch/models/__init__.py +72 -0
  52. qpytorch/models/approximate_qep.py +115 -0
  53. qpytorch/models/deep_qeps/__init__.py +22 -0
  54. qpytorch/models/deep_qeps/deep_qep.py +155 -0
  55. qpytorch/models/deep_qeps/dspp.py +114 -0
  56. qpytorch/models/exact_prediction_strategies.py +880 -0
  57. qpytorch/models/exact_qep.py +349 -0
  58. qpytorch/models/model_list.py +100 -0
  59. qpytorch/models/pyro/__init__.py +28 -0
  60. qpytorch/models/pyro/_pyro_mixin.py +57 -0
  61. qpytorch/models/pyro/distributions/__init__.py +5 -0
  62. qpytorch/models/pyro/pyro_qep.py +105 -0
  63. qpytorch/models/qep.py +7 -0
  64. qpytorch/models/qeplvm/__init__.py +6 -0
  65. qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
  66. qpytorch/models/qeplvm/latent_variable.py +102 -0
  67. qpytorch/module.py +30 -0
  68. qpytorch/optim/__init__.py +5 -0
  69. qpytorch/priors/__init__.py +42 -0
  70. qpytorch/priors/qep_priors.py +81 -0
  71. qpytorch/test/__init__.py +22 -0
  72. qpytorch/test/base_likelihood_test_case.py +106 -0
  73. qpytorch/test/model_test_case.py +150 -0
  74. qpytorch/test/variational_test_case.py +400 -0
  75. qpytorch/utils/__init__.py +38 -0
  76. qpytorch/utils/warnings.py +37 -0
  77. qpytorch/variational/__init__.py +47 -0
  78. qpytorch/variational/_variational_distribution.py +61 -0
  79. qpytorch/variational/_variational_strategy.py +391 -0
  80. qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
  81. qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
  82. qpytorch/variational/cholesky_variational_distribution.py +65 -0
  83. qpytorch/variational/ciq_variational_strategy.py +352 -0
  84. qpytorch/variational/delta_variational_distribution.py +41 -0
  85. qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
  86. qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
  87. qpytorch/variational/lmc_variational_strategy.py +248 -0
  88. qpytorch/variational/mean_field_variational_distribution.py +58 -0
  89. qpytorch/variational/multitask_variational_strategy.py +317 -0
  90. qpytorch/variational/natural_variational_distribution.py +152 -0
  91. qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
  92. qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
  93. qpytorch/variational/tril_natural_variational_distribution.py +130 -0
  94. qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
  95. qpytorch/variational/unwhitened_variational_strategy.py +225 -0
  96. qpytorch/variational/variational_strategy.py +280 -0
  97. qpytorch/version.py +4 -0
  98. qpytorch-0.1.dist-info/LICENSE +21 -0
  99. qpytorch-0.1.dist-info/METADATA +177 -0
  100. qpytorch-0.1.dist-info/RECORD +102 -0
  101. qpytorch-0.1.dist-info/WHEEL +5 -0
  102. qpytorch-0.1.dist-info/top_level.txt +1 -0
qpytorch/variational/nearest_neighbor_variational_strategy.py
@@ -0,0 +1,487 @@
+ #!/usr/bin/env python3
+
+ from typing import Any, Optional, Union
+
+ import torch
+ from jaxtyping import Float
+ from linear_operator import to_dense
+ from linear_operator.operators import DiagLinearOperator, LinearOperator, TriangularLinearOperator
+ from linear_operator.utils.cholesky import psd_safe_cholesky
+ from torch import LongTensor, Tensor
+
+ from ..distributions import MultivariateNormal, MultivariateQExponential
+ from ..models import ApproximateGP, ApproximateQEP, ExactGP, ExactQEP
+ from ..module import Module
+ from gpytorch.utils.errors import CachingError
+ from gpytorch.utils.memoize import add_to_cache, cached, pop_from_cache
+ from gpytorch.utils.nearest_neighbors import NNUtil
+ from ._variational_distribution import _VariationalDistribution
+ from .mean_field_variational_distribution import MeanFieldVariationalDistribution
+ from .unwhitened_variational_strategy import UnwhitenedVariationalStrategy
+
+
+ class NNVariationalStrategy(UnwhitenedVariationalStrategy):
+     r"""
+     This strategy sets all inducing point locations to observed inputs,
+     and employs a :math:`k`-nearest-neighbor approximation. It was introduced as
+     `Variational Nearest Neighbor Gaussian Processes (VNNGP)` in `Wu et al (2022)`_.
+     See the `VNNGP tutorial`_ for an example.
+
+     VNNGP assumes a k-nearest-neighbor generative process for the inducing points :math:`\mathbf u`:
+     :math:`q(\mathbf u) = \prod_{j=1}^M q(u_j | \mathbf u_{n(j)})`,
+     where :math:`n(j)` denotes the indices of the :math:`k` nearest neighbors of :math:`u_j` among
+     :math:`u_1, \cdots, u_{j-1}`. For any test observation :math:`\mathbf f`,
+     VNNGP makes predictive inference conditioned on its :math:`k` nearest inducing points
+     :math:`\mathbf u_{n(f)}`, i.e. :math:`p(f|\mathbf u_{n(f)})`.
+
+     VNNGP's objective factorizes over inducing points and observations, so stochastic optimization
+     over both is immediately available. After a one-time cost of computing the
+     :math:`k`-nearest-neighbor structure, the training and inference complexity is :math:`O(k^3)`.
+     Since VNNGP uses observations as inducing points, the user may either (1)
+     use the same mini-batch of inducing points and observations (recommended),
+     or (2) use different mini-batches of inducing points and observations. See the `VNNGP tutorial`_ for
+     implementation and comparison.
+
+     .. note::
+
+         The current implementation only supports :obj:`~qpytorch.variational.MeanFieldVariationalDistribution`.
+
+         We recommend installing the `faiss`_ library (which requires separate installation)
+         for nearest neighbor search, as it is significantly faster than the `scikit-learn`
+         nearest neighbor search. GPyTorch will automatically use `faiss` if it is installed,
+         and will fall back to `scikit-learn` otherwise.
+
+         Different inducing point orderings will produce different nearest neighbor approximations.
+
+     :param ~gpytorch.models.ApproximateGP (~qpytorch.models.ApproximateQEP) model: Model this strategy is applied to.
+         Typically passed in when the VariationalStrategy is created in the
+         __init__ method of the user-defined model.
+         It should have a ``power`` attribute if a Q-Exponential distribution is involved.
+     :param inducing_points: Tensor containing a set of inducing
+         points to use for variational inference.
+     :param variational_distribution: A
+         VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`
+     :param k: Number of nearest neighbors.
+     :param training_batch_size: The number of data points in each training mini-batch.
+     :param jitter_val: Amount of diagonal jitter to add for covariance matrix numerical stability.
+     :param compute_full_kl: Whether to compute the full KL divergence or a stochastic estimate.
+
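+     Example (a minimal sketch, not from the package docs; ``model`` and ``train_x`` stand for a
+     user-defined approximate GP/QEP model and its training inputs):
+         >>> inducing_points = train_x  # VNNGP sets the inducing points to the observed inputs
+         >>> variational_distribution = qpytorch.variational.MeanFieldVariationalDistribution(
+         >>>     inducing_points.size(-2)
+         >>> )
+         >>> variational_strategy = qpytorch.variational.NNVariationalStrategy(
+         >>>     model, inducing_points, variational_distribution, k=256, training_batch_size=256
+         >>> )
+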
+     .. _Wu et al (2022):
+         https://arxiv.org/pdf/2202.01694.pdf
+     .. _VNNGP tutorial:
+         examples/04_Variational_and_Approximate_GPs/VNNGP.html
+     .. _faiss:
+         https://github.com/facebookresearch/faiss
+     """
+
+     def __init__(
+         self,
+         model: Union[ApproximateGP, ApproximateQEP],
+         inducing_points: Float[Tensor, "... M D"],
+         variational_distribution: Float[_VariationalDistribution, "... M"],
+         k: int,
+         training_batch_size: Optional[int] = None,
+         jitter_val: Optional[float] = 1e-3,
+         compute_full_kl: Optional[bool] = False,
+     ):
+         assert isinstance(
+             variational_distribution, MeanFieldVariationalDistribution
+         ), "Currently, NNVariationalStrategy only supports MeanFieldVariationalDistribution."
+
+         super().__init__(
+             model, inducing_points, variational_distribution, learn_inducing_locations=False, jitter_val=jitter_val
+         )
+
+         # Model
+         object.__setattr__(self, "model", model)
+
+         self.inducing_points = inducing_points
+         self.M, self.D = inducing_points.shape[-2:]
+         self.k = k
+         assert self.k < self.M, (
+             f"Number of nearest neighbors k must be smaller than the number of inducing points, "
+             f"but got k = {k}, M = {self.M}."
+         )
+
+         self._inducing_batch_shape: torch.Size = inducing_points.shape[:-2]
+         self._model_batch_shape: torch.Size = self._variational_distribution.variational_mean.shape[:-1]
+         self._batch_shape: torch.Size = torch.broadcast_shapes(self._inducing_batch_shape, self._model_batch_shape)
+
+         self.nn_util: NNUtil = NNUtil(
+             k, dim=self.D, batch_shape=self._inducing_batch_shape, device=inducing_points.device
+         )
+         self._compute_nn()
+
+         self.training_batch_size = training_batch_size if training_batch_size is not None else self.M
+         self._set_training_iterator()
+
+         self.compute_full_kl = compute_full_kl
+
+     @property
+     @cached(name="prior_distribution_memo")
+     def prior_distribution(self) -> Union[Float[MultivariateNormal, "... M"], Float[MultivariateQExponential, "... M"]]:
+         out = self.model.forward(self.inducing_points)
+         if hasattr(self.model, 'power'):
+             res = MultivariateQExponential(out.mean, out.lazy_covariance_matrix.add_jitter(self.jitter_val), power=self.model.power)
+         else:
+             res = MultivariateNormal(out.mean, out.lazy_covariance_matrix.add_jitter(self.jitter_val))
+         return res
+
+     def _cholesky_factor(
+         self, induc_induc_covar: Float[LinearOperator, "... M M"]
+     ) -> Float[TriangularLinearOperator, "... M M"]:
+         # Uncached version
+         L = psd_safe_cholesky(to_dense(induc_induc_covar))
+         return TriangularLinearOperator(L)
+
+     def __call__(
+         self, x: Float[Tensor, "... N D"], prior: bool = False, **kwargs: Any
+     ) -> Union[Float[MultivariateNormal, "... N"], Float[MultivariateQExponential, "... N"]]:
+         # If we're in prior mode, then we're done!
+         if prior:
+             return self.model.forward(x, **kwargs)
+
+         if x is not None:
+             # Make sure x and inducing points have the same batch shape
+             if not (self.inducing_points.shape[:-2] == x.shape[:-2]):
+                 try:
+                     x = x.expand(*self.inducing_points.shape[:-2], *x.shape[-2:]).contiguous()
+                 except RuntimeError:
+                     raise RuntimeError(
+                         f"x batch shape must match or broadcast with the inducing points' batch shape, "
+                         f"but got x batch shape = {x.shape[:-2]}, "
+                         f"inducing points batch shape = {self.inducing_points.shape[:-2]}."
+                     )
+
+             # Delete previously cached items from the training distribution
+             if self.training:
+                 self._clear_cache()
+
+             # (Maybe) initialize variational distribution
+             if not self.variational_params_initialized.item():
+                 prior_dist = self.prior_distribution
+                 self._variational_distribution.variational_mean.data.copy_(prior_dist.mean)
+                 self._variational_distribution.variational_mean.data.add_(
+                     torch.randn_like(prior_dist.mean), alpha=self._variational_distribution.mean_init_std
+                 )
+                 # initialize with a small variational stddev for quicker convergence of the KL divergence
+                 self._variational_distribution._variational_stddev.data.copy_(torch.tensor(1e-2))
+                 self.variational_params_initialized.fill_(1)
+
+             return self.forward(
+                 x, self.inducing_points, inducing_values=None, variational_inducing_covar=None, **kwargs
+             )
+         else:
+             # x is None: train on mini-batches drawn from the inducing points themselves
+             inducing_points = self.inducing_points
+             return self.forward(x, inducing_points, inducing_values=None, variational_inducing_covar=None, **kwargs)
+
+     def forward(
+         self,
+         x: Float[Tensor, "... N D"],
+         inducing_points: Float[Tensor, "... M D"],
+         inducing_values: Float[Tensor, "... M"],
+         variational_inducing_covar: Optional[Float[LinearOperator, "... M M"]] = None,
+         **kwargs: Any,
+     ) -> Union[Float[MultivariateNormal, "... N"], Float[MultivariateQExponential, "... N"]]:
+         if self.training:
+             # In training mode, the full inducing point set = the full training dataset.
+             # Users have the option to pass either None or a tensor of training data for x.
+             # If x is None, a mini-batch of training data is sampled from the inducing points.
+             # Otherwise, the indices of inducing points that are equal to x are looked up.
+             if x is None:
+                 x_indices = self._get_training_indices()
+                 kl_indices = x_indices
+
+                 predictive_mean = self._variational_distribution.variational_mean[..., x_indices]
+                 predictive_var = self._variational_distribution._variational_stddev[..., x_indices] ** 2
+
+             else:
+                 # find the indices of inducing points that correspond to x
+                 x_indices = self.nn_util.find_nn_idx(x.float(), k=1).squeeze(-1)  # (*inducing_batch_shape, batch_size)
+
+                 expanded_x_indices = x_indices.expand(*self._batch_shape, x_indices.shape[-1])
+                 expanded_variational_mean = self._variational_distribution.variational_mean.expand(
+                     *self._batch_shape, self.M
+                 )
+                 expanded_variational_var = (
+                     self._variational_distribution._variational_stddev.expand(*self._batch_shape, self.M) ** 2
+                 )
+
+                 predictive_mean = expanded_variational_mean.gather(-1, expanded_x_indices)
+                 predictive_var = expanded_variational_var.gather(-1, expanded_x_indices)
+
+                 # sample different indices for a stochastic estimate of the KL
+                 kl_indices = self._get_training_indices()
+
+             kl = self._kl_divergence(kl_indices)
+             add_to_cache(self, "kl_divergence_memo", kl)
+
+             if hasattr(self.model, 'power'):
+                 return MultivariateQExponential(predictive_mean, DiagLinearOperator(predictive_var), power=self.model.power)
+             else:
+                 return MultivariateNormal(predictive_mean, DiagLinearOperator(predictive_var))
+         else:
+             nn_indices = self.nn_util.find_nn_idx(x.float())
+
+             x_batch_shape = x.shape[:-2]
+             batch_shape = torch.broadcast_shapes(self._batch_shape, x_batch_shape)
+             x_bsz = x.shape[-2]
+             assert nn_indices.shape == (*x_batch_shape, x_bsz, self.k), nn_indices.shape
+
+             # select k nearest neighbors from inducing points for test point x
+             expanded_nn_indices = nn_indices.unsqueeze(-1).expand(*x_batch_shape, x_bsz, self.k, self.D)
+             expanded_inducing_points = inducing_points.unsqueeze(-2).expand(*x_batch_shape, self.M, self.k, self.D)
+             inducing_points = expanded_inducing_points.gather(-3, expanded_nn_indices)
+             assert inducing_points.shape == (*x_batch_shape, x_bsz, self.k, self.D)
+
+             # get variational mean and covar for nearest neighbors
+             inducing_values = self._variational_distribution.variational_mean
+             expanded_inducing_values = inducing_values.unsqueeze(-1).expand(*batch_shape, self.M, self.k)
+             expanded_nn_indices = nn_indices.expand(*batch_shape, x_bsz, self.k)
+             inducing_values = expanded_inducing_values.gather(-2, expanded_nn_indices)
+             assert inducing_values.shape == (*batch_shape, x_bsz, self.k)
+
+             variational_stddev = self._variational_distribution._variational_stddev
+             assert variational_stddev.shape == (*self._model_batch_shape, self.M)
+             expanded_variational_stddev = variational_stddev.unsqueeze(-1).expand(*batch_shape, self.M, self.k)
+             variational_inducing_covar = expanded_variational_stddev.gather(-2, expanded_nn_indices) ** 2
+             assert variational_inducing_covar.shape == (*batch_shape, x_bsz, self.k)
+             variational_inducing_covar = DiagLinearOperator(variational_inducing_covar)
+             assert variational_inducing_covar.shape == (*batch_shape, x_bsz, self.k, self.k)
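+
+             # (Each test point is now treated as its own batch of size 1 over its k neighbors, so the
+             # parent UnwhitenedVariationalStrategy's forward below applies pointwise at O(k^3) cost.)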
+
+             # Make everything batch mode
+             x = x.unsqueeze(-2)
+             assert x.shape == (*x_batch_shape, x_bsz, 1, self.D)
+             x = x.expand(*batch_shape, x_bsz, 1, self.D)
+
+             # Compute forward mode in the standard way
+             _batch_dims = tuple(range(len(batch_shape)))
+             _x = x.permute((-3,) + _batch_dims + (-2, -1))  # (x_bsz, *batch_shape, 1, D)
+
+             # inducing_points.shape (*x_batch_shape, x_bsz, self.k, self.D)
+             inducing_points = inducing_points.expand(*batch_shape, x_bsz, self.k, self.D)
+             _inducing_points = inducing_points.permute((-3,) + _batch_dims + (-2, -1))  # (x_bsz, *batch_shape, k, D)
+             _inducing_values = inducing_values.permute((-2,) + _batch_dims + (-1,))
+             _variational_inducing_covar = variational_inducing_covar.permute((-3,) + _batch_dims + (-2, -1))
+             dist = super().forward(_x, _inducing_points, _inducing_values, _variational_inducing_covar, **kwargs)
+
+             _x_batch_dims = tuple(range(1, 1 + len(batch_shape)))
+             predictive_mean = dist.mean  # (x_bsz, *x_batch_shape, 1)
+             predictive_covar = dist.covariance_matrix  # (x_bsz, *x_batch_shape, 1, 1)
+             predictive_mean = predictive_mean.permute(_x_batch_dims + (0, -1))
+             predictive_covar = predictive_covar.permute(_x_batch_dims + (0, -2, -1))
+
+             # Undo batch mode
+             predictive_mean = predictive_mean.squeeze(-1)
+             predictive_var = predictive_covar.squeeze(-2).squeeze(-1)
+             assert predictive_var.shape == predictive_covar.shape[:-2]
+             assert predictive_mean.shape == predictive_covar.shape[:-2]
+
+             # Return the distribution
+             if hasattr(self.model, 'power'):
+                 return MultivariateQExponential(predictive_mean, DiagLinearOperator(predictive_var), power=self.model.power)
+             else:
+                 return MultivariateNormal(predictive_mean, DiagLinearOperator(predictive_var))
+
+     def get_fantasy_model(
+         self,
+         inputs: Float[Tensor, "... N D"],
+         targets: Float[Tensor, "... N"],
+         mean_module: Optional[Module] = None,
+         covar_module: Optional[Module] = None,
+         **kwargs,
+     ) -> Union[ExactGP, ExactQEP]:
+         raise NotImplementedError(
+             f"No fantasy model support for {self.__class__.__name__}. "
+             "Only VariationalStrategy and UnwhitenedVariationalStrategy are currently supported."
+         )
+
+     def _set_training_iterator(self) -> None:
+         self._training_indices_iter = 0
+         if self.training_batch_size == self.M:
+             self._training_indices_iterator = (torch.arange(self.M, device=self.inducing_points.device),)
+         else:
+             # The first training batch always contains the first k inducing points.
+             # This is because computing the KL divergence for the first k inducing points is special-cased
+             # (since the first k inducing points have < k neighbors);
+             # see the helper function _firstk_kl_helper below.
+             training_indices = torch.randperm(self.M - self.k, device=self.inducing_points.device) + self.k
+             self._training_indices_iterator = (torch.arange(self.k, device=self.inducing_points.device),) + training_indices.split(self.training_batch_size)
+         self._total_training_batches = len(self._training_indices_iterator)
+
+     def _get_training_indices(self) -> LongTensor:
+         self.current_training_indices = self._training_indices_iterator[self._training_indices_iter]
+         self._training_indices_iter += 1
+         if self._training_indices_iter == self._total_training_batches:
+             self._set_training_iterator()
+         return self.current_training_indices
+
+     def _firstk_kl_helper(self) -> Float[Tensor, "..."]:
+         # Compute the KL divergence for the first k inducing points
+         train_x_firstk = self.inducing_points[..., : self.k, :]
+         full_output = self.model.forward(train_x_firstk)
+
+         induc_mean, induc_induc_covar = full_output.mean, full_output.lazy_covariance_matrix
+
+         induc_induc_covar = induc_induc_covar.add_jitter(self.jitter_val)
+         if hasattr(self.model, 'power'):
+             prior_dist = MultivariateQExponential(induc_mean, induc_induc_covar, power=self.model.power)
+         else:
+             prior_dist = MultivariateNormal(induc_mean, induc_induc_covar)
+
+         inducing_values = self._variational_distribution.variational_mean[..., : self.k]
+         variational_covar_firstk = self._variational_distribution._variational_stddev[..., : self.k] ** 2
+         variational_inducing_covar = DiagLinearOperator(variational_covar_firstk)
+
+         if hasattr(self.model, 'power'):
+             variational_distribution = MultivariateQExponential(inducing_values, variational_inducing_covar, power=self.model.power)
+         else:
+             variational_distribution = MultivariateNormal(inducing_values, variational_inducing_covar)
+         kl = torch.distributions.kl.kl_divergence(variational_distribution, prior_dist)  # model_batch_shape
+         return kl
+
+     def _stochastic_kl_helper(self, kl_indices: Float[Tensor, "n_batch"]) -> Float[Tensor, "..."]:  # noqa: F821
+         # Compute the KL divergence for a mini-batch of the remaining M - k inducing points
+         # (see the paper appendix for the KL breakdown).
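+         # Sketch of the Gaussian-case computation implemented below: with interpolation weights
+         # B_j = K_{n(j)n(j)}^{-1} k_{n(j),j} and conditional variance F_j = k_{jj} - k_{j,n(j)} B_j,
+         # each factor contributes
+         #     KL_j = 0.5 * [ log F_j - log s_j^2 - 1
+         #                    + (s_j^2 + B_j^T diag(s_{n(j)}^2) B_j
+         #                       + (m_j - mu_j - B_j^T (m_{n(j)} - mu_{n(j)}))^2) / F_j ],
+         # summed over the mini-batch; the Q-Exponential case applies a power correction at the end.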
+         kl_bs = len(kl_indices)  # training_batch_size
+         variational_mean = self._variational_distribution.variational_mean  # (*model_bs, M)
+         variational_stddev = self._variational_distribution._variational_stddev
+
+         # (1) compute logdet_q
+         inducing_point_log_variational_covar = (variational_stddev[..., kl_indices] ** 2).log()
+         logdet_q = torch.sum(inducing_point_log_variational_covar, dim=-1)  # model_bs
+
+         # (2) compute logdet_p
+         # Select a mini-batch of inducing points according to kl_indices
+         inducing_points = self.inducing_points[..., kl_indices, :].expand(*self._batch_shape, kl_bs, self.D)
+         # (*bs, kl_bs, D)
+         # Select their k nearest neighbors
+         nearest_neighbor_indices = self.nn_xinduce_idx[..., kl_indices - self.k, :].to(inducing_points.device)
+         # (*bs, kl_bs, K)
+         expanded_inducing_points_all = self.inducing_points.unsqueeze(-2).expand(
+             *self._batch_shape, self.M, self.k, self.D
+         )
+         expanded_nearest_neighbor_indices = nearest_neighbor_indices.unsqueeze(-1).expand(
+             *self._batch_shape, kl_bs, self.k, self.D
+         )
+         nearest_neighbors = expanded_inducing_points_all.gather(-3, expanded_nearest_neighbor_indices)
+         # (*bs, kl_bs, K, D)
+
+         # Compute prior distribution
+         # Move the kl_bs dimension to the first dimension to enable batch covar_module computation
+         nearest_neighbors_ = nearest_neighbors.permute((-3,) + tuple(range(len(self._batch_shape))) + (-2, -1))
+         # (kl_bs, *bs, K, D)
+         inducing_points_ = inducing_points.permute((-2,) + tuple(range(len(self._batch_shape))) + (-1,))
+         # (kl_bs, *bs, D)
+         full_output = self.model.forward(torch.cat([nearest_neighbors_, inducing_points_.unsqueeze(-2)], dim=-2))
+         full_mean, full_covar = full_output.mean, full_output.covariance_matrix
+
+         # Mean terms
+         _undo_permute_dims = tuple(range(1, 1 + len(self._batch_shape))) + (0, -1)
+         nearest_neighbors_prior_mean = full_mean[..., : self.k].permute(_undo_permute_dims)  # (*inducing_bs, kl_bs, K)
+         inducing_prior_mean = full_mean[..., self.k :].permute(_undo_permute_dims).squeeze(-1)  # (*inducing_bs, kl_bs)
+         # Covar terms
+         nearest_neighbors_prior_cov = full_covar[..., : self.k, : self.k]
+         nearest_neighbors_inducing_prior_cross_cov = full_covar[..., : self.k, self.k :]
+
+         # Interpolation term K_nn^{-1} k_{nu}
+         interp_term = torch.linalg.solve(
+             nearest_neighbors_prior_cov + self.jitter_val * torch.eye(self.k, device=self.inducing_points.device),
+             nearest_neighbors_inducing_prior_cross_cov,
+         ).squeeze(-1)  # (kl_bs, *inducing_bs, K)
+         interp_term = interp_term.permute(_undo_permute_dims)  # (*inducing_bs, kl_bs, K)
+         nearest_neighbors_inducing_prior_cross_cov = nearest_neighbors_inducing_prior_cross_cov.squeeze(-1).permute(
+             _undo_permute_dims
+         )  # k_{n(j),j}, (*inducing_bs, kl_bs, K)
+
+         invquad_term_for_F = torch.sum(
+             interp_term * nearest_neighbors_inducing_prior_cross_cov, dim=-1
+         )  # (*inducing_bs, kl_bs)
+
+         inducing_prior_cov = self.model.covar_module.forward(
+             inducing_points, inducing_points, diag=True
+         )  # (*inducing_bs, kl_bs)
+
+         # F = K_uu - k_un K_nn^{-1} k_nu
+         F = inducing_prior_cov - invquad_term_for_F
+         F = F + self.jitter_val
+         logdet_p = F.log().sum(dim=-1)  # shape: inducing_bs
+
+         # (3) compute trace_term
+         expanded_variational_stddev = variational_stddev.unsqueeze(-1).expand(*self._batch_shape, self.M, self.k)
+         expanded_variational_mean = variational_mean.unsqueeze(-1).expand(*self._batch_shape, self.M, self.k)
+         expanded_nearest_neighbor_indices = nearest_neighbor_indices.expand(*self._batch_shape, kl_bs, self.k)
+         nearest_neighbor_variational_covar = (
+             expanded_variational_stddev.gather(-2, expanded_nearest_neighbor_indices) ** 2
+         )  # (*batch_shape, kl_bs, k)
+         bjsquared_s_nearest_neighbors = torch.sum(
+             interp_term**2 * nearest_neighbor_variational_covar, dim=-1
+         )  # (*batch_shape, kl_bs)
+         inducing_point_variational_covar = variational_stddev[..., kl_indices] ** 2  # (model_bs, kl_bs)
+         trace_term = (1.0 / F * (bjsquared_s_nearest_neighbors + inducing_point_variational_covar)).sum(
+             dim=-1
+         )  # batch_shape
+
+         # (4) compute invquad_term
+         nearest_neighbors_variational_mean = expanded_variational_mean.gather(-2, expanded_nearest_neighbor_indices)
+         Bj_m_nearest_neighbors = torch.sum(
+             interp_term * (nearest_neighbors_variational_mean - nearest_neighbors_prior_mean), dim=-1
+         )
+         inducing_variational_mean = variational_mean[..., kl_indices]
+         invquad_term = torch.sum(
+             (inducing_variational_mean - inducing_prior_mean - Bj_m_nearest_neighbors) ** 2 / F, dim=-1
+         )
+
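+         # For a Q-Exponential model (power q != 2), the quadratic form below is raised to the power
+         # q/2 and an entropy correction term is subtracted; with q = 2 the expression reduces to the
+         # standard Gaussian KL.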
+         trace_plus_invquad_form = trace_term + invquad_term
+         if hasattr(self.model, 'power'):
+             trace_plus_invquad_form = trace_plus_invquad_form ** (self.model.power / 2.0)
+         kl = (logdet_p - logdet_q - kl_bs + trace_plus_invquad_form) * (1.0 / 2)
+         if hasattr(self.model, 'power') and self.model.power != 2:
+             kl -= kl_bs * (1.0 / 2 - 1.0 / self.model.power) * (
+                 torch.log(trace_plus_invquad_form) + torch.distributions.Chi2(kl_bs).entropy()
+             )
+         assert kl.shape == self._batch_shape, kl.shape
+
+         return kl
+
+     def _kl_divergence(
+         self, kl_indices: Optional[LongTensor] = None, batch_size: Optional[int] = None
+     ) -> Float[Tensor, "..."]:
+         if self.compute_full_kl or (self._total_training_batches == 1):
+             if batch_size is None:
+                 batch_size = self.training_batch_size
+             kl = self._firstk_kl_helper()
+             for kl_indices in torch.split(torch.arange(self.k, self.M), batch_size):
+                 kl += self._stochastic_kl_helper(kl_indices)
+         else:
+             # compute a stochastic estimate
+             assert kl_indices is not None
+             if self._training_indices_iter == 1:
+                 assert len(kl_indices) == self.k, (
+                     f"kl_indices should be the first training batch, of length k, "
+                     f"but got len(kl_indices) = {len(kl_indices)} and k = {self.k}."
+                 )
+                 kl = self._firstk_kl_helper() * self.M / self.k
+             else:
+                 kl = self._stochastic_kl_helper(kl_indices) * self.M / len(kl_indices)
+         return kl
+
+     def kl_divergence(self) -> Float[Tensor, "..."]:
+         try:
+             return pop_from_cache(self, "kl_divergence_memo")
+         except CachingError:
+             raise RuntimeError("KL Divergence of variational strategy was called before nearest neighbors were set.")
+
+     def _compute_nn(self) -> "NNVariationalStrategy":
+         with torch.no_grad():
+             inducing_points_fl = self.inducing_points.data.float()
+             self.nn_util.set_nn_idx(inducing_points_fl)
+             self.nn_xinduce_idx = self.nn_util.build_sequential_nn_idx(inducing_points_fl)
+             # shape (*_inducing_batch_shape, M-k, k)
+         return self
qpytorch/variational/orthogonally_decoupled_variational_strategy.py
@@ -0,0 +1,128 @@
+ #!/usr/bin/env python3
+
+ from typing import Optional, Union
+
+ import torch
+ from linear_operator.operators import LinearOperator
+ from torch import Tensor
+
+ from ..distributions import MultivariateNormal, MultivariateQExponential
+ from gpytorch.utils.memoize import add_to_cache, cached
+ from ._variational_distribution import _VariationalDistribution
+ from ._variational_strategy import _VariationalStrategy
+ from .delta_variational_distribution import DeltaVariationalDistribution
+
+
+ class OrthogonallyDecoupledVariationalStrategy(_VariationalStrategy):
17
+ r"""
18
+ Implements orthogonally decoupled VGPs as defined in `Salimbeni et al. (2018)`_.
19
+ This variational strategy uses a different set of inducing points for the mean and covariance functions.
20
+ The idea is to use more inducing points for the (computationally efficient) mean and fewer inducing points for the
21
+ (computationally expensive) covaraince.
22
+
23
+ This variational strategy defines the inducing points/:obj:`~qpytorch.variational._VariationalDistribution`
24
+ for the mean function.
25
+ It then wraps a different :obj:`~qpytorch.variational._VariationalStrategy` which
26
+ defines the covariance inducing points.
27
+
28
+ :param covar_variational_strategy:
29
+ The variational strategy for the covariance term.
30
+ :param inducing_points: Tensor containing a set of inducing
31
+ points to use for variational inference.
32
+ :param variational_distribution: A
33
+ VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`
34
+ :param jitter_val: Amount of diagonal jitter to add for Cholesky factorization numerical stability
35
+
36
+ Example:
37
+ >>> mean_inducing_points = torch.randn(1000, train_x.size(-1), dtype=train_x.dtype, device=train_x.device)
38
+ >>> covar_inducing_points = torch.randn(100, train_x.size(-1), dtype=train_x.dtype, device=train_x.device)
39
+ >>>
40
+ >>> covar_variational_strategy = qpytorch.variational.VariationalStrategy(
41
+ >>> model, covar_inducing_points,
42
+ >>> qpytorch.variational.CholeskyVariationalDistribution(covar_inducing_points.size(-2)),
43
+ >>> learn_inducing_locations=True
44
+ >>> )
45
+ >>>
46
+ >>> variational_strategy = qpytorch.variational.OrthogonallyDecoupledVariationalStrategy(
47
+ >>> covar_variational_strategy, mean_inducing_points,
48
+ >>> qpytorch.variational.DeltaVariationalDistribution(mean_inducing_points.size(-2)),
49
+ >>> )
50
+
51
+ .. _Salimbeni et al. (2018):
52
+ https://arxiv.org/abs/1809.08820
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ covar_variational_strategy: _VariationalStrategy,
58
+ inducing_points: Tensor,
59
+ variational_distribution: _VariationalDistribution,
60
+ jitter_val: Optional[float] = None,
61
+ ):
62
+ if not isinstance(variational_distribution, DeltaVariationalDistribution):
63
+ raise NotImplementedError(
64
+ "OrthogonallyDecoupledVariationalStrategy currently works with DeltaVariationalDistribution"
65
+ )
66
+
67
+ super().__init__(
68
+ covar_variational_strategy,
69
+ inducing_points,
70
+ variational_distribution,
71
+ learn_inducing_locations=True,
72
+ jitter_val=jitter_val,
73
+ )
74
+ self.base_variational_strategy = covar_variational_strategy
75
+
76
+ @property
77
+ @cached(name="prior_distribution_memo")
78
+ def prior_distribution(self) -> Union[MultivariateNormal, MultivariateQExponential]:
79
+ out = self.model(self.inducing_points)
80
+ if isinstance(out, MultivariateNormal):
81
+ res = MultivariateNormal(out.mean, out.lazy_covariance_matrix.add_jitter(self.jitter_val))
82
+ elif isinstance(out, MultivariateQExponential):
83
+ res = MultivariateQExponential(out.mean, out.lazy_covariance_matrix.add_jitter(self.jitter_val), power=out.power)
84
+ return res
85
+
86
+ def forward(
87
+ self,
88
+ x: Tensor,
89
+ inducing_points: Tensor,
90
+ inducing_values: Tensor,
91
+ variational_inducing_covar: Optional[LinearOperator] = None,
92
+ **kwargs,
93
+ ) -> Union[MultivariateNormal, MultivariateQExponential]:
94
+ if variational_inducing_covar is not None:
95
+ raise NotImplementedError(
96
+ "OrthogonallyDecoupledVariationalStrategy currently works with DeltaVariationalDistribution"
97
+ )
98
+
99
+ num_data = x.size(-2)
100
+ full_output = self.model(torch.cat([x, inducing_points], dim=-2), **kwargs)
101
+ full_mean = full_output.mean
102
+ full_covar = full_output.lazy_covariance_matrix
103
+
104
+ if self.training:
105
+ induc_mean = full_mean[..., num_data:]
106
+ induc_induc_covar = full_covar[..., num_data:, num_data:]
107
+ if isinstance(full_output, MultivariateNormal):
108
+ prior_dist = MultivariateNormal(induc_mean, induc_induc_covar)
109
+ if isinstance(full_output, MultivariateQExponential):
110
+ prior_dist = MultivariateQExponential(induc_mean, induc_induc_covar, power=full_output.power)
111
+ add_to_cache(self, "prior_distribution_memo", prior_dist)
112
+
113
+ test_mean = full_mean[..., :num_data]
114
+ data_induc_covar = full_covar[..., :num_data, num_data:]
115
+ predictive_mean = (data_induc_covar @ inducing_values.unsqueeze(-1)).squeeze(-1).add(test_mean)
116
+ predictive_covar = full_covar[..., :num_data, :num_data]
117
+
118
+ # Return the distribution
119
+ if isinstance(full_output, MultivariateNormal):
120
+ return MultivariateNormal(predictive_mean, predictive_covar)
121
+ elif isinstance(full_output, MultivariateQExponential):
122
+ return MultivariateQExponential(predictive_mean, predictive_covar, power=full_output.power)
123
+
124
+ def kl_divergence(self) -> Tensor:
125
+ mean = self.variational_distribution.mean
+         mean = self.variational_distribution.mean
+         induc_induc_covar = self.prior_distribution.lazy_covariance_matrix
+         kl = self.model.kl_divergence() + ((induc_induc_covar @ mean.unsqueeze(-1)).squeeze(-1) * mean).sum(-1).mul(0.5)
+         return kl