qpytorch-0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of qpytorch might be problematic.
- qpytorch/__init__.py +327 -0
- qpytorch/constraints/__init__.py +3 -0
- qpytorch/distributions/__init__.py +21 -0
- qpytorch/distributions/delta.py +86 -0
- qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
- qpytorch/distributions/multivariate_qexponential.py +581 -0
- qpytorch/distributions/power.py +113 -0
- qpytorch/distributions/qexponential.py +153 -0
- qpytorch/functions/__init__.py +58 -0
- qpytorch/kernels/__init__.py +80 -0
- qpytorch/kernels/grid_interpolation_kernel.py +213 -0
- qpytorch/kernels/inducing_point_kernel.py +151 -0
- qpytorch/kernels/kernel.py +695 -0
- qpytorch/kernels/matern32_kernel_grad.py +155 -0
- qpytorch/kernels/matern52_kernel_grad.py +194 -0
- qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
- qpytorch/kernels/polynomial_kernel_grad.py +88 -0
- qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
- qpytorch/kernels/rbf_kernel_grad.py +125 -0
- qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
- qpytorch/kernels/rff_kernel.py +153 -0
- qpytorch/lazy/__init__.py +9 -0
- qpytorch/likelihoods/__init__.py +66 -0
- qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
- qpytorch/likelihoods/beta_likelihood.py +76 -0
- qpytorch/likelihoods/gaussian_likelihood.py +472 -0
- qpytorch/likelihoods/laplace_likelihood.py +59 -0
- qpytorch/likelihoods/likelihood.py +437 -0
- qpytorch/likelihoods/likelihood_list.py +60 -0
- qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
- qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
- qpytorch/likelihoods/noise_models.py +184 -0
- qpytorch/likelihoods/qexponential_likelihood.py +494 -0
- qpytorch/likelihoods/softmax_likelihood.py +97 -0
- qpytorch/likelihoods/student_t_likelihood.py +90 -0
- qpytorch/means/__init__.py +23 -0
- qpytorch/metrics/__init__.py +17 -0
- qpytorch/mlls/__init__.py +53 -0
- qpytorch/mlls/_approximate_mll.py +79 -0
- qpytorch/mlls/deep_approximate_mll.py +30 -0
- qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
- qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
- qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
- qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
- qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
- qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
- qpytorch/mlls/marginal_log_likelihood.py +48 -0
- qpytorch/mlls/predictive_log_likelihood.py +76 -0
- qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
- qpytorch/mlls/variational_elbo.py +77 -0
- qpytorch/models/__init__.py +72 -0
- qpytorch/models/approximate_qep.py +115 -0
- qpytorch/models/deep_qeps/__init__.py +22 -0
- qpytorch/models/deep_qeps/deep_qep.py +155 -0
- qpytorch/models/deep_qeps/dspp.py +114 -0
- qpytorch/models/exact_prediction_strategies.py +880 -0
- qpytorch/models/exact_qep.py +349 -0
- qpytorch/models/model_list.py +100 -0
- qpytorch/models/pyro/__init__.py +28 -0
- qpytorch/models/pyro/_pyro_mixin.py +57 -0
- qpytorch/models/pyro/distributions/__init__.py +5 -0
- qpytorch/models/pyro/pyro_qep.py +105 -0
- qpytorch/models/qep.py +7 -0
- qpytorch/models/qeplvm/__init__.py +6 -0
- qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
- qpytorch/models/qeplvm/latent_variable.py +102 -0
- qpytorch/module.py +30 -0
- qpytorch/optim/__init__.py +5 -0
- qpytorch/priors/__init__.py +42 -0
- qpytorch/priors/qep_priors.py +81 -0
- qpytorch/test/__init__.py +22 -0
- qpytorch/test/base_likelihood_test_case.py +106 -0
- qpytorch/test/model_test_case.py +150 -0
- qpytorch/test/variational_test_case.py +400 -0
- qpytorch/utils/__init__.py +38 -0
- qpytorch/utils/warnings.py +37 -0
- qpytorch/variational/__init__.py +47 -0
- qpytorch/variational/_variational_distribution.py +61 -0
- qpytorch/variational/_variational_strategy.py +391 -0
- qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
- qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
- qpytorch/variational/cholesky_variational_distribution.py +65 -0
- qpytorch/variational/ciq_variational_strategy.py +352 -0
- qpytorch/variational/delta_variational_distribution.py +41 -0
- qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
- qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
- qpytorch/variational/lmc_variational_strategy.py +248 -0
- qpytorch/variational/mean_field_variational_distribution.py +58 -0
- qpytorch/variational/multitask_variational_strategy.py +317 -0
- qpytorch/variational/natural_variational_distribution.py +152 -0
- qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
- qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
- qpytorch/variational/tril_natural_variational_distribution.py +130 -0
- qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
- qpytorch/variational/unwhitened_variational_strategy.py +225 -0
- qpytorch/variational/variational_strategy.py +280 -0
- qpytorch/version.py +4 -0
- qpytorch-0.1.dist-info/LICENSE +21 -0
- qpytorch-0.1.dist-info/METADATA +177 -0
- qpytorch-0.1.dist-info/RECORD +102 -0
- qpytorch-0.1.dist-info/WHEEL +5 -0
- qpytorch-0.1.dist-info/top_level.txt +1 -0
qpytorch/distributions/multitask_multivariate_qexponential.py
@@ -0,0 +1,435 @@
#!/usr/bin/env python3

import torch
from linear_operator import LinearOperator, to_linear_operator
from linear_operator.operators import (
    BlockDiagLinearOperator,
    BlockInterleavedLinearOperator,
    CatLinearOperator,
    DiagLinearOperator,
)

from .multivariate_qexponential import MultivariateQExponential


class MultitaskMultivariateQExponential(MultivariateQExponential):
    """
    Constructs a multi-output multivariate Q-Exponential random variable, based on mean and covariance.
    Can be a multi-output multivariate Q-Exponential, or a batch of multi-output multivariate Q-Exponentials.

    Passing a matrix mean corresponds to a multi-output multivariate Q-Exponential.
    Passing a batch matrix mean corresponds to a batch of multi-output multivariate Q-Exponentials.

    :param torch.Tensor mean: An `n x t` or batch `b x n x t` matrix of means for the QEP distribution.
    :param ~linear_operator.operators.LinearOperator covariance_matrix: An `... x NT x NT` (batch)
        covariance matrix of the QEP distribution.
    :param power: (default=2.0) (scalar) power of the QEP distribution.
    :param bool validate_args: (default=False) If True, validate the `mean` and `covariance_matrix` arguments.
    :param bool interleaved: (default=True) If True, the covariance matrix is interpreted as block-diagonal w.r.t.
        inter-task covariances for each observation. If False, it is interpreted as block-diagonal
        w.r.t. inter-observation covariances for each task.
    """
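    # Layout sketch (illustrative only; the shapes and names below are an assumed
    # example, not part of the API): for n = 2 observations (x0, x1) and t = 2
    # tasks (a, b), the flattened `nt`-vector is ordered
    #   interleaved=True:   [x0_a, x0_b, x1_a, x1_b]   (task index varies fastest)
    #   interleaved=False:  [x0_a, x1_a, x0_b, x1_b]   (data index varies fastest)
    # and the `nt x nt` covariance follows the same ordering on both axes.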
    def __init__(self, mean, covariance_matrix, power=torch.tensor(2.0), validate_args=False, interleaved=True):
        if not torch.is_tensor(mean) and not isinstance(mean, LinearOperator):
            raise RuntimeError("The mean of a MultitaskMultivariateQExponential must be a Tensor or LinearOperator")

        if not torch.is_tensor(covariance_matrix) and not isinstance(covariance_matrix, LinearOperator):
            raise RuntimeError(
                "The covariance of a MultitaskMultivariateQExponential must be a Tensor or LinearOperator"
            )

        if mean.dim() < 2:
            raise RuntimeError("mean should be a matrix or a batch matrix (batch mode)")

        # Ensure that shapes are broadcasted appropriately across the mean and covariance
        # Means can have singleton dimensions for either the `n` or `t` dimensions
        batch_shape = torch.broadcast_shapes(mean.shape[:-2], covariance_matrix.shape[:-2])
        if mean.shape[-2:].numel() != covariance_matrix.size(-1):
            if covariance_matrix.size(-1) % mean.shape[-2:].numel():
                raise RuntimeError(
                    f"mean shape {mean.shape} is incompatible with covariance shape {covariance_matrix.shape}"
                )
            elif mean.size(-2) == 1:
                mean = mean.expand(*batch_shape, covariance_matrix.size(-1) // mean.size(-1), mean.size(-1))
            elif mean.size(-1) == 1:
                mean = mean.expand(*batch_shape, mean.size(-2), covariance_matrix.size(-2) // mean.size(-2))
            else:
                raise RuntimeError(
                    f"mean shape {mean.shape} is incompatible with covariance shape {covariance_matrix.shape}"
                )
        else:
            mean = mean.expand(*batch_shape, *mean.shape[-2:])

        self._output_shape = mean.shape
        # TODO: Instead of transpose / view operations, use a PermutationLinearOperator (see #539)
        # to handle interleaving
        self._interleaved = interleaved
        if self._interleaved:
            mean_qep = mean.reshape(*mean.shape[:-2], -1)
        else:
            mean_qep = mean.transpose(-1, -2).reshape(*mean.shape[:-2], -1)
        super().__init__(mean=mean_qep, covariance_matrix=covariance_matrix, power=power, validate_args=validate_args)
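    # Shape handling in `__init__`, by example (an illustrative sketch of the
    # broadcasting rules above, not additional behavior): a mean of shape (3, 1)
    # paired with a 6 x 6 covariance is expanded to (3, 2), i.e. t = 2 tasks are
    # inferred from nt / n; a (1, 2) mean with the same covariance is likewise
    # expanded to (3, 2).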
    @property
    def base_sample_shape(self):
        """
        Returns the shape of a base sample (without batching) that is used to
        generate a single sample.
        """
        base_sample_shape = self.event_shape
        return base_sample_shape

    @property
    def event_shape(self):
        return self._output_shape[-2:]

    @classmethod
    def from_batch_qep(cls, batch_qep, task_dim=-1):
        """
        Reinterpret a batch of multivariate q-exponential distributions as an (uncorrelated) multitask
        multivariate q-exponential distribution.

        :param ~qpytorch.distributions.MultivariateQExponential batch_qep: The base QEP distribution.
            (This distribution should have at least one batch dimension).
        :param int task_dim: Which batch dimension should be interpreted as the dimension for the independent tasks.
        :returns: the uncorrelated multitask distribution
        :rtype: qpytorch.distributions.MultitaskMultivariateQExponential

        Example:
            >>> # model is a qpytorch.models.VariationalQEP
            >>> # likelihood is a qpytorch.likelihoods.Likelihood
            >>> mean = torch.randn(4, 2, 3)
            >>> covar_factor = torch.randn(4, 2, 3, 3)
            >>> covar = covar_factor @ covar_factor.transpose(-1, -2)
            >>> power = torch.tensor(1.0)
            >>> qep = qpytorch.distributions.MultivariateQExponential(mean, covar, power)
            >>> print(qep.event_shape, qep.batch_shape)
            >>> # torch.Size([3]), torch.Size([4, 2])
            >>>
            >>> mqep = MultitaskMultivariateQExponential.from_batch_qep(qep, task_dim=-1)
            >>> print(mqep.event_shape, mqep.batch_shape)
            >>> # torch.Size([3, 2]), torch.Size([4])
        """
        orig_task_dim = task_dim
        task_dim = task_dim if task_dim >= 0 else (len(batch_qep.batch_shape) + task_dim)
        if task_dim < 0 or task_dim > len(batch_qep.batch_shape):
            raise ValueError(
                f"task_dim of {orig_task_dim} is incompatible with QEP batch shape of {batch_qep.batch_shape}"
            )

        num_dim = batch_qep.mean.dim()
        res = cls(
            mean=batch_qep.mean.permute(*range(0, task_dim), *range(task_dim + 1, num_dim), task_dim),
            covariance_matrix=BlockInterleavedLinearOperator(batch_qep.lazy_covariance_matrix, block_dim=task_dim),
            power=batch_qep.power,
        )
        return res

    @classmethod
    def from_uncorrelated_qeps(cls, qeps):
        """
        Convert an iterable of QEPs into a :obj:`~qpytorch.distributions.MultitaskMultivariateQExponential`.
        The resulting distribution will have ``len(qeps)`` tasks, and the tasks will be uncorrelated.

        :param qeps: An iterable of base QEP distributions
            (each a ~qpytorch.distributions.MultivariateQExponential).
        :returns: the uncorrelated multitask distribution
        :rtype: qpytorch.distributions.MultitaskMultivariateQExponential

        Example:
            >>> # model is a qpytorch.models.VariationalQEP
            >>> # likelihood is a qpytorch.likelihoods.Likelihood
            >>> mean = torch.randn(4, 3)
            >>> covar_factor = torch.randn(4, 3, 3)
            >>> covar = covar_factor @ covar_factor.transpose(-1, -2)
            >>> power = torch.tensor(1.0)
            >>> qep1 = qpytorch.distributions.MultivariateQExponential(mean, covar, power)
            >>>
            >>> mean = torch.randn(4, 3)
            >>> covar_factor = torch.randn(4, 3, 3)
            >>> covar = covar_factor @ covar_factor.transpose(-1, -2)
            >>> qep2 = qpytorch.distributions.MultivariateQExponential(mean, covar, power)
            >>>
            >>> mqep = MultitaskMultivariateQExponential.from_uncorrelated_qeps([qep1, qep2])
            >>> print(mqep.event_shape, mqep.batch_shape)
            >>> # torch.Size([3, 2]), torch.Size([4])
        """
        if len(qeps) < 2:
            raise ValueError("Must provide at least 2 QEPs to form a MultitaskMultivariateQExponential")
        if any(isinstance(qep, MultitaskMultivariateQExponential) for qep in qeps):
            raise ValueError("Cannot accept MultitaskMultivariateQExponentials")
        if not all(m.batch_shape == qeps[0].batch_shape for m in qeps[1:]):
            batch_shape = torch.broadcast_shapes(*(m.batch_shape for m in qeps))
            qeps = [qep.expand(batch_shape) for qep in qeps]
        if not all(m.event_shape == qeps[0].event_shape for m in qeps[1:]):
            raise ValueError("All MultivariateQExponentials must have the same event shape")
        mean = torch.stack([qep.mean for qep in qeps], -1)
        # TODO: To do the following efficiently, we don't want to evaluate the
        # covariance matrices. Instead, we want to use the lazies directly in the
        # BlockDiagLinearOperator. This will require implementing a new BatchLinearOperator:
        # https://github.com/cornellius-gp/gpytorch/issues/468
        covar_blocks_lazy = CatLinearOperator(
            *[qep.lazy_covariance_matrix.unsqueeze(0) for qep in qeps], dim=0, output_device=mean.device
        )
        covar_lazy = BlockDiagLinearOperator(covar_blocks_lazy, block_dim=0)
        return cls(mean=mean, covariance_matrix=covar_lazy, power=qeps[0].power, interleaved=False)

    @classmethod
    def from_repeated_qep(cls, qep, num_tasks):
        """
        Convert a single QEP into a :obj:`~qpytorch.distributions.MultitaskMultivariateQExponential`,
        where each task shares the same mean and covariance.

        :param ~qpytorch.distributions.MultivariateQExponential qep: The base QEP distribution.
        :param int num_tasks: How many tasks to create.
        :returns: the uncorrelated multitask distribution
        :rtype: qpytorch.distributions.MultitaskMultivariateQExponential

        Example:
            >>> # model is a qpytorch.models.VariationalQEP
            >>> # likelihood is a qpytorch.likelihoods.Likelihood
            >>> mean = torch.randn(4, 3)
            >>> covar_factor = torch.randn(4, 3, 3)
            >>> covar = covar_factor @ covar_factor.transpose(-1, -2)
            >>> qep = qpytorch.distributions.MultivariateQExponential(mean, covar)
            >>> print(qep.event_shape, qep.batch_shape)
            >>> # torch.Size([3]), torch.Size([4])
            >>>
            >>> mqep = MultitaskMultivariateQExponential.from_repeated_qep(qep, num_tasks=2)
            >>> print(mqep.event_shape, mqep.batch_shape)
            >>> # torch.Size([3, 2]), torch.Size([4])
        """
        return cls.from_batch_qep(qep.expand(torch.Size([num_tasks]) + qep.batch_shape), task_dim=0)

    def expand(self, batch_size):
        new_mean = self.mean.expand(torch.Size(batch_size) + self.mean.shape[-2:])
        new_covar = self._covar.expand(torch.Size(batch_size) + self._covar.shape[-2:])
        res = self.__class__(new_mean, new_covar, power=self.power, interleaved=self._interleaved)
        return res

    def get_base_samples(self, sample_shape=torch.Size(), **kwargs):
        base_samples = super().get_base_samples(sample_shape, **kwargs)
        if not self._interleaved:
            # flip shape of last two dimensions
            new_shape = sample_shape + self._output_shape[:-2] + self._output_shape[:-3:-1]
            return base_samples.view(new_shape).transpose(-1, -2).contiguous()
        return base_samples.view(*sample_shape, *self._output_shape)

    def log_prob(self, value):
        if not self._interleaved:
            # flip shape of last two dimensions
            new_shape = value.shape[:-2] + value.shape[:-3:-1]
            value = value.view(new_shape).transpose(-1, -2).contiguous()
        return super().log_prob(value.reshape(*value.shape[:-2], -1))

    @property
    def mean(self):
        mean = super().mean
        if not self._interleaved:
            # flip shape of last two dimensions
            new_shape = self._output_shape[:-2] + self._output_shape[:-3:-1]
            return mean.view(new_shape).transpose(-1, -2).contiguous()
        return mean.view(self._output_shape)

    @property
    def num_tasks(self):
        return self._output_shape[-1]

    def rsample(self, sample_shape=torch.Size(), base_samples=None, **kwargs):
        if base_samples is not None:
            # Make sure that the base samples agree with the distribution
            mean_shape = self.mean.shape
            base_sample_shape = base_samples.shape[-self.mean.ndimension() :]
            if mean_shape != base_sample_shape:
                raise RuntimeError(
                    "The shape of base_samples (minus sample shape dimensions) should agree with the shape "
                    "of self.mean. Expected ...{} but got {}".format(mean_shape, base_sample_shape)
                )
            sample_shape = base_samples.shape[: -self.mean.ndimension()]
            base_samples = base_samples.view(*sample_shape, *self.loc.shape)

        samples = super().rsample(sample_shape=sample_shape, base_samples=base_samples, **kwargs)
        if not self._interleaved:
            # flip shape of last two dimensions
            new_shape = sample_shape + self._output_shape[:-2] + self._output_shape[:-3:-1]
            return samples.view(new_shape).transpose(-1, -2).contiguous()
        return samples.view(sample_shape + self._output_shape)
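    # Reshape note (illustrative): with interleaved=False the flat vector is
    # data-major, so `mean`, `variance`, `get_base_samples`, and `rsample`
    # first view the flat tensor as `... x t x n` (`self._output_shape[:-3:-1]`
    # is the last two sizes reversed) and then transpose back to `... x n x t`.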
    def to_data_uncorrelated_dist(self, jitter_val=1e-4):
        """
        Convert a multitask QEP into a batched (non-multitask) QEP.
        The result retains the inter-task covariances, but gets rid of the inter-data covariances.

        :returns: the batched data-uncorrelated QEP
        :rtype: qpytorch.distributions.MultivariateQExponential
        """
        # Create batch distribution where all data are independent, but the tasks are dependent
        full_covar = self.lazy_covariance_matrix
        num_data, num_tasks = self.mean.shape[-2:]
        if self._interleaved:
            data_indices = torch.arange(0, num_data * num_tasks, num_tasks, device=full_covar.device).view(-1, 1, 1)
            task_indices = torch.arange(num_tasks, device=full_covar.device)
        else:
            data_indices = torch.arange(num_data, device=full_covar.device).view(-1, 1, 1)
            task_indices = torch.arange(0, num_data * num_tasks, num_data, device=full_covar.device)
        task_covars = full_covar[
            ..., data_indices + task_indices.unsqueeze(-2), data_indices + task_indices.unsqueeze(-1)
        ]
        return MultivariateQExponential(
            self.mean, to_linear_operator(task_covars).add_jitter(jitter_val=jitter_val), self.power
        )

    # to_data_independent_dist = to_data_uncorrelated_dist  # alias to the same function with a more appropriate name
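    # Index sketch for `to_data_uncorrelated_dist` (illustrative): with
    # interleaved storage, entry (datum i, task j) of the flat vector sits at
    # i * num_tasks + j, so `data_indices + task_indices` enumerates, per datum,
    # the flat positions of all of its tasks; the advanced indexing above
    # therefore extracts an `... x n x t x t` stack of per-datum inter-task
    # covariance blocks.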
    @property
    def variance(self):
        var = super().variance
        if not self._interleaved:
            # flip shape of last two dimensions
            new_shape = self._output_shape[:-2] + self._output_shape[:-3:-1]
            return var.view(new_shape).transpose(-1, -2).contiguous()
        return var.view(self._output_shape)

    def __getitem__(self, idx) -> MultivariateQExponential:
        """
        Constructs a new MultivariateQExponential that represents a random variable
        modified by an indexing operation.

        The mean and covariance matrix arguments are indexed accordingly.

        :param Any idx: Index to apply to the mean. The covariance matrix is indexed accordingly.
        :returns: If the indices specify a slice for both samples and tasks, returns a
            MultitaskMultivariateQExponential; otherwise returns a MultivariateQExponential.
        """
        # Normalize index to a tuple
        if not isinstance(idx, tuple):
            idx = (idx,)

        if ... in idx:
            # Replace ellipsis '...' with explicit indices
            ellipsis_location = idx.index(...)
            if ... in idx[ellipsis_location + 1 :]:
                raise IndexError("Only one ellipsis '...' is supported!")
            prefix = idx[:ellipsis_location]
            suffix = idx[ellipsis_location + 1 :]
            infix_length = self.mean.dim() - len(prefix) - len(suffix)
            if infix_length < 0:
                raise IndexError(f"Index {idx} has too many dimensions")
            idx = prefix + (slice(None),) * infix_length + suffix
        elif len(idx) == self.mean.dim() - 1:
            # Normalize indices that omit the task index to include it
            idx = idx + (slice(None),)

        new_mean = self.mean[idx]

        # We now create a covariance matrix appropriate for new_mean
        if len(idx) <= self.mean.dim() - 2:
            # We are only indexing the batch dimensions in this case
            return MultitaskMultivariateQExponential(
                mean=new_mean,
                covariance_matrix=self.lazy_covariance_matrix[idx],
                power=self.power,
                interleaved=self._interleaved,
            )
        elif len(idx) > self.mean.dim():
            raise IndexError(f"Index {idx} has too many dimensions")
        else:
            # We have an index that extends over all dimensions
            batch_idx = idx[:-2]
            if self._interleaved:
                row_idx = idx[-2]
                col_idx = idx[-1]
                num_rows = self._output_shape[-2]
                num_cols = self._output_shape[-1]
            else:
                row_idx = idx[-1]
                col_idx = idx[-2]
                num_rows = self._output_shape[-1]
                num_cols = self._output_shape[-2]

            if isinstance(row_idx, int) and isinstance(col_idx, int):
                # Single sample with single task
                row_idx = _normalize_index(row_idx, num_rows)
                col_idx = _normalize_index(col_idx, num_cols)
                new_cov = DiagLinearOperator(
                    self.lazy_covariance_matrix.diagonal()[batch_idx + (row_idx * num_cols + col_idx,)]
                )
                return MultivariateQExponential(mean=new_mean, covariance_matrix=new_cov, power=self.power)
            elif isinstance(row_idx, int) and isinstance(col_idx, slice):
                # A block of the covariance matrix
                row_idx = _normalize_index(row_idx, num_rows)
                col_idx = _normalize_slice(col_idx, num_cols)
                new_slice = slice(
                    col_idx.start + row_idx * num_cols,
                    col_idx.stop + row_idx * num_cols,
                    col_idx.step,
                )
                new_cov = self.lazy_covariance_matrix[batch_idx + (new_slice, new_slice)]
                return MultivariateQExponential(mean=new_mean, covariance_matrix=new_cov, power=self.power)
            elif isinstance(row_idx, slice) and isinstance(col_idx, int):
                # A block of the reversely interleaved covariance matrix
                row_idx = _normalize_slice(row_idx, num_rows)
                col_idx = _normalize_index(col_idx, num_cols)
                # The flat position of (row i, col j) is i * num_cols + j, so the
                # slice start must also be scaled by num_cols
                new_slice = slice(
                    row_idx.start * num_cols + col_idx,
                    row_idx.stop * num_cols + col_idx,
                    row_idx.step * num_cols,
                )
                new_cov = self.lazy_covariance_matrix[batch_idx + (new_slice, new_slice)]
                return MultivariateQExponential(mean=new_mean, covariance_matrix=new_cov, power=self.power)
            elif (
                isinstance(row_idx, slice)
                and isinstance(col_idx, slice)
                and row_idx == col_idx == slice(None, None, None)
            ):
                new_cov = self.lazy_covariance_matrix[batch_idx]
                return MultitaskMultivariateQExponential(
                    mean=new_mean,
                    covariance_matrix=new_cov,
                    power=self.power,
                    interleaved=self._interleaved,
                    validate_args=False,
                )
            elif isinstance(row_idx, slice) or isinstance(col_idx, slice):
                # slice x slice or indices x slice or slice x indices
                if isinstance(row_idx, slice):
                    row_idx = torch.arange(num_rows)[row_idx]
                if isinstance(col_idx, slice):
                    col_idx = torch.arange(num_cols)[col_idx]
                row_grid, col_grid = torch.meshgrid(row_idx, col_idx, indexing="ij")
                indices = (row_grid * num_cols + col_grid).reshape(-1)
                new_cov = self.lazy_covariance_matrix[batch_idx + (indices,)][..., indices]
                return MultitaskMultivariateQExponential(
                    mean=new_mean,
                    covariance_matrix=new_cov,
                    power=self.power,
                    interleaved=self._interleaved,
                    validate_args=False,
                )
            else:
                # row_idx and col_idx have pairs of indices
                indices = row_idx * num_cols + col_idx
                new_cov = self.lazy_covariance_matrix[batch_idx + (indices,)][..., indices]
                return MultivariateQExponential(
                    mean=new_mean,
                    covariance_matrix=new_cov,
                    power=self.power,
                )

    def __repr__(self) -> str:
        return f"MultitaskMultivariateQExponential(mean shape: {self._output_shape})"

def _normalize_index(i: int, dim_size: int) -> int:
    if i < 0:
        return dim_size + i
    else:
        return i


def _normalize_slice(s: slice, dim_size: int) -> slice:
    start = s.start
    if start is None:
        start = 0
    elif start < 0:
        start = dim_size + start
    stop = s.stop
    if stop is None:
        stop = dim_size
    elif stop < 0:
        stop = dim_size + stop
    step = s.step
    if step is None:
        step = 1
    return slice(start, stop, step)
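A minimal usage sketch, assembled from the doctest examples in this file (the shapes, the power value, and the PSD construction below are illustrative assumptions, not requirements of the package):

import torch
import qpytorch

n, t = 3, 2                                    # 3 data points, 2 tasks
mean = torch.randn(n, t)
covar_factor = torch.randn(n * t, n * t)
covar = covar_factor @ covar_factor.mT + torch.eye(n * t)  # PSD `nt x nt` covariance
power = torch.tensor(1.0)

mqep = qpytorch.distributions.MultitaskMultivariateQExponential(mean, covar, power)
print(mqep.event_shape, mqep.num_tasks)        # torch.Size([3, 2]), 2
sample = mqep.rsample()                        # tensor of shape (3, 2)
batched = mqep.to_data_uncorrelated_dist()     # batch of 3 two-task QEPs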