qpytorch-0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of qpytorch has been flagged as possibly problematic.
- qpytorch/__init__.py +327 -0
- qpytorch/constraints/__init__.py +3 -0
- qpytorch/distributions/__init__.py +21 -0
- qpytorch/distributions/delta.py +86 -0
- qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
- qpytorch/distributions/multivariate_qexponential.py +581 -0
- qpytorch/distributions/power.py +113 -0
- qpytorch/distributions/qexponential.py +153 -0
- qpytorch/functions/__init__.py +58 -0
- qpytorch/kernels/__init__.py +80 -0
- qpytorch/kernels/grid_interpolation_kernel.py +213 -0
- qpytorch/kernels/inducing_point_kernel.py +151 -0
- qpytorch/kernels/kernel.py +695 -0
- qpytorch/kernels/matern32_kernel_grad.py +155 -0
- qpytorch/kernels/matern52_kernel_grad.py +194 -0
- qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
- qpytorch/kernels/polynomial_kernel_grad.py +88 -0
- qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
- qpytorch/kernels/rbf_kernel_grad.py +125 -0
- qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
- qpytorch/kernels/rff_kernel.py +153 -0
- qpytorch/lazy/__init__.py +9 -0
- qpytorch/likelihoods/__init__.py +66 -0
- qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
- qpytorch/likelihoods/beta_likelihood.py +76 -0
- qpytorch/likelihoods/gaussian_likelihood.py +472 -0
- qpytorch/likelihoods/laplace_likelihood.py +59 -0
- qpytorch/likelihoods/likelihood.py +437 -0
- qpytorch/likelihoods/likelihood_list.py +60 -0
- qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
- qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
- qpytorch/likelihoods/noise_models.py +184 -0
- qpytorch/likelihoods/qexponential_likelihood.py +494 -0
- qpytorch/likelihoods/softmax_likelihood.py +97 -0
- qpytorch/likelihoods/student_t_likelihood.py +90 -0
- qpytorch/means/__init__.py +23 -0
- qpytorch/metrics/__init__.py +17 -0
- qpytorch/mlls/__init__.py +53 -0
- qpytorch/mlls/_approximate_mll.py +79 -0
- qpytorch/mlls/deep_approximate_mll.py +30 -0
- qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
- qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
- qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
- qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
- qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
- qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
- qpytorch/mlls/marginal_log_likelihood.py +48 -0
- qpytorch/mlls/predictive_log_likelihood.py +76 -0
- qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
- qpytorch/mlls/variational_elbo.py +77 -0
- qpytorch/models/__init__.py +72 -0
- qpytorch/models/approximate_qep.py +115 -0
- qpytorch/models/deep_qeps/__init__.py +22 -0
- qpytorch/models/deep_qeps/deep_qep.py +155 -0
- qpytorch/models/deep_qeps/dspp.py +114 -0
- qpytorch/models/exact_prediction_strategies.py +880 -0
- qpytorch/models/exact_qep.py +349 -0
- qpytorch/models/model_list.py +100 -0
- qpytorch/models/pyro/__init__.py +28 -0
- qpytorch/models/pyro/_pyro_mixin.py +57 -0
- qpytorch/models/pyro/distributions/__init__.py +5 -0
- qpytorch/models/pyro/pyro_qep.py +105 -0
- qpytorch/models/qep.py +7 -0
- qpytorch/models/qeplvm/__init__.py +6 -0
- qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
- qpytorch/models/qeplvm/latent_variable.py +102 -0
- qpytorch/module.py +30 -0
- qpytorch/optim/__init__.py +5 -0
- qpytorch/priors/__init__.py +42 -0
- qpytorch/priors/qep_priors.py +81 -0
- qpytorch/test/__init__.py +22 -0
- qpytorch/test/base_likelihood_test_case.py +106 -0
- qpytorch/test/model_test_case.py +150 -0
- qpytorch/test/variational_test_case.py +400 -0
- qpytorch/utils/__init__.py +38 -0
- qpytorch/utils/warnings.py +37 -0
- qpytorch/variational/__init__.py +47 -0
- qpytorch/variational/_variational_distribution.py +61 -0
- qpytorch/variational/_variational_strategy.py +391 -0
- qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
- qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
- qpytorch/variational/cholesky_variational_distribution.py +65 -0
- qpytorch/variational/ciq_variational_strategy.py +352 -0
- qpytorch/variational/delta_variational_distribution.py +41 -0
- qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
- qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
- qpytorch/variational/lmc_variational_strategy.py +248 -0
- qpytorch/variational/mean_field_variational_distribution.py +58 -0
- qpytorch/variational/multitask_variational_strategy.py +317 -0
- qpytorch/variational/natural_variational_distribution.py +152 -0
- qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
- qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
- qpytorch/variational/tril_natural_variational_distribution.py +130 -0
- qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
- qpytorch/variational/unwhitened_variational_strategy.py +225 -0
- qpytorch/variational/variational_strategy.py +280 -0
- qpytorch/version.py +4 -0
- qpytorch-0.1.dist-info/LICENSE +21 -0
- qpytorch-0.1.dist-info/METADATA +177 -0
- qpytorch-0.1.dist-info/RECORD +102 -0
- qpytorch-0.1.dist-info/WHEEL +5 -0
- qpytorch-0.1.dist-info/top_level.txt +1 -0
qpytorch/variational/delta_variational_distribution.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python3
+
+from typing import Union
+
+import torch
+
+from ..distributions import Delta, Distribution, MultivariateNormal, MultivariateQExponential
+from ._variational_distribution import _VariationalDistribution
+
+
+class DeltaVariationalDistribution(_VariationalDistribution):
+    """
+    This :obj:`~qpytorch.variational._VariationalDistribution` object replaces a variational distribution
+    with a single particle. It is equivalent to doing MAP inference.
+
+    :param int num_inducing_points: Size of the variational distribution. This implies that the variational mean
+        should be this size, and the variational covariance matrix should have this many rows and columns.
+    :param batch_shape: Specifies an optional batch size
+        for the variational parameters. This is useful for example when doing additive variational inference.
+    :type batch_shape: :obj:`torch.Size`, optional
+    :param float mean_init_std: (Default: 1e-3) Standard deviation of gaussian (q-exponential) noise to add to the mean initialization.
+    """
+
+    def __init__(
+        self,
+        num_inducing_points: int,
+        batch_shape: torch.Size = torch.Size([]),
+        mean_init_std: float = 1e-3,
+        **kwargs,
+    ):
+        super().__init__(num_inducing_points=num_inducing_points, batch_shape=batch_shape, mean_init_std=mean_init_std)
+        mean_init = torch.zeros(num_inducing_points)
+        mean_init = mean_init.repeat(*batch_shape, 1)
+        self.register_parameter(name="variational_mean", parameter=torch.nn.Parameter(mean_init))
+
+    def forward(self) -> Distribution:
+        return Delta(self.variational_mean)
+
+    def initialize_variational_distribution(self, prior_dist: Union[MultivariateNormal, MultivariateQExponential]) -> None:
+        self.variational_mean.data.copy_(prior_dist.mean)
+        self.variational_mean.data.add_(torch.randn_like(prior_dist.mean), alpha=self.mean_init_std)
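The delta distribution above collapses q(u) to a single particle, so maximizing the ELBO reduces to MAP estimation of the inducing values. A minimal usage sketch (not part of the package; it assumes DeltaVariationalDistribution and VariationalStrategy are re-exported from qpytorch.variational, as the file layout and the docstring example further below suggest):

import torch
import qpytorch


class MAPModel(qpytorch.models.ApproximateGP):
    """Sparse model whose inducing values are point-estimated (MAP) rather than given a full q(u)."""

    def __init__(self, inducing_points):
        # One particle per inducing point; no variational covariance parameters are created
        variational_distribution = qpytorch.variational.DeltaVariationalDistribution(inducing_points.size(-2))
        variational_strategy = qpytorch.variational.VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = qpytorch.means.ConstantMean()
        self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())

    def forward(self, x):
        return qpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


model = MAPModel(inducing_points=torch.randn(16, 1))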
qpytorch/variational/grid_interpolation_variational_strategy.py
@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+
+import torch
+from linear_operator.operators import InterpolatedLinearOperator
+from linear_operator.utils.interpolation import left_interp
+
+from ..distributions import MultivariateNormal, MultivariateQExponential
+from gpytorch.utils.interpolation import Interpolation
+from gpytorch.utils.memoize import cached
+from ._variational_strategy import _VariationalStrategy
+
+
+class GridInterpolationVariationalStrategy(_VariationalStrategy):
+    r"""
+    This strategy constrains the inducing points to a grid and applies a deterministic
+    relationship between :math:`\mathbf f` and :math:`\mathbf u`.
+    It was introduced by `Wilson et al. (2016)`_.
+
+    Here, the inducing points are not learned. Instead, the strategy
+    automatically creates inducing points based on a set of grid sizes and grid
+    bounds.
+
+    .. _Wilson et al. (2016):
+        https://arxiv.org/abs/1611.00336
+
+    :param ~gpytorch.models.ApproximateGP (~qpytorch.models.ApproximateQEP) model: Model this strategy is applied to.
+        Typically passed in when the VariationalStrategy is created in the
+        __init__ method of the user defined model.
+        It should contain a `power` attribute if a Q-Exponential distribution is involved.
+    :param int grid_size: Size of the grid
+    :param list grid_bounds: Bounds of each dimension of the grid (should be a list of (float, float) tuples)
+    :param ~qpytorch.variational.VariationalDistribution variational_distribution: A
+        VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`
+    """
+
+    def __init__(self, model, grid_size, grid_bounds, variational_distribution):
+        grid = torch.zeros(grid_size, len(grid_bounds))
+        for i in range(len(grid_bounds)):
+            grid_diff = float(grid_bounds[i][1] - grid_bounds[i][0]) / (grid_size - 2)
+            grid[:, i] = torch.linspace(grid_bounds[i][0] - grid_diff, grid_bounds[i][1] + grid_diff, grid_size)
+
+        inducing_points = torch.zeros(int(pow(grid_size, len(grid_bounds))), len(grid_bounds))
+        prev_points = None
+        for i in range(len(grid_bounds)):
+            for j in range(grid_size):
+                inducing_points[j * grid_size**i : (j + 1) * grid_size**i, i].fill_(grid[j, i])
+                if prev_points is not None:
+                    inducing_points[j * grid_size**i : (j + 1) * grid_size**i, :i].copy_(prev_points)
+            prev_points = inducing_points[: grid_size ** (i + 1), : (i + 1)]
+
+        super(GridInterpolationVariationalStrategy, self).__init__(
+            model, inducing_points, variational_distribution, learn_inducing_locations=False
+        )
+        object.__setattr__(self, "model", model)
+
+        self.register_buffer("grid", grid)
+
+    def _compute_grid(self, inputs):
+        n_data, n_dimensions = inputs.size(-2), inputs.size(-1)
+        batch_shape = inputs.shape[:-2]
+
+        inputs = inputs.reshape(-1, n_dimensions)
+        interp_indices, interp_values = Interpolation().interpolate(self.grid, inputs)
+        interp_indices = interp_indices.view(*batch_shape, n_data, -1)
+        interp_values = interp_values.view(*batch_shape, n_data, -1)
+
+        if (interp_indices.dim() - 2) != len(self._variational_distribution.batch_shape):
+            batch_shape = torch.broadcast_shapes(interp_indices.shape[:-2], self._variational_distribution.batch_shape)
+            interp_indices = interp_indices.expand(*batch_shape, *interp_indices.shape[-2:])
+            interp_values = interp_values.expand(*batch_shape, *interp_values.shape[-2:])
+        return interp_indices, interp_values
+
+    @property
+    @cached(name="prior_distribution_memo")
+    def prior_distribution(self):
+        out = self.model.forward(self.inducing_points)
+        # TODO: investigate why smaller than 1e-3 breaks some tests
+        if hasattr(self.model, 'power'):
+            res = MultivariateQExponential(out.mean, out.lazy_covariance_matrix.add_jitter(1e-3), power=self.model.power)
+        else:
+            res = MultivariateNormal(out.mean, out.lazy_covariance_matrix.add_jitter(1e-3))
+        return res
+
+    def forward(self, x, inducing_points, inducing_values, variational_inducing_covar=None):
+        if variational_inducing_covar is None:
+            raise RuntimeError(
+                "GridInterpolationVariationalStrategy is only compatible with Gaussian (Q-Exponential) variational "
+                f"distributions. Got {self.variational_distribution.__class__.__name__}."
+            )
+
+        variational_distribution = self.variational_distribution
+
+        # Get interpolations
+        interp_indices, interp_values = self._compute_grid(x)
+
+        # Compute test mean
+        # Left multiply samples by interpolation matrix
+        predictive_mean = left_interp(interp_indices, interp_values, inducing_values.unsqueeze(-1))
+        predictive_mean = predictive_mean.squeeze(-1)
+
+        # Compute test covar
+        predictive_covar = InterpolatedLinearOperator(
+            variational_distribution.lazy_covariance_matrix,
+            interp_indices,
+            interp_values,
+            interp_indices,
+            interp_values,
+        )
+        if hasattr(self.model, 'power'):
+            output = MultivariateQExponential(predictive_mean, predictive_covar, power=self.model.power)
+        else:
+            output = MultivariateNormal(predictive_mean, predictive_covar)
+        return output
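For orientation, a hedged wiring sketch in the usual KISS-GP style: one inducing point per grid node, with predictions interpolated off the grid. CholeskyVariationalDistribution and GridInterpolationVariationalStrategy are assumed to be re-exported from qpytorch.variational, as the file layout suggests; a q-exponential model would instead subclass ApproximateQEP and set a `power` attribute, which prior_distribution and forward above pick up via hasattr(self.model, 'power'):

import torch
import qpytorch


class GridGPModel(qpytorch.models.ApproximateGP):
    def __init__(self, grid_size=64, grid_bounds=((0.0, 1.0),)):
        # The strategy places one inducing point on every grid node: grid_size ** num_dims in total
        variational_distribution = qpytorch.variational.CholeskyVariationalDistribution(
            int(grid_size ** len(grid_bounds))
        )
        variational_strategy = qpytorch.variational.GridInterpolationVariationalStrategy(
            self, grid_size, grid_bounds, variational_distribution
        )
        super().__init__(variational_strategy)
        self.mean_module = qpytorch.means.ConstantMean()
        self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())

    def forward(self, x):
        return qpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


# Inputs only need to fall inside grid_bounds; inducing points themselves are never learned
model = GridGPModel()
output = model(torch.rand(10, 1))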
qpytorch/variational/independent_multitask_variational_strategy.py
@@ -0,0 +1,114 @@
+#!/usr/bin/env python3
+
+import warnings
+
+import torch
+from linear_operator.operators import RootLinearOperator
+
+from ..distributions import MultitaskMultivariateNormal, MultivariateNormal
+from ..module import Module
+from ._variational_strategy import _VariationalStrategy
+
+
+class IndependentMultitaskVariationalStrategy(_VariationalStrategy):
+    """
+    IndependentMultitaskVariationalStrategy wraps an existing
+    :obj:`~qpytorch.variational.VariationalStrategy` to produce vector-valued (multi-task)
+    output distributions. Each task will be independent of one another.
+
+    The output will either be a :obj:`~gpytorch.distributions.MultitaskMultivariateNormal` distribution
+    (if we wish to evaluate all tasks for each input) or a :obj:`~gpytorch.distributions.MultivariateNormal`
+    (if we wish to evaluate a single task for each input).
+
+    The base variational strategy is assumed to operate on a batch of GPs. One of the batch
+    dimensions corresponds to the multiple tasks.
+
+    :param ~qpytorch.variational.VariationalStrategy base_variational_strategy: Base variational strategy
+    :param int num_tasks: Number of tasks. Should correspond to the batch size of task_dim.
+    :param int task_dim: (Default: -1) Which batch dimension is the task dimension
+    """
+
+    def __init__(self, base_variational_strategy, num_tasks, task_dim=-1):
+        Module.__init__(self)
+        self.base_variational_strategy = base_variational_strategy
+        self.task_dim = task_dim
+        self.num_tasks = num_tasks
+
+    @property
+    def prior_distribution(self):
+        return self.base_variational_strategy.prior_distribution
+
+    @property
+    def variational_distribution(self):
+        return self.base_variational_strategy.variational_distribution
+
+    @property
+    def variational_params_initialized(self):
+        return self.base_variational_strategy.variational_params_initialized
+
+    def kl_divergence(self):
+        return super().kl_divergence().sum(dim=-1)
+
+    def __call__(self, x, task_indices=None, prior=False, **kwargs):
+        r"""
+        See :class:`LMCVariationalStrategy`.
+        """
+        function_dist = self.base_variational_strategy(x, prior=prior, **kwargs)
+
+        if task_indices is None:
+            # Every data point will get an output for each task
+            if (
+                self.task_dim > 0
+                and self.task_dim > len(function_dist.batch_shape)
+                or self.task_dim < 0
+                and self.task_dim + len(function_dist.batch_shape) < 0
+            ):
+                return MultitaskMultivariateNormal.from_repeated_mvn(function_dist, num_tasks=self.num_tasks)
+            else:
+                function_dist = MultitaskMultivariateNormal.from_batch_mvn(function_dist, task_dim=self.task_dim)
+                assert function_dist.event_shape[-1] == self.num_tasks
+                return function_dist
+
+        else:
+            # Each data point will get a single output corresponding to a single task
+
+            if self.task_dim > 0:
+                raise RuntimeError(f"task_dim must be a negative indexed batch dimension: got {self.task_dim}.")
+            num_batch = len(function_dist.batch_shape)
+            task_dim = num_batch + self.task_dim
+
+            # Create a mask to choose specific task assignment
+            shape = list(function_dist.batch_shape + function_dist.event_shape)
+            shape[task_dim] = 1
+            task_indices = task_indices.expand(shape).squeeze(task_dim)
+
+            # Create a mask to choose specific task assignment
+            task_mask = torch.nn.functional.one_hot(task_indices, num_classes=self.num_tasks)
+            task_mask = task_mask.permute(*range(0, task_dim), *range(task_dim + 1, num_batch + 1), task_dim)
+
+            mean = (function_dist.mean * task_mask).sum(task_dim)
+            covar = (function_dist.lazy_covariance_matrix * RootLinearOperator(task_mask[..., None])).sum(task_dim)
+            return MultivariateNormal(mean, covar)
+
+
+class MultitaskVariationalStrategy(IndependentMultitaskVariationalStrategy):
+    """
+    IndependentMultitaskVariationalStrategy wraps an existing
+    :obj:`~qpytorch.variational.VariationalStrategy`
+    to produce a :obj:`~gpytorch.distributions.MultitaskMultivariateNormal` distribution.
+    All outputs will be independent of one another.
+
+    The base variational strategy is assumed to operate on a batch of GPs. One of the batch
+    dimensions corresponds to the multiple tasks.
+
+    :param ~qpytorch.variational.VariationalStrategy base_variational_strategy: Base variational strategy
+    :param int num_tasks: Number of tasks. Should correspond to the batch size of task_dim.
+    :param int task_dim: (Default: -1) Which batch dimension is the task dimension
+    """
+
+    def __init__(self, base_variational_strategy, num_tasks, task_dim=-1):
+        warnings.warn(
+            "MultitaskVariationalStrategy has been renamed to IndependentMultitaskVariationalStrategy",
+            DeprecationWarning,
+        )
+        super().__init__(base_variational_strategy, num_tasks, task_dim=task_dim)
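A hedged sketch of the two calling modes handled in __call__ above: a batch of independent latent GPs is wrapped so that evaluating all tasks returns a multitask distribution, while passing task_indices returns one output per input. The qpytorch.variational exports are assumed from the file layout:

import torch
import qpytorch


class MultitaskModel(qpytorch.models.ApproximateGP):
    def __init__(self, num_tasks=4, num_inducing=16):
        # A batch of `num_tasks` independent latent GPs; the batch dim doubles as the task dim
        inducing_points = torch.rand(num_tasks, num_inducing, 1)
        variational_distribution = qpytorch.variational.MeanFieldVariationalDistribution(
            num_inducing, batch_shape=torch.Size([num_tasks])
        )
        variational_strategy = qpytorch.variational.IndependentMultitaskVariationalStrategy(
            qpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            ),
            num_tasks=num_tasks,
            task_dim=-1,
        )
        super().__init__(variational_strategy)
        self.mean_module = qpytorch.means.ConstantMean(batch_shape=torch.Size([num_tasks]))
        self.covar_module = qpytorch.kernels.ScaleKernel(
            qpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_tasks]))
        )

    def forward(self, x):
        return qpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


model = MultitaskModel()
x = torch.rand(10, 1)
all_tasks = model(x)                                       # MultitaskMultivariateNormal, event shape 10 x 4
one_task = model(x, task_indices=torch.randint(4, (10,)))  # MultivariateNormal, event shape 10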
qpytorch/variational/lmc_variational_strategy.py
@@ -0,0 +1,248 @@
+#!/usr/bin/env python3
+
+from typing import Optional, Union
+
+import torch
+from linear_operator.operators import KroneckerProductLinearOperator, RootLinearOperator
+from linear_operator.utils.interpolation import left_interp
+from torch import LongTensor, Tensor
+
+from .. import settings
+from ..distributions import MultitaskMultivariateNormal, MultitaskMultivariateQExponential, MultivariateNormal, MultivariateQExponential
+from ..module import Module
+from ._variational_strategy import _VariationalStrategy
+
+
+def _select_lmc_coefficients(lmc_coefficients: torch.Tensor, indices: torch.LongTensor) -> torch.Tensor:
+    """
+    Given a list of indices for ... x N datapoints,
+    select the row from lmc_coefficient that corresponds to each datapoint
+
+    lmc_coefficients: torch.Tensor ... x num_latents x ... x num_tasks
+    indices: torch.Tensor ... x N
+    """
+    batch_shape = torch.broadcast_shapes(lmc_coefficients.shape[:-1], indices.shape[:-1])
+
+    # We will use the left_interp helper to do the indexing
+    lmc_coefficients = lmc_coefficients.expand(*batch_shape, lmc_coefficients.shape[-1])[..., None]
+    indices = indices.expand(*batch_shape, indices.shape[-1])[..., None]
+    res = left_interp(
+        indices,
+        torch.ones(indices.shape, dtype=torch.long, device=indices.device),
+        lmc_coefficients,
+    ).squeeze(-1)
+    return res
+
+
+class LMCVariationalStrategy(_VariationalStrategy):
+    r"""
+    LMCVariationalStrategy is an implementation of the "Linear Model of Coregionalization"
+    for multitask GPs (QEPs). This model assumes that there are :math:`Q` latent functions
+    :math:`\mathbf g(\cdot) = [g^{(1)}(\cdot), \ldots, g^{(q)}(\cdot)]`,
+    each of which is modelled by a GP (QEP).
+    The output functions (tasks) are linear combinations of the latent functions:
+
+    .. math::
+
+        f_{\text{task } i}( \mathbf x) = \sum_{q=1}^Q a_i^{(q)} g^{(q)} ( \mathbf x )
+
+    LMCVariationalStrategy wraps an existing :obj:`~qpytorch.variational.VariationalStrategy`.
+    The output will either be a :obj:`~gpytorch.distributions.MultitaskMultivariateNormal` (:obj:`~qpytorch.distributions.MultitaskMultivariateQExponential`) distribution
+    (if we wish to evaluate all tasks for each input) or a :obj:`~gpytorch.distributions.MultivariateNormal` (:obj:`~qpytorch.distributions.MultivariateQExponential`)
+    (if we wish to evaluate a single task for each input).
+
+    The base variational strategy is assumed to operate on a multi-batch of GPs (QEPs), where one
+    of the batch dimensions corresponds to the latent function dimension.
+
+    .. note::
+
+        The batch shape of the base :obj:`~qpytorch.variational.VariationalStrategy` does not
+        necessarily have to correspond to the batch shape of the underlying GP (QEP) objects.
+
+        For example, if the base variational strategy has a batch shape of `[3]` (corresponding
+        to 3 latent functions), the GP (QEP) kernel object could have a batch shape of `[3]` or no
+        batch shape. This would correspond to each of the latent functions having different kernels
+        or the same kernel, respectively.
+
+    Example:
+        >>> class LMCMultitaskGP(qpytorch.models.ApproximateGP):
+        >>>     '''
+        >>>     3 latent functions
+        >>>     5 output dimensions (tasks)
+        >>>     '''
+        >>>     def __init__(self):
+        >>>         # Each latent function shares the same inducing points
+        >>>         # We'll have 32 inducing points, and let's assume the input dimensionality is 2
+        >>>         inducing_points = torch.randn(32, 2)
+        >>>
+        >>>         # The variational parameters have a batch_shape of [3] - for 3 latent functions
+        >>>         variational_distribution = qpytorch.variational.MeanFieldVariationalDistribution(
+        >>>             inducing_points.size(-2), batch_shape=torch.Size([3]),
+        >>>         )
+        >>>         variational_strategy = qpytorch.variational.LMCVariationalStrategy(
+        >>>             qpytorch.variational.VariationalStrategy(
+        >>>                 self, inducing_points, variational_distribution, learn_inducing_locations=True,
+        >>>             ),
+        >>>             num_tasks=5,
+        >>>             num_latents=3,
+        >>>             latent_dim=-1,
+        >>>         )
+        >>>
+        >>>         # Each latent function has its own mean/kernel function
+        >>>         super().__init__(variational_strategy)
+        >>>         self.mean_module = qpytorch.means.ConstantMean(batch_shape=torch.Size([3]))
+        >>>         self.covar_module = qpytorch.kernels.ScaleKernel(
+        >>>             qpytorch.kernels.RBFKernel(batch_shape=torch.Size([3])),
+        >>>             batch_shape=torch.Size([3]),
+        >>>         )
+        >>>
+
+    :param base_variational_strategy: Base variational strategy
+    :param num_tasks: The total number of tasks (output functions)
+    :param num_latents: The total number of latent functions in each group
+    :param latent_dim: (Default: -1) Which batch dimension corresponds to the latent function batch.
+        **Must be negative indexed**
+    :param jitter_val: Amount of diagonal jitter to add for Cholesky factorization numerical stability
+    """
+
+    def __init__(
+        self,
+        base_variational_strategy: _VariationalStrategy,
+        num_tasks: int,
+        num_latents: int = 1,
+        latent_dim: int = -1,
+        jitter_val: Optional[float] = None,
+    ):
+        Module.__init__(self)
+        self.base_variational_strategy = base_variational_strategy
+        self.num_tasks = num_tasks
+        batch_shape = self.base_variational_strategy._variational_distribution.batch_shape
+
+        # Check if no functions
+        if latent_dim >= 0:
+            raise RuntimeError(f"latent_dim must be a negative indexed batch dimension: got {latent_dim}.")
+        if not (batch_shape[latent_dim] == num_latents or batch_shape[latent_dim] == 1):
+            raise RuntimeError(
+                f"Mismatch in num_latents: got a variational distribution of batch shape {batch_shape}, "
+                f"expected the function dim {latent_dim} to be {num_latents}."
+            )
+        self.num_latents = num_latents
+        self.latent_dim = latent_dim
+
+        # Make the batch_shape
+        self.batch_shape = list(batch_shape)
+        del self.batch_shape[self.latent_dim]
+        self.batch_shape = torch.Size(self.batch_shape)
+
+        # LMC coefficients
+        lmc_coefficients = torch.randn(*batch_shape, self.num_tasks)
+        self.register_parameter("lmc_coefficients", torch.nn.Parameter(lmc_coefficients))
+
+        if jitter_val is None:
+            self.jitter_val = settings.variational_cholesky_jitter.value(
+                self.base_variational_strategy.inducing_points.dtype
+            )
+        else:
+            self.jitter_val = jitter_val
+
+    @property
+    def prior_distribution(self) -> Union[MultivariateNormal, MultivariateQExponential]:
+        return self.base_variational_strategy.prior_distribution
+
+    @property
+    def variational_distribution(self) -> Union[MultivariateNormal, MultivariateQExponential]:
+        return self.base_variational_strategy.variational_distribution
+
+    @property
+    def variational_params_initialized(self) -> bool:
+        return self.base_variational_strategy.variational_params_initialized
+
+    def kl_divergence(self) -> Tensor:
+        return super().kl_divergence().sum(dim=self.latent_dim)
+
+    def __call__(
+        self, x: Tensor, prior: bool = False, task_indices: Optional[LongTensor] = None, **kwargs
+    ) -> Union[MultitaskMultivariateNormal, MultitaskMultivariateQExponential, MultivariateNormal, MultivariateQExponential]:
+        r"""
+        Computes the variational (or prior) distribution
+        :math:`q( \mathbf f \mid \mathbf X)` (or :math:`p( \mathbf f \mid \mathbf X)`).
+        There are two modes:
+
+        1. Compute **all tasks** for all inputs.
+           If this is the case, the task_indices attribute should be None.
+           The return type will be a (... x N x num_tasks)
+           :class:`~gpytorch.distributions.MultitaskMultivariateNormal` (:class:`~qpytorch.distributions.MultitaskMultivariateQExponential`).
+        2. Compute **one task** per input.
+           If this is the case, the (... x N) task_indices tensor should contain
+           the indices of each input's assigned task.
+           The return type will be a (... x N)
+           :class:`~gpytorch.distributions.MultivariateNormal` (:class:`~qpytorch.distributions.MultivariateQExponential`).
+
+        :param x: (... x N x D) Input locations to evaluate variational strategy
+        :param task_indices: (Default: None) Task index associated with each input.
+            If this **is not** provided, then the returned distribution evaluates every input on every task
+            (returns :class:`~gpytorch.distributions.MultitaskMultivariateNormal` or :class:`~qpytorch.distributions.MultitaskMultivariateQExponential`).
+            If this **is** provided, then the returned distribution evaluates each input only on its assigned task
+            (returns :class:`~gpytorch.distributions.MultivariateNormal` or :class:`~qpytorch.distributions.MultivariateQExponential`).
+        :param prior: (Default: False) If False, returns the variational distribution
+            :math:`q( \mathbf f \mid \mathbf X)`.
+            If True, returns the prior distribution
+            :math:`p( \mathbf f \mid \mathbf X)`.
+        :return: :math:`q( \mathbf f \mid \mathbf X)` (or the prior),
+            either for all tasks (if `task_indices == None`)
+            or for a specific task (if `task_indices != None`).
+        :rtype: ~gpytorch.distributions.MultitaskMultivariateNormal (~qpytorch.distributions.MultitaskMultivariateQExponential) (... x N x num_tasks)
+            or ~gpytorch.distributions.MultivariateNormal (~qpytorch.distributions.MultivariateQExponential) (... x N)
+        """
+        latent_dist = self.base_variational_strategy(x, prior=prior, **kwargs)
+        num_batch = len(latent_dist.batch_shape)
+        latent_dim = num_batch + self.latent_dim
+
+        if task_indices is None:
+            num_dim = num_batch + len(latent_dist.event_shape)
+
+            # Every data point will get an output for each task
+            # Therefore, we will set up the lmc_coefficients shape for a matmul
+            lmc_coefficients = self.lmc_coefficients.expand(*latent_dist.batch_shape, self.lmc_coefficients.size(-1))
+
+            # Mean: ... x N x num_tasks
+            latent_mean = latent_dist.mean.permute(*range(0, latent_dim), *range(latent_dim + 1, num_dim), latent_dim)
+            mean = latent_mean @ lmc_coefficients.permute(
+                *range(0, latent_dim), *range(latent_dim + 1, num_dim - 1), latent_dim, -1
+            )
+
+            # Covar: ... x (N x num_tasks) x (N x num_tasks)
+            latent_covar = latent_dist.lazy_covariance_matrix
+            lmc_factor = RootLinearOperator(lmc_coefficients.unsqueeze(-1))
+            covar = KroneckerProductLinearOperator(latent_covar, lmc_factor).sum(latent_dim)
+            # Add a bit of jitter to make the covar PD
+            covar = covar.add_jitter(self.jitter_val)
+
+            # Done!
+            if isinstance(latent_dist, MultivariateNormal):
+                function_dist = MultitaskMultivariateNormal(mean, covar)
+            elif isinstance(latent_dist, MultivariateQExponential):
+                function_dist = MultitaskMultivariateQExponential(mean, covar, power=latent_dist.power)
+
+        else:
+            # Each data point will get a single output corresponding to a single task
+            # Therefore, we will select the appropriate lmc coefficients for each task
+            lmc_coefficients = _select_lmc_coefficients(self.lmc_coefficients, task_indices)
+
+            # Mean: ... x N
+            mean = (latent_dist.mean * lmc_coefficients).sum(latent_dim)
+
+            # Covar: ... x N x N
+            latent_covar = latent_dist.lazy_covariance_matrix
+            lmc_factor = RootLinearOperator(lmc_coefficients.unsqueeze(-1))
+            covar = (latent_covar * lmc_factor).sum(latent_dim)
+            # Add a bit of jitter to make the covar PD
+            covar = covar.add_jitter(self.jitter_val)
+
+            # Done!
+            if isinstance(latent_dist, MultivariateNormal):
+                function_dist = MultivariateNormal(mean, covar)
+            elif isinstance(latent_dist, MultivariateQExponential):
+                function_dist = MultivariateQExponential(mean, covar, power=latent_dist.power)
+
+        return function_dist
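The mixing carried out in __call__ can be sanity-checked in plain torch: with Q latent functions and a coefficient matrix A of shape Q x num_tasks, the task means are m_latent^T A and the joint task covariance is the sum over q of K_q kron (a_q a_q^T), which mirrors the KroneckerProductLinearOperator construction above. A small numerical sketch (plain torch only, no qpytorch):

import torch

Q, T, N = 3, 5, 4                       # latent functions, tasks, data points
A = torch.randn(Q, T)                   # lmc_coefficients
m = torch.randn(Q, N)                   # latent means, one row per latent function
K = [torch.eye(N) for _ in range(Q)]    # latent covariances (identity, for simplicity)

# Task means: N x T, matching the "Mean: ... x N x num_tasks" step above
task_mean = m.transpose(-1, -2) @ A

# Task covariance: an (N * T) x (N * T) matrix in data-major (interleaved) ordering
task_covar = sum(torch.kron(K[q], torch.outer(A[q], A[q])) for q in range(Q))

print(task_mean.shape, task_covar.shape)  # torch.Size([4, 5]) torch.Size([20, 20])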
qpytorch/variational/mean_field_variational_distribution.py
@@ -0,0 +1,58 @@
+#!/usr/bin/env python3
+
+import torch
+from linear_operator.operators import DiagLinearOperator
+
+from ..distributions import MultivariateNormal, MultivariateQExponential
+from ._variational_distribution import _VariationalDistribution
+
+
+class MeanFieldVariationalDistribution(_VariationalDistribution):
+    """
+    A :obj:`~qpytorch.variational._VariationalDistribution` that is defined to be a multivariate normal (q-exponential) distribution
+    with a diagonal covariance matrix. This will not be as flexible/expressive as a
+    :obj:`~qpytorch.variational.CholeskyVariationalDistribution`.
+
+    :param int num_inducing_points: Size of the variational distribution. This implies that the variational mean
+        should be this size, and the variational covariance matrix should have this many rows and columns.
+    :param batch_shape: Specifies an optional batch size
+        for the variational parameters. This is useful for example when doing additive variational inference.
+    :type batch_shape: :obj:`torch.Size`, optional
+    :param float mean_init_std: (Default: 1e-3) Standard deviation of gaussian (q-exponential) noise to add to the mean initialization.
+    """
+
+    def __init__(self, num_inducing_points, batch_shape=torch.Size([]), mean_init_std=1e-3, **kwargs):
+        super().__init__(num_inducing_points=num_inducing_points, batch_shape=batch_shape, mean_init_std=mean_init_std)
+        mean_init = torch.zeros(num_inducing_points)
+        covar_init = torch.ones(num_inducing_points)
+        mean_init = mean_init.repeat(*batch_shape, 1)
+        covar_init = covar_init.repeat(*batch_shape, 1)
+
+        self.register_parameter(name="variational_mean", parameter=torch.nn.Parameter(mean_init))
+        self.register_parameter(name="_variational_stddev", parameter=torch.nn.Parameter(covar_init))
+
+        if 'power' in kwargs: self.power = kwargs.pop('power')
+
+    @property
+    def variational_stddev(self):
+        # TODO: if we don't multiply self._variational_stddev by a mask of one, Pyro models fail
+        # not sure where this bug is occurring (in Pyro or PyTorch)
+        # throwing this in as a hotfix for now - we should investigate later
+        mask = torch.ones_like(self._variational_stddev)
+        return self._variational_stddev.mul(mask).abs().clamp_min(1e-8)
+
+    def forward(self):
+        # TODO: if we don't multiply self._variational_stddev by a mask of one, Pyro models fail
+        # not sure where this bug is occurring (in Pyro or PyTorch)
+        # throwing this in as a hotfix for now - we should investigate later
+        mask = torch.ones_like(self._variational_stddev)
+        variational_covar = DiagLinearOperator(self._variational_stddev.mul(mask).pow(2))
+        if not hasattr(self, 'power'):
+            return MultivariateNormal(self.variational_mean, variational_covar)
+        else:
+            return MultivariateQExponential(self.variational_mean, variational_covar, power=self.power)
+
+    def initialize_variational_distribution(self, prior_dist):
+        self.variational_mean.data.copy_(prior_dist.mean)
+        self.variational_mean.data.add_(torch.randn_like(prior_dist.mean), alpha=self.mean_init_std)
+        self._variational_stddev.data.copy_(prior_dist.stddev)
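A short construction sketch to close out: the mean-field distribution stores one mean and one standard deviation per inducing point (2 * num_inducing_points free parameters), and the `power` keyword handled above switches forward() from a multivariate normal to a multivariate q-exponential. The qpytorch.variational export and the tensor form of `power` are assumptions, not confirmed by this diff:

import torch
import qpytorch

# Gaussian case: q(u) = N(m, diag(s^2))
gauss_q = qpytorch.variational.MeanFieldVariationalDistribution(num_inducing_points=32)
print(gauss_q().__class__.__name__)  # MultivariateNormal

# Q-exponential case: passing `power` makes forward() return a MultivariateQExponential
qep_q = qpytorch.variational.MeanFieldVariationalDistribution(
    num_inducing_points=32, power=torch.tensor(1.0)  # power value is illustrative
)
print(qep_q().__class__.__name__)  # MultivariateQExponential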