qpytorch 0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of qpytorch might be problematic.

Files changed (102)
  1. qpytorch/__init__.py +327 -0
  2. qpytorch/constraints/__init__.py +3 -0
  3. qpytorch/distributions/__init__.py +21 -0
  4. qpytorch/distributions/delta.py +86 -0
  5. qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
  6. qpytorch/distributions/multivariate_qexponential.py +581 -0
  7. qpytorch/distributions/power.py +113 -0
  8. qpytorch/distributions/qexponential.py +153 -0
  9. qpytorch/functions/__init__.py +58 -0
  10. qpytorch/kernels/__init__.py +80 -0
  11. qpytorch/kernels/grid_interpolation_kernel.py +213 -0
  12. qpytorch/kernels/inducing_point_kernel.py +151 -0
  13. qpytorch/kernels/kernel.py +695 -0
  14. qpytorch/kernels/matern32_kernel_grad.py +155 -0
  15. qpytorch/kernels/matern52_kernel_grad.py +194 -0
  16. qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
  17. qpytorch/kernels/polynomial_kernel_grad.py +88 -0
  18. qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
  19. qpytorch/kernels/rbf_kernel_grad.py +125 -0
  20. qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
  21. qpytorch/kernels/rff_kernel.py +153 -0
  22. qpytorch/lazy/__init__.py +9 -0
  23. qpytorch/likelihoods/__init__.py +66 -0
  24. qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
  25. qpytorch/likelihoods/beta_likelihood.py +76 -0
  26. qpytorch/likelihoods/gaussian_likelihood.py +472 -0
  27. qpytorch/likelihoods/laplace_likelihood.py +59 -0
  28. qpytorch/likelihoods/likelihood.py +437 -0
  29. qpytorch/likelihoods/likelihood_list.py +60 -0
  30. qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
  31. qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
  32. qpytorch/likelihoods/noise_models.py +184 -0
  33. qpytorch/likelihoods/qexponential_likelihood.py +494 -0
  34. qpytorch/likelihoods/softmax_likelihood.py +97 -0
  35. qpytorch/likelihoods/student_t_likelihood.py +90 -0
  36. qpytorch/means/__init__.py +23 -0
  37. qpytorch/metrics/__init__.py +17 -0
  38. qpytorch/mlls/__init__.py +53 -0
  39. qpytorch/mlls/_approximate_mll.py +79 -0
  40. qpytorch/mlls/deep_approximate_mll.py +30 -0
  41. qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
  42. qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
  43. qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
  44. qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
  45. qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
  46. qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
  47. qpytorch/mlls/marginal_log_likelihood.py +48 -0
  48. qpytorch/mlls/predictive_log_likelihood.py +76 -0
  49. qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
  50. qpytorch/mlls/variational_elbo.py +77 -0
  51. qpytorch/models/__init__.py +72 -0
  52. qpytorch/models/approximate_qep.py +115 -0
  53. qpytorch/models/deep_qeps/__init__.py +22 -0
  54. qpytorch/models/deep_qeps/deep_qep.py +155 -0
  55. qpytorch/models/deep_qeps/dspp.py +114 -0
  56. qpytorch/models/exact_prediction_strategies.py +880 -0
  57. qpytorch/models/exact_qep.py +349 -0
  58. qpytorch/models/model_list.py +100 -0
  59. qpytorch/models/pyro/__init__.py +28 -0
  60. qpytorch/models/pyro/_pyro_mixin.py +57 -0
  61. qpytorch/models/pyro/distributions/__init__.py +5 -0
  62. qpytorch/models/pyro/pyro_qep.py +105 -0
  63. qpytorch/models/qep.py +7 -0
  64. qpytorch/models/qeplvm/__init__.py +6 -0
  65. qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
  66. qpytorch/models/qeplvm/latent_variable.py +102 -0
  67. qpytorch/module.py +30 -0
  68. qpytorch/optim/__init__.py +5 -0
  69. qpytorch/priors/__init__.py +42 -0
  70. qpytorch/priors/qep_priors.py +81 -0
  71. qpytorch/test/__init__.py +22 -0
  72. qpytorch/test/base_likelihood_test_case.py +106 -0
  73. qpytorch/test/model_test_case.py +150 -0
  74. qpytorch/test/variational_test_case.py +400 -0
  75. qpytorch/utils/__init__.py +38 -0
  76. qpytorch/utils/warnings.py +37 -0
  77. qpytorch/variational/__init__.py +47 -0
  78. qpytorch/variational/_variational_distribution.py +61 -0
  79. qpytorch/variational/_variational_strategy.py +391 -0
  80. qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
  81. qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
  82. qpytorch/variational/cholesky_variational_distribution.py +65 -0
  83. qpytorch/variational/ciq_variational_strategy.py +352 -0
  84. qpytorch/variational/delta_variational_distribution.py +41 -0
  85. qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
  86. qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
  87. qpytorch/variational/lmc_variational_strategy.py +248 -0
  88. qpytorch/variational/mean_field_variational_distribution.py +58 -0
  89. qpytorch/variational/multitask_variational_strategy.py +317 -0
  90. qpytorch/variational/natural_variational_distribution.py +152 -0
  91. qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
  92. qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
  93. qpytorch/variational/tril_natural_variational_distribution.py +130 -0
  94. qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
  95. qpytorch/variational/unwhitened_variational_strategy.py +225 -0
  96. qpytorch/variational/variational_strategy.py +280 -0
  97. qpytorch/version.py +4 -0
  98. qpytorch-0.1.dist-info/LICENSE +21 -0
  99. qpytorch-0.1.dist-info/METADATA +177 -0
  100. qpytorch-0.1.dist-info/RECORD +102 -0
  101. qpytorch-0.1.dist-info/WHEEL +5 -0
  102. qpytorch-0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,317 @@
+#!/usr/bin/env python3
+
+import warnings
+from typing import Any, Dict, Iterable, Optional, Tuple, Union
+
+import torch
+from linear_operator import to_dense
+from linear_operator.operators import (
+    BlockDiagLinearOperator,
+    CholLinearOperator,
+    DiagLinearOperator,
+    KroneckerProductLinearOperator,
+    LinearOperator,
+    MatmulLinearOperator,
+    RootLinearOperator,
+    SumLinearOperator,
+    TriangularLinearOperator,
+)
+from linear_operator.utils.cholesky import psd_safe_cholesky
+from linear_operator.utils.errors import NotPSDError
+from torch import Tensor
+
+from gpytorch.settings import _linalg_dtype_cholesky, trace_mode
+from gpytorch.utils.errors import CachingError
+from gpytorch.utils.memoize import cached, clear_cache_hook, pop_from_cache_ignore_args
+
+from ..distributions import (
+    MultitaskMultivariateNormal,
+    MultitaskMultivariateQExponential,
+    MultivariateNormal,
+    MultivariateQExponential,
+)
+from ..models import ApproximateGP, ApproximateQEP
+from ..utils.warnings import OldVersionWarning
+from . import _VariationalDistribution
+from ._variational_strategy import _VariationalStrategy
+from .cholesky_variational_distribution import CholeskyVariationalDistribution
+
+
+def _ensure_updated_strategy_flag_set(
+    state_dict: Dict[str, Tensor],
+    prefix: str,
+    local_metadata: Dict[str, Any],
+    strict: bool,
+    missing_keys: Iterable[str],
+    unexpected_keys: Iterable[str],
+    error_msgs: Iterable[str],
+):
+    device = state_dict[list(state_dict.keys())[0]].device
+    if prefix + "updated_strategy" not in state_dict:
+        state_dict[prefix + "updated_strategy"] = torch.tensor(False, device=device)
+        warnings.warn(
+            "You have loaded a variational GP (QEP) model (using `VariationalStrategy`) from a previous version of "
+            "GPyTorch. We have updated the parameters of your model to work with the new version of "
+            "`VariationalStrategy` that uses whitened parameters.\nYour model will work as expected, but we "
+            "recommend that you re-save your model.",
+            OldVersionWarning,
+        )
+
+
+class MultitaskVariationalStrategy(_VariationalStrategy):
+    r"""
+    The modified variational strategy, as defined by `Hensman et al. (2015)`_.
+    This strategy takes a set of :math:`m \ll n` inducing points :math:`\mathbf Z`
+    and applies an approximate distribution :math:`q( \mathbf u)` over their function values.
+    (Here, we use the common notation :math:`\mathbf u = f(\mathbf Z)`.)
+    The approximate function distribution for any arbitrary input :math:`\mathbf X` is given by:
+
+    .. math::
+
+        q( f(\mathbf X) ) = \int p( f(\mathbf X) \mid \mathbf u) q(\mathbf u) \: d\mathbf u
+
+    This variational strategy uses "whitening" to accelerate the optimization of the variational
+    parameters. See `Matthews (2017)`_ for more info.
+
+    :param model: Model this strategy is applied to.
+        Typically passed in when the VariationalStrategy is created in the
+        __init__ method of the user-defined model.
+        It should have a ``power`` attribute if a Q-Exponential distribution is involved, and its
+        ``forward`` method should output a MultitaskMultivariateNormal (MultitaskMultivariateQExponential) distribution.
+    :param inducing_points: Tensor containing a set of inducing
+        points to use for variational inference.
+    :param variational_distribution: A
+        VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`
+    :param learn_inducing_locations: (Default True): Whether or not
+        the inducing point locations :math:`\mathbf Z` should be learned (i.e. are they
+        parameters of the model).
+    :param jitter_val: Amount of diagonal jitter to add for Cholesky factorization numerical stability
+
+    .. _Hensman et al. (2015):
+        http://proceedings.mlr.press/v38/hensman15.pdf
+    .. _Matthews (2017):
+        https://www.repository.cam.ac.uk/handle/1810/278022
+    """
+
+    def __init__(
+        self,
+        model: Union[ApproximateGP, ApproximateQEP],
+        inducing_points: Tensor,
+        variational_distribution: _VariationalDistribution,
+        learn_inducing_locations: bool = True,
+        jitter_val: Optional[float] = None,
+    ):
+        super().__init__(
+            model, inducing_points, variational_distribution, learn_inducing_locations, jitter_val=jitter_val
+        )
+        self.register_buffer("updated_strategy", torch.tensor(True))
+        self._register_load_state_dict_pre_hook(_ensure_updated_strategy_flag_set)
+        self.has_fantasy_strategy = True
+
+    @cached(name="cholesky_factor", ignore_args=True)
+    def _cholesky_factor(self, induc_induc_covar: LinearOperator) -> TriangularLinearOperator:
+        L = psd_safe_cholesky(to_dense(induc_induc_covar).type(_linalg_dtype_cholesky.value()))
+        return TriangularLinearOperator(L)
+
+    @property
+    @cached(name="prior_distribution_memo")
+    def prior_distribution(self) -> Union[MultivariateNormal, MultivariateQExponential]:
+        zeros = torch.zeros(
+            self._variational_distribution.shape(),
+            dtype=self._variational_distribution.dtype,
+            device=self._variational_distribution.device,
+        )
+        ones = torch.ones_like(zeros)
+        if hasattr(self.model, "power"):
+            res = MultivariateQExponential(zeros, DiagLinearOperator(ones), power=self.model.power)
+        else:
+            res = MultivariateNormal(zeros, DiagLinearOperator(ones))
+        return res
+
+    @property
+    @cached(name="pseudo_points_memo")
+    def pseudo_points(self) -> Tuple[Tensor, Tensor]:
+        # TODO: have var_mean, var_cov come from a method of _variational_distribution,
+        # while having Kmm_root be a root decomposition to enable CIQVariationalDistribution support.
+
+        # retrieve the variational mean, m, and covariance matrix, S
+        if not isinstance(self._variational_distribution, CholeskyVariationalDistribution):
+            raise NotImplementedError(
+                "Only CholeskyVariationalDistribution currently has pseudo-point support, "
+                f"but your _variational_distribution is a {self._variational_distribution.__class__.__name__}."
+            )
+
+        var_cov_root = TriangularLinearOperator(self._variational_distribution.chol_variational_covar)
+        var_cov = CholLinearOperator(var_cov_root)
+        var_mean = self.variational_distribution.mean
+        if var_mean.shape[-1] != 1:
+            var_mean = var_mean.unsqueeze(-1)
+
+        # compute R = I - S
+        cov_diff = var_cov.add_jitter(-1.0)
+        cov_diff = -1.0 * cov_diff
+
+        # K^{1/2}
+        Kmm = self.model.covar_module(self.inducing_points)
+        Kmm_root = Kmm.cholesky()
+
+        # D_a = (S^{-1} - K^{-1})^{-1} = S + S R^{-1} S
+        # note that in the whitened case R = I - S, in the unwhitened case R = K - S
+        # we compute (R R^{T})^{-1} R^T S for stability reasons, as R is probably not PSD
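+        # (Editor's note) A sketch of the identity above in the whitened case (K = I):
+        #   (S^{-1} - I)^{-1} = S (I - S)^{-1} = S + S (I - S)^{-1} S = S + S R^{-1} S,
+        # and R^{-1} S is obtained below via the normal equations (R R^T)^{-1} R^T S,
+        # which remain solvable even when R is indefinite.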
+        eval_var_cov = var_cov.to_dense()
+        eval_rhs = cov_diff.transpose(-1, -2).matmul(eval_var_cov)
+        inner_term = cov_diff.matmul(cov_diff.transpose(-1, -2))
+        # TODO: flag the jitter here
+        inner_solve = inner_term.add_jitter(self.jitter_val).solve(eval_rhs, eval_var_cov.transpose(-1, -2))
+        inducing_covar = var_cov + inner_solve
+
+        inducing_covar = Kmm_root.matmul(inducing_covar).matmul(Kmm_root.transpose(-1, -2))
+
+        # mean term: D_a S^{-1} m
+        # unwhitened: (S - S R^{-1} S) S^{-1} m = (I - S R^{-1}) m
+        rhs = cov_diff.transpose(-1, -2).matmul(var_mean)
+        # TODO: this jitter too
+        inner_rhs_mean_solve = inner_term.add_jitter(self.jitter_val).solve(rhs)
+        pseudo_target_mean = Kmm_root.matmul(inner_rhs_mean_solve)
+
+        # ensure inducing covar is psd
+        # TODO: make this be an explicit root decomposition
+        try:
+            pseudo_target_covar = CholLinearOperator(inducing_covar.add_jitter(self.jitter_val).cholesky()).to_dense()
+        except NotPSDError:
+            evals, evecs = torch.linalg.eigh(inducing_covar)
+            pseudo_target_covar = (
+                evecs.matmul(DiagLinearOperator(evals + self.jitter_val)).matmul(evecs.transpose(-1, -2)).to_dense()
+            )
+
+        return pseudo_target_covar, pseudo_target_mean
+
+    def forward(
+        self,
+        x: Tensor,
+        inducing_points: Tensor,
+        inducing_values: Tensor,
+        variational_inducing_covar: Optional[LinearOperator] = None,
+        **kwargs,
+    ) -> Union[MultitaskMultivariateNormal, MultitaskMultivariateQExponential]:
+        # Compute full prior distribution
+        full_inputs = torch.cat([inducing_points, x], dim=-2)
+        full_output = self.model.forward(full_inputs, **kwargs)  # MultitaskMultivariateNormal or MultitaskMultivariateQExponential
+        if type(full_output) not in (MultitaskMultivariateNormal, MultitaskMultivariateQExponential):
+            raise TypeError(
+                f"The type of model forward p(f(X)) is {full_output.__class__.__name__}, "
+                "which is not multitask. Please use the regular VariationalStrategy instead."
+            )
+        full_covar = full_output.lazy_covariance_matrix
+
+        num_tasks = full_output.num_tasks
+        _interleaved = full_output._interleaved
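+        # (Editor's note) On the two layouts: an interleaved multitask distribution
+        # orders entries as (x_1, t_1), (x_1, t_2), ..., (x_1, t_T), (x_2, t_1), ...,
+        # so the inducing block occupies the leading num_induc * num_tasks rows/columns.
+        # The non-interleaved ordering is (x_1, t_1), ..., (x_N, t_1), (x_1, t_2), ...,
+        # which is why the else branch below gathers strided indices instead.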
+        # Covariance terms
+        num_induc = inducing_points.size(-2)
+        test_mean = full_output.mean[..., num_induc:, :]
+        if _interleaved:
+            induc_induc_covar = full_covar[..., : (num_induc * num_tasks), : (num_induc * num_tasks)].add_jitter(self.jitter_val)
+            induc_data_covar = full_covar[..., : (num_induc * num_tasks), (num_induc * num_tasks) :].to_dense()
+            data_data_covar = full_covar[..., (num_induc * num_tasks) :, (num_induc * num_tasks) :]
+        else:
+            induc_idx = (
+                torch.arange(num_induc, device=full_covar.device)
+                + torch.arange(num_tasks, device=full_covar.device)[:, None] * full_output.event_shape[0]
+            ).flatten()
+            data_idx = (
+                torch.arange(num_induc, full_output.event_shape[0], device=full_covar.device)
+                + torch.arange(num_tasks, device=full_covar.device)[:, None] * full_output.event_shape[0]
+            ).flatten()
+            induc_induc_covar = full_covar[..., induc_idx, :][..., induc_idx].add_jitter(self.jitter_val)
+            induc_data_covar = full_covar[..., induc_idx, :][..., data_idx].to_dense()
+            data_data_covar = full_covar[..., data_idx, :][..., data_idx]
+
+        # Compute interpolation terms
+        # K_ZZ^{-1/2} K_ZX
+        # K_ZZ^{-1/2} \mu_Z
+        L = self._cholesky_factor(induc_induc_covar)
+        if L.shape != induc_induc_covar.shape:
+            # Aggressive caching can cause nasty shape incompatibilities when evaluating with different batch shapes
+            # TODO: Use a hook for this
+            try:
+                pop_from_cache_ignore_args(self, "cholesky_factor")
+            except CachingError:
+                pass
+            L = self._cholesky_factor(induc_induc_covar)
+        interp_term = L.solve(induc_data_covar.type(_linalg_dtype_cholesky.value())).to(full_inputs.dtype)
+
+        # Compute the mean of q(f)
+        # k_XZ K_ZZ^{-1/2} (m - K_ZZ^{-1/2} \mu_Z) + \mu_X
+        if len(self.variational_distribution.batch_shape) > 0:
+            if _interleaved:
+                inducing_values = inducing_values.transpose(-1, -2)
+            inducing_values = inducing_values.reshape(*inducing_values.shape[:-2], -1)
+        else:
+            inducing_values = (
+                inducing_values.repeat_interleave(num_tasks, -1) if _interleaved else inducing_values.tile(num_tasks)
+            )
+        predictive_mean = (interp_term.transpose(-1, -2) @ inducing_values.unsqueeze(-1)).squeeze(-1)
+        if _interleaved:
+            predictive_mean = predictive_mean.reshape_as(test_mean) + test_mean
+        else:
+            new_shape = test_mean.shape[:-2] + test_mean.shape[:-3:-1]
+            predictive_mean = predictive_mean.view(new_shape).transpose(-1, -2).contiguous() + test_mean
+
+        # Compute the covariance of q(f)
+        # K_XX + k_XZ K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2} k_ZX
+        middle_term = self.prior_distribution.lazy_covariance_matrix.mul(-1)
+        if variational_inducing_covar is not None:
+            middle_term = SumLinearOperator(variational_inducing_covar, middle_term)
+        if len(self.variational_distribution.batch_shape) > 0:
+            middle_term = BlockDiagLinearOperator(middle_term)
+            if _interleaved:
+                pi = torch.arange(num_induc * num_tasks, device=middle_term.device).view(num_tasks, num_induc).t().reshape(num_induc * num_tasks)
+                middle_term = middle_term[..., pi, :][..., :, pi]
+        else:
+            if _interleaved:
+                middle_term = KroneckerProductLinearOperator(
+                    middle_term, DiagLinearOperator(torch.ones(num_tasks, device=middle_term.device))
+                )
+            else:
+                middle_term = KroneckerProductLinearOperator(
+                    DiagLinearOperator(torch.ones(num_tasks, device=middle_term.device)), middle_term
+                )
+
+        if trace_mode.on():
+            predictive_covar = (
+                data_data_covar.add_jitter(self.jitter_val).to_dense()
+                + interp_term.transpose(-1, -2) @ middle_term.to_dense() @ interp_term
+            )
+        else:
+            predictive_covar = SumLinearOperator(
+                data_data_covar.add_jitter(self.jitter_val),
+                MatmulLinearOperator(interp_term.transpose(-1, -2), middle_term @ interp_term),
+            )
+
+        # Return the distribution
+        if hasattr(self.model, "power"):
+            return MultitaskMultivariateQExponential(predictive_mean, predictive_covar, power=self.model.power, interleaved=_interleaved)
+        else:
+            return MultitaskMultivariateNormal(predictive_mean, predictive_covar, interleaved=_interleaved)
+
+    def __call__(self, x: Tensor, prior: bool = False, **kwargs) -> Union[MultivariateNormal, MultivariateQExponential]:
+        if not self.updated_strategy.item() and not prior:
+            with torch.no_grad():
+                # Get unwhitened p(u)
+                prior_function_dist = self(self.inducing_points, prior=True)
+                prior_mean = prior_function_dist.loc
+                L = self._cholesky_factor(prior_function_dist.lazy_covariance_matrix.add_jitter(self.jitter_val))
+
+                # Temporarily turn off noise that's added to the mean
+                orig_mean_init_std = self._variational_distribution.mean_init_std
+                self._variational_distribution.mean_init_std = 0.0
+
+                # Change the variational parameters to be whitened
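+                # (Editor's note: i.e. m_white = L^{-1} (m - \mu_Z) and
+                # S_white = L^{-1} S L^{-T}, applied below through a root
+                # decomposition of S so the result stays a valid covariance root)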
+                variational_dist = self.variational_distribution
+                if isinstance(variational_dist, (MultivariateNormal, MultivariateQExponential)):
+                    mean_diff = (variational_dist.loc - prior_mean).unsqueeze(-1).type(_linalg_dtype_cholesky.value())
+                    whitened_mean = L.solve(mean_diff).squeeze(-1).to(variational_dist.loc.dtype)
+                    covar_root = variational_dist.lazy_covariance_matrix.root_decomposition().root.to_dense()
+                    covar_root = covar_root.type(_linalg_dtype_cholesky.value())
+                    whitened_covar = RootLinearOperator(L.solve(covar_root).to(variational_dist.loc.dtype))
+                    whitened_variational_distribution = variational_dist.__class__(whitened_mean, whitened_covar)
+                    if isinstance(variational_dist, MultivariateQExponential):
+                        whitened_variational_distribution.power = variational_dist.power
+                    self._variational_distribution.initialize_variational_distribution(
+                        whitened_variational_distribution
+                    )
+
+                # Reset the random noise parameter of the model
+                self._variational_distribution.mean_init_std = orig_mean_init_std
+
+                # Reset the cache
+                clear_cache_hook(self)
+
+                # Mark that we have updated the variational strategy
+                self.updated_strategy.fill_(True)
+
+        return super().__call__(x, prior=prior, **kwargs)
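For orientation, here is a minimal usage sketch (not part of the package) of the MultitaskVariationalStrategy defined above. It assumes qpytorch follows the gpytorch conventions its imports suggest: an ApproximateQEP base class that takes the strategy, gpytorch-style MultitaskMean / MultitaskKernel re-exports, and a CholeskyVariationalDistribution over the inducing values.

```python
import torch
import qpytorch


class MultitaskQEPModel(qpytorch.models.ApproximateQEP):
    """Hypothetical sketch: a model whose forward returns a multitask
    distribution, as MultitaskVariationalStrategy's type check requires."""

    def __init__(self, inducing_points, num_tasks, power=torch.tensor(1.0)):
        # q(u) over the m inducing values; with an unbatched distribution the
        # strategy Kronecker-expands the covariance correction across tasks.
        variational_distribution = qpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2)
        )
        variational_strategy = qpytorch.variational.MultitaskVariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.power = power  # found by the strategy via hasattr(self.model, "power")
        self.mean_module = qpytorch.means.MultitaskMean(
            qpytorch.means.ConstantMean(), num_tasks=num_tasks
        )
        self.covar_module = qpytorch.kernels.MultitaskKernel(
            qpytorch.kernels.RBFKernel(), num_tasks=num_tasks
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return qpytorch.distributions.MultitaskMultivariateQExponential(
            mean_x, covar_x, power=self.power
        )


model = MultitaskQEPModel(inducing_points=torch.randn(16, 1), num_tasks=2)
q_f = model(torch.randn(50, 1))  # q(f): a 50 x 2 MultitaskMultivariateQExponential
```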
@@ -0,0 +1,152 @@
+#!/usr/bin/env python3
+
+import abc
+
+import torch
+from linear_operator.operators import CholLinearOperator, TriangularLinearOperator
+from linear_operator.utils.cholesky import psd_safe_cholesky
+
+from ..distributions import MultivariateNormal, MultivariateQExponential
+from ._variational_distribution import _VariationalDistribution
+
+
+class _NaturalVariationalDistribution(_VariationalDistribution, abc.ABC):
+    r"""Any :obj:`~qpytorch.variational._VariationalDistribution` which calculates
+    natural gradients with respect to its parameters.
+    """
+    pass
+
+
+class NaturalVariationalDistribution(_NaturalVariationalDistribution):
+    r"""A multivariate normal :obj:`~qpytorch.variational._VariationalDistribution`,
+    parameterized by **natural** parameters.
+
+    .. note::
+        The :obj:`~qpytorch.variational.NaturalVariationalDistribution` can only
+        be used with :obj:`gpytorch.optim.NGD`, or other optimizers
+        that follow exactly the gradient direction. Failure to do so will cause
+        the natural matrix :math:`\mathbf \Theta_\text{mat}` to stop being
+        positive definite, and a :obj:`~RuntimeError` will be raised.
+
+    .. seealso::
+        The `natural gradient descent tutorial
+        <examples/04_Variational_and_Approximate_GPs/Natural_Gradient_Descent.ipynb>`_
+        for usage instructions.
+
+        The :obj:`~qpytorch.variational.TrilNaturalVariationalDistribution` for
+        a more numerically stable parameterization, at the cost of needing more
+        iterations to make variational regression converge.
+
+    :param int num_inducing_points: Size of the variational distribution. This implies that the variational mean
+        should be this size, and the variational covariance matrix should have this many rows and columns.
+    :param batch_shape: Specifies an optional batch size
+        for the variational parameters. This is useful for example when doing additive variational inference.
+    :type batch_shape: :obj:`torch.Size`, optional
+    :param float mean_init_std: (Default: 1e-3) Standard deviation of Gaussian (q-exponential) noise to add to the mean initialization.
+    """
+
+    def __init__(self, num_inducing_points, batch_shape=torch.Size([]), mean_init_std=1e-3, **kwargs):
+        super().__init__(num_inducing_points=num_inducing_points, batch_shape=batch_shape, mean_init_std=mean_init_std)
+        scaled_mean_init = torch.zeros(num_inducing_points)
+        neg_prec_init = torch.eye(num_inducing_points, num_inducing_points).mul(-0.5)
+        scaled_mean_init = scaled_mean_init.repeat(*batch_shape, 1)
+        neg_prec_init = neg_prec_init.repeat(*batch_shape, 1, 1)
+
+        # eta1 and eta2 parameterization of the variational distribution
+        self.register_parameter(name="natural_vec", parameter=torch.nn.Parameter(scaled_mean_init))
+        self.register_parameter(name="natural_mat", parameter=torch.nn.Parameter(neg_prec_init))
+
+        if "power" in kwargs:
+            self.power = kwargs.pop("power")
+
+    def forward(self):
+        mean, chol_covar = _NaturalToMuVarSqrt.apply(self.natural_vec, self.natural_mat)
+        covar = CholLinearOperator(TriangularLinearOperator(chol_covar))
+        if not hasattr(self, "power"):
+            res = MultivariateNormal(mean, covar)
+        else:
+            res = MultivariateQExponential(mean, covar, power=self.power)
+        return res
+
+    def initialize_variational_distribution(self, prior_dist):
+        prior_prec = prior_dist.covariance_matrix.inverse()
+        prior_mean = prior_dist.mean
+        noise = torch.randn_like(prior_mean).mul_(self.mean_init_std)
+
+        self.natural_vec.data.copy_((prior_prec @ prior_mean.unsqueeze(-1)).squeeze(-1).add_(noise))
+        self.natural_mat.data.copy_(prior_prec.mul(-0.5))
+
+
+def _triangular_inverse(A, upper=False):
+    eye = torch.eye(A.size(-1), dtype=A.dtype, device=A.device)
+    return torch.linalg.solve_triangular(A, eye, upper=upper)
+
+
+def _phi_for_cholesky_(A):
+    "Modifies A to be the phi function used in differentiating through Cholesky"
+    A.tril_().diagonal(offset=0, dim1=-2, dim2=-1).mul_(0.5)
+    return A
+
+
+def _cholesky_backward(dout_dL, L, L_inverse):
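+    # (Editor's note) This is the standard Cholesky pullback,
+    # dSigma = sym(L^{-T} Phi(L^T dL) L^{-1}), where Phi keeps the lower
+    # triangle and halves the diagonal (computed by _phi_for_cholesky_ above).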
+    # c.f. https://github.com/pytorch/pytorch/blob/25ba802ce4cbdeaebcad4a03cec8502f0de9b7b3/
+    # tools/autograd/templates/Functions.cpp
+    A = L.transpose(-1, -2) @ dout_dL
+    phi = _phi_for_cholesky_(A)
+    grad_input = (L_inverse.transpose(-1, -2) @ phi) @ L_inverse
+    # Symmetrize gradient
+    return grad_input.add(grad_input.transpose(-1, -2)).mul_(0.5)
+
+
+class _NaturalToMuVarSqrt(torch.autograd.Function):
+    @staticmethod
+    def _forward(nat_mean, nat_covar):
+        try:
+            L_inv = psd_safe_cholesky(-2.0 * nat_covar, upper=False)
+        except RuntimeError as e:
+            if str(e).startswith("cholesky"):
+                raise RuntimeError(
+                    "Non-negative-definite natural covariance. You probably "
+                    "updated it using an optimizer other than gpytorch.optim.NGD (such as Adam). "
+                    "This is not supported."
+                )
+            else:
+                raise e
+        L = _triangular_inverse(L_inv, upper=False)
+        S = L.transpose(-1, -2) @ L
+        mu = (S @ nat_mean.unsqueeze(-1)).squeeze(-1)
+        # Two Choleskys are annoying, but we don't have good support for a
+        # LinearOperator of the form L.T @ L
+        return mu, psd_safe_cholesky(S, upper=False)
+
+    @staticmethod
+    def forward(ctx, nat_mean, nat_covar):
+        mu, L = _NaturalToMuVarSqrt._forward(nat_mean, nat_covar)
+        ctx.save_for_backward(mu, L)
+        return mu, L
+
+    @staticmethod
+    def _backward(dout_dmu, dout_dL, mu, L, C):
+        """Calculate dout/d(eta1, eta2), where (eta1, eta2) are the expectation parameters:
+        eta1 = mu
+        eta2 = mu mu^T + L L^T = mu mu^T + Sigma
+
+        Thus:
+        dout/deta1 = dout/dmu + dout/dSigma dSigma/deta1
+        dout/deta2 = dout/dSigma
+
+        For L = chol(eta2 - eta1 eta1^T):
+        dout/dSigma = _cholesky_backward(dout/dL, L)
+        dSigma/deta1 = -2 mu, so dout/deta1 = dout/dmu - 2 (dout/dSigma) mu
+        """
+        dout_dSigma = _cholesky_backward(dout_dL, L, C)
+        dout_deta1 = dout_dmu - 2 * (dout_dSigma @ mu.unsqueeze(-1)).squeeze(-1)
+        return dout_deta1, dout_dSigma
+
+    @staticmethod
+    def backward(ctx, dout_dmu, dout_dL):
+        "Calculates the natural gradient with respect to nat_mean, nat_covar"
+        mu, L = ctx.saved_tensors
+        C = _triangular_inverse(L, upper=False)
+        return _NaturalToMuVarSqrt._backward(dout_dmu, dout_dL, mu, L, C)
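To close, a hedged training-loop sketch for NaturalVariationalDistribution, mirroring the natural gradient descent tutorial cited in its docstring. The model, likelihood, and training data are assumed to be built elsewhere (e.g. an approximate QEP as in the sketch above, with the natural distribution swapped in); the key point is that gpytorch.optim.NGD, not Adam, must step natural_vec / natural_mat, per the class note.

```python
import torch
import gpytorch
import qpytorch


def fit_with_ngd(model, likelihood, train_x, train_y, num_iter=100):
    # NGD must be the optimizer that steps the natural parameters; an optimizer
    # like Adam would let natural_mat drift off the negative-definite cone and
    # trigger the RuntimeError raised in _NaturalToMuVarSqrt._forward.
    variational_ngd_optimizer = gpytorch.optim.NGD(
        model.variational_parameters(), num_data=train_y.size(0), lr=0.1
    )
    hyperparameter_optimizer = torch.optim.Adam(model.hyperparameters(), lr=0.01)
    mll = qpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(0))

    model.train()
    likelihood.train()
    for _ in range(num_iter):
        variational_ngd_optimizer.zero_grad()
        hyperparameter_optimizer.zero_grad()
        loss = -mll(model(train_x), train_y)
        loss.backward()
        variational_ngd_optimizer.step()  # natural-gradient step on natural_vec / natural_mat
        hyperparameter_optimizer.step()  # ordinary gradient step on the remaining parameters
    return model
```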