qpytorch-0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of qpytorch might be problematic.

Files changed (102)
  1. qpytorch/__init__.py +327 -0
  2. qpytorch/constraints/__init__.py +3 -0
  3. qpytorch/distributions/__init__.py +21 -0
  4. qpytorch/distributions/delta.py +86 -0
  5. qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
  6. qpytorch/distributions/multivariate_qexponential.py +581 -0
  7. qpytorch/distributions/power.py +113 -0
  8. qpytorch/distributions/qexponential.py +153 -0
  9. qpytorch/functions/__init__.py +58 -0
  10. qpytorch/kernels/__init__.py +80 -0
  11. qpytorch/kernels/grid_interpolation_kernel.py +213 -0
  12. qpytorch/kernels/inducing_point_kernel.py +151 -0
  13. qpytorch/kernels/kernel.py +695 -0
  14. qpytorch/kernels/matern32_kernel_grad.py +155 -0
  15. qpytorch/kernels/matern52_kernel_grad.py +194 -0
  16. qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
  17. qpytorch/kernels/polynomial_kernel_grad.py +88 -0
  18. qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
  19. qpytorch/kernels/rbf_kernel_grad.py +125 -0
  20. qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
  21. qpytorch/kernels/rff_kernel.py +153 -0
  22. qpytorch/lazy/__init__.py +9 -0
  23. qpytorch/likelihoods/__init__.py +66 -0
  24. qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
  25. qpytorch/likelihoods/beta_likelihood.py +76 -0
  26. qpytorch/likelihoods/gaussian_likelihood.py +472 -0
  27. qpytorch/likelihoods/laplace_likelihood.py +59 -0
  28. qpytorch/likelihoods/likelihood.py +437 -0
  29. qpytorch/likelihoods/likelihood_list.py +60 -0
  30. qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
  31. qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
  32. qpytorch/likelihoods/noise_models.py +184 -0
  33. qpytorch/likelihoods/qexponential_likelihood.py +494 -0
  34. qpytorch/likelihoods/softmax_likelihood.py +97 -0
  35. qpytorch/likelihoods/student_t_likelihood.py +90 -0
  36. qpytorch/means/__init__.py +23 -0
  37. qpytorch/metrics/__init__.py +17 -0
  38. qpytorch/mlls/__init__.py +53 -0
  39. qpytorch/mlls/_approximate_mll.py +79 -0
  40. qpytorch/mlls/deep_approximate_mll.py +30 -0
  41. qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
  42. qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
  43. qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
  44. qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
  45. qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
  46. qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
  47. qpytorch/mlls/marginal_log_likelihood.py +48 -0
  48. qpytorch/mlls/predictive_log_likelihood.py +76 -0
  49. qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
  50. qpytorch/mlls/variational_elbo.py +77 -0
  51. qpytorch/models/__init__.py +72 -0
  52. qpytorch/models/approximate_qep.py +115 -0
  53. qpytorch/models/deep_qeps/__init__.py +22 -0
  54. qpytorch/models/deep_qeps/deep_qep.py +155 -0
  55. qpytorch/models/deep_qeps/dspp.py +114 -0
  56. qpytorch/models/exact_prediction_strategies.py +880 -0
  57. qpytorch/models/exact_qep.py +349 -0
  58. qpytorch/models/model_list.py +100 -0
  59. qpytorch/models/pyro/__init__.py +28 -0
  60. qpytorch/models/pyro/_pyro_mixin.py +57 -0
  61. qpytorch/models/pyro/distributions/__init__.py +5 -0
  62. qpytorch/models/pyro/pyro_qep.py +105 -0
  63. qpytorch/models/qep.py +7 -0
  64. qpytorch/models/qeplvm/__init__.py +6 -0
  65. qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
  66. qpytorch/models/qeplvm/latent_variable.py +102 -0
  67. qpytorch/module.py +30 -0
  68. qpytorch/optim/__init__.py +5 -0
  69. qpytorch/priors/__init__.py +42 -0
  70. qpytorch/priors/qep_priors.py +81 -0
  71. qpytorch/test/__init__.py +22 -0
  72. qpytorch/test/base_likelihood_test_case.py +106 -0
  73. qpytorch/test/model_test_case.py +150 -0
  74. qpytorch/test/variational_test_case.py +400 -0
  75. qpytorch/utils/__init__.py +38 -0
  76. qpytorch/utils/warnings.py +37 -0
  77. qpytorch/variational/__init__.py +47 -0
  78. qpytorch/variational/_variational_distribution.py +61 -0
  79. qpytorch/variational/_variational_strategy.py +391 -0
  80. qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
  81. qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
  82. qpytorch/variational/cholesky_variational_distribution.py +65 -0
  83. qpytorch/variational/ciq_variational_strategy.py +352 -0
  84. qpytorch/variational/delta_variational_distribution.py +41 -0
  85. qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
  86. qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
  87. qpytorch/variational/lmc_variational_strategy.py +248 -0
  88. qpytorch/variational/mean_field_variational_distribution.py +58 -0
  89. qpytorch/variational/multitask_variational_strategy.py +317 -0
  90. qpytorch/variational/natural_variational_distribution.py +152 -0
  91. qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
  92. qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
  93. qpytorch/variational/tril_natural_variational_distribution.py +130 -0
  94. qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
  95. qpytorch/variational/unwhitened_variational_strategy.py +225 -0
  96. qpytorch/variational/variational_strategy.py +280 -0
  97. qpytorch/version.py +4 -0
  98. qpytorch-0.1.dist-info/LICENSE +21 -0
  99. qpytorch-0.1.dist-info/METADATA +177 -0
  100. qpytorch-0.1.dist-info/RECORD +102 -0
  101. qpytorch-0.1.dist-info/WHEEL +5 -0
  102. qpytorch-0.1.dist-info/top_level.txt +1 -0
qpytorch/variational/tril_natural_variational_distribution.py
@@ -0,0 +1,130 @@
+ #!/usr/bin/env python3
+
+ from typing import Tuple, Union
+
+ import torch
+ from linear_operator.operators import CholLinearOperator, TriangularLinearOperator
+ from torch import Tensor
+ from torch.autograd.function import FunctionCtx
+
+ from ..distributions import Distribution, MultivariateNormal, MultivariateQExponential
+ from .natural_variational_distribution import (
+     _NaturalToMuVarSqrt,
+     _NaturalVariationalDistribution,
+     _phi_for_cholesky_,
+     _triangular_inverse,
+ )
+
+
+ class TrilNaturalVariationalDistribution(_NaturalVariationalDistribution):
+     r"""A multivariate normal :obj:`~qpytorch.variational._VariationalDistribution`,
+     parameterized by the natural vector, and a triangular decomposition of the
+     natural matrix (which is not the Cholesky).
+
+     .. note::
+         The :obj:`~qpytorch.variational.TrilNaturalVariationalDistribution` should only
+         be used with :obj:`gpytorch.optim.NGD`, or other optimizers
+         that follow exactly the gradient direction.
+
+     .. seealso::
+         The `natural gradient descent tutorial
+         <examples/04_Variational_and_Approximate_GPs/Natural_Gradient_Descent.ipynb>`_
+         for use instructions.
+
+         The :obj:`~qpytorch.variational.NaturalVariationalDistribution`, which
+         needs fewer iterations to make variational regression converge, at the
+         cost of introducing numerical instability.
+
+     .. note::
+         The relationship of the parameter :math:`\mathbf \Theta_\text{tril_mat}`
+         to the natural parameter :math:`\mathbf \Theta_\text{mat}` from
+         :obj:`~qpytorch.variational.NaturalVariationalDistribution` is
+         :math:`\mathbf \Theta_\text{mat} = -1/2 {\mathbf \Theta_\text{tril_mat}}^T {\mathbf \Theta_\text{tril_mat}}`.
+         Note that this is not the form of the Cholesky decomposition of :math:`\boldsymbol \Theta_\text{mat}`.
+
+     :param int num_inducing_points: Size of the variational distribution. This implies that the variational mean
+         should be this size, and the variational covariance matrix should have this many rows and columns.
+     :param batch_shape: Specifies an optional batch size
+         for the variational parameters. This is useful for example when doing additive variational inference.
+     :type batch_shape: :obj:`torch.Size`, optional
+     :param float mean_init_std: (Default: 1e-3) Standard deviation of gaussian (q-exponential) noise to add to the mean initialization.
+     """
+
+     def __init__(self, num_inducing_points: int, batch_shape: torch.Size = torch.Size([]), mean_init_std: float = 1e-3, **kwargs):
+         super().__init__(num_inducing_points=num_inducing_points, batch_shape=batch_shape, mean_init_std=mean_init_std)
+         scaled_mean_init = torch.zeros(num_inducing_points)
+         neg_prec_init = torch.eye(num_inducing_points, num_inducing_points)
+         scaled_mean_init = scaled_mean_init.repeat(*batch_shape, 1)
+         neg_prec_init = neg_prec_init.repeat(*batch_shape, 1, 1)
+
+         # eta1 and tril_dec(eta2) parameterization of the variational distribution
+         self.register_parameter(name="natural_vec", parameter=torch.nn.Parameter(scaled_mean_init))
+         self.register_parameter(name="natural_tril_mat", parameter=torch.nn.Parameter(neg_prec_init))
+
+         if 'power' in kwargs: self.power = kwargs.pop('power')
+
+     def forward(self) -> Distribution:
+         mean, chol_covar = _TrilNaturalToMuVarSqrt.apply(self.natural_vec, self.natural_tril_mat)
+         covar = CholLinearOperator(TriangularLinearOperator(chol_covar))
+         if not hasattr(self, 'power'):
+             return MultivariateNormal(mean, covar)
+         else:
+             return MultivariateQExponential(mean, covar, power=self.power)
+
+     def initialize_variational_distribution(self, prior_dist: Union[MultivariateNormal, MultivariateQExponential]) -> None:
+         prior_cov = prior_dist.lazy_covariance_matrix
+         chol = prior_cov.cholesky().to_dense()
+         tril_mat = _triangular_inverse(chol, upper=False)
+
+         natural_vec = prior_cov.solve(prior_dist.mean.unsqueeze(-1)).squeeze(-1)
+         noise = torch.randn_like(natural_vec).mul_(self.mean_init_std)
+
+         self.natural_vec.data.copy_(natural_vec.add_(noise))
+         self.natural_tril_mat.data.copy_(tril_mat)
+
+
+ class _TrilNaturalToMuVarSqrt(torch.autograd.Function):
+     @staticmethod
+     def _forward(nat_mean: Tensor, tril_nat_covar: Tensor) -> Tuple[Tensor, Tensor]:
+         L = _triangular_inverse(tril_nat_covar, upper=False)
+         mu = L @ (L.transpose(-1, -2) @ nat_mean.unsqueeze(-1))
+         return mu.squeeze(-1), L
+         # return nat_mean, L
+
+     @staticmethod
+     def forward(ctx: FunctionCtx, nat_mean: Tensor, tril_nat_covar: Tensor) -> Tuple[Tensor, Tensor]:
+         mu, L = _TrilNaturalToMuVarSqrt._forward(nat_mean, tril_nat_covar)
+         ctx.save_for_backward(mu, L, tril_nat_covar)
+         return mu, L
+
+     @staticmethod
+     def backward(ctx: FunctionCtx, dout_dmu: Tensor, dout_dL: Tensor) -> Tuple[Tensor, Tensor]:
+         mu, L, C = ctx.saved_tensors
+         dout_dnat1, dout_dnat2 = _NaturalToMuVarSqrt._backward(dout_dmu, dout_dL, mu, L, C)
+         """
+         Now we need to do the Jacobian-vector product for the transformation:
+             L = inv(chol(inv(-2 theta_cov)))
+
+             C^T C = -2 theta_cov
+
+         so we need to do forward differentiation, starting with the sensitivity (sensitivities marked with .dots.)
+             .theta_cov. = dout_dnat2
+
+         and ending with the sensitivity .C.
+
+         If B = inv(-2 theta_cov), then:
+
+             .B. = d inv(-2 theta_cov)/dtheta_cov * .theta_cov. = -B (-2 .theta_cov.) B
+
+         If L = chol(B), B = LL^T, then (https://homepages.inf.ed.ac.uk/imurray2/pub/16choldiff/choldiff.pdf):
+
+             .L. = L phi(L^{-1} .B. (L^{-1})^T) = L phi(2 L^T .theta_cov. L)
+
+         Then C = inv(L), so
+
+             .C. = -C .L. C = phi(-2 L^T .theta_cov. L) C
+         """
+         A = L.transpose(-2, -1) @ dout_dnat2 @ L
+         phi = _phi_for_cholesky_(A.mul_(-2))
+         dout_dtril = phi @ C
+         return dout_dnat1, dout_dtril
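Usage sketch (editorial illustration, not part of the packaged files): the docstring above pairs this distribution with gpytorch.optim.NGD for the natural parameters while ordinary optimizers handle the remaining hyperparameters. The sketch assumes qpytorch mirrors the gpytorch approximate-model API; ApproximateQEP, VariationalStrategy, ConstantMean, ScaleKernel/RBFKernel, QExponentialLikelihood, variational_parameters()/hyperparameters(), and the power keyword handling are assumed names taken from the file list above, not verified against this release.

import torch
import gpytorch
import qpytorch

class SVQEPModel(qpytorch.models.ApproximateQEP):
    def __init__(self, inducing_points, power=1.0):
        power = torch.tensor(power)
        # The natural parameters (natural_vec, natural_tril_mat) live inside this distribution.
        variational_distribution = qpytorch.variational.TrilNaturalVariationalDistribution(
            inducing_points.size(-2), power=power
        )
        variational_strategy = qpytorch.variational.VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.power = power
        self.mean_module = qpytorch.means.ConstantMean()
        self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())

    def forward(self, x):
        return qpytorch.distributions.MultivariateQExponential(
            self.mean_module(x), self.covar_module(x), power=self.power
        )

inducing_points = torch.linspace(0, 1, 16).unsqueeze(-1)
model = SVQEPModel(inducing_points)
likelihood = qpytorch.likelihoods.QExponentialLikelihood(power=torch.tensor(1.0))

# As the class docstring notes, the natural parameters should follow the gradient direction
# exactly, so they get NGD; everything else uses a conventional optimizer.
# num_data should equal the training set size (100 is a placeholder here).
variational_optimizer = gpytorch.optim.NGD(model.variational_parameters(), num_data=100, lr=0.1)
hyperparameter_optimizer = torch.optim.Adam(
    [{"params": model.hyperparameters()}, {"params": likelihood.parameters()}], lr=0.01
)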
qpytorch/variational/uncorrelated_multitask_variational_strategy.py
@@ -0,0 +1,114 @@
+ #!/usr/bin/env python3
+
+ import warnings
+
+ import torch
+ from linear_operator.operators import RootLinearOperator
+
+ from ..distributions import MultitaskMultivariateQExponential, MultivariateQExponential
+ from ..module import Module
+ from ._variational_strategy import _VariationalStrategy
+
+
+ class UncorrelatedMultitaskVariationalStrategy(_VariationalStrategy):
+     """
+     UncorrelatedMultitaskVariationalStrategy wraps an existing
+     :obj:`~qpytorch.variational.VariationalStrategy` to produce vector-valued (multi-task)
+     output distributions. Each task will be uncorrelated with the others.
+
+     The output will either be a :obj:`~qpytorch.distributions.MultitaskMultivariateQExponential` distribution
+     (if we wish to evaluate all tasks for each input) or a :obj:`~qpytorch.distributions.MultivariateQExponential`
+     (if we wish to evaluate a single task for each input).
+
+     The base variational strategy is assumed to operate on a batch of QEPs. One of the batch
+     dimensions corresponds to the multiple tasks.
+
+     :param ~qpytorch.variational.VariationalStrategy base_variational_strategy: Base variational strategy
+     :param int num_tasks: Number of tasks. Should correspond to the batch size of task_dim.
+     :param int task_dim: (Default: -1) Which batch dimension is the task dimension
+     """
+
+     def __init__(self, base_variational_strategy, num_tasks, task_dim=-1):
+         Module.__init__(self)
+         self.base_variational_strategy = base_variational_strategy
+         self.task_dim = task_dim
+         self.num_tasks = num_tasks
+
+     @property
+     def prior_distribution(self):
+         return self.base_variational_strategy.prior_distribution
+
+     @property
+     def variational_distribution(self):
+         return self.base_variational_strategy.variational_distribution
+
+     @property
+     def variational_params_initialized(self):
+         return self.base_variational_strategy.variational_params_initialized
+
+     def kl_divergence(self):
+         return super().kl_divergence().sum(dim=-1)
+
+     def __call__(self, x, task_indices=None, prior=False, **kwargs):
+         r"""
+         See :class:`LMCVariationalStrategy`.
+         """
+         function_dist = self.base_variational_strategy(x, prior=prior, **kwargs)
+
+         if task_indices is None:
+             # Every data point will get an output for each task
+             if (
+                 self.task_dim > 0
+                 and self.task_dim > len(function_dist.batch_shape)
+                 or self.task_dim < 0
+                 and self.task_dim + len(function_dist.batch_shape) < 0
+             ):
+                 return MultitaskMultivariateQExponential.from_repeated_qep(function_dist, num_tasks=self.num_tasks)
+             else:
+                 function_dist = MultitaskMultivariateQExponential.from_batch_qep(function_dist, task_dim=self.task_dim)
+                 assert function_dist.event_shape[-1] == self.num_tasks
+                 return function_dist
+
+         else:
+             # Each data point will get a single output corresponding to a single task
+
+             if self.task_dim > 0:
+                 raise RuntimeError(f"task_dim must be a negative indexed batch dimension: got {self.task_dim}.")
+             num_batch = len(function_dist.batch_shape)
+             task_dim = num_batch + self.task_dim
+
+             # Create a mask to choose specific task assignment
+             shape = list(function_dist.batch_shape + function_dist.event_shape)
+             shape[task_dim] = 1
+             task_indices = task_indices.expand(shape).squeeze(task_dim)
+
+             # Create a mask to choose specific task assignment
+             task_mask = torch.nn.functional.one_hot(task_indices, num_classes=self.num_tasks)
+             task_mask = task_mask.permute(*range(0, task_dim), *range(task_dim + 1, num_batch + 1), task_dim)
+
+             mean = (function_dist.mean * task_mask).sum(task_dim)
+             covar = (function_dist.lazy_covariance_matrix * RootLinearOperator(task_mask[..., None])).sum(task_dim)
+             return MultivariateQExponential(mean, covar, power=function_dist.power)
+
+
+ class MultitaskVariationalStrategy(UncorrelatedMultitaskVariationalStrategy):
+     """
+     UncorrelatedMultitaskVariationalStrategy wraps an existing
+     :obj:`~qpytorch.variational.VariationalStrategy`
+     to produce a :obj:`~qpytorch.distributions.MultitaskMultivariateQExponential` distribution.
+     All outputs will be uncorrelated with one another.
+
+     The base variational strategy is assumed to operate on a batch of QEPs. One of the batch
+     dimensions corresponds to the multiple tasks.
+
+     :param ~qpytorch.variational.VariationalStrategy base_variational_strategy: Base variational strategy
+     :param int num_tasks: Number of tasks. Should correspond to the batch size of task_dim.
+     :param int task_dim: (Default: -1) Which batch dimension is the task dimension
+     """
+
+     def __init__(self, base_variational_strategy, num_tasks, task_dim=-1):
+         warnings.warn(
+             "MultitaskVariationalStrategy has been renamed to UncorrelatedMultitaskVariationalStrategy",
+             DeprecationWarning,
+         )
+         super().__init__(base_variational_strategy, num_tasks, task_dim=task_dim)
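Usage sketch (editorial illustration, not shipped in this diff): a batched base strategy with one batch dimension per task, wrapped by UncorrelatedMultitaskVariationalStrategy, plus the two call modes described in the docstring. Apart from the wrapper class itself, the names here (ApproximateQEP, CholeskyVariationalDistribution with a power keyword, batched ConstantMean/RBFKernel) mirror the gpytorch multitask pattern and are assumptions, not code from this release.

import torch
import qpytorch

class MultitaskSVQEP(qpytorch.models.ApproximateQEP):
    def __init__(self, num_tasks=4, num_inducing=32, power=1.0):
        power = torch.tensor(power)
        # One batch of inducing points per task; the task batch dimension is the last one (task_dim=-1).
        inducing_points = torch.rand(num_tasks, num_inducing, 1)
        variational_distribution = qpytorch.variational.CholeskyVariationalDistribution(
            num_inducing, batch_shape=torch.Size([num_tasks]), power=power
        )
        base_strategy = qpytorch.variational.VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        variational_strategy = qpytorch.variational.UncorrelatedMultitaskVariationalStrategy(
            base_strategy, num_tasks=num_tasks, task_dim=-1
        )
        super().__init__(variational_strategy)
        self.power = power
        self.mean_module = qpytorch.means.ConstantMean(batch_shape=torch.Size([num_tasks]))
        self.covar_module = qpytorch.kernels.ScaleKernel(
            qpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_tasks])),
            batch_shape=torch.Size([num_tasks]),
        )

    def forward(self, x):
        return qpytorch.distributions.MultivariateQExponential(
            self.mean_module(x), self.covar_module(x), power=self.power
        )

model = MultitaskSVQEP()
x = torch.rand(10, 1)
all_tasks = model(x)                                         # MultitaskMultivariateQExponential: every task per input
one_task = model(x, task_indices=torch.randint(4, (10,)))    # MultivariateQExponential: one task per input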
qpytorch/variational/unwhitened_variational_strategy.py
@@ -0,0 +1,225 @@
+ #!/usr/bin/env python3
+
+ import math
+ from typing import Optional, Tuple, Union
+
+ import torch
+ from linear_operator import to_dense
+ from linear_operator.operators import (
+     CholLinearOperator,
+     DiagLinearOperator,
+     LinearOperator,
+     PsdSumLinearOperator,
+     RootLinearOperator,
+     TriangularLinearOperator,
+     ZeroLinearOperator,
+ )
+ from linear_operator.utils.cholesky import psd_safe_cholesky
+ from linear_operator.utils.errors import NotPSDError
+ from torch import Tensor
+
+ from .. import settings
+ from ..distributions import MultivariateNormal, MultivariateQExponential
+ from gpytorch.utils.memoize import add_to_cache, cached
+ from ._variational_strategy import _VariationalStrategy
+ from .cholesky_variational_distribution import CholeskyVariationalDistribution
+
+
+ class UnwhitenedVariationalStrategy(_VariationalStrategy):
+     r"""
+     Similar to :obj:`~qpytorch.variational.VariationalStrategy`, but does not perform the
+     whitening operation. In almost all cases :obj:`~qpytorch.variational.VariationalStrategy`
+     is preferable, with a few exceptions:
+
+     - When the inducing points are exactly equal to the training points (i.e. :math:`\mathbf Z = \mathbf X`).
+       Unwhitened models are faster in this case.
+
+     - When the number of inducing points is very large (e.g. >2000). Unwhitened models can use CG for faster
+       computation.
+
+     :param model: Model this strategy is applied to.
+         Typically passed in when the VariationalStrategy is created in the
+         __init__ method of the user-defined model.
+         It should have a ``power`` attribute if a Q-Exponential distribution is involved.
+     :param inducing_points: Tensor containing a set of inducing
+         points to use for variational inference.
+     :param variational_distribution: A
+         VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`
+     :param learn_inducing_locations: (Default: True) Whether or not
+         the inducing point locations :math:`\mathbf Z` should be learned (i.e. are they
+         parameters of the model).
+     :param jitter_val: Amount of diagonal jitter to add for Cholesky factorization numerical stability
+     """
+     has_fantasy_strategy = True
+
+     @cached(name="cholesky_factor", ignore_args=True)
+     def _cholesky_factor(self, induc_induc_covar: LinearOperator) -> TriangularLinearOperator:
+         # Maybe used - if we're not using CG
+         L = psd_safe_cholesky(to_dense(induc_induc_covar))
+         return TriangularLinearOperator(L)
+
+     @property
+     @cached(name="prior_distribution_memo")
+     def prior_distribution(self) -> Union[MultivariateNormal, MultivariateQExponential]:
+         out = self.model.forward(self.inducing_points)
+         if hasattr(self.model, 'power'):
+             res = MultivariateQExponential(out.mean, out.lazy_covariance_matrix.add_jitter(), power=self.model.power)
+         else:
+             res = MultivariateNormal(out.mean, out.lazy_covariance_matrix.add_jitter())
+         return res
+
+     @property
+     @cached(name="pseudo_points_memo")
+     def pseudo_points(self) -> Tuple[Tensor, Tensor]:
+         # TODO: implement for other distributions
+         # retrieve the variational mean, m and covariance matrix, S.
+         if not isinstance(self._variational_distribution, CholeskyVariationalDistribution):
+             raise NotImplementedError(
+                 "Only CholeskyVariationalDistribution has pseudo-point support currently, "
+                 "but your _variational_distribution is a "
+                 + self._variational_distribution.__class__.__name__
+             )
+
+         # retrieve the variational mean, m and covariance matrix, S.
+         var_cov_root = TriangularLinearOperator(self._variational_distribution.chol_variational_covar)
+         var_cov = CholLinearOperator(var_cov_root)
+         var_mean = self.variational_distribution.mean  # .unsqueeze(-1)
+         if var_mean.shape[-1] != 1:
+             var_mean = var_mean.unsqueeze(-1)
+
+         # R = K - S
+         Kmm = self.model.covar_module(self.inducing_points)
+         res = Kmm - var_cov
+
+         cov_diff = res
+
+         # D_a = (S^{-1} - K^{-1})^{-1} = S + S R^{-1} S
+         # note that in the whitened case R = I - S, unwhitened R = K - S
+         # we compute (R R^{T})^{-1} R^T S for stability reasons as R is probably not PSD.
+         eval_lhs = var_cov.to_dense()
+         eval_rhs = cov_diff.transpose(-1, -2).matmul(eval_lhs)
+         inner_term = cov_diff.matmul(cov_diff.transpose(-1, -2))
+         # TODO: flag the jitter here
+         inner_solve = inner_term.add_jitter(self.jitter_val).solve(eval_rhs, eval_lhs.transpose(-1, -2))
+         inducing_covar = var_cov + inner_solve
+
+         # mean term: D_a S^{-1} m
+         # unwhitened: (S - S R^{-1} S) S^{-1} m = (I - S R^{-1}) m
+         rhs = cov_diff.transpose(-1, -2).matmul(var_mean)
+         inner_rhs_mean_solve = inner_term.add_jitter(self.jitter_val).solve(rhs)
+         pseudo_target_mean = var_mean + var_cov.matmul(inner_rhs_mean_solve)
+
+         # ensure inducing covar is psd
+         try:
+             pseudo_target_covar = CholLinearOperator(inducing_covar.add_jitter(self.jitter_val).cholesky()).to_dense()
+         except NotPSDError:
+             from linear_operator.operators import DiagLinearOperator
+
+             evals, evecs = torch.linalg.eigh(inducing_covar)
+             pseudo_target_covar = (
+                 evecs.matmul(DiagLinearOperator(evals + self.jitter_val)).matmul(evecs.transpose(-1, -2)).to_dense()
+             )
+
+         return pseudo_target_covar, pseudo_target_mean
+
+     def forward(
+         self,
+         x: Tensor,
+         inducing_points: Tensor,
+         inducing_values: Tensor,
+         variational_inducing_covar: Optional[LinearOperator] = None,
+         **kwargs,
+     ) -> Union[MultivariateNormal, MultivariateQExponential]:
+         # If our points equal the inducing points, we're done
+         if torch.equal(x, inducing_points):
+             if variational_inducing_covar is None:
+                 raise RuntimeError
+             else:
+                 if hasattr(self.model, 'power'):
+                     return MultivariateQExponential(inducing_values, variational_inducing_covar, power=self.model.power)
+                 else:
+                     return MultivariateNormal(inducing_values, variational_inducing_covar)
+
+         # Otherwise, we have to marginalize
+         num_induc = inducing_points.size(-2)
+         full_inputs = torch.cat([inducing_points, x], dim=-2)
+         full_output = self.model.forward(full_inputs)
+         full_mean, full_covar = full_output.mean, full_output.lazy_covariance_matrix
+
+         # Mean terms
+         test_mean = full_mean[..., num_induc:]
+         induc_mean = full_mean[..., :num_induc]
+         mean_diff = (inducing_values - induc_mean).unsqueeze(-1)
+
+         # Covariance terms
+         induc_induc_covar = full_covar[..., :num_induc, :num_induc].add_jitter(self.jitter_val)
+         induc_data_covar = full_covar[..., :num_induc, num_induc:].to_dense()
+         data_data_covar = full_covar[..., num_induc:, num_induc:]
+
+         # Compute Cholesky factorization of inducing covariance matrix
+         if settings.fast_computations.log_prob.off() or (num_induc <= settings.max_cholesky_size.value()):
+             induc_induc_covar = CholLinearOperator(self._cholesky_factor(induc_induc_covar))
+
+         # If we are making predictions and don't need variances, we can do things very quickly.
+         if not self.training and settings.skip_posterior_variances.on():
+             self._mean_cache = induc_induc_covar.solve(mean_diff).detach()
+             predictive_mean = torch.add(
+                 test_mean, induc_data_covar.transpose(-2, -1).matmul(self._mean_cache).squeeze(-1)
+             )
+             predictive_covar = ZeroLinearOperator(test_mean.size(-1), test_mean.size(-1))
+             if hasattr(self.model, 'power'):
+                 return MultivariateQExponential(predictive_mean, predictive_covar, power=self.model.power)
+             else:
+                 return MultivariateNormal(predictive_mean, predictive_covar)
+
+         # Expand everything to the right size
+         shapes = [mean_diff.shape[:-1], induc_data_covar.shape[:-1], induc_induc_covar.shape[:-1]]
+         root_variational_covar = None
+         if variational_inducing_covar is not None:
+             root_variational_covar = variational_inducing_covar.root_decomposition().root.to_dense()
+             shapes.append(root_variational_covar.shape[:-1])
+         shape = torch.broadcast_shapes(*shapes)
+         mean_diff = mean_diff.expand(*shape, mean_diff.size(-1))
+         induc_data_covar = induc_data_covar.expand(*shape, induc_data_covar.size(-1))
+         induc_induc_covar = induc_induc_covar.expand(*shape, induc_induc_covar.size(-1))
+         if variational_inducing_covar is not None:
+             root_variational_covar = root_variational_covar.expand(*shape, root_variational_covar.size(-1))
+
+         # Cache the kernel matrix with the cached CG calls
+         if self.training:
+             if hasattr(self.model, 'power'):
+                 prior_dist = MultivariateQExponential(induc_mean, induc_induc_covar, power=self.model.power)
+             else:
+                 prior_dist = MultivariateNormal(induc_mean, induc_induc_covar)
+             add_to_cache(self, "prior_distribution_memo", prior_dist)
+
+         # Compute predictive mean
+         if variational_inducing_covar is None:
+             left_tensors = mean_diff
+         else:
+             left_tensors = torch.cat([mean_diff, root_variational_covar], -1)
+         inv_products = induc_induc_covar.solve(induc_data_covar, left_tensors.transpose(-1, -2))
+         predictive_mean = torch.add(test_mean, inv_products[..., 0, :])
+
+         # Compute covariance
+         if self.training:
+             interp_data_data_var, _ = induc_induc_covar.inv_quad_logdet(
+                 induc_data_covar, logdet=False, reduce_inv_quad=False
+             )
+             data_covariance = DiagLinearOperator(
+                 (data_data_covar.diagonal(dim1=-1, dim2=-2) - interp_data_data_var).clamp(0, math.inf)
+             )
+         else:
+             neg_induc_data_data_covar = torch.matmul(
+                 induc_data_covar.transpose(-1, -2).mul(-1), induc_induc_covar.solve(induc_data_covar)
+             )
+             data_covariance = data_data_covar + neg_induc_data_data_covar
+         predictive_covar = PsdSumLinearOperator(
+             RootLinearOperator(inv_products[..., 1:, :].transpose(-1, -2)), data_covariance
+         )
+
+         # Done!
+         if hasattr(self.model, 'power'):
+             return MultivariateQExponential(predictive_mean, predictive_covar, power=self.model.power)
+         else:
+             return MultivariateNormal(predictive_mean, predictive_covar)
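Usage sketch (editorial illustration, not shipped in this diff): the first exception listed in the docstring, inducing points fixed to the training inputs (Z = X), where the unwhitened strategy is stated to be faster. Apart from UnwhitenedVariationalStrategy itself, the surrounding names (ApproximateQEP, CholeskyVariationalDistribution with a power keyword, ConstantMean, ScaleKernel/RBFKernel) mirror the gpytorch API and are assumptions.

import torch
import qpytorch

class UnwhitenedSVQEP(qpytorch.models.ApproximateQEP):
    def __init__(self, train_x, power=1.0):
        power = torch.tensor(power)
        variational_distribution = qpytorch.variational.CholeskyVariationalDistribution(
            train_x.size(-2), power=power
        )
        # Z = X: the inducing points are the training inputs themselves and are not learned,
        # which is the case where the unwhitened strategy avoids redundant work.
        variational_strategy = qpytorch.variational.UnwhitenedVariationalStrategy(
            self, train_x, variational_distribution, learn_inducing_locations=False
        )
        super().__init__(variational_strategy)
        self.power = power
        self.mean_module = qpytorch.means.ConstantMean()
        self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())

    def forward(self, x):
        return qpytorch.distributions.MultivariateQExponential(
            self.mean_module(x), self.covar_module(x), power=self.power
        )

train_x = torch.linspace(0, 1, 64).unsqueeze(-1)
model = UnwhitenedSVQEP(train_x)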