qpytorch 0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of qpytorch has been flagged as potentially problematic; see the package registry's page for this release for details.
- qpytorch/__init__.py +327 -0
- qpytorch/constraints/__init__.py +3 -0
- qpytorch/distributions/__init__.py +21 -0
- qpytorch/distributions/delta.py +86 -0
- qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
- qpytorch/distributions/multivariate_qexponential.py +581 -0
- qpytorch/distributions/power.py +113 -0
- qpytorch/distributions/qexponential.py +153 -0
- qpytorch/functions/__init__.py +58 -0
- qpytorch/kernels/__init__.py +80 -0
- qpytorch/kernels/grid_interpolation_kernel.py +213 -0
- qpytorch/kernels/inducing_point_kernel.py +151 -0
- qpytorch/kernels/kernel.py +695 -0
- qpytorch/kernels/matern32_kernel_grad.py +155 -0
- qpytorch/kernels/matern52_kernel_grad.py +194 -0
- qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
- qpytorch/kernels/polynomial_kernel_grad.py +88 -0
- qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
- qpytorch/kernels/rbf_kernel_grad.py +125 -0
- qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
- qpytorch/kernels/rff_kernel.py +153 -0
- qpytorch/lazy/__init__.py +9 -0
- qpytorch/likelihoods/__init__.py +66 -0
- qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
- qpytorch/likelihoods/beta_likelihood.py +76 -0
- qpytorch/likelihoods/gaussian_likelihood.py +472 -0
- qpytorch/likelihoods/laplace_likelihood.py +59 -0
- qpytorch/likelihoods/likelihood.py +437 -0
- qpytorch/likelihoods/likelihood_list.py +60 -0
- qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
- qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
- qpytorch/likelihoods/noise_models.py +184 -0
- qpytorch/likelihoods/qexponential_likelihood.py +494 -0
- qpytorch/likelihoods/softmax_likelihood.py +97 -0
- qpytorch/likelihoods/student_t_likelihood.py +90 -0
- qpytorch/means/__init__.py +23 -0
- qpytorch/metrics/__init__.py +17 -0
- qpytorch/mlls/__init__.py +53 -0
- qpytorch/mlls/_approximate_mll.py +79 -0
- qpytorch/mlls/deep_approximate_mll.py +30 -0
- qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
- qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
- qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
- qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
- qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
- qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
- qpytorch/mlls/marginal_log_likelihood.py +48 -0
- qpytorch/mlls/predictive_log_likelihood.py +76 -0
- qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
- qpytorch/mlls/variational_elbo.py +77 -0
- qpytorch/models/__init__.py +72 -0
- qpytorch/models/approximate_qep.py +115 -0
- qpytorch/models/deep_qeps/__init__.py +22 -0
- qpytorch/models/deep_qeps/deep_qep.py +155 -0
- qpytorch/models/deep_qeps/dspp.py +114 -0
- qpytorch/models/exact_prediction_strategies.py +880 -0
- qpytorch/models/exact_qep.py +349 -0
- qpytorch/models/model_list.py +100 -0
- qpytorch/models/pyro/__init__.py +28 -0
- qpytorch/models/pyro/_pyro_mixin.py +57 -0
- qpytorch/models/pyro/distributions/__init__.py +5 -0
- qpytorch/models/pyro/pyro_qep.py +105 -0
- qpytorch/models/qep.py +7 -0
- qpytorch/models/qeplvm/__init__.py +6 -0
- qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
- qpytorch/models/qeplvm/latent_variable.py +102 -0
- qpytorch/module.py +30 -0
- qpytorch/optim/__init__.py +5 -0
- qpytorch/priors/__init__.py +42 -0
- qpytorch/priors/qep_priors.py +81 -0
- qpytorch/test/__init__.py +22 -0
- qpytorch/test/base_likelihood_test_case.py +106 -0
- qpytorch/test/model_test_case.py +150 -0
- qpytorch/test/variational_test_case.py +400 -0
- qpytorch/utils/__init__.py +38 -0
- qpytorch/utils/warnings.py +37 -0
- qpytorch/variational/__init__.py +47 -0
- qpytorch/variational/_variational_distribution.py +61 -0
- qpytorch/variational/_variational_strategy.py +391 -0
- qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
- qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
- qpytorch/variational/cholesky_variational_distribution.py +65 -0
- qpytorch/variational/ciq_variational_strategy.py +352 -0
- qpytorch/variational/delta_variational_distribution.py +41 -0
- qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
- qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
- qpytorch/variational/lmc_variational_strategy.py +248 -0
- qpytorch/variational/mean_field_variational_distribution.py +58 -0
- qpytorch/variational/multitask_variational_strategy.py +317 -0
- qpytorch/variational/natural_variational_distribution.py +152 -0
- qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
- qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
- qpytorch/variational/tril_natural_variational_distribution.py +130 -0
- qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
- qpytorch/variational/unwhitened_variational_strategy.py +225 -0
- qpytorch/variational/variational_strategy.py +280 -0
- qpytorch/version.py +4 -0
- qpytorch-0.1.dist-info/LICENSE +21 -0
- qpytorch-0.1.dist-info/METADATA +177 -0
- qpytorch-0.1.dist-info/RECORD +102 -0
- qpytorch-0.1.dist-info/WHEEL +5 -0
- qpytorch-0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
import math
|
|
5
|
+
from typing import Optional, Tuple
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from linear_operator import to_dense
|
|
9
|
+
from linear_operator.operators import (
|
|
10
|
+
DiagLinearOperator,
|
|
11
|
+
LowRankRootAddedDiagLinearOperator,
|
|
12
|
+
LowRankRootLinearOperator,
|
|
13
|
+
MatmulLinearOperator,
|
|
14
|
+
)
|
|
15
|
+
from linear_operator.utils.cholesky import psd_safe_cholesky
|
|
16
|
+
from torch import Tensor
|
|
17
|
+
|
|
18
|
+
from .. import settings
|
|
19
|
+
from ..distributions import MultivariateNormal
|
|
20
|
+
from ..likelihoods import Likelihood
|
|
21
|
+
from ..mlls import InducingPointKernelAddedLossTerm
|
|
22
|
+
from ..models import exact_prediction_strategies
|
|
23
|
+
from .kernel import Kernel
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class InducingPointKernel(Kernel):
    r"""
    Nystrom-style low-rank approximation of ``base_kernel`` built from learnable inducing points.

    With inducing points :math:`Z` the kernel returns

    .. math::
        K(x_1, x_2) \approx K(x_1, Z) K(Z, Z)^{-1} K(Z, x_2),

    where :math:`K(Z, Z)^{-1}` is applied through the inverse of the upper Cholesky factor of
    :math:`K(Z, Z)`, so the result stays a low-rank linear operator. When ``x1`` equals ``x2``
    outside of training mode and ``settings.sgpr_diagonal_correction`` is on, the missing diagonal
    variance ``diag(K(x1, x1)) - diag(approximation)`` (clamped at zero) is added back.

    In training mode each call replaces the ``"inducing_point_loss_term"`` added loss term with an
    ``InducingPointKernelAddedLossTerm`` built from the exact prior diagonal, the low-rank prior and
    ``likelihood``.

    :param base_kernel: Kernel whose covariance is approximated.
    :param inducing_points: Inducing point locations; a 1-D tensor is treated as a single input dimension.
    :param likelihood: Likelihood handed to the added loss term.
    :param active_dims: Input dimensions the kernel acts on (forwarded to ``Kernel``).
    """

    def __init__(
        self,
        base_kernel: Kernel,
        inducing_points: Tensor,
        likelihood: Likelihood,
        active_dims: Optional[Tuple[int, ...]] = None,
    ):
        super().__init__(active_dims=active_dims)
        self.base_kernel = base_kernel
        self.likelihood = likelihood

        # A 1-D tensor of locations means m points in a one-dimensional input space.
        if inducing_points.ndimension() == 1:
            inducing_points = inducing_points.unsqueeze(-1)

        self.register_parameter(name="inducing_points", parameter=torch.nn.Parameter(inducing_points))
        # Slot that ``forward`` fills in training mode.
        self.register_added_loss_term("inducing_point_loss_term")

    def _clear_cache(self):
        """Drop the cached ``K(Z, Z)`` and its inverse root so they are recomputed on next access."""
        if hasattr(self, "_cached_kernel_mat"):
            del self._cached_kernel_mat
        if hasattr(self, "_cached_kernel_inv_root"):
            del self._cached_kernel_inv_root

    @property
    def _inducing_mat(self):
        """Dense ``K(Z, Z)`` from ``base_kernel``; cached while not in training mode."""
        if not self.training and hasattr(self, "_cached_kernel_mat"):
            return self._cached_kernel_mat
        res = to_dense(self.base_kernel(self.inducing_points, self.inducing_points))
        if not self.training:
            self._cached_kernel_mat = res
        return res

    @property
    def _inducing_inv_root(self):
        """
        Inverse of the upper Cholesky factor ``U`` of ``K(Z, Z)`` (``K = U^T U``).

        ``inv_root @ inv_root^T == K(Z, Z)^{-1}``. Cached while not in training mode.
        """
        if not self.training and hasattr(self, "_cached_kernel_inv_root"):
            return self._cached_kernel_inv_root
        chol = psd_safe_cholesky(self._inducing_mat, upper=True)
        eye = torch.eye(chol.size(-1), device=chol.device, dtype=chol.dtype)
        res = torch.linalg.solve_triangular(chol, eye, upper=True)
        if not self.training:
            self._cached_kernel_inv_root = res
        return res

    def _get_covariance(self, x1, x2):
        """
        Build the low-rank covariance ``K(x1, Z) K(Z, Z)^{-1} K(Z, x2)`` as a linear operator.

        Returns a ``LowRankRootLinearOperator`` (optionally with a diagonal correction) when
        ``x1`` equals ``x2``, and a ``MatmulLinearOperator`` otherwise.
        """
        k_ux1 = to_dense(self.base_kernel(x1, self.inducing_points))
        # Read the property once: in training mode it is not cached and each access would
        # redo the Cholesky factorisation and triangular solve.
        inv_root = self._inducing_inv_root

        if torch.equal(x1, x2):
            covar = LowRankRootLinearOperator(k_ux1.matmul(inv_root))

            # Diagonal correction for predictive posterior
            if not self.training and settings.sgpr_diagonal_correction.on():
                correction = (self.base_kernel(x1, x2, diag=True) - covar.diagonal(dim1=-1, dim2=-2)).clamp(0, math.inf)
                covar = LowRankRootAddedDiagLinearOperator(covar, DiagLinearOperator(correction))
        else:
            k_ux2 = to_dense(self.base_kernel(x2, self.inducing_points))
            covar = MatmulLinearOperator(k_ux1.matmul(inv_root), k_ux2.matmul(inv_root).transpose(-1, -2))

        return covar

    def _covar_diag(self, inputs):
        """Exact prior diagonal of ``base_kernel`` at ``inputs``, wrapped in a ``DiagLinearOperator``."""
        if inputs.ndimension() == 1:
            inputs = inputs.unsqueeze(1)

        covar_diag = to_dense(self.base_kernel(inputs, diag=True))
        return DiagLinearOperator(covar_diag)

    def forward(self, x1, x2, diag=False, **kwargs):
        """
        Compute the approximate covariance between ``x1`` and ``x2``.

        :param x1: First set of inputs.
        :param x2: Second set of inputs; must equal ``x1`` in training mode.
        :param diag: If True, return only the diagonal of the covariance.
        :return: The covariance linear operator, or its diagonal when ``diag`` is True.
        :raises RuntimeError: In training mode, when ``x1`` and ``x2`` differ.
        """
        # Validate before paying for K(x1, Z) and the Cholesky of K(Z, Z).
        if self.training and not torch.equal(x1, x2):
            raise RuntimeError("x1 should equal x2 in training mode")

        covar = self._get_covariance(x1, x2)

        if self.training:
            zero_mean = torch.zeros_like(x1.select(-1, 0))
            new_added_loss_term = InducingPointKernelAddedLossTerm(
                MultivariateNormal(zero_mean, self._covar_diag(x1)),
                MultivariateNormal(zero_mean, covar),
                self.likelihood,
            )
            self.update_added_loss_term("inducing_point_loss_term", new_added_loss_term)

        if diag:
            return covar.diagonal(dim1=-1, dim2=-2)
        return covar

    def num_outputs_per_input(self, x1, x2):
        """Delegate to ``base_kernel``."""
        return self.base_kernel.num_outputs_per_input(x1, x2)

    def __deepcopy__(self, memo):
        """
        Deep-copy the kernel while reusing the cached ``K(Z, Z)`` and inverse root.

        ``base_kernel`` and ``inducing_points`` are deep-copied through ``memo`` so objects shared
        with an enclosing deepcopy stay shared. ``likelihood`` is not copied: the copy reuses the
        copy an enclosing deepcopy already made, if any, and the original object otherwise. The
        training flag and ``inducing_points.requires_grad`` carry over; cached tensors are shared.
        """
        if memo is None:
            memo = {}

        cp = self.__class__(
            base_kernel=copy.deepcopy(self.base_kernel, memo),
            inducing_points=copy.deepcopy(self.inducing_points, memo),
            # Point at the enclosing copy's likelihood when one exists; otherwise share as before.
            likelihood=memo.get(id(self.likelihood), self.likelihood),
            active_dims=self.active_dims,
        )
        memo[id(self)] = cp

        # The constructor re-wraps the inducing points in a new Parameter (requires_grad=True)
        # and a freshly built module starts in training mode; restore the original's state so
        # frozen inducing points stay frozen and eval-mode caches are actually used.
        cp.inducing_points.requires_grad_(self.inducing_points.requires_grad)
        cp.training = self.training

        # Cached tensors are read-only results, so the copy can share them.
        for name in ("_cached_kernel_inv_root", "_cached_kernel_mat"):
            if hasattr(self, name):
                setattr(cp, name, getattr(self, name))

        return cp

    def prediction_strategy(self, train_inputs, train_prior_dist, train_labels, likelihood):
        """Use the SGPR prediction strategy, which exploits the low-rank structure for fast variances."""
        return exact_prediction_strategies.SGPRPredictionStrategy(
            train_inputs, train_prior_dist, train_labels, likelihood
        )
|