qpytorch-0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of qpytorch might be problematic.

Files changed (102)
  1. qpytorch/__init__.py +327 -0
  2. qpytorch/constraints/__init__.py +3 -0
  3. qpytorch/distributions/__init__.py +21 -0
  4. qpytorch/distributions/delta.py +86 -0
  5. qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
  6. qpytorch/distributions/multivariate_qexponential.py +581 -0
  7. qpytorch/distributions/power.py +113 -0
  8. qpytorch/distributions/qexponential.py +153 -0
  9. qpytorch/functions/__init__.py +58 -0
  10. qpytorch/kernels/__init__.py +80 -0
  11. qpytorch/kernels/grid_interpolation_kernel.py +213 -0
  12. qpytorch/kernels/inducing_point_kernel.py +151 -0
  13. qpytorch/kernels/kernel.py +695 -0
  14. qpytorch/kernels/matern32_kernel_grad.py +155 -0
  15. qpytorch/kernels/matern52_kernel_grad.py +194 -0
  16. qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
  17. qpytorch/kernels/polynomial_kernel_grad.py +88 -0
  18. qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
  19. qpytorch/kernels/rbf_kernel_grad.py +125 -0
  20. qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
  21. qpytorch/kernels/rff_kernel.py +153 -0
  22. qpytorch/lazy/__init__.py +9 -0
  23. qpytorch/likelihoods/__init__.py +66 -0
  24. qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
  25. qpytorch/likelihoods/beta_likelihood.py +76 -0
  26. qpytorch/likelihoods/gaussian_likelihood.py +472 -0
  27. qpytorch/likelihoods/laplace_likelihood.py +59 -0
  28. qpytorch/likelihoods/likelihood.py +437 -0
  29. qpytorch/likelihoods/likelihood_list.py +60 -0
  30. qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
  31. qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
  32. qpytorch/likelihoods/noise_models.py +184 -0
  33. qpytorch/likelihoods/qexponential_likelihood.py +494 -0
  34. qpytorch/likelihoods/softmax_likelihood.py +97 -0
  35. qpytorch/likelihoods/student_t_likelihood.py +90 -0
  36. qpytorch/means/__init__.py +23 -0
  37. qpytorch/metrics/__init__.py +17 -0
  38. qpytorch/mlls/__init__.py +53 -0
  39. qpytorch/mlls/_approximate_mll.py +79 -0
  40. qpytorch/mlls/deep_approximate_mll.py +30 -0
  41. qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
  42. qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
  43. qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
  44. qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
  45. qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
  46. qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
  47. qpytorch/mlls/marginal_log_likelihood.py +48 -0
  48. qpytorch/mlls/predictive_log_likelihood.py +76 -0
  49. qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
  50. qpytorch/mlls/variational_elbo.py +77 -0
  51. qpytorch/models/__init__.py +72 -0
  52. qpytorch/models/approximate_qep.py +115 -0
  53. qpytorch/models/deep_qeps/__init__.py +22 -0
  54. qpytorch/models/deep_qeps/deep_qep.py +155 -0
  55. qpytorch/models/deep_qeps/dspp.py +114 -0
  56. qpytorch/models/exact_prediction_strategies.py +880 -0
  57. qpytorch/models/exact_qep.py +349 -0
  58. qpytorch/models/model_list.py +100 -0
  59. qpytorch/models/pyro/__init__.py +28 -0
  60. qpytorch/models/pyro/_pyro_mixin.py +57 -0
  61. qpytorch/models/pyro/distributions/__init__.py +5 -0
  62. qpytorch/models/pyro/pyro_qep.py +105 -0
  63. qpytorch/models/qep.py +7 -0
  64. qpytorch/models/qeplvm/__init__.py +6 -0
  65. qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
  66. qpytorch/models/qeplvm/latent_variable.py +102 -0
  67. qpytorch/module.py +30 -0
  68. qpytorch/optim/__init__.py +5 -0
  69. qpytorch/priors/__init__.py +42 -0
  70. qpytorch/priors/qep_priors.py +81 -0
  71. qpytorch/test/__init__.py +22 -0
  72. qpytorch/test/base_likelihood_test_case.py +106 -0
  73. qpytorch/test/model_test_case.py +150 -0
  74. qpytorch/test/variational_test_case.py +400 -0
  75. qpytorch/utils/__init__.py +38 -0
  76. qpytorch/utils/warnings.py +37 -0
  77. qpytorch/variational/__init__.py +47 -0
  78. qpytorch/variational/_variational_distribution.py +61 -0
  79. qpytorch/variational/_variational_strategy.py +391 -0
  80. qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
  81. qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
  82. qpytorch/variational/cholesky_variational_distribution.py +65 -0
  83. qpytorch/variational/ciq_variational_strategy.py +352 -0
  84. qpytorch/variational/delta_variational_distribution.py +41 -0
  85. qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
  86. qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
  87. qpytorch/variational/lmc_variational_strategy.py +248 -0
  88. qpytorch/variational/mean_field_variational_distribution.py +58 -0
  89. qpytorch/variational/multitask_variational_strategy.py +317 -0
  90. qpytorch/variational/natural_variational_distribution.py +152 -0
  91. qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
  92. qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
  93. qpytorch/variational/tril_natural_variational_distribution.py +130 -0
  94. qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
  95. qpytorch/variational/unwhitened_variational_strategy.py +225 -0
  96. qpytorch/variational/variational_strategy.py +280 -0
  97. qpytorch/version.py +4 -0
  98. qpytorch-0.1.dist-info/LICENSE +21 -0
  99. qpytorch-0.1.dist-info/METADATA +177 -0
  100. qpytorch-0.1.dist-info/RECORD +102 -0
  101. qpytorch-0.1.dist-info/WHEEL +5 -0
  102. qpytorch-0.1.dist-info/top_level.txt +1 -0
qpytorch/variational/batch_decoupled_variational_strategy.py
@@ -0,0 +1,256 @@
+ #!/usr/bin/env python3
+
+ from typing import Optional, Tuple, Union
+
+ import torch
+ from linear_operator.operators import LinearOperator, MatmulLinearOperator, SumLinearOperator
+ from torch import Tensor
+ from torch.distributions.kl import kl_divergence
+
+ from ..distributions import Delta, MultivariateNormal, MultivariateQExponential
+ from ..models import ApproximateGP, ApproximateQEP
+ from gpytorch.utils.errors import CachingError
+ from gpytorch.utils.memoize import pop_from_cache_ignore_args
+ from ._variational_distribution import _VariationalDistribution
+ from .delta_variational_distribution import DeltaVariationalDistribution
+ from .variational_strategy import VariationalStrategy
+
+
+ class BatchDecoupledVariationalStrategy(VariationalStrategy):
+     r"""
+     A VariationalStrategy that uses a different set of inducing points for the
+     variational mean and variational covar. It follows the "decoupled" model
+     proposed by `Jankowiak et al. (2020)`_ (which is roughly based on the strategies
+     proposed by `Cheng et al. (2017)`_).
+
+     Let :math:`\mathbf Z_\mu` and :math:`\mathbf Z_\sigma` be the mean/variance
+     inducing points. The variational distribution for an input :math:`\mathbf x` is given by:
+
+     .. math::
+
+         \begin{align*}
+             \mathbb E[ f(\mathbf x) ] &= \mathbf k_{\mathbf Z_\mu \mathbf x}^\top
+                 \mathbf K_{\mathbf Z_\mu \mathbf Z_\mu}^{-1} \mathbf m
+             \\
+             \text{Var}[ f(\mathbf x) ] &= k_{\mathbf x \mathbf x} - \mathbf k_{\mathbf Z_\sigma \mathbf x}^\top
+                 \mathbf K_{\mathbf Z_\sigma \mathbf Z_\sigma}^{-1}
+                 \left( \mathbf K_{\mathbf Z_\sigma \mathbf Z_\sigma} - \mathbf S \right)
+                 \mathbf K_{\mathbf Z_\sigma \mathbf Z_\sigma}^{-1}
+                 \mathbf k_{\mathbf Z_\sigma \mathbf x}
+         \end{align*}
+
+     where :math:`\mathbf m` and :math:`\mathbf S` are the variational parameters.
+     Unlike the original proposed implementation, :math:`\mathbf Z_\mu` and :math:`\mathbf Z_\sigma`
+     have **the same number of inducing points**, which allows us to perform batched operations.
+
+     Additionally, you can use a different set of kernel hyperparameters for the mean and the variance function.
+     We recommend using this feature only with the :obj:`~qpytorch.mlls.PredictiveLogLikelihood` objective function,
+     as proposed in "Parametric Gaussian Process Regressors" (`Jankowiak et al. (2020)`_).
+     Use the ``mean_var_batch_dim`` argument to indicate which batch dimension corresponds to the different mean/var
+     kernels.
+
+     .. note::
+         We recommend using the "right-most" batch dimension (i.e. ``mean_var_batch_dim=-1``) for the dimension
+         that corresponds to the different mean/variance kernel parameters.
+
+     Assuming you want `b1` many independent GPs (uncorrelated QEPs), the :obj:`~qpytorch.variational._VariationalDistribution`
+     objects should have a batch shape of `b1`, and the mean/covar modules
+     of the GP (QEP) should have a batch shape of `b1 x 2`.
+     (The 2 corresponds to the mean/variance hyperparameters.)
+
+     .. seealso::
+         :obj:`~qpytorch.variational.OrthogonallyDecoupledVariationalStrategy` (a variant proposed by
+         `Salimbeni et al. (2018)`_ that uses orthogonal projections.)
+
+     :param model: Model this strategy is applied to.
+         Typically passed in when the VariationalStrategy is created in the
+         __init__ method of the user defined model.
+         It should contain a ``power`` attribute if a Q-Exponential distribution is involved.
+     :param inducing_points: Tensor containing a set of inducing
+         points to use for variational inference.
+     :param variational_distribution: A
+         VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`
+     :param learn_inducing_locations: (Default True): Whether or not
+         the inducing point locations :math:`\mathbf Z` should be learned (i.e. are they
+         parameters of the model).
+     :param mean_var_batch_dim: (Default `None`):
+         Set this parameter (ideally to `-1`) to indicate which dimension corresponds to different
+         kernel hyperparameters for the mean/variance functions.
+     :param jitter_val: Amount of diagonal jitter to add for Cholesky factorization numerical stability
+
+     .. _Cheng et al. (2017):
+         https://arxiv.org/abs/1711.10127
+
+     .. _Salimbeni et al. (2018):
+         https://arxiv.org/abs/1809.08820
+
+     .. _Jankowiak et al. (2020):
+         https://arxiv.org/abs/1910.07123
+
+     Example (**different** hypers for mean/variance):
+         >>> class MeanFieldDecoupledModel(qpytorch.models.ApproximateGP or qpytorch.models.ApproximateQEP):
+         >>>     '''
+         >>>     A batch of 3 independent MeanFieldDecoupled PPGPR (PPQEP) models.
+         >>>     '''
+         >>>     def __init__(self, inducing_points):
+         >>>         # The variational parameters have a batch_shape of [3]
+         >>>         variational_distribution = qpytorch.variational.MeanFieldVariationalDistribution(
+         >>>             inducing_points.size(-1), batch_shape=torch.Size([3]),
+         >>>         )
+         >>>         variational_strategy = qpytorch.variational.BatchDecoupledVariationalStrategy(
+         >>>             self, inducing_points, variational_distribution, learn_inducing_locations=True,
+         >>>             mean_var_batch_dim=-1
+         >>>         )
+         >>>
+         >>>         # The mean/covar modules have a batch_shape of [3, 2]
+         >>>         # where the last batch dim corresponds to the mean & variance hyperparameters
+         >>>         super().__init__(variational_strategy)
+         >>>         self.mean_module = qpytorch.means.ConstantMean(batch_shape=torch.Size([3, 2]))
+         >>>         self.covar_module = qpytorch.kernels.ScaleKernel(
+         >>>             qpytorch.kernels.RBFKernel(batch_shape=torch.Size([3, 2])),
+         >>>             batch_shape=torch.Size([3, 2]),
+         >>>         )
+
+     Example (**shared** hypers for mean/variance):
+         >>> class MeanFieldDecoupledModel(qpytorch.models.ApproximateGP or qpytorch.models.ApproximateQEP):
+         >>>     '''
+         >>>     A batch of 3 independent MeanFieldDecoupled PPGPR (PPQEP) models.
+         >>>     '''
+         >>>     def __init__(self, inducing_points):
+         >>>         # The variational parameters have a batch_shape of [3]
+         >>>         variational_distribution = qpytorch.variational.MeanFieldVariationalDistribution(
+         >>>             inducing_points.size(-1), batch_shape=torch.Size([3]),
+         >>>         )
+         >>>         variational_strategy = qpytorch.variational.BatchDecoupledVariationalStrategy(
+         >>>             self, inducing_points, variational_distribution, learn_inducing_locations=True,
+         >>>         )
+         >>>
+         >>>         # The mean/covar modules have a batch_shape of [3, 1]
+         >>>         # where the singleton dimension corresponds to the shared mean/variance hyperparameters
+         >>>         super().__init__(variational_strategy)
+         >>>         self.mean_module = qpytorch.means.ConstantMean(batch_shape=torch.Size([3, 1]))
+         >>>         self.covar_module = qpytorch.kernels.ScaleKernel(
+         >>>             qpytorch.kernels.RBFKernel(batch_shape=torch.Size([3, 1])),
+         >>>             batch_shape=torch.Size([3, 1]),
+         >>>         )
+     """
+
+     def __init__(
+         self,
+         model: Union[ApproximateGP, ApproximateQEP],
+         inducing_points: Tensor,
+         variational_distribution: _VariationalDistribution,
+         learn_inducing_locations: bool = True,
+         mean_var_batch_dim: Optional[int] = None,
+         jitter_val: Optional[float] = None,
+     ):
+         if isinstance(variational_distribution, DeltaVariationalDistribution):
+             raise NotImplementedError(
+                 "BatchDecoupledVariationalStrategy does not work with DeltaVariationalDistribution"
+             )
+
+         if mean_var_batch_dim is not None and mean_var_batch_dim >= 0:
+             raise ValueError(f"mean_var_batch_dim should be negative indexed, got {mean_var_batch_dim}")
+         self.mean_var_batch_dim = mean_var_batch_dim
+
+         # Maybe unsqueeze inducing points
+         if inducing_points.dim() == 1:
+             inducing_points = inducing_points.unsqueeze(-1)
+
+         # We're going to create two sets of inducing points
+         # One set for computing the mean, one set for computing the variance
+         if self.mean_var_batch_dim is not None:
+             inducing_points = torch.stack([inducing_points, inducing_points], dim=(self.mean_var_batch_dim - 2))
+         else:
+             inducing_points = torch.stack([inducing_points, inducing_points], dim=-3)
+         super().__init__(
+             model, inducing_points, variational_distribution, learn_inducing_locations, jitter_val=jitter_val
+         )
+
+     def _expand_inputs(self, x: Tensor, inducing_points: Tensor) -> Tuple[Tensor, Tensor]:
+         # If we haven't explicitly marked a dimension as batch, add the corresponding batch dimension to the input
+         if self.mean_var_batch_dim is None:
+             x = x.unsqueeze(-3)
+         else:
+             x = x.unsqueeze(self.mean_var_batch_dim - 2)
+         return super()._expand_inputs(x, inducing_points)
+
+     def forward(
+         self,
+         x: Tensor,
+         inducing_points: Tensor,
+         inducing_values: Tensor,
+         variational_inducing_covar: Optional[LinearOperator] = None,
+         **kwargs,
+     ) -> Union[MultivariateNormal, MultivariateQExponential]:
+         # We'll compute the covariance, and cross-covariance terms for both the
+         # pred-mean and pred-covar, using their different inducing points (and maybe kernel hypers)
+
+         mean_var_batch_dim = self.mean_var_batch_dim or -1
+
+         # Compute full prior distribution
+         full_inputs = torch.cat([inducing_points, x], dim=-2)
+         full_output = self.model.forward(full_inputs, **kwargs)
+         full_covar = full_output.lazy_covariance_matrix
+
+         # Covariance terms
+         num_induc = inducing_points.size(-2)
+         test_mean = full_output.mean[..., num_induc:]
+         induc_induc_covar = full_covar[..., :num_induc, :num_induc].add_jitter(self.jitter_val)
+         induc_data_covar = full_covar[..., :num_induc, num_induc:].to_dense()
+         data_data_covar = full_covar[..., num_induc:, num_induc:]
+
+         # Compute interpolation terms
+         # K_ZZ^{-1/2} K_ZX
+         # K_ZZ^{-1/2} \mu_Z
+         L = self._cholesky_factor(induc_induc_covar)
+         if L.shape != induc_induc_covar.shape:
+             # Aggressive caching can cause nasty shape incompatibilities when evaluating with different batch shapes
+             # TODO: Use a hook to make this cleaner
+             try:
+                 pop_from_cache_ignore_args(self, "cholesky_factor")
+             except CachingError:
+                 pass
+             L = self._cholesky_factor(induc_induc_covar)
+         interp_term = L.solve(induc_data_covar.double()).to(full_inputs.dtype)
+         mean_interp_term = interp_term.select(mean_var_batch_dim - 2, 0)
+         var_interp_term = interp_term.select(mean_var_batch_dim - 2, 1)
+
+         # Compute the mean of q(f)
+         # k_XZ K_ZZ^{-1/2} m + \mu_X
+         # Here we're using the terms that correspond to the mean's inducing points
+         predictive_mean = torch.add(
+             torch.matmul(mean_interp_term.transpose(-1, -2), inducing_values.unsqueeze(-1)).squeeze(-1),
+             test_mean.select(mean_var_batch_dim - 1, 0),
+         )
+
+         # Compute the covariance of q(f)
+         # K_XX + k_XZ K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2} k_ZX
+         middle_term = self.prior_distribution.lazy_covariance_matrix.mul(-1)
+         if variational_inducing_covar is not None:
+             middle_term = SumLinearOperator(variational_inducing_covar, middle_term)
+         predictive_covar = SumLinearOperator(
+             data_data_covar.add_jitter(self.jitter_val).to_dense().select(mean_var_batch_dim - 2, 1),
+             MatmulLinearOperator(var_interp_term.transpose(-1, -2), middle_term @ var_interp_term),
+         )
+
+         if hasattr(self.model, 'power'):
+             return MultivariateQExponential(predictive_mean, predictive_covar, power=self.model.power)
+         else:
+             return MultivariateNormal(predictive_mean, predictive_covar)
+
+     def kl_divergence(self) -> Tensor:
+         variational_dist = self.variational_distribution
+         prior_dist = self.prior_distribution
+
+         mean_dist = Delta(variational_dist.mean)
+         if hasattr(self.model, 'power'):
+             covar_dist = MultivariateQExponential(
+                 torch.zeros_like(variational_dist.mean), variational_dist.lazy_covariance_matrix, power=self.model.power
+             )
+         else:
+             covar_dist = MultivariateNormal(
+                 torch.zeros_like(variational_dist.mean), variational_dist.lazy_covariance_matrix
+             )
+         return kl_divergence(mean_dist, prior_dist) + kl_divergence(covar_dist, prior_dist)
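
For orientation, the class above is easiest to understand from a complete (if minimal) usage sketch. The snippet below is illustrative only and is not part of the released wheel; it assumes the GPyTorch-style API that the docstring references (qpytorch.means.ConstantMean, qpytorch.kernels.ScaleKernel/RBFKernel, qpytorch.likelihoods.GaussianLikelihood, qpytorch.mlls.PredictiveLogLikelihood), and train_x/train_y are placeholder tensors.

import torch
import qpytorch


class DecoupledModel(qpytorch.models.ApproximateGP):
    """A single decoupled SVGP regressor (shared kernel hypers for mean and variance)."""

    def __init__(self, inducing_points):  # inducing_points: (M, D) tensor
        variational_distribution = qpytorch.variational.MeanFieldVariationalDistribution(inducing_points.size(-2))
        variational_strategy = qpytorch.variational.BatchDecoupledVariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = qpytorch.means.ConstantMean()
        self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())

    def forward(self, x):
        return qpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


# Placeholder data; PredictiveLogLikelihood is the objective the docstring recommends for this strategy.
train_x, train_y = torch.randn(100, 1), torch.randn(100)
model = DecoupledModel(inducing_points=train_x[:20].clone())
likelihood = qpytorch.likelihoods.GaussianLikelihood()
mll = qpytorch.mlls.PredictiveLogLikelihood(likelihood, model, num_data=train_y.numel())
optimizer = torch.optim.Adam([*model.parameters(), *likelihood.parameters()], lr=0.01)

for _ in range(100):
    optimizer.zero_grad()
    loss = -mll(model(train_x), train_y)
    loss.backward()
    optimizer.step()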
qpytorch/variational/cholesky_variational_distribution.py
@@ -0,0 +1,65 @@
+ #!/usr/bin/env python3
+
+ from typing import Union
+
+ import torch
+ from linear_operator.operators import CholLinearOperator, TriangularLinearOperator
+
+ from ..distributions import MultivariateNormal, MultivariateQExponential
+ from ._variational_distribution import _VariationalDistribution
+
+
+ class CholeskyVariationalDistribution(_VariationalDistribution):
+     """
+     A :obj:`~qpytorch.variational._VariationalDistribution` that is defined to be a multivariate normal (q-exponential) distribution
+     with a full covariance matrix.
+
+     The most common way this distribution is defined is to parameterize it in terms of a mean vector and a covariance
+     matrix. In order to ensure that the covariance matrix remains positive definite, we only consider the lower
+     triangle.
+
+     :param num_inducing_points: Size of the variational distribution. This implies that the variational mean
+         should be this size, and the variational covariance matrix should have this many rows and columns.
+     :param batch_shape: Specifies an optional batch size
+         for the variational parameters. This is useful for example when doing additive variational inference.
+     :param mean_init_std: (Default: 1e-3) Standard deviation of gaussian (q-exponential) noise to add to the mean initialization.
+     """
+
+     def __init__(
+         self,
+         num_inducing_points: int,
+         batch_shape: torch.Size = torch.Size([]),
+         mean_init_std: float = 1e-3,
+         **kwargs,
+     ):
+         super().__init__(num_inducing_points=num_inducing_points, batch_shape=batch_shape, mean_init_std=mean_init_std)
+         mean_init = torch.zeros(num_inducing_points)
+         covar_init = torch.eye(num_inducing_points, num_inducing_points)
+         mean_init = mean_init.repeat(*batch_shape, 1)
+         covar_init = covar_init.repeat(*batch_shape, 1, 1)
+
+         self.register_parameter(name="variational_mean", parameter=torch.nn.Parameter(mean_init))
+         self.register_parameter(name="chol_variational_covar", parameter=torch.nn.Parameter(covar_init))
+
+         if 'power' in kwargs: self.power = kwargs.pop('power')
+
+     def forward(self) -> Union[MultivariateNormal, MultivariateQExponential]:
+         chol_variational_covar = self.chol_variational_covar
+         dtype = chol_variational_covar.dtype
+         device = chol_variational_covar.device
+
+         # First make sure the Cholesky factor is lower triangular
+         lower_mask = torch.ones(self.chol_variational_covar.shape[-2:], dtype=dtype, device=device).tril(0)
+         chol_variational_covar = TriangularLinearOperator(chol_variational_covar.mul(lower_mask))
+
+         # Now construct the actual matrix
+         variational_covar = CholLinearOperator(chol_variational_covar)
+         if not hasattr(self, 'power'):
+             return MultivariateNormal(self.variational_mean, variational_covar)
+         else:
+             return MultivariateQExponential(self.variational_mean, variational_covar, power=self.power)
+
+     def initialize_variational_distribution(self, prior_dist: Union[MultivariateNormal, MultivariateQExponential]) -> None:
+         self.variational_mean.data.copy_(prior_dist.mean)
+         self.variational_mean.data.add_(torch.randn_like(prior_dist.mean), alpha=self.mean_init_std)
+         self.chol_variational_covar.data.copy_(prior_dist.lazy_covariance_matrix.cholesky().to_dense())
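
As a quick orientation for the class above (an illustrative sketch, not part of the wheel): the distribution stores a free-form square matrix, masks it to its lower triangle on every forward call, and returns a distribution whose covariance is L L^T, so positive semi-definiteness survives unconstrained gradient updates. The shapes in the comments are what one would expect under that reading, and the assumption that calling the module dispatches to forward() follows the GPyTorch convention.

import torch
import qpytorch

# A batch of 2 variational distributions over 16 inducing values.
vdist = qpytorch.variational.CholeskyVariationalDistribution(
    num_inducing_points=16, batch_shape=torch.Size([2]), mean_init_std=1e-3
)

q_u = vdist()  # dispatches to forward(); a MultivariateNormal when no `power` is given
q_u.mean.shape                     # expected: torch.Size([2, 16])
q_u.lazy_covariance_matrix.shape   # expected: torch.Size([2, 16, 16])

# Only the lower triangle of `chol_variational_covar` influences the covariance,
# so optimizing it freely cannot break positive (semi-)definiteness.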
qpytorch/variational/ciq_variational_strategy.py
@@ -0,0 +1,352 @@
+ #!/usr/bin/env python3
+
+ from typing import Optional, Tuple, Union
+
+ import torch
+ from linear_operator import to_linear_operator
+ from linear_operator.operators import DiagLinearOperator, LinearOperator, MatmulLinearOperator, SumLinearOperator
+ from linear_operator.utils import linear_cg
+ from torch import Tensor
+ from torch.autograd.function import FunctionCtx
+
+ from .. import settings
+ from ..distributions import Delta, Distribution, MultivariateNormal, MultivariateQExponential
+ from ..module import Module
+ from gpytorch.utils.memoize import cached
+ from ._variational_strategy import _VariationalStrategy
+ from .natural_variational_distribution import NaturalVariationalDistribution
+
+
+ class _NgdInterpTerms(torch.autograd.Function):
+     """
+     This function takes in
+
+     - the kernel interpolation term K_ZZ^{-1/2} k_ZX
+     - the natural parameters of the variational distribution
+
+     and returns
+
+     - the predictive distribution mean/covariance
+     - the inducing KL divergence KL( q(u) || p(u))
+
+     However, the gradients will be with respect to the **canonical parameters**
+     of the variational distribution, rather than the **natural parameters**.
+     This corresponds to performing natural gradient descent on the variational distribution.
+     """
+
+     @staticmethod
+     def forward(
+         ctx: FunctionCtx,
+         interp_term: torch.Tensor,
+         natural_vec: torch.Tensor,
+         natural_mat: torch.Tensor,
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         # Compute precision
+         prec = natural_mat.mul(-2.0)
+         diag = prec.diagonal(dim1=-1, dim2=-2).unsqueeze(-1)
+
+         # Make sure that interp_term and natural_vec are the same batch shape
+         batch_shape = torch.broadcast_shapes(interp_term.shape[:-2], natural_vec.shape[:-1])
+         expanded_interp_term = interp_term.expand(*batch_shape, *interp_term.shape[-2:])
+         expanded_natural_vec = natural_vec.expand(*batch_shape, natural_vec.size(-1))
+
+         # Compute necessary solves with the precision. We need
+         # m = expec_vec = S * natural_vec
+         # S K^{-1/2} k
+         solves = linear_cg(
+             prec.matmul,
+             torch.cat([expanded_natural_vec.unsqueeze(-1), expanded_interp_term], dim=-1),
+             n_tridiag=0,
+             max_iter=settings.max_cg_iterations.value(),
+             tolerance=min(settings.eval_cg_tolerance.value(), settings.cg_tolerance.value()),
+             max_tridiag_iter=settings.max_lanczos_quadrature_iterations.value(),
+             preconditioner=lambda x: x / diag,
+         )
+         expec_vec = solves[..., 0]
+         s_times_interp_term = solves[..., 1:]
+
+         # Compute the interpolated mean
+         # k^T K^{-1/2} m
+         interp_mean = (s_times_interp_term.transpose(-1, -2) @ natural_vec.unsqueeze(-1)).squeeze(-1)
+
+         # Compute the interpolated variance
+         # k^T K^{-1/2} S K^{-1/2} k = k^T K^{-1/2} (expec_mat - expec_vec expec_vec^T) K^{-1/2} k
+         interp_var = (s_times_interp_term * interp_term).sum(dim=-2)
+
+         # Let's not bother actually computing the KL-div in the forward pass
+         # 1/2 ( -log | S | + tr(S) + m^T m - len(m) )
+         # = 1/2 ( -log | expec_mat - expec_vec expec_vec^T | + tr(expec_mat) - len(m) )
+         kl_div = torch.zeros_like(interp_mean[..., 0])
+
+         # We're done!
+         ctx.save_for_backward(interp_term, s_times_interp_term, interp_mean, natural_vec, expec_vec, prec)
+         return interp_mean, interp_var, kl_div
+
+     @staticmethod
+     def backward(
+         ctx: FunctionCtx, interp_mean_grad: torch.Tensor, interp_var_grad: torch.Tensor, kl_div_grad: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None]:
+         # Get the saved terms
+         interp_term, s_times_interp_term, interp_mean, natural_vec, expec_vec, prec = ctx.saved_tensors
+
+         # Expand data-dependent gradients
+         interp_mean_grad = interp_mean_grad.unsqueeze(-2)
+         interp_var_grad = interp_var_grad.unsqueeze(-2)
+
+         # Compute gradient of interp term (K^{-1/2} k)
+         # interp_mean component: m
+         # interp_var component: S K^{-1/2} k
+         # kl component: 0
+         interp_term_grad = (interp_var_grad * s_times_interp_term).mul(2.0) + (
+             interp_mean_grad * expec_vec.unsqueeze(-1)
+         )
+
+         # Compute gradient of expected vector (m)
+         # interp_mean component: K^{-1/2} k
+         # interp_var component: (k^T K^{-1/2} m) K^{-1/2} k
+         # kl component: S^{-1} m
+         expec_vec_grad = (
+             (interp_var_grad * interp_mean.unsqueeze(-2) * interp_term).sum(dim=-1).mul(-2)
+             + (interp_mean_grad * interp_term).sum(dim=-1)
+             + (kl_div_grad.unsqueeze(-1) * natural_vec)
+         )
+
+         # Compute gradient of expected matrix (mm^T + S)
+         # interp_mean component: 0
+         # interp_var component: K^{-1/2} k k^T K^{-1/2}
+         # kl component: 1/2 ( I - S^{-1} )
+         eye = torch.eye(expec_vec.size(-1), device=expec_vec.device, dtype=expec_vec.dtype)
+         expec_mat_grad = torch.add(
+             (interp_var_grad * interp_term) @ interp_term.transpose(-1, -2),
+             (kl_div_grad.unsqueeze(-1).unsqueeze(-1) * (eye - prec).mul(0.5)),
+         )
+
+         # We're done!
+         return interp_term_grad, expec_vec_grad, expec_mat_grad, None  # Extra "None" for the kwarg
+
+
+ class CiqVariationalStrategy(_VariationalStrategy):
+     r"""
+     Similar to :class:`~qpytorch.variational.VariationalStrategy`,
+     except the whitening operation is performed using Contour Integral Quadrature
+     rather than Cholesky (see `Pleiss et al. (2020)`_ for more info).
+     See the `CIQ-SVGP tutorial`_ for an example.
+
+     Contour Integral Quadrature uses iterative matrix-vector multiplication to approximate
+     the :math:`\mathbf K_{\mathbf Z \mathbf Z}^{-1/2}` matrix used for the whitening operation.
+     This can be more efficient than the standard variational strategy for large numbers
+     of inducing points (e.g. :math:`M > 1000`) or when the inducing points have structure
+     (e.g. they lie on an evenly-spaced grid).
+
+     .. note::
+
+         It is recommended that this object is used in conjunction with
+         :obj:`~qpytorch.variational.NaturalVariationalDistribution` and
+         `natural gradient descent`_.
+
+     :param model: Model this strategy is applied to.
+         Typically passed in when the VariationalStrategy is created in the
+         __init__ method of the user defined model.
+         It should contain a ``power`` attribute if a Q-Exponential distribution is involved.
+     :param inducing_points: Tensor containing a set of inducing
+         points to use for variational inference.
+     :param variational_distribution: A
+         VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`
+     :param learn_inducing_locations: (Default True): Whether or not
+         the inducing point locations :math:`\mathbf Z` should be learned (i.e. are they
+         parameters of the model).
+     :param jitter_val: Amount of diagonal jitter to add for Cholesky factorization numerical stability
+
+     .. _Pleiss et al. (2020):
+         https://arxiv.org/pdf/2006.11267.pdf
+     .. _CIQ-SVGP tutorial:
+         examples/04_Variational_and_Approximate_GPs/SVGP_CIQ.html
+     .. _natural gradient descent:
+         examples/04_Variational_and_Approximate_GPs/Natural_Gradient_Descent.html
+     """
+
+     def _ngd(self) -> bool:
+         return isinstance(self._variational_distribution, NaturalVariationalDistribution)
+
+     @property
+     @cached(name="prior_distribution_memo")
+     def prior_distribution(self) -> Union[MultivariateNormal, MultivariateQExponential]:
+         zeros = torch.zeros(
+             self._variational_distribution.shape(),
+             dtype=self._variational_distribution.dtype,
+             device=self._variational_distribution.device,
+         )
+         ones = torch.ones_like(zeros)
+         if hasattr(self.model, 'power'):
+             res = MultivariateQExponential(zeros, DiagLinearOperator(ones), power=self.model.power)
+         else:
+             res = MultivariateNormal(zeros, DiagLinearOperator(ones))
+         return res
+
+     @property
+     @cached(name="variational_distribution_memo")
+     def variational_distribution(self) -> Distribution:
+         if self._ngd():
+             raise RuntimeError(
+                 "Variational distribution for NGD-CIQ should be computed during forward calls. "
+                 "This is probably a bug in GPyTorch."
+             )
+         return super().variational_distribution
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         inducing_points: torch.Tensor,
+         inducing_values: torch.Tensor,
+         variational_inducing_covar: Optional[LinearOperator] = None,
+         *params,
+         **kwargs,
+     ) -> Union[MultivariateNormal, MultivariateQExponential]:
+         # Compute full prior distribution
+         full_inputs = torch.cat([inducing_points, x], dim=-2)
+         full_output = self.model.forward(full_inputs, *params, **kwargs)
+         full_covar = full_output.lazy_covariance_matrix
+
+         # Covariance terms
+         num_induc = inducing_points.size(-2)
+         test_mean = full_output.mean[..., num_induc:]
+         induc_induc_covar = full_covar[..., :num_induc, :num_induc].evaluate_kernel().add_jitter(self.jitter_val)
+         induc_data_covar = full_covar[..., :num_induc, num_induc:].to_dense()
+         data_data_covar = full_covar[..., num_induc:, num_induc:].add_jitter(self.jitter_val)
+
+         # Compute interpolation terms
+         # K_XZ K_ZZ^{-1} \mu_z
+         # K_XZ K_ZZ^{-1/2} \mu_Z
+         with settings.max_preconditioner_size(0):  # Turn off preconditioning for CIQ
+             interp_term = to_linear_operator(induc_induc_covar).sqrt_inv_matmul(induc_data_covar)
+
+         # Compute interpolated mean and variance terms
+         # We have separate computation rules for NGD versus standard GD
+         if self._ngd():
+             interp_mean, interp_var, kl_div = _NgdInterpTerms.apply(
+                 interp_term,
+                 self._variational_distribution.natural_vec,
+                 self._variational_distribution.natural_mat,
+             )
+
+             # Compute the covariance of q(f)
+             predictive_var = data_data_covar.diagonal(dim1=-1, dim2=-2) - interp_term.pow(2).sum(dim=-2) + interp_var
+             predictive_var = torch.clamp_min(predictive_var, settings.min_variance.value(predictive_var.dtype))
+             predictive_covar = DiagLinearOperator(predictive_var)
+
+             # Also compute and cache the KL divergence
+             if not hasattr(self, "_memoize_cache"):
+                 self._memoize_cache = dict()
+             self._memoize_cache["kl"] = kl_div
+
+         else:
+             # Compute interpolated mean term
+             interp_mean = torch.matmul(
+                 interp_term.transpose(-1, -2), (inducing_values - self.prior_distribution.mean).unsqueeze(-1)
+             ).squeeze(-1)
+
+             # Compute the covariance of q(f)
+             middle_term = self.prior_distribution.lazy_covariance_matrix.mul(-1)
+             if variational_inducing_covar is not None:
+                 middle_term = SumLinearOperator(variational_inducing_covar, middle_term)
+             predictive_covar = SumLinearOperator(
+                 data_data_covar.add_jitter(self.jitter_val),
+                 MatmulLinearOperator(interp_term.transpose(-1, -2), middle_term @ interp_term),
+             )
+
+         # Compute the mean of q(f)
+         # k_XZ K_ZZ^{-1/2} (m - K_ZZ^{-1/2} \mu_Z) + \mu_X
+         predictive_mean = interp_mean + test_mean
+
+         # Return the distribution
+         if hasattr(self.model, 'power'):
+             return MultivariateQExponential(predictive_mean, predictive_covar, power=self.model.power)
+         else:
+             return MultivariateNormal(predictive_mean, predictive_covar)
+
+     def kl_divergence(self) -> Tensor:
+         r"""
+         Compute the KL divergence between the variational inducing distribution :math:`q(\mathbf u)`
+         and the prior inducing distribution :math:`p(\mathbf u)`.
+
+         :rtype: torch.Tensor
+         """
+         if self._ngd():
+             if hasattr(self, "_memoize_cache") and "kl" in self._memoize_cache:
+                 return self._memoize_cache["kl"]
+             else:
+                 raise RuntimeError(
+ "KL divergence for NGD-CIQ should be computed during forward calls."
280
+ "This is probably a bug in GPyTorch."
281
+ )
282
+ else:
283
+ return super().kl_divergence()
284
+
285
+ def __call__(self, x: torch.Tensor, prior: bool = False, *params, **kwargs) -> Union[MultivariateNormal, MultivariateQExponential]:
286
+ # This is mostly the same as _VariationalStrategy.__call__()
287
+ # but with special rules for natural gradient descent (to prevent O(M^3) computation)
288
+
289
+ # If we're in prior mode, then we're done!
290
+ if prior:
291
+ return self.model.forward(x)
292
+
293
+ # Delete previously cached items from the training distribution
294
+ if self.training:
295
+ self._clear_cache()
296
+
297
+ # (Maybe) initialize variational distribution
298
+ if not self.variational_params_initialized.item():
299
+ if self._ngd():
300
+ noise = torch.randn_like(self.prior_distribution.mean).mul_(1e-3)
301
+ eye = torch.eye(noise.size(-1), dtype=noise.dtype, device=noise.device).mul(-0.5)
302
+ self._variational_distribution.natural_vec.data.copy_(noise)
303
+ self._variational_distribution.natural_mat.data.copy_(eye)
304
+ self.variational_params_initialized.fill_(1)
305
+ else:
306
+ prior_dist = self.prior_distribution
307
+ self._variational_distribution.initialize_variational_distribution(prior_dist)
308
+ self.variational_params_initialized.fill_(1)
309
+
310
+ # Ensure inducing_points and x are the same size
311
+ inducing_points = self.inducing_points
312
+ if inducing_points.shape[:-2] != x.shape[:-2]:
313
+ x, inducing_points = self._expand_inputs(x, inducing_points)
314
+
315
+ # Get q(f)
316
+ if self._ngd():
317
+ return Module.__call__(
318
+ self,
319
+ x,
320
+ inducing_points,
321
+ inducing_values=None,
322
+ variational_inducing_covar=None,
323
+ *params,
324
+ **kwargs,
325
+ )
326
+ else:
327
+ # Get p(u)/q(u)
328
+ variational_dist_u = self.variational_distribution
329
+
330
+ if isinstance(variational_dist_u, (MultivariateNormal, MultivariateQExponential)):
331
+ return Module.__call__(
332
+ self,
333
+ x,
334
+ inducing_points,
335
+ inducing_values=variational_dist_u.mean,
336
+ variational_inducing_covar=variational_dist_u.lazy_covariance_matrix,
337
+ **kwargs,
338
+ )
339
+ elif isinstance(variational_dist_u, Delta):
340
+ return Module.__call__(
341
+ self,
342
+ x,
343
+ inducing_points,
344
+ inducing_values=variational_dist_u.mean,
345
+ variational_inducing_covar=None,
346
+ **kwargs,
347
+ )
348
+ else:
349
+ raise RuntimeError(
350
+ f"Invalid variational distribuition ({type(variational_dist_u)}). "
351
+ "Expected a multivariate normal (q-exponential) or a delta distribution."
352
+ )
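
Since the docstring above recommends pairing CiqVariationalStrategy with NaturalVariationalDistribution and natural gradient descent, here is a hedged end-to-end sketch (not part of the wheel). It assumes qpytorch mirrors the corresponding GPyTorch API: qpytorch.optim.NGD, qpytorch.mlls.VariationalELBO, qpytorch.likelihoods.GaussianLikelihood, and the model.variational_parameters()/model.hyperparameters() helpers; train_x/train_y are placeholders.

import torch
import qpytorch


class CiqSVGP(qpytorch.models.ApproximateGP):
    def __init__(self, inducing_points):  # inducing_points: (M, D), with M large enough for CIQ to pay off
        variational_distribution = qpytorch.variational.NaturalVariationalDistribution(inducing_points.size(-2))
        variational_strategy = qpytorch.variational.CiqVariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = qpytorch.means.ConstantMean()
        self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())

    def forward(self, x):
        return qpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


train_x, train_y = torch.randn(500, 1), torch.randn(500)
model = CiqSVGP(inducing_points=torch.linspace(-3, 3, 1000).unsqueeze(-1))
likelihood = qpytorch.likelihoods.GaussianLikelihood()
mll = qpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.numel())

# Two optimizers: NGD for the natural variational parameters, Adam for kernel/likelihood hyperparameters.
ngd_optimizer = qpytorch.optim.NGD(model.variational_parameters(), num_data=train_y.numel(), lr=0.1)
hyper_optimizer = torch.optim.Adam(
    [{"params": model.hyperparameters()}, {"params": likelihood.parameters()}], lr=0.01
)

for _ in range(50):
    ngd_optimizer.zero_grad()
    hyper_optimizer.zero_grad()
    loss = -mll(model(train_x), train_y)
    loss.backward()
    ngd_optimizer.step()
    hyper_optimizer.step()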