qpytorch 0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of qpytorch might be problematic. Click here for more details.
- qpytorch/__init__.py +327 -0
- qpytorch/constraints/__init__.py +3 -0
- qpytorch/distributions/__init__.py +21 -0
- qpytorch/distributions/delta.py +86 -0
- qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
- qpytorch/distributions/multivariate_qexponential.py +581 -0
- qpytorch/distributions/power.py +113 -0
- qpytorch/distributions/qexponential.py +153 -0
- qpytorch/functions/__init__.py +58 -0
- qpytorch/kernels/__init__.py +80 -0
- qpytorch/kernels/grid_interpolation_kernel.py +213 -0
- qpytorch/kernels/inducing_point_kernel.py +151 -0
- qpytorch/kernels/kernel.py +695 -0
- qpytorch/kernels/matern32_kernel_grad.py +155 -0
- qpytorch/kernels/matern52_kernel_grad.py +194 -0
- qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
- qpytorch/kernels/polynomial_kernel_grad.py +88 -0
- qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
- qpytorch/kernels/rbf_kernel_grad.py +125 -0
- qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
- qpytorch/kernels/rff_kernel.py +153 -0
- qpytorch/lazy/__init__.py +9 -0
- qpytorch/likelihoods/__init__.py +66 -0
- qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
- qpytorch/likelihoods/beta_likelihood.py +76 -0
- qpytorch/likelihoods/gaussian_likelihood.py +472 -0
- qpytorch/likelihoods/laplace_likelihood.py +59 -0
- qpytorch/likelihoods/likelihood.py +437 -0
- qpytorch/likelihoods/likelihood_list.py +60 -0
- qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
- qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
- qpytorch/likelihoods/noise_models.py +184 -0
- qpytorch/likelihoods/qexponential_likelihood.py +494 -0
- qpytorch/likelihoods/softmax_likelihood.py +97 -0
- qpytorch/likelihoods/student_t_likelihood.py +90 -0
- qpytorch/means/__init__.py +23 -0
- qpytorch/metrics/__init__.py +17 -0
- qpytorch/mlls/__init__.py +53 -0
- qpytorch/mlls/_approximate_mll.py +79 -0
- qpytorch/mlls/deep_approximate_mll.py +30 -0
- qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
- qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
- qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
- qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
- qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
- qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
- qpytorch/mlls/marginal_log_likelihood.py +48 -0
- qpytorch/mlls/predictive_log_likelihood.py +76 -0
- qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
- qpytorch/mlls/variational_elbo.py +77 -0
- qpytorch/models/__init__.py +72 -0
- qpytorch/models/approximate_qep.py +115 -0
- qpytorch/models/deep_qeps/__init__.py +22 -0
- qpytorch/models/deep_qeps/deep_qep.py +155 -0
- qpytorch/models/deep_qeps/dspp.py +114 -0
- qpytorch/models/exact_prediction_strategies.py +880 -0
- qpytorch/models/exact_qep.py +349 -0
- qpytorch/models/model_list.py +100 -0
- qpytorch/models/pyro/__init__.py +28 -0
- qpytorch/models/pyro/_pyro_mixin.py +57 -0
- qpytorch/models/pyro/distributions/__init__.py +5 -0
- qpytorch/models/pyro/pyro_qep.py +105 -0
- qpytorch/models/qep.py +7 -0
- qpytorch/models/qeplvm/__init__.py +6 -0
- qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
- qpytorch/models/qeplvm/latent_variable.py +102 -0
- qpytorch/module.py +30 -0
- qpytorch/optim/__init__.py +5 -0
- qpytorch/priors/__init__.py +42 -0
- qpytorch/priors/qep_priors.py +81 -0
- qpytorch/test/__init__.py +22 -0
- qpytorch/test/base_likelihood_test_case.py +106 -0
- qpytorch/test/model_test_case.py +150 -0
- qpytorch/test/variational_test_case.py +400 -0
- qpytorch/utils/__init__.py +38 -0
- qpytorch/utils/warnings.py +37 -0
- qpytorch/variational/__init__.py +47 -0
- qpytorch/variational/_variational_distribution.py +61 -0
- qpytorch/variational/_variational_strategy.py +391 -0
- qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
- qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
- qpytorch/variational/cholesky_variational_distribution.py +65 -0
- qpytorch/variational/ciq_variational_strategy.py +352 -0
- qpytorch/variational/delta_variational_distribution.py +41 -0
- qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
- qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
- qpytorch/variational/lmc_variational_strategy.py +248 -0
- qpytorch/variational/mean_field_variational_distribution.py +58 -0
- qpytorch/variational/multitask_variational_strategy.py +317 -0
- qpytorch/variational/natural_variational_distribution.py +152 -0
- qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
- qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
- qpytorch/variational/tril_natural_variational_distribution.py +130 -0
- qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
- qpytorch/variational/unwhitened_variational_strategy.py +225 -0
- qpytorch/variational/variational_strategy.py +280 -0
- qpytorch/version.py +4 -0
- qpytorch-0.1.dist-info/LICENSE +21 -0
- qpytorch-0.1.dist-info/METADATA +177 -0
- qpytorch-0.1.dist-info/RECORD +102 -0
- qpytorch-0.1.dist-info/WHEEL +5 -0
- qpytorch-0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,317 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from linear_operator import to_dense
|
|
8
|
+
from linear_operator.operators import (
|
|
9
|
+
CholLinearOperator,
|
|
10
|
+
DiagLinearOperator,
|
|
11
|
+
LinearOperator,
|
|
12
|
+
MatmulLinearOperator,
|
|
13
|
+
RootLinearOperator,
|
|
14
|
+
SumLinearOperator,
|
|
15
|
+
TriangularLinearOperator,
|
|
16
|
+
BlockDiagLinearOperator,
|
|
17
|
+
KroneckerProductLinearOperator
|
|
18
|
+
)
|
|
19
|
+
from linear_operator.utils.cholesky import psd_safe_cholesky
|
|
20
|
+
from linear_operator.utils.errors import NotPSDError
|
|
21
|
+
from torch import Tensor
|
|
22
|
+
|
|
23
|
+
from ._variational_strategy import _VariationalStrategy
|
|
24
|
+
from .cholesky_variational_distribution import CholeskyVariationalDistribution
|
|
25
|
+
|
|
26
|
+
from ..distributions import MultivariateNormal, MultivariateQExponential, MultitaskMultivariateNormal, MultitaskMultivariateQExponential
|
|
27
|
+
from ..models import ApproximateGP, ApproximateQEP
|
|
28
|
+
from gpytorch.settings import _linalg_dtype_cholesky, trace_mode
|
|
29
|
+
from gpytorch.utils.errors import CachingError
|
|
30
|
+
from gpytorch.utils.memoize import cached, clear_cache_hook, pop_from_cache_ignore_args
|
|
31
|
+
from ..utils.warnings import OldVersionWarning
|
|
32
|
+
from . import _VariationalDistribution
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _ensure_updated_strategy_flag_set(
|
|
36
|
+
state_dict: Dict[str, Tensor],
|
|
37
|
+
prefix: str,
|
|
38
|
+
local_metadata: Dict[str, Any],
|
|
39
|
+
strict: bool,
|
|
40
|
+
missing_keys: Iterable[str],
|
|
41
|
+
unexpected_keys: Iterable[str],
|
|
42
|
+
error_msgs: Iterable[str],
|
|
43
|
+
):
|
|
44
|
+
device = state_dict[list(state_dict.keys())[0]].device
|
|
45
|
+
if prefix + "updated_strategy" not in state_dict:
|
|
46
|
+
state_dict[prefix + "updated_strategy"] = torch.tensor(False, device=device)
|
|
47
|
+
warnings.warn(
|
|
48
|
+
"You have loaded a variational GP (QEP) model (using `VariationalStrategy`) from a previous version of "
|
|
49
|
+
"GPyTorch. We have updated the parameters of your model to work with the new version of "
|
|
50
|
+
"`VariationalStrategy` that uses whitened parameters.\nYour model will work as expected, but we "
|
|
51
|
+
"recommend that you re-save your model.",
|
|
52
|
+
OldVersionWarning,
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class MultitaskVariationalStrategy(_VariationalStrategy):
    r"""
    A multitask version of the whitened variational strategy of `Hensman et al. (2015)`_.

    This strategy takes a set of :math:`m \ll n` inducing points :math:`\mathbf Z`
    and applies an approximate distribution :math:`q( \mathbf u)` over their function values.
    (Here, we use the common notation :math:`\mathbf u = f(\mathbf Z)`.)
    The approximate function distribution for any arbitrary input :math:`\mathbf X` is given by:

    .. math::

        q( f(\mathbf X) ) = \int p( f(\mathbf X) \mid \mathbf u) q(\mathbf u) \: d\mathbf u

    The model's ``forward`` must return a MultitaskMultivariateNormal (or
    MultitaskMultivariateQExponential). Both interleaved and non-interleaved task layouts are
    handled. When the variational distribution has a non-empty batch shape, its last batch
    dimension supplies one independent block per task (presumably of size ``num_tasks``).
    Otherwise the same :math:`q(\mathbf u)` is shared by every task.

    This variational strategy uses "whitening" to accelerate the optimization of the variational
    parameters. See `Matthews (2017)`_ for more info.

    :param model: Model this strategy is applied to.
        Typically passed in when the VariationalStrategy is created in the
        __init__ method of the user defined model.
        It should have a ``power`` attribute if a Q-Exponential distribution is involved.
        Its ``forward`` must output a MultitaskMultivariateNormal (MultitaskMultivariateQExponential)
        distribution.
    :param inducing_points: Tensor containing a set of inducing
        points to use for variational inference.
    :param variational_distribution: A
        VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`
    :param learn_inducing_locations: (Default True): Whether or not
        the inducing point locations :math:`\mathbf Z` should be learned (i.e. are they
        parameters of the model).
    :param jitter_val: Amount of diagonal jitter to add for Cholesky factorization numerical stability

    .. _Hensman et al. (2015):
        http://proceedings.mlr.press/v38/hensman15.pdf
    .. _Matthews (2017):
        https://www.repository.cam.ac.uk/handle/1810/278022
    """

    def __init__(
        self,
        model: Union[ApproximateGP, ApproximateQEP],
        inducing_points: Tensor,
        variational_distribution: _VariationalDistribution,
        learn_inducing_locations: bool = True,
        jitter_val: Optional[float] = None,
    ):
        super().__init__(
            model, inducing_points, variational_distribution, learn_inducing_locations, jitter_val=jitter_val
        )
        # True means the variational parameters are already whitened. The pre-hook inserts
        # False when loading a state dict that predates this buffer, and __call__ then
        # converts the parameters.
        self.register_buffer("updated_strategy", torch.tensor(True))
        self._register_load_state_dict_pre_hook(_ensure_updated_strategy_flag_set)
        self.has_fantasy_strategy = True

    @cached(name="cholesky_factor", ignore_args=True)
    def _cholesky_factor(self, induc_induc_covar: LinearOperator) -> TriangularLinearOperator:
        """Return the lower Cholesky factor of ``induc_induc_covar`` in the Cholesky linalg dtype.

        The result is cached while ignoring the argument. ``forward`` therefore evicts the cache
        when the cached factor's shape no longer matches.
        """
        L = psd_safe_cholesky(to_dense(induc_induc_covar).type(_linalg_dtype_cholesky.value()))
        return TriangularLinearOperator(L)

    @property
    @cached(name="prior_distribution_memo")
    def prior_distribution(self) -> Union[MultivariateNormal, MultivariateQExponential]:
        """Whitened prior p(u): zero mean and identity covariance.

        The prior is q-exponential (with the model's ``power``) when the model defines ``power``,
        and Gaussian otherwise.
        """
        zeros = torch.zeros(
            self._variational_distribution.shape(),
            dtype=self._variational_distribution.dtype,
            device=self._variational_distribution.device,
        )
        ones = torch.ones_like(zeros)
        if hasattr(self.model, 'power'):
            res = MultivariateQExponential(zeros, DiagLinearOperator(ones), power=self.model.power)
        else:
            res = MultivariateNormal(zeros, DiagLinearOperator(ones))
        return res

    @property
    @cached(name="pseudo_points_memo")
    def pseudo_points(self) -> Tuple[Tensor, Tensor]:
        """Return ``(pseudo_target_covar, pseudo_target_mean)`` derived from q(u) and K_ZZ.

        :raises NotImplementedError: If the variational distribution is not a
            CholeskyVariationalDistribution.
        """
        # TODO: have var_mean, var_cov come from a method of _variational_distribution
        # while having Kmm_root be a root decomposition to enable CIQVariationalDistribution support.

        # retrieve the variational mean, m and covariance matrix, S.
        if not isinstance(self._variational_distribution, CholeskyVariationalDistribution):
            # Instances have no ``__name__``; use the class name so the intended error is raised.
            raise NotImplementedError(
                "Only CholeskyVariationalDistribution has pseudo-point support currently, "
                f"but your _variational_distribution is a {type(self._variational_distribution).__name__}"
            )

        var_cov_root = TriangularLinearOperator(self._variational_distribution.chol_variational_covar)
        var_cov = CholLinearOperator(var_cov_root)
        var_mean = self.variational_distribution.mean
        if var_mean.shape[-1] != 1:
            var_mean = var_mean.unsqueeze(-1)

        # compute R = I - S
        cov_diff = var_cov.add_jitter(-1.0)
        cov_diff = -1.0 * cov_diff

        # K^{1/2}
        Kmm = self.model.covar_module(self.inducing_points)
        Kmm_root = Kmm.cholesky()

        # D_a = (S^{-1} - K^{-1})^{-1} = S + S R^{-1} S
        # note that in the whitened case R = I - S, unwhitened R = K - S
        # we compute (R R^{T})^{-1} R^T S for stability reasons as R is probably not PSD.
        eval_var_cov = var_cov.to_dense()
        eval_rhs = cov_diff.transpose(-1, -2).matmul(eval_var_cov)
        inner_term = cov_diff.matmul(cov_diff.transpose(-1, -2))
        # TODO: flag the jitter here
        inner_solve = inner_term.add_jitter(self.jitter_val).solve(eval_rhs, eval_var_cov.transpose(-1, -2))
        inducing_covar = var_cov + inner_solve

        inducing_covar = Kmm_root.matmul(inducing_covar).matmul(Kmm_root.transpose(-1, -2))

        # mean term: D_a S^{-1} m
        # unwhitened: (S - S R^{-1} S) S^{-1} m = (I - S R^{-1}) m
        rhs = cov_diff.transpose(-1, -2).matmul(var_mean)
        # TODO: this jitter too
        inner_rhs_mean_solve = inner_term.add_jitter(self.jitter_val).solve(rhs)
        pseudo_target_mean = Kmm_root.matmul(inner_rhs_mean_solve)

        # ensure inducing covar is psd
        # TODO: make this be an explicit root decomposition
        try:
            pseudo_target_covar = CholLinearOperator(inducing_covar.add_jitter(self.jitter_val).cholesky()).to_dense()
        except NotPSDError:
            # Fall back to an eigendecomposition with jittered eigenvalues.
            evals, evecs = torch.linalg.eigh(inducing_covar)
            pseudo_target_covar = (
                evecs.matmul(DiagLinearOperator(evals + self.jitter_val)).matmul(evecs.transpose(-1, -2)).to_dense()
            )

        return pseudo_target_covar, pseudo_target_mean

    def forward(
        self,
        x: Tensor,
        inducing_points: Tensor,
        inducing_values: Tensor,
        variational_inducing_covar: Optional[LinearOperator] = None,
        **kwargs,
    ) -> Union[MultitaskMultivariateNormal, MultitaskMultivariateQExponential]:
        """Compute the multitask predictive distribution q(f(x)).

        :param x: Test inputs.
        :param inducing_points: Inducing locations Z.
        :param inducing_values: Mean of q(u) in the whitened space (it is multiplied by
            ``L^{-1} K_ZX`` below).
        :param variational_inducing_covar: Covariance of q(u) in the whitened space, if any.
        :raises TypeError: If ``self.model.forward`` does not return a multitask distribution.
        """
        # Compute full prior distribution over [Z; X] jointly
        full_inputs = torch.cat([inducing_points, x], dim=-2)
        full_output = self.model.forward(full_inputs, **kwargs)
        if not isinstance(full_output, (MultitaskMultivariateNormal, MultitaskMultivariateQExponential)):
            raise TypeError(
                f"The type of model forward p(f(X)) is {full_output.__class__.__name__}, not multitask. "
                "Please use regular VariationalStrategy instead."
            )
        full_covar = full_output.lazy_covariance_matrix

        num_tasks = full_output.num_tasks
        _interleaved = full_output._interleaved
        # Covariance terms
        num_induc = inducing_points.size(-2)
        test_mean = full_output.mean[..., num_induc:, :]
        if _interleaved:
            # Point-major layout: the first num_induc * num_tasks rows/columns are the inducing points.
            split = num_induc * num_tasks
            induc_induc_covar = full_covar[..., :split, :split].add_jitter(self.jitter_val)
            induc_data_covar = full_covar[..., :split, split:].to_dense()
            data_data_covar = full_covar[..., split:, split:]
        else:
            # Task-major layout: entry (task t, point i) sits at t * num_points + i, so gather
            # the inducing / data rows of every task block.
            num_points = full_output.event_shape[0]
            device = full_covar.device
            task_offsets = torch.arange(num_tasks, device=device)[:, None] * num_points
            induc_idx = (torch.arange(num_induc, device=device) + task_offsets).flatten()
            data_idx = (torch.arange(num_induc, num_points, device=device) + task_offsets).flatten()
            induc_induc_covar = full_covar[..., induc_idx, :][..., induc_idx].add_jitter(self.jitter_val)
            induc_data_covar = full_covar[..., induc_idx, :][..., data_idx].to_dense()
            data_data_covar = full_covar[..., data_idx, :][..., data_idx]

        # Compute interpolation terms
        # K_ZZ^{-1/2} K_ZX
        # K_ZZ^{-1/2} \mu_Z
        L = self._cholesky_factor(induc_induc_covar)
        if L.shape != induc_induc_covar.shape:
            # Aggressive caching can cause nasty shape incompatibilities when evaluating with different batch shapes
            # TODO: Use a hook for this
            try:
                pop_from_cache_ignore_args(self, "cholesky_factor")
            except CachingError:
                pass
            L = self._cholesky_factor(induc_induc_covar)
        interp_term = L.solve(induc_data_covar.type(_linalg_dtype_cholesky.value())).to(full_inputs.dtype)

        # Compute the mean of q(f)
        # k_XZ K_ZZ^{-1/2} (m - K_ZZ^{-1/2} \mu_Z) + \mu_X
        if len(self.variational_distribution.batch_shape) > 0:
            # One set of inducing values per task (assumes shape (..., num_tasks, m) -- confirm);
            # flatten it into the covariance layout.
            if _interleaved:
                inducing_values = inducing_values.transpose(-1, -2)
            inducing_values = inducing_values.reshape(*inducing_values.shape[:-2], -1)
        elif _interleaved:
            # Shared inducing values, point-major layout: repeat each value num_tasks times.
            inducing_values = inducing_values.repeat_interleave(num_tasks, -1)
        else:
            # Shared inducing values, task-major layout: one full copy per task.
            inducing_values = inducing_values.tile(num_tasks)
        predictive_mean = (interp_term.transpose(-1, -2) @ inducing_values.unsqueeze(-1)).squeeze(-1)
        if _interleaved:
            predictive_mean = predictive_mean.reshape_as(test_mean) + test_mean
        else:
            # Task-major result: view as (..., num_tasks, num_data), then move tasks last.
            new_shape = test_mean.shape[:-2] + test_mean.shape[:-3:-1]
            predictive_mean = predictive_mean.view(new_shape).transpose(-1, -2).contiguous() + test_mean

        # Compute the covariance of q(f)
        # K_XX + k_XZ K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2} k_ZX
        middle_term = self.prior_distribution.lazy_covariance_matrix.mul(-1)
        if variational_inducing_covar is not None:
            middle_term = SumLinearOperator(variational_inducing_covar, middle_term)
        if len(self.variational_distribution.batch_shape) > 0:
            # Task-major block diagonal (one block per task).
            middle_term = BlockDiagLinearOperator(middle_term)
            if _interleaved:
                # Permute task-major rows/columns into point-major (interleaved) order.
                pi = (
                    torch.arange(num_induc * num_tasks, device=middle_term.device)
                    .view(num_tasks, num_induc)
                    .t()
                    .reshape((num_induc * num_tasks))
                )
                middle_term = middle_term[..., pi, :][..., :, pi]
        else:
            # Share the middle term across tasks. Match its dtype so float64 models do not
            # mix in a float32 identity inside the Kronecker product.
            task_eye = DiagLinearOperator(
                torch.ones(num_tasks, dtype=middle_term.dtype, device=middle_term.device)
            )
            if _interleaved:
                middle_term = KroneckerProductLinearOperator(middle_term, task_eye)
            else:
                middle_term = KroneckerProductLinearOperator(task_eye, middle_term)

        if trace_mode.on():
            predictive_covar = (
                data_data_covar.add_jitter(self.jitter_val).to_dense()
                + interp_term.transpose(-1, -2) @ middle_term.to_dense() @ interp_term
            )
        else:
            predictive_covar = SumLinearOperator(
                data_data_covar.add_jitter(self.jitter_val),
                MatmulLinearOperator(interp_term.transpose(-1, -2), middle_term @ interp_term),
            )

        # Return the distribution
        if hasattr(self.model, 'power'):
            return MultitaskMultivariateQExponential(predictive_mean, predictive_covar, power=self.model.power, interleaved=_interleaved)
        else:
            return MultitaskMultivariateNormal(predictive_mean, predictive_covar, interleaved=_interleaved)

    def __call__(self, x: Tensor, prior: bool = False, **kwargs) -> Union[MultivariateNormal, MultivariateQExponential]:
        """Evaluate the strategy at ``x``.

        If ``updated_strategy`` is False (a legacy checkpoint), the unwhitened variational
        parameters are first converted in place to the whitened parameterization.
        """
        if not self.updated_strategy.item() and not prior:
            with torch.no_grad():
                # Get unwhitened p(u)
                prior_function_dist = self(self.inducing_points, prior=True)
                prior_mean = prior_function_dist.loc
                L = self._cholesky_factor(prior_function_dist.lazy_covariance_matrix.add_jitter(self.jitter_val))

                # Temporarily turn off noise that's added to the mean
                orig_mean_init_std = self._variational_distribution.mean_init_std
                self._variational_distribution.mean_init_std = 0.0

                # Change the variational parameters to be whitened
                variational_dist = self.variational_distribution
                if isinstance(variational_dist, (MultivariateNormal, MultivariateQExponential)):
                    mean_diff = (variational_dist.loc - prior_mean).unsqueeze(-1).type(_linalg_dtype_cholesky.value())
                    whitened_mean = L.solve(mean_diff).squeeze(-1).to(variational_dist.loc.dtype)
                    covar_root = variational_dist.lazy_covariance_matrix.root_decomposition().root.to_dense()
                    covar_root = covar_root.type(_linalg_dtype_cholesky.value())
                    whitened_covar = RootLinearOperator(L.solve(covar_root).to(variational_dist.loc.dtype))
                    if isinstance(variational_dist, MultivariateQExponential):
                        # Pass the power at construction, as elsewhere in this module.
                        whitened_variational_distribution = variational_dist.__class__(
                            whitened_mean, whitened_covar, power=variational_dist.power
                        )
                    else:
                        whitened_variational_distribution = variational_dist.__class__(whitened_mean, whitened_covar)
                    self._variational_distribution.initialize_variational_distribution(
                        whitened_variational_distribution
                    )

                # Reset the random noise parameter of the model
                self._variational_distribution.mean_init_std = orig_mean_init_std

                # Reset the cache
                clear_cache_hook(self)

                # Mark that we have updated the variational strategy
                self.updated_strategy.fill_(True)

        return super().__call__(x, prior=prior, **kwargs)
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from linear_operator.operators import CholLinearOperator, TriangularLinearOperator
|
|
7
|
+
from linear_operator.utils.cholesky import psd_safe_cholesky
|
|
8
|
+
|
|
9
|
+
from ..distributions import MultivariateNormal, MultivariateQExponential
|
|
10
|
+
from ._variational_distribution import _VariationalDistribution
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class _NaturalVariationalDistribution(_VariationalDistribution, abc.ABC):
    r"""Abstract base for any :obj:`~qpytorch.variational._VariationalDistribution`
    that calculates natural gradients with respect to its parameters.

    It adds no behaviour of its own; concrete subclasses provide the parameterization.
    """
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class NaturalVariationalDistribution(_NaturalVariationalDistribution):
    r"""A multivariate normal (or, when a ``power`` keyword is given, multivariate q-exponential)
    :obj:`~qpytorch.variational._VariationalDistribution`, parameterized by **natural** parameters.

    .. note::
        The :obj:`~qpytorch.variational.NaturalVariationalDistribution` can only
        be used with :obj:`gpytorch.optim.NGD`, or other optimizers
        that follow exactly the gradient direction. Failure to do so will cause
        the natural matrix :math:`\mathbf \Theta_\text{mat}` to stop being
        positive definite, and a :obj:`~RuntimeError` will be raised.

    .. seealso::
        The `natural gradient descent tutorial
        <examples/04_Variational_and_Approximate_GPs/Natural_Gradient_Descent.ipynb>`_
        for use instructions.

        The :obj:`~qpytorch.variational.TrilNaturalVariationalDistribution` for
        a more numerically stable parameterization, at the cost of needing more
        iterations to make variational regression converge.

    :param int num_inducing_points: Size of the variational distribution. This implies that the variational mean
        should be this size, and the variational covariance matrix should have this many rows and columns.
    :param batch_shape: Specifies an optional batch size
        for the variational parameters. This is useful for example when doing additive variational inference.
    :type batch_shape: :obj:`torch.Size`, optional
    :param float mean_init_std: (Default: 1e-3) Standard deviation of gaussian (q-exponential) noise to add to the mean initialization.
    :param power: (Optional keyword) When supplied, it is stored on the module and ``forward``
        returns a MultivariateQExponential with this power. Other keyword arguments are ignored.
    """

    def __init__(self, num_inducing_points, batch_shape=torch.Size([]), mean_init_std=1e-3, **kwargs):
        super().__init__(num_inducing_points=num_inducing_points, batch_shape=batch_shape, mean_init_std=mean_init_std)
        size = num_inducing_points
        # Natural parameters of N(0, I), copied across every batch dimension:
        # eta1 = S^{-1} m = 0 and eta2 = -1/2 S^{-1} = -1/2 I.
        eta1_init = torch.zeros(*batch_shape, size)
        eta2_init = torch.eye(size, size).mul(-0.5).expand(*batch_shape, size, size).clone()

        # eta1 and eta2 parameterization of the variational distribution
        self.register_parameter(name="natural_vec", parameter=torch.nn.Parameter(eta1_init))
        self.register_parameter(name="natural_mat", parameter=torch.nn.Parameter(eta2_init))

        if 'power' in kwargs:
            self.power = kwargs.pop('power')

    def forward(self):
        """Return q(u) built from the current natural parameters."""
        mean, chol_covar = _NaturalToMuVarSqrt.apply(self.natural_vec, self.natural_mat)
        covar = CholLinearOperator(TriangularLinearOperator(chol_covar))
        if hasattr(self, 'power'):
            return MultivariateQExponential(mean, covar, power=self.power)
        return MultivariateNormal(mean, covar)

    def initialize_variational_distribution(self, prior_dist):
        """Set the natural parameters from ``prior_dist``, adding ``mean_init_std``-scaled noise to eta1."""
        precision = prior_dist.covariance_matrix.inverse()
        center = prior_dist.mean
        jitter = torch.randn_like(center).mul_(self.mean_init_std)

        eta1 = (precision @ center.unsqueeze(-1)).squeeze(-1).add_(jitter)
        eta2 = precision.mul(-0.5)
        self.natural_vec.data.copy_(eta1)
        self.natural_mat.data.copy_(eta2)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _triangular_inverse(A, upper=False):
|
|
81
|
+
eye = torch.eye(A.size(-1), dtype=A.dtype, device=A.device)
|
|
82
|
+
return torch.linalg.solve_triangular(A, eye, upper=upper)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _phi_for_cholesky_(A):
|
|
86
|
+
"Modifies A to be the phi function used in differentiating through Cholesky"
|
|
87
|
+
A.tril_().diagonal(offset=0, dim1=-2, dim2=-1).mul_(0.5)
|
|
88
|
+
return A
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _cholesky_backward(dout_dL, L, L_inverse):
    """Map a gradient with respect to a Cholesky factor ``L`` to one with respect to the factored matrix.

    Computes ``L^{-T} phi(L^T dout_dL) L^{-1}`` and returns its symmetric part, where
    ``L_inverse`` is the precomputed inverse of ``L``.
    """
    # c.f. https://github.com/pytorch/pytorch/blob/25ba802ce4cbdeaebcad4a03cec8502f0de9b7b3/
    # tools/autograd/templates/Functions.cpp
    phi = _phi_for_cholesky_(L.transpose(-1, -2) @ dout_dL)
    left = L_inverse.transpose(-1, -2) @ phi
    raw_grad = left @ L_inverse
    # Symmetrize gradient
    symmetric = raw_grad.add(raw_grad.transpose(-1, -2))
    return symmetric.mul_(0.5)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class _NaturalToMuVarSqrt(torch.autograd.Function):
    r"""Convert natural parameters into a mean and a Cholesky factor of the covariance.

    ``forward(nat_mean, nat_covar)`` treats ``-2 * nat_covar`` as the precision :math:`S^{-1}` and
    ``nat_mean`` as :math:`S^{-1} \mu`, and returns ``(mu, chol(S))``. ``backward`` returns the
    natural gradient with respect to ``nat_mean`` and ``nat_covar`` (see :meth:`_backward`).
    """

    @staticmethod
    def _forward(nat_mean, nat_covar):
        """Return ``(mu, L)`` with ``S = (-2 nat_covar)^{-1}``, ``mu = S nat_mean`` and ``L = chol(S)``.

        :raises RuntimeError: (a ``NotPSDError`` when that is the underlying cause) if
            ``-2 * nat_covar`` cannot be Cholesky-factorized.
        """
        try:
            L_inv = psd_safe_cholesky(-2.0 * nat_covar, upper=False)
        except RuntimeError as e:
            from linear_operator.utils.errors import NotPSDError

            # psd_safe_cholesky may report a non-positive-definite input through NotPSDError,
            # whose message need not start with "cholesky". Check the exception type as well
            # as the message text before substituting the user-facing explanation.
            is_not_psd = isinstance(e, NotPSDError)
            if not (is_not_psd or "cholesky" in str(e).lower()):
                raise
            # Keep NotPSDError catchable as such; other Cholesky failures stay plain RuntimeErrors.
            error_cls = NotPSDError if is_not_psd else RuntimeError
            raise error_cls(
                "Non-negative-definite natural covariance. You probably "
                "updated it using an optimizer other than gpytorch.optim.NGD (such as Adam). "
                "This is not supported."
            ) from e
        # L_inv L_inv^T = S^{-1}, so L = L_inv^{-1} satisfies S = L^T L.
        L = _triangular_inverse(L_inv, upper=False)
        S = L.transpose(-1, -2) @ L
        mu = (S @ nat_mean.unsqueeze(-1)).squeeze(-1)
        # Two choleskys are annoying, but we don't have good support for a
        # LinearOperator of form L.T @ L
        return mu, psd_safe_cholesky(S, upper=False)

    @staticmethod
    def forward(ctx, nat_mean, nat_covar):
        mu, L = _NaturalToMuVarSqrt._forward(nat_mean, nat_covar)
        ctx.save_for_backward(mu, L)
        return mu, L

    @staticmethod
    def _backward(dout_dmu, dout_dL, mu, L, C):
        """Calculate dout/d(eta1, eta2), which are:
        eta1 = mu
        eta2 = mu*mu^T + LL^T = mu*mu^T + Sigma

        Thus:
        dout/deta1 = dout/dmu + dout/dL dL/deta1
        dout/deta2 = dout/dL dL/deta2

        For L = chol(eta2 - eta1*eta1^T).
        dout/dSigma = _cholesky_backward(dout/dL, L, L^{-1})   (C is L^{-1})
        dout/deta2 = dout/dSigma
        dout/deta1 = dout/dmu - 2 * (dout/dSigma) mu
        """
        dout_dSigma = _cholesky_backward(dout_dL, L, C)
        dout_deta1 = dout_dmu - 2 * (dout_dSigma @ mu.unsqueeze(-1)).squeeze(-1)
        return dout_deta1, dout_dSigma

    @staticmethod
    def backward(ctx, dout_dmu, dout_dL):
        "Calculates the natural gradient with respect to nat_mean, nat_covar"
        mu, L = ctx.saved_tensors
        C = _triangular_inverse(L, upper=False)
        return _NaturalToMuVarSqrt._backward(dout_dmu, dout_dL, mu, L, C)
|