qpytorch 0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of qpytorch might be problematic. Click here for more details.
- qpytorch/__init__.py +327 -0
- qpytorch/constraints/__init__.py +3 -0
- qpytorch/distributions/__init__.py +21 -0
- qpytorch/distributions/delta.py +86 -0
- qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
- qpytorch/distributions/multivariate_qexponential.py +581 -0
- qpytorch/distributions/power.py +113 -0
- qpytorch/distributions/qexponential.py +153 -0
- qpytorch/functions/__init__.py +58 -0
- qpytorch/kernels/__init__.py +80 -0
- qpytorch/kernels/grid_interpolation_kernel.py +213 -0
- qpytorch/kernels/inducing_point_kernel.py +151 -0
- qpytorch/kernels/kernel.py +695 -0
- qpytorch/kernels/matern32_kernel_grad.py +155 -0
- qpytorch/kernels/matern52_kernel_grad.py +194 -0
- qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
- qpytorch/kernels/polynomial_kernel_grad.py +88 -0
- qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
- qpytorch/kernels/rbf_kernel_grad.py +125 -0
- qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
- qpytorch/kernels/rff_kernel.py +153 -0
- qpytorch/lazy/__init__.py +9 -0
- qpytorch/likelihoods/__init__.py +66 -0
- qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
- qpytorch/likelihoods/beta_likelihood.py +76 -0
- qpytorch/likelihoods/gaussian_likelihood.py +472 -0
- qpytorch/likelihoods/laplace_likelihood.py +59 -0
- qpytorch/likelihoods/likelihood.py +437 -0
- qpytorch/likelihoods/likelihood_list.py +60 -0
- qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
- qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
- qpytorch/likelihoods/noise_models.py +184 -0
- qpytorch/likelihoods/qexponential_likelihood.py +494 -0
- qpytorch/likelihoods/softmax_likelihood.py +97 -0
- qpytorch/likelihoods/student_t_likelihood.py +90 -0
- qpytorch/means/__init__.py +23 -0
- qpytorch/metrics/__init__.py +17 -0
- qpytorch/mlls/__init__.py +53 -0
- qpytorch/mlls/_approximate_mll.py +79 -0
- qpytorch/mlls/deep_approximate_mll.py +30 -0
- qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
- qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
- qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
- qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
- qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
- qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
- qpytorch/mlls/marginal_log_likelihood.py +48 -0
- qpytorch/mlls/predictive_log_likelihood.py +76 -0
- qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
- qpytorch/mlls/variational_elbo.py +77 -0
- qpytorch/models/__init__.py +72 -0
- qpytorch/models/approximate_qep.py +115 -0
- qpytorch/models/deep_qeps/__init__.py +22 -0
- qpytorch/models/deep_qeps/deep_qep.py +155 -0
- qpytorch/models/deep_qeps/dspp.py +114 -0
- qpytorch/models/exact_prediction_strategies.py +880 -0
- qpytorch/models/exact_qep.py +349 -0
- qpytorch/models/model_list.py +100 -0
- qpytorch/models/pyro/__init__.py +28 -0
- qpytorch/models/pyro/_pyro_mixin.py +57 -0
- qpytorch/models/pyro/distributions/__init__.py +5 -0
- qpytorch/models/pyro/pyro_qep.py +105 -0
- qpytorch/models/qep.py +7 -0
- qpytorch/models/qeplvm/__init__.py +6 -0
- qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
- qpytorch/models/qeplvm/latent_variable.py +102 -0
- qpytorch/module.py +30 -0
- qpytorch/optim/__init__.py +5 -0
- qpytorch/priors/__init__.py +42 -0
- qpytorch/priors/qep_priors.py +81 -0
- qpytorch/test/__init__.py +22 -0
- qpytorch/test/base_likelihood_test_case.py +106 -0
- qpytorch/test/model_test_case.py +150 -0
- qpytorch/test/variational_test_case.py +400 -0
- qpytorch/utils/__init__.py +38 -0
- qpytorch/utils/warnings.py +37 -0
- qpytorch/variational/__init__.py +47 -0
- qpytorch/variational/_variational_distribution.py +61 -0
- qpytorch/variational/_variational_strategy.py +391 -0
- qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
- qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
- qpytorch/variational/cholesky_variational_distribution.py +65 -0
- qpytorch/variational/ciq_variational_strategy.py +352 -0
- qpytorch/variational/delta_variational_distribution.py +41 -0
- qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
- qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
- qpytorch/variational/lmc_variational_strategy.py +248 -0
- qpytorch/variational/mean_field_variational_distribution.py +58 -0
- qpytorch/variational/multitask_variational_strategy.py +317 -0
- qpytorch/variational/natural_variational_distribution.py +152 -0
- qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
- qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
- qpytorch/variational/tril_natural_variational_distribution.py +130 -0
- qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
- qpytorch/variational/unwhitened_variational_strategy.py +225 -0
- qpytorch/variational/variational_strategy.py +280 -0
- qpytorch/version.py +4 -0
- qpytorch-0.1.dist-info/LICENSE +21 -0
- qpytorch-0.1.dist-info/METADATA +177 -0
- qpytorch-0.1.dist-info/RECORD +102 -0
- qpytorch-0.1.dist-info/WHEEL +5 -0
- qpytorch-0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
from typing import Tuple, Union
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from linear_operator.operators import CholLinearOperator, TriangularLinearOperator
|
|
7
|
+
from torch import Tensor
|
|
8
|
+
from torch.autograd.function import FunctionCtx
|
|
9
|
+
|
|
10
|
+
from ..distributions import Distribution, MultivariateNormal, MultivariateQExponential
|
|
11
|
+
from .natural_variational_distribution import (
|
|
12
|
+
_NaturalToMuVarSqrt,
|
|
13
|
+
_NaturalVariationalDistribution,
|
|
14
|
+
_phi_for_cholesky_,
|
|
15
|
+
_triangular_inverse,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class TrilNaturalVariationalDistribution(_NaturalVariationalDistribution):
    r"""A :obj:`~qpytorch.variational._VariationalDistribution` parameterized by the
    natural vector and a triangular decomposition of the natural matrix (which is
    not the Cholesky). It produces a multivariate normal, or a multivariate
    q-exponential when a ``power`` keyword argument is supplied.

    .. note::
       The :obj:`~qpytorch.variational.TrilNaturalVariationalDistribution` should only
       be used with :obj:`gpytorch.optim.NGD`, or other optimizers
       that follow exactly the gradient direction.

    .. seealso::
        The `natural gradient descent tutorial
        <examples/04_Variational_and_Approximate_GPs/Natural_Gradient_Descent.ipynb>`_
        for use instructions.

        The :obj:`~qpytorch.variational.NaturalVariationalDistribution`, which
        needs fewer iterations to make variational regression converge, at the
        cost of introducing numerical instability.

    .. note::
        The relationship of the parameter :math:`\mathbf \Theta_\text{tril_mat}`
        to the natural parameter :math:`\mathbf \Theta_\text{mat}` from
        :obj:`~qpytorch.variational.NaturalVariationalDistribution` is
        :math:`\mathbf \Theta_\text{mat} = -1/2 {\mathbf \Theta_\text{tril_mat}}^T {\mathbf \Theta_\text{tril_mat}}`.
        Note that this is not the form of the Cholesky decomposition of :math:`\boldsymbol \Theta_\text{mat}`.

    :param int num_inducing_points: Size of the variational distribution. This implies that the variational mean
        should be this size, and the variational covariance matrix should have this many rows and columns.
    :param batch_shape: Specifies an optional batch size
        for the variational parameters. This is useful for example when doing additive variational inference.
    :type batch_shape: :obj:`torch.Size`, optional
    :param float mean_init_std: (Default: 1e-3) Standard deviation of gaussian (q-exponential) noise to add to
        the mean initialization.
    :param power: (optional keyword argument) If given, it is stored on the module and :meth:`forward`
        returns a :obj:`~qpytorch.distributions.MultivariateQExponential` with this power.
    """

    def __init__(
        self,
        num_inducing_points: int,
        batch_shape: torch.Size = torch.Size([]),
        mean_init_std: float = 1e-3,
        **kwargs,
    ):
        super().__init__(
            num_inducing_points=num_inducing_points, batch_shape=batch_shape, mean_init_std=mean_init_std
        )

        # eta1 and tril_dec(eta2) parameterization of the variational distribution:
        # a zero vector and an identity matrix, tiled over the batch shape.
        natural_vec_init = torch.zeros(num_inducing_points).repeat(*batch_shape, 1)
        natural_tril_init = torch.eye(num_inducing_points, num_inducing_points).repeat(*batch_shape, 1, 1)
        self.register_parameter(name="natural_vec", parameter=torch.nn.Parameter(natural_vec_init))
        self.register_parameter(name="natural_tril_mat", parameter=torch.nn.Parameter(natural_tril_init))

        # `power` is only stored when supplied; its presence is what `forward` tests for.
        if "power" in kwargs:
            self.power = kwargs.pop("power")

    def forward(self) -> Distribution:
        """Build the variational distribution from the current natural parameters.

        :return: A :obj:`MultivariateQExponential` if ``power`` was supplied at construction,
            otherwise a :obj:`MultivariateNormal`.
        """
        mean, chol_covar = _TrilNaturalToMuVarSqrt.apply(self.natural_vec, self.natural_tril_mat)
        covar = CholLinearOperator(TriangularLinearOperator(chol_covar))
        if hasattr(self, "power"):
            return MultivariateQExponential(mean, covar, power=self.power)
        return MultivariateNormal(mean, covar)

    def initialize_variational_distribution(
        self, prior_dist: Union[MultivariateNormal, MultivariateQExponential]
    ) -> None:
        """Overwrite the natural parameters in place using the prior distribution.

        :param prior_dist: Prior whose mean and covariance seed the parameters.
        """
        prior_covar = prior_dist.lazy_covariance_matrix

        # Triangular parameter: `_triangular_inverse` applied to the dense lower Cholesky factor.
        prior_chol = prior_covar.cholesky().to_dense()
        init_tril_mat = _triangular_inverse(prior_chol, upper=False)

        # Natural vector: solve the prior covariance against the prior mean, then perturb it
        # with noise scaled by `mean_init_std`.
        init_natural_vec = prior_covar.solve(prior_dist.mean.unsqueeze(-1)).squeeze(-1)
        perturbation = torch.randn_like(init_natural_vec).mul_(self.mean_init_std)
        init_natural_vec.add_(perturbation)

        self.natural_vec.data.copy_(init_natural_vec)
        self.natural_tril_mat.data.copy_(init_tril_mat)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class _TrilNaturalToMuVarSqrt(torch.autograd.Function):
    """Autograd function mapping ``(natural_vec, natural_tril_mat)`` to ``(mu, L)``.

    ``L`` is ``_triangular_inverse(natural_tril_mat)`` and ``mu = L L^T natural_vec``.
    The backward pass reuses ``_NaturalToMuVarSqrt._backward`` and then maps the
    gradient w.r.t. the natural matrix onto the triangular parameter.
    """

    @staticmethod
    def _forward(nat_mean: Tensor, tril_nat_covar: Tensor) -> Tuple[Tensor, Tensor]:
        """Compute ``(mu, L)`` without touching any autograd context."""
        chol_factor = _triangular_inverse(tril_nat_covar, upper=False)
        projected = chol_factor.transpose(-1, -2) @ nat_mean.unsqueeze(-1)
        mean = (chol_factor @ projected).squeeze(-1)
        return mean, chol_factor

    @staticmethod
    def forward(ctx: FunctionCtx, nat_mean: Tensor, tril_nat_covar: Tensor) -> Tuple[Tensor, Tensor]:
        """Run :meth:`_forward` and stash ``(mu, L, tril_nat_covar)`` for the backward pass."""
        outputs = _TrilNaturalToMuVarSqrt._forward(nat_mean, tril_nat_covar)
        ctx.save_for_backward(*outputs, tril_nat_covar)
        return outputs

    @staticmethod
    def backward(ctx: FunctionCtx, dout_dmu: Tensor, dout_dL: Tensor) -> Tuple[Tensor, Tensor]:
        """Return gradients w.r.t. ``nat_mean`` and ``tril_nat_covar``."""
        mu, L, C = ctx.saved_tensors
        grad_nat_vec, grad_nat_mat = _NaturalToMuVarSqrt._backward(dout_dmu, dout_dL, mu, L, C)

        # Now we need to do the Jacobian-Vector Product for the transformation:
        #     L = inv(chol(inv(-2 theta_cov)))
        #     C^T C = -2 theta_cov
        #
        # so we need to do forward differentiation, starting with sensitivity
        # (sensitivities marked with .dots.)
        #     .theta_cov. = dout_dnat2
        # and ending with sensitivity .C.
        #
        # if B = inv(-2 theta_cov) then:
        #     .B. = d inv(-2 theta_cov)/dtheta_cov * .theta_cov. = -B (-2 .theta_cov.) B
        #
        # if L = chol(B), B = LL^T then
        # (https://homepages.inf.ed.ac.uk/imurray2/pub/16choldiff/choldiff.pdf):
        #     .L. = L phi(L^{-1} .B. (L^{-1})^T) = L phi(2 L^T .theta_cov. L)
        #
        # Then C = inv(L), so
        #     .C. = -C .L. C = phi(-2 L^T .theta_cov. L)C
        scaled_sensitivity = (L.transpose(-2, -1) @ grad_nat_mat @ L).mul_(-2)
        grad_tril = _phi_for_cholesky_(scaled_sensitivity) @ C
        return grad_nat_vec, grad_tril
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from linear_operator.operators import RootLinearOperator
|
|
7
|
+
|
|
8
|
+
from ..distributions import MultitaskMultivariateQExponential, MultivariateQExponential
|
|
9
|
+
from ..module import Module
|
|
10
|
+
from ._variational_strategy import _VariationalStrategy
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class UncorrelatedMultitaskVariationalStrategy(_VariationalStrategy):
    """
    UncorrelatedMultitaskVariationalStrategy wraps an existing
    :obj:`~qpytorch.variational.VariationalStrategy` to produce vector-valued (multi-task)
    output distributions. Each task will be uncorrelated to one another.

    The output will either be a :obj:`~qpytorch.distributions.MultitaskMultivariateQExponential` distribution
    (if we wish to evaluate all tasks for each input) or a :obj:`~qpytorch.distributions.MultivariateQExponential`
    (if we wish to evaluate a single task for each input).

    The base variational strategy is assumed to operate on a batch of QEPs. One of the batch
    dimensions corresponds to the multiple tasks.

    :param ~qpytorch.variational.VariationalStrategy base_variational_strategy: Base variational strategy
    :param int num_tasks: Number of tasks. Should correspond to the batch size of task_dim.
    :param int task_dim: (Default: -1) Which batch dimension is the task dimension
    """

    def __init__(self, base_variational_strategy, num_tasks, task_dim=-1):
        # Only the bare Module initializer runs here; prior, variational distribution and
        # initialization state are all delegated to the wrapped strategy via the properties below.
        Module.__init__(self)
        self.base_variational_strategy = base_variational_strategy
        self.task_dim = task_dim
        self.num_tasks = num_tasks

    @property
    def prior_distribution(self):
        """Prior distribution of the wrapped base variational strategy."""
        return self.base_variational_strategy.prior_distribution

    @property
    def variational_distribution(self):
        """Variational distribution of the wrapped base variational strategy."""
        return self.base_variational_strategy.variational_distribution

    @property
    def variational_params_initialized(self):
        """Initialization flag of the wrapped base variational strategy."""
        return self.base_variational_strategy.variational_params_initialized

    def kl_divergence(self):
        """Return the parent class's KL divergence summed over its last dimension."""
        return super().kl_divergence().sum(dim=-1)

    def __call__(self, x, task_indices=None, prior=False, **kwargs):
        r"""
        See :class:`LMCVariationalStrategy`.

        :param x: Inputs forwarded to the base variational strategy.
        :param task_indices: (Optional) Integer tensor selecting one task per data point.
            If omitted, every data point gets an output for every task.
        :param bool prior: Forwarded to the base variational strategy.
        :return: A :obj:`~qpytorch.distributions.MultitaskMultivariateQExponential` when
            ``task_indices`` is None, otherwise a :obj:`~qpytorch.distributions.MultivariateQExponential`.
        :raises RuntimeError: If ``task_indices`` is given and ``task_dim`` is not negative.
        """
        function_dist = self.base_variational_strategy(x, prior=prior, **kwargs)
        num_batch = len(function_dist.batch_shape)

        if task_indices is None:
            # Every data point will get an output for each task
            task_dim_outside_batch = (self.task_dim > 0 and self.task_dim > num_batch) or (
                self.task_dim < 0 and self.task_dim + num_batch < 0
            )
            if task_dim_outside_batch:
                # The base output has no task batch dimension: share it across all tasks.
                return MultitaskMultivariateQExponential.from_repeated_qep(function_dist, num_tasks=self.num_tasks)

            function_dist = MultitaskMultivariateQExponential.from_batch_qep(function_dist, task_dim=self.task_dim)
            assert function_dist.event_shape[-1] == self.num_tasks
            return function_dist

        # Each data point will get a single output corresponding to a single task.
        # A task_dim of 0 must be rejected as well: it would resolve to `num_batch` below,
        # which indexes the event dimension instead of a batch dimension.
        if self.task_dim >= 0:
            raise RuntimeError(f"task_dim must be a negative indexed batch dimension: got {self.task_dim}.")
        task_dim = num_batch + self.task_dim

        # Broadcast task_indices against the output shape with the task dimension removed
        shape = list(function_dist.batch_shape + function_dist.event_shape)
        shape[task_dim] = 1
        task_indices = task_indices.expand(shape).squeeze(task_dim)

        # Create a mask to choose specific task assignment
        task_mask = torch.nn.functional.one_hot(task_indices, num_classes=self.num_tasks)
        # one_hot appends the task axis last; move it back to the task_dim position
        task_mask = task_mask.permute(*range(0, task_dim), *range(task_dim + 1, num_batch + 1), task_dim)

        mean = (function_dist.mean * task_mask).sum(task_dim)
        covar = (function_dist.lazy_covariance_matrix * RootLinearOperator(task_mask[..., None])).sum(task_dim)
        return MultivariateQExponential(mean, covar, power=function_dist.power)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class MultitaskVariationalStrategy(UncorrelatedMultitaskVariationalStrategy):
    """
    Deprecated name for :obj:`~qpytorch.variational.UncorrelatedMultitaskVariationalStrategy`;
    constructing it emits a :obj:`DeprecationWarning`.

    Wraps an existing :obj:`~qpytorch.variational.VariationalStrategy`
    to produce a :obj:`~qpytorch.distributions.MultitaskMultivariateQExponential` distribution.
    All outputs will be uncorrelated to one another.

    The base variational strategy is assumed to operate on a batch of QEPs. One of the batch
    dimensions corresponds to the multiple tasks.

    :param ~qpytorch.variational.VariationalStrategy base_variational_strategy: Base variational strategy
    :param int num_tasks: Number of tasks. Should correspond to the batch size of task_dim.
    :param int task_dim: (Default: -1) Which batch dimension is the task dimension
    """

    def __init__(self, base_variational_strategy, num_tasks, task_dim=-1):
        warnings.warn(
            "MultitaskVariationalStrategy has been renamed to UncorrelatedMultitaskVariationalStrategy",
            DeprecationWarning,
        )
        # Forward the caller's task_dim; it was previously hard-coded to -1 and silently ignored.
        super().__init__(base_variational_strategy, num_tasks, task_dim=task_dim)
|
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
from typing import Optional, Tuple, Union
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from linear_operator import to_dense
|
|
8
|
+
from linear_operator.operators import (
|
|
9
|
+
CholLinearOperator,
|
|
10
|
+
DiagLinearOperator,
|
|
11
|
+
LinearOperator,
|
|
12
|
+
PsdSumLinearOperator,
|
|
13
|
+
RootLinearOperator,
|
|
14
|
+
TriangularLinearOperator,
|
|
15
|
+
ZeroLinearOperator,
|
|
16
|
+
)
|
|
17
|
+
from linear_operator.utils.cholesky import psd_safe_cholesky
|
|
18
|
+
from linear_operator.utils.errors import NotPSDError
|
|
19
|
+
from torch import Tensor
|
|
20
|
+
|
|
21
|
+
from .. import settings
|
|
22
|
+
from ..distributions import MultivariateNormal, MultivariateQExponential
|
|
23
|
+
from gpytorch.utils.memoize import add_to_cache, cached
|
|
24
|
+
from ._variational_strategy import _VariationalStrategy
|
|
25
|
+
from .cholesky_variational_distribution import CholeskyVariationalDistribution
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class UnwhitenedVariationalStrategy(_VariationalStrategy):
    r"""
    Similar to :obj:`~qpytorch.variational.VariationalStrategy`, but does not perform the
    whitening operation. In almost all cases :obj:`~qpytorch.variational.VariationalStrategy`
    is preferable, with a few exceptions:

    - When the inducing points are exactly equal to the training points (i.e. :math:`\mathbf Z = \mathbf X`).
      Unwhitened models are faster in this case.

    - When the number of inducing points is very large (e.g. >2000). Unwhitened models can use CG for faster
      computation.

    :param model: Model this strategy is applied to.
        Typically passed in when the VariationalStrategy is created in the
        __init__ method of the user defined model.
        If the model has a ``power`` attribute, q-exponential distributions are
        produced with that power; otherwise multivariate normals are produced.
    :param inducing_points: Tensor containing a set of inducing
        points to use for variational inference.
    :param variational_distribution: A
        VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`
    :param learn_inducing_locations: (default True): Whether or not
        the inducing point locations :math:`\mathbf Z` should be learned (i.e. are they
        parameters of the model).
    :param jitter_val: Amount of diagonal jitter to add for Cholesky factorization numerical stability
    """
    has_fantasy_strategy = True

    @cached(name="cholesky_factor", ignore_args=True)
    def _cholesky_factor(self, induc_induc_covar: LinearOperator) -> TriangularLinearOperator:
        """Return the (cached) Cholesky factor of the dense inducing covariance."""
        # Maybe used - if we're not using CG
        L = psd_safe_cholesky(to_dense(induc_induc_covar))
        return TriangularLinearOperator(L)

    @property
    @cached(name="prior_distribution_memo")
    def prior_distribution(self) -> Union[MultivariateNormal, MultivariateQExponential]:
        """Prior at the inducing points: the model's forward output with jitter added to the covariance."""
        out = self.model.forward(self.inducing_points)
        if hasattr(self.model, 'power'):
            res = MultivariateQExponential(out.mean, out.lazy_covariance_matrix.add_jitter(), power=self.model.power)
        else:
            res = MultivariateNormal(out.mean, out.lazy_covariance_matrix.add_jitter())
        return res

    @property
    @cached(name="pseudo_points_memo")
    def pseudo_points(self) -> Tuple[Tensor, Tensor]:
        """Return ``(pseudo_target_covar, pseudo_target_mean)`` derived from the variational distribution.

        :raises NotImplementedError: If the variational distribution is not a
            :obj:`CholeskyVariationalDistribution`.
        """
        # TODO: implement for other distributions
        if not isinstance(self._variational_distribution, CholeskyVariationalDistribution):
            # Build one message string and take the name from the type: instances have no
            # `__name__`, so reading it would raise AttributeError instead of this error.
            raise NotImplementedError(
                "Only CholeskyVariationalDistribution has pseudo-point support currently, "
                f"but your _variational_distribution is a {type(self._variational_distribution).__name__}"
            )

        # retrieve the variational mean, m and covariance matrix, S.
        var_cov_root = TriangularLinearOperator(self._variational_distribution.chol_variational_covar)
        var_cov = CholLinearOperator(var_cov_root)
        var_mean = self.variational_distribution.mean  # .unsqueeze(-1)
        if var_mean.shape[-1] != 1:
            var_mean = var_mean.unsqueeze(-1)

        # R = K - S
        Kmm = self.model.covar_module(self.inducing_points)
        cov_diff = Kmm - var_cov

        # D_a = (S^{-1} - K^{-1})^{-1} = S + S R^{-1} S
        # note that in the whitened case R = I - S, unwhitened R = K - S
        # we compute (R R^{T})^{-1} R^T S for stability reasons as R is probably not PSD.
        eval_lhs = var_cov.to_dense()
        eval_rhs = cov_diff.transpose(-1, -2).matmul(eval_lhs)
        inner_term = cov_diff.matmul(cov_diff.transpose(-1, -2))
        # TODO: flag the jitter here
        jittered_inner_term = inner_term.add_jitter(self.jitter_val)
        inner_solve = jittered_inner_term.solve(eval_rhs, eval_lhs.transpose(-1, -2))
        inducing_covar = var_cov + inner_solve

        # mean term: D_a S^{-1} m
        # unwhitened: (S + S R^{-1} S) S^{-1} m = (I + S R^{-1}) m
        rhs = cov_diff.transpose(-1, -2).matmul(var_mean)
        inner_rhs_mean_solve = jittered_inner_term.solve(rhs)
        pseudo_target_mean = var_mean + var_cov.matmul(inner_rhs_mean_solve)

        # ensure inducing covar is psd
        try:
            pseudo_target_covar = CholLinearOperator(inducing_covar.add_jitter(self.jitter_val).cholesky()).to_dense()
        except NotPSDError:
            # Cholesky failed: rebuild the matrix from its eigendecomposition with jittered eigenvalues.
            evals, evecs = torch.linalg.eigh(inducing_covar)
            pseudo_target_covar = (
                evecs.matmul(DiagLinearOperator(evals + self.jitter_val)).matmul(evecs.transpose(-1, -2)).to_dense()
            )

        return pseudo_target_covar, pseudo_target_mean

    def forward(
        self,
        x: Tensor,
        inducing_points: Tensor,
        inducing_values: Tensor,
        variational_inducing_covar: Optional[LinearOperator] = None,
        **kwargs,
    ) -> Union[MultivariateNormal, MultivariateQExponential]:
        """Compute the predictive distribution at ``x`` without whitening.

        :param x: Locations to evaluate at.
        :param inducing_points: Inducing point locations.
        :param inducing_values: Mean of the variational distribution over inducing values.
        :param variational_inducing_covar: (Optional) Covariance of the variational distribution.
        :return: A :obj:`MultivariateQExponential` if the model has a ``power`` attribute,
            otherwise a :obj:`MultivariateNormal`.
        :raises RuntimeError: If ``x`` equals ``inducing_points`` and no
            ``variational_inducing_covar`` is given.
        """
        # If our points equal the inducing points, we're done
        if torch.equal(x, inducing_points):
            if variational_inducing_covar is None:
                raise RuntimeError(
                    "variational_inducing_covar is required when x is equal to the inducing points."
                )
            if hasattr(self.model, 'power'):
                return MultivariateQExponential(inducing_values, variational_inducing_covar, power=self.model.power)
            return MultivariateNormal(inducing_values, variational_inducing_covar)

        # Otherwise, we have to marginalize
        num_induc = inducing_points.size(-2)
        full_inputs = torch.cat([inducing_points, x], dim=-2)
        full_output = self.model.forward(full_inputs)
        full_mean, full_covar = full_output.mean, full_output.lazy_covariance_matrix

        # Mean terms
        test_mean = full_mean[..., num_induc:]
        induc_mean = full_mean[..., :num_induc]
        mean_diff = (inducing_values - induc_mean).unsqueeze(-1)

        # Covariance terms
        induc_induc_covar = full_covar[..., :num_induc, :num_induc].add_jitter(self.jitter_val)
        induc_data_covar = full_covar[..., :num_induc, num_induc:].to_dense()
        data_data_covar = full_covar[..., num_induc:, num_induc:]

        # Compute Cholesky factorization of inducing covariance matrix
        if settings.fast_computations.log_prob.off() or (num_induc <= settings.max_cholesky_size.value()):
            induc_induc_covar = CholLinearOperator(self._cholesky_factor(induc_induc_covar))

        # If we are making predictions and don't need variances, we can do things very quickly.
        if not self.training and settings.skip_posterior_variances.on():
            self._mean_cache = induc_induc_covar.solve(mean_diff).detach()
            predictive_mean = torch.add(
                test_mean, induc_data_covar.transpose(-2, -1).matmul(self._mean_cache).squeeze(-1)
            )
            predictive_covar = ZeroLinearOperator(test_mean.size(-1), test_mean.size(-1))
            if hasattr(self.model, 'power'):
                return MultivariateQExponential(predictive_mean, predictive_covar, power=self.model.power)
            else:
                return MultivariateNormal(predictive_mean, predictive_covar)

        # Expand everything to the right size
        shapes = [mean_diff.shape[:-1], induc_data_covar.shape[:-1], induc_induc_covar.shape[:-1]]
        root_variational_covar = None
        if variational_inducing_covar is not None:
            root_variational_covar = variational_inducing_covar.root_decomposition().root.to_dense()
            shapes.append(root_variational_covar.shape[:-1])
        shape = torch.broadcast_shapes(*shapes)
        mean_diff = mean_diff.expand(*shape, mean_diff.size(-1))
        induc_data_covar = induc_data_covar.expand(*shape, induc_data_covar.size(-1))
        induc_induc_covar = induc_induc_covar.expand(*shape, induc_induc_covar.size(-1))
        if variational_inducing_covar is not None:
            root_variational_covar = root_variational_covar.expand(*shape, root_variational_covar.size(-1))

        # Cache the kernel matrix with the cached CG calls
        if self.training:
            if hasattr(self.model, 'power'):
                prior_dist = MultivariateQExponential(induc_mean, induc_induc_covar, power=self.model.power)
            else:
                prior_dist = MultivariateNormal(induc_mean, induc_induc_covar)
            add_to_cache(self, "prior_distribution_memo", prior_dist)

        # Compute predictive mean
        if variational_inducing_covar is None:
            left_tensors = mean_diff
        else:
            left_tensors = torch.cat([mean_diff, root_variational_covar], -1)
        # One solve handles both the mean (row 0) and the variational-covariance root (rows 1:).
        inv_products = induc_induc_covar.solve(induc_data_covar, left_tensors.transpose(-1, -2))
        predictive_mean = torch.add(test_mean, inv_products[..., 0, :])

        # Compute covariance
        if self.training:
            # Only the diagonal is computed in training mode; it is clamped to be non-negative.
            interp_data_data_var, _ = induc_induc_covar.inv_quad_logdet(
                induc_data_covar, logdet=False, reduce_inv_quad=False
            )
            data_covariance = DiagLinearOperator(
                (data_data_covar.diagonal(dim1=-1, dim2=-2) - interp_data_data_var).clamp(0, math.inf)
            )
        else:
            neg_induc_data_data_covar = torch.matmul(
                induc_data_covar.transpose(-1, -2).mul(-1), induc_induc_covar.solve(induc_data_covar)
            )
            data_covariance = data_data_covar + neg_induc_data_data_covar
        predictive_covar = PsdSumLinearOperator(
            RootLinearOperator(inv_products[..., 1:, :].transpose(-1, -2)), data_covariance
        )

        # Done!
        if hasattr(self.model, 'power'):
            return MultivariateQExponential(predictive_mean, predictive_covar, power=self.model.power)
        else:
            return MultivariateNormal(predictive_mean, predictive_covar)
|