qpytorch 0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of qpytorch might be problematic. Click here for more details.
- qpytorch/__init__.py +327 -0
- qpytorch/constraints/__init__.py +3 -0
- qpytorch/distributions/__init__.py +21 -0
- qpytorch/distributions/delta.py +86 -0
- qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
- qpytorch/distributions/multivariate_qexponential.py +581 -0
- qpytorch/distributions/power.py +113 -0
- qpytorch/distributions/qexponential.py +153 -0
- qpytorch/functions/__init__.py +58 -0
- qpytorch/kernels/__init__.py +80 -0
- qpytorch/kernels/grid_interpolation_kernel.py +213 -0
- qpytorch/kernels/inducing_point_kernel.py +151 -0
- qpytorch/kernels/kernel.py +695 -0
- qpytorch/kernels/matern32_kernel_grad.py +155 -0
- qpytorch/kernels/matern52_kernel_grad.py +194 -0
- qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
- qpytorch/kernels/polynomial_kernel_grad.py +88 -0
- qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
- qpytorch/kernels/rbf_kernel_grad.py +125 -0
- qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
- qpytorch/kernels/rff_kernel.py +153 -0
- qpytorch/lazy/__init__.py +9 -0
- qpytorch/likelihoods/__init__.py +66 -0
- qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
- qpytorch/likelihoods/beta_likelihood.py +76 -0
- qpytorch/likelihoods/gaussian_likelihood.py +472 -0
- qpytorch/likelihoods/laplace_likelihood.py +59 -0
- qpytorch/likelihoods/likelihood.py +437 -0
- qpytorch/likelihoods/likelihood_list.py +60 -0
- qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
- qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
- qpytorch/likelihoods/noise_models.py +184 -0
- qpytorch/likelihoods/qexponential_likelihood.py +494 -0
- qpytorch/likelihoods/softmax_likelihood.py +97 -0
- qpytorch/likelihoods/student_t_likelihood.py +90 -0
- qpytorch/means/__init__.py +23 -0
- qpytorch/metrics/__init__.py +17 -0
- qpytorch/mlls/__init__.py +53 -0
- qpytorch/mlls/_approximate_mll.py +79 -0
- qpytorch/mlls/deep_approximate_mll.py +30 -0
- qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
- qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
- qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
- qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
- qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
- qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
- qpytorch/mlls/marginal_log_likelihood.py +48 -0
- qpytorch/mlls/predictive_log_likelihood.py +76 -0
- qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
- qpytorch/mlls/variational_elbo.py +77 -0
- qpytorch/models/__init__.py +72 -0
- qpytorch/models/approximate_qep.py +115 -0
- qpytorch/models/deep_qeps/__init__.py +22 -0
- qpytorch/models/deep_qeps/deep_qep.py +155 -0
- qpytorch/models/deep_qeps/dspp.py +114 -0
- qpytorch/models/exact_prediction_strategies.py +880 -0
- qpytorch/models/exact_qep.py +349 -0
- qpytorch/models/model_list.py +100 -0
- qpytorch/models/pyro/__init__.py +28 -0
- qpytorch/models/pyro/_pyro_mixin.py +57 -0
- qpytorch/models/pyro/distributions/__init__.py +5 -0
- qpytorch/models/pyro/pyro_qep.py +105 -0
- qpytorch/models/qep.py +7 -0
- qpytorch/models/qeplvm/__init__.py +6 -0
- qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
- qpytorch/models/qeplvm/latent_variable.py +102 -0
- qpytorch/module.py +30 -0
- qpytorch/optim/__init__.py +5 -0
- qpytorch/priors/__init__.py +42 -0
- qpytorch/priors/qep_priors.py +81 -0
- qpytorch/test/__init__.py +22 -0
- qpytorch/test/base_likelihood_test_case.py +106 -0
- qpytorch/test/model_test_case.py +150 -0
- qpytorch/test/variational_test_case.py +400 -0
- qpytorch/utils/__init__.py +38 -0
- qpytorch/utils/warnings.py +37 -0
- qpytorch/variational/__init__.py +47 -0
- qpytorch/variational/_variational_distribution.py +61 -0
- qpytorch/variational/_variational_strategy.py +391 -0
- qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
- qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
- qpytorch/variational/cholesky_variational_distribution.py +65 -0
- qpytorch/variational/ciq_variational_strategy.py +352 -0
- qpytorch/variational/delta_variational_distribution.py +41 -0
- qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
- qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
- qpytorch/variational/lmc_variational_strategy.py +248 -0
- qpytorch/variational/mean_field_variational_distribution.py +58 -0
- qpytorch/variational/multitask_variational_strategy.py +317 -0
- qpytorch/variational/natural_variational_distribution.py +152 -0
- qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
- qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
- qpytorch/variational/tril_natural_variational_distribution.py +130 -0
- qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
- qpytorch/variational/unwhitened_variational_strategy.py +225 -0
- qpytorch/variational/variational_strategy.py +280 -0
- qpytorch/version.py +4 -0
- qpytorch-0.1.dist-info/LICENSE +21 -0
- qpytorch-0.1.dist-info/METADATA +177 -0
- qpytorch-0.1.dist-info/RECORD +102 -0
- qpytorch-0.1.dist-info/WHEEL +5 -0
- qpytorch-0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,581 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import math
|
|
6
|
+
import warnings
|
|
7
|
+
from numbers import Number
|
|
8
|
+
from typing import Optional, Tuple, Union
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
from linear_operator import to_dense, to_linear_operator
|
|
12
|
+
from linear_operator.operators import DiagLinearOperator, LinearOperator, RootLinearOperator
|
|
13
|
+
from torch import Tensor
|
|
14
|
+
from torch.distributions import MultivariateNormal as TMultivariateNormal, Chi2
|
|
15
|
+
from torch.distributions.kl import register_kl
|
|
16
|
+
from torch.distributions.utils import _standard_normal, lazy_property
|
|
17
|
+
|
|
18
|
+
from .. import settings
|
|
19
|
+
from ..utils.warnings import NumericalWarning
|
|
20
|
+
from gpytorch.distributions.distribution import Distribution
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class MultivariateQExponential(TMultivariateNormal, Distribution):
|
|
24
|
+
"""
|
|
25
|
+
Constructs a multivariate q-exponential random variable, based on mean and covariance, whose density is
|
|
26
|
+
|
|
27
|
+
.. math::
|
|
28
|
+
|
|
29
|
+
p(x; \\mu, C) = \\frac{q}{2} (2\\pi)^{-\\frac{N}{2}} |C|^{-\\frac{1}{2}}
|
|
30
|
+
r^{\\left(\\frac{q}{2}-1\\right)\\frac{N}{2}} \\exp\\left\\{ -0.5 * r^{\\frac{q}{2}} \\right\\}, \\quad
|
|
31
|
+
r(x) = (x - \\mu)^T C^{-1} (x - \\mu).
|
|
32
|
+
|
|
33
|
+
The result can be multivariate, or a batch of multivariate q-exponentials.
|
|
34
|
+
Passing a vector mean corresponds to a multivariate q-exponential.
|
|
35
|
+
Passing a matrix mean corresponds to a batch of multivariate q-exponentials.
|
|
36
|
+
|
|
37
|
+
:param mean: `... x N` mean of qep distribution.
|
|
38
|
+
:param covariance_matrix: `... x N X N` covariance matrix of qep distribution.
|
|
39
|
+
:param power: (scalar) power of qep distribution. (Default: 2.)
|
|
40
|
+
:param validate_args: If True, validate `mean` and `covariance_matrix` arguments. (Default: False.)
|
|
41
|
+
|
|
42
|
+
:ivar torch.Size base_sample_shape: The shape of a base sample (without
|
|
43
|
+
batching) that is used to generate a single sample.
|
|
44
|
+
:ivar torch.Tensor covariance_matrix: The covariance matrix, represented as a dense :class:`torch.Tensor`
|
|
45
|
+
:ivar ~linear_operator.LinearOperator lazy_covariance_matrix: The covariance matrix, represented
|
|
46
|
+
as a :class:`~linear_operator.LinearOperator`.
|
|
47
|
+
:ivar torch.Tensor mean: The mean.
|
|
48
|
+
:ivar torch.Tensor stddev: The standard deviation.
|
|
49
|
+
:ivar torch.Tensor variance: The variance.
|
|
50
|
+
"""
|
|
51
|
+
|
|
52
|
+
def __init__(
    self,
    mean: Tensor,
    covariance_matrix: Union[Tensor, LinearOperator],
    power: Tensor = torch.tensor(2.0),
    validate_args: bool = False,
):
    """
    Build the distribution from a mean, a covariance (dense tensor or LinearOperator) and a power.

    :param mean: `... x N` mean (stored as ``loc``).
    :param covariance_matrix: `... x N x N` covariance, either a dense tensor or a LinearOperator.
    :param power: The power `q` of the distribution. A plain Python number is converted to a
        0-dim float tensor, because ``rescalor``, ``log_prob`` and ``entropy`` apply
        ``torch.lgamma`` / ``torch.log`` to it. (Default: 2.)
    :param validate_args: If True, validate the arguments. For lazy inputs only the trailing
        sizes of `mean` and `covariance_matrix` are checked. (Default: False.)
    :raises ValueError: If `validate_args` is True, the inputs are lazy, and the last size of
        `mean` differs from either of the last two sizes of `covariance_matrix`.
    """
    if isinstance(power, Number):
        # torch.lgamma / torch.log (used by other methods) reject Python scalars.
        power = torch.tensor(float(power))

    self._islazy = isinstance(mean, LinearOperator) or isinstance(covariance_matrix, LinearOperator)
    if self._islazy:
        if validate_args:
            ms = mean.size(-1)
            cs1 = covariance_matrix.size(-1)
            cs2 = covariance_matrix.size(-2)
            if not (ms == cs1 and ms == cs2):
                raise ValueError(f"Wrong shapes in {self._repr_sizes(mean, covariance_matrix)}")
        self.loc = mean
        self._covar = covariance_matrix
        # The Cholesky factor is computed on demand by the `_unbroadcasted_scale_tril` property.
        self.__unbroadcasted_scale_tril = None
        self._validate_args = validate_args
        batch_shape = torch.broadcast_shapes(self.loc.shape[:-1], covariance_matrix.shape[:-2])
        event_shape = self.loc.shape[-1:]

        # Skip torch's MultivariateNormal initialiser (which works on dense tensors) and
        # initialise the base torch Distribution directly.
        # TODO: Integrate argument validation for LinearOperators into torch.distribution validation logic
        super(TMultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=False)
    else:
        super().__init__(loc=mean, covariance_matrix=covariance_matrix, validate_args=validate_args)
    self.power = power
|
|
74
|
+
|
|
75
|
+
def _extended_shape(self, sample_shape: torch.Size = torch.Size()) -> torch.Size:
    """
    Compute the shape of a sample drawn with the given `sample_shape`.

    The result is ``sample_shape + batch_shape + base_sample_shape``. The batch and event
    shapes of a distribution are fixed when it is constructed.

    :param sample_shape: the size of the sample to be drawn (any sequence of ints is accepted).
    """
    prefix = sample_shape if isinstance(sample_shape, torch.Size) else torch.Size(sample_shape)
    return prefix + self._batch_shape + self.base_sample_shape
|
|
87
|
+
|
|
88
|
+
@staticmethod
def _repr_sizes(mean: Tensor, covariance_matrix: Union[Tensor, LinearOperator], power: Tensor = torch.tensor(2.0)) -> str:
    """Summarise the sizes of the constructor arguments (used in shape-mismatch error messages)."""
    loc_size, scale_size, pow_size = mean.size(), covariance_matrix.size(), power.size()
    return f"MultivariateQExponential(loc: {loc_size}, scale: {scale_size}, pow: {pow_size})"
|
|
91
|
+
|
|
92
|
+
@property
def _unbroadcasted_scale_tril(self) -> Tensor:
    """
    Lower Cholesky factor of the covariance, without batch broadcasting.

    For lazy distributions the factor is computed on first access as a dense tensor from
    ``lazy_covariance_matrix.cholesky()`` and then cached.
    """
    cached = self.__unbroadcasted_scale_tril
    if cached is None and self.islazy:
        # First access on a lazy distribution: compute and cache the root decomposition.
        cached = to_dense(self.lazy_covariance_matrix.cholesky())
        self.__unbroadcasted_scale_tril = cached
    return cached
|
|
99
|
+
|
|
100
|
+
@_unbroadcasted_scale_tril.setter
def _unbroadcasted_scale_tril(self, ust: Tensor):
    """
    Store the Cholesky factor (e.g. when assigned by torch's MultivariateNormal initialiser).

    :raises NotImplementedError: For lazy distributions, whose factor is derived on demand.
    """
    if not self.islazy:
        self.__unbroadcasted_scale_tril = ust
        return
    raise NotImplementedError("Cannot set _unbroadcasted_scale_tril for lazy QEP distributions")
|
|
106
|
+
|
|
107
|
+
def add_jitter(self, noise: float = 1e-4) -> MultivariateQExponential:
    r"""
    Return a new distribution whose covariance is this one's after
    ``LinearOperator.add_jitter(noise)`` (a constant diagonal added for numerical stability).
    The mean and power are carried over unchanged.

    :param noise: The size of the constant diagonal.
    """
    mean = self.mean
    jittered_covar = self.lazy_covariance_matrix.add_jitter(noise)
    return self.__class__(mean, jittered_covar, self.power)
|
|
114
|
+
|
|
115
|
+
@property
def base_sample_shape(self) -> torch.Size:
    """
    Shape of a single unbatched base sample.

    This is ``event_shape``, except when the covariance is a RootLinearOperator. In that case
    it is the trailing dimension of the operator's root.
    """
    covar = self.lazy_covariance_matrix
    if isinstance(covar, RootLinearOperator):
        return covar.root.shape[-1:]
    return self.event_shape
|
|
122
|
+
|
|
123
|
+
@lazy_property
def covariance_matrix(self) -> Tensor:
    """Dense covariance matrix. Lazy covariances are materialised with ``to_dense()``; the value is cached."""
    if not self.islazy:
        return super().covariance_matrix
    return self._covar.to_dense()
|
|
129
|
+
|
|
130
|
+
@property
def rescalor(self) -> Tensor:
    """
    Rescaling factor used by ``confidence_region(rescale=True)`` and ``get_base_samples(rescale=True)``.

    Equals ``sqrt(2^(2/q) * Gamma(N/2 + 2/q) / (N * Gamma(N/2)))`` with ``N = event_shape[0]``
    and ``q = power``. It is 1 when q = 2.
    """
    n = self.event_shape[0]
    two_over_q = 2.0 / self.power
    log_sq = two_over_q * math.log(2) - math.log(n) + torch.lgamma(n / 2.0 + two_over_q) - math.lgamma(n / 2.0)
    return torch.exp(log_sq / 2.0)
|
|
134
|
+
|
|
135
|
+
def confidence_region(self, rescale=False) -> Tuple[Tensor, Tensor]:
    """
    Return the band from two standard deviations below to two above the mean.

    :param rescale: If True, the half-width ``2 * stddev`` is further multiplied by ``rescalor``.
    :return: Pair of tensors of size `... x N`, where N is the dimensionality of the random
        variable: the lower and the upper end of the confidence region.
    """
    factor = self.rescalor if rescale else 1
    half_width = self.stddev.mul(2).mul(factor)
    center = self.mean
    lower = center.sub(half_width)
    upper = center.add(half_width)
    return lower, upper
|
|
146
|
+
|
|
147
|
+
def expand(self, batch_size: torch.Size) -> MultivariateQExponential:
    r"""
    See :py:meth:`torch.distributions.Distribution.expand
    <torch.distributions.distribution.Distribution.expand>`.

    The mean, covariance, power and any cached Cholesky factor are carried over. Tensors are
    broadcast with ``.expand`` to the new batch shape.
    """
    # NOTE: Pyro may call this method with list[int] instead of torch.Size.
    batch_size = torch.Size(batch_size)
    new_loc = self.loc.expand(batch_size + self.loc.shape[-1:])

    if not self.islazy:
        # Non-lazy QEP is represented using scale_tril in PyTorch. Allocate with __new__ (skipping
        # __init__) and run torch's initialiser from the expanded scale_tril to avoid recomputation.
        tril = self.__unbroadcasted_scale_tril
        new = self.__new__(type(self))
        new._islazy = False
        super(MultivariateQExponential, new).__init__(
            loc=new_loc, scale_tril=tril.expand(batch_size + tril.shape[-2:])
        )
        new.power = self.power
        # Set the covar matrix, since it is always available for QPyTorch QEP.
        dense = self.covariance_matrix
        new.covariance_matrix = dense.expand(batch_size + dense.shape[-2:])
        return new

    new = self.__class__(
        mean=new_loc,
        covariance_matrix=self._covar.expand(batch_size + self._covar.shape[-2:]),
        power=self.power,
    )
    tril = self.__unbroadcasted_scale_tril
    if tril is not None:
        # Reuse the cached scale tril instead of recomputing it.
        new.__unbroadcasted_scale_tril = tril.expand(batch_size + tril.shape[-2:])
    return new
|
|
177
|
+
|
|
178
|
+
def unsqueeze(self, dim: int) -> MultivariateQExponential:
    r"""
    Construct a new MultivariateQExponential whose batch shape has a singleton
    dimension inserted at `dim`.

    For example, with `self.batch_shape = torch.Size([2, 3])`, `dim = 0` gives
    `batch_shape = torch.Size([1, 2, 3])` and `dim = -1` gives `torch.Size([2, 3, 1])`.

    :raises IndexError: If `dim` lies outside `[-len(batch_shape) - 1, len(batch_shape)]`.
    """
    n_batch = len(self.batch_shape)
    if not -n_batch - 1 <= dim <= n_batch:
        raise IndexError(
            "Dimension out of range (expected to be in range of "
            f"[{-len(self.batch_shape) - 1}, {len(self.batch_shape)}], but got {dim})."
        )
    if dim < 0:
        # Translate a negative dim into its positive equivalent for the enlarged batch shape.
        dim += n_batch + 1

    new_loc = self.loc.unsqueeze(dim)

    if not self.islazy:
        # Non-lazy QEP is represented using scale_tril in PyTorch; build the result from the
        # unsqueezed scale_tril via __new__ so __init__ (and a fresh Cholesky) is skipped.
        tril = self.__unbroadcasted_scale_tril
        new = self.__new__(type(self))
        new._islazy = False
        super(MultivariateQExponential, new).__init__(loc=new_loc, scale_tril=tril.unsqueeze(dim))
        new.power = self.power
        # Set the covar matrix, since it is always available for QPyTorch QEP.
        new.covariance_matrix = self.covariance_matrix.unsqueeze(dim)
        return new

    new = self.__class__(mean=new_loc, covariance_matrix=self._covar.unsqueeze(dim), power=self.power)
    tril = self.__unbroadcasted_scale_tril
    if tril is not None:
        # Reuse the cached scale tril instead of recomputing it.
        new.__unbroadcasted_scale_tril = tril.unsqueeze(dim)
    return new
|
|
215
|
+
|
|
216
|
+
def get_base_samples(self, sample_shape: torch.Size = torch.Size(), rescale=False) -> Tensor:
    r"""
    Returns marginally identical but uncorrelated (m.i.u.) standard Q-Exponential samples to be used with
    :py:meth:`MultivariateQExponential.rsample(base_samples=base_samples)
    <qpytorch.distributions.MultivariateQExponential.rsample>`.

    For ``power == 2`` these are standard normal draws. Otherwise each draw is a normalised
    standard-normal vector (a random direction) scaled by :math:`s^{1/q}` with
    :math:`s \sim \chi^2_N`, where `N` is the size of the last dimension.

    :param sample_shape: The number of samples to generate. (Default: `torch.Size([])`.)
    :param rescale: If True, divide the samples by ``rescalor``. (Default: False.)
    :return: A `*sample_shape x *batch_shape x N` tensor of m.i.u. standard Q-Exponential samples,
        with the dtype and device of ``loc``.
    """
    with torch.no_grad():
        shape = self._extended_shape(sample_shape)
        dtype, device = self.loc.dtype, self.loc.device
        base_samples = _standard_normal(shape, dtype=dtype, device=device)
        if self.power != 2:
            # Draw the chi-square radii directly in loc's dtype/device. The previous code drew them in
            # the default dtype on the CPU and moved them afterwards, which truncated float64 models.
            df = torch.tensor(float(shape[-1]), dtype=dtype, device=device)
            radius = Chi2(df).sample(shape[:-1] + torch.Size([1])) ** (1.0 / self.power)
            base_samples = torch.nn.functional.normalize(base_samples, dim=-1) * radius
        if rescale:
            base_samples /= self.rescalor
    return base_samples
|
|
231
|
+
|
|
232
|
+
@lazy_property
def lazy_covariance_matrix(self) -> LinearOperator:
    """Covariance as a LinearOperator: the stored operator when lazy, otherwise a wrapper around the dense matrix. Cached."""
    if not self.islazy:
        return to_linear_operator(super().covariance_matrix)
    return self._covar
|
|
238
|
+
|
|
239
|
+
def log_prob(self, value: Tensor) -> Tensor:
    r"""
    Log-density of the multivariate q-exponential distribution at `value`:

    .. math::
        \log p(x) = \log\frac{q}{2} - \frac{N}{2}\log(2\pi) - \frac{1}{2}\log|C|
            + \left(\frac{q}{2}-1\right)\frac{N}{2}\log r - \frac{1}{2} r^{q/2},
        \quad r = (x-\mu)^T C^{-1} (x-\mu).

    With ``settings.fast_computations.log_prob`` on, `r` and `log|C|` come from
    ``LinearOperator.inv_quad_logdet``. Otherwise they are computed exactly from the Cholesky
    factor. In that exact mode the Gaussian case ``power == 2`` defers to torch's
    ``MultivariateNormal.log_prob``.

    :param value: `... x N` points at which to evaluate the density.
    :return: Tensor of log-densities (the trailing `N` dimension is reduced).
    """
    fast = not settings.fast_computations.log_prob.off()
    if not fast and self.power == 2:
        # q = 2 is exactly the Gaussian density.
        return super().log_prob(value)

    if self._validate_args:
        self._validate_sample(value)

    mean, power = self.loc, self.power
    diff = value - mean

    if fast:
        covar = self.lazy_covariance_matrix
        # Repeat the covar to match the batch shape of diff
        if diff.shape[:-1] != covar.batch_shape:
            if len(diff.shape[:-1]) < len(covar.batch_shape):
                diff = diff.expand(covar.shape[:-1])
            else:
                padded_batch_shape = (*(1 for _ in range(diff.dim() + 1 - covar.dim())), *covar.batch_shape)
                covar = covar.repeat(
                    *(diff_size // covar_size for diff_size, covar_size in zip(diff.shape[:-1], padded_batch_shape)),
                    1,
                    1,
                )

        # Get log determinant and first part of quadratic form
        covar = covar.evaluate_kernel()
        inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True)
    else:
        # The parent (Gaussian) log_prob ignores `power`, so compute r and log|C| exactly from
        # the Cholesky factor, the same way torch's MultivariateNormal does.
        from torch.distributions.multivariate_normal import _batch_mahalanobis

        scale_tril = self._unbroadcasted_scale_tril
        inv_quad = _batch_mahalanobis(scale_tril, diff)
        logdet = 2.0 * scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)

    num_dims = diff.size(-1)
    res = -0.5 * sum([inv_quad ** (power / 2.0), logdet, num_dims * math.log(2 * math.pi)])
    if power != 2:
        res += sum([0.5 * num_dims * (power / 2.0 - 1) * torch.log(inv_quad), torch.log(power / 2.0)])
    return res
|
|
272
|
+
|
|
273
|
+
def entropy(self, exact: bool = False) -> Tensor:
    r"""
    Differential entropy of the distribution.

    With :math:`r = (x-\mu)^T C^{-1}(x-\mu)`, the variable :math:`s = r^{q/2}` follows
    :math:`\chi^2_N`, so :math:`E[s] = N` and :math:`E[\log r] = \frac{2}{q}(\psi(N/2) + \log 2)`.
    Hence

    .. math::
        H = \frac{N}{2}\log(2\pi) + \frac{1}{2}\log|C| + \frac{N}{2}
            - \left(\frac{q}{2}-1\right)\frac{N}{2} E[\log r] - \log\frac{q}{2}.

    :param exact: If True, return the exact entropy above. If False (default), use the plug-in
        approximation :math:`r \approx N`: the :math:`N/2` term becomes :math:`N^{q/2}/2` and
        :math:`E[\log r]` becomes :math:`\log N`.
    :return: The entropy, one value per batch element.
    """
    d = self._event_shape[0]
    if self.islazy:
        logdet = self.lazy_covariance_matrix.logdet()
        res = 0.5 * sum([d * math.log(2 * math.pi), logdet, d ** (1 if exact else self.power / 2.0)])
    else:
        # Gaussian entropy 0.5 * (d * (1 + log 2pi) + log|C|), then swap the d/2 term if approximating.
        res = super().entropy()
        if not exact:
            res += 0.5 * (-d + d ** (self.power / 2.0))
    if self.power != 2:
        if exact:
            # E[log s] for s ~ chi2(d) is digamma(d / 2) + log 2 (not the chi-square entropy).
            half_d = torch.tensor(d / 2.0, dtype=res.dtype, device=res.device)
            expected_log_r = 2.0 / self.power * (torch.digamma(half_d) + math.log(2))
        else:
            expected_log_r = math.log(d)
        res += -d / 2.0 * (self.power / 2.0 - 1) * expected_log_r - torch.log(self.power / 2.0)
    return res
|
|
288
|
+
|
|
289
|
+
def zero_mean_qep_samples(self, op: LinearOperator, num_samples: int, **kwargs) -> Tensor:
    r"""
    Draw samples from a zero-mean QEP :math:`\mathcal Q( \mathbf 0, \mathbf A)`.

    The LinearOperator :math:`\mathbf A` is assumed to be a covariance matrix, or a batch of
    covariance matrices. Base samples from :py:meth:`get_base_samples` are multiplied by a root
    of :math:`\mathbf A`. The root is obtained by contour-integral quadrature when
    ``settings.ciq_samples`` is on. Otherwise it comes from ``op.root_decomposition()``, or from
    an elementwise square root for `1 x 1` operators.

    :param op: The covariance operator :math:`\mathbf A`.
    :param num_samples: Number of samples to draw.
    :param kwargs: Forwarded to :py:meth:`get_base_samples` (e.g. ``rescale``).
    :return: Samples from QEP :math:`\mathcal Q( \mathbf 0, \mathbf A)`, sample dimension first.
    """
    from linear_operator.utils.contour_integral_quad import contour_integral_quad

    def _draw_flat_base_samples() -> Tensor:
        # Draw the base samples. In the multitask case (2-D events) each `n x t` event is
        # flattened into one vector of length `op.size(-1)`, transposed first when not interleaved.
        samples = self.get_base_samples(torch.Size([num_samples]), **kwargs)
        if len(self.event_shape) == 2:
            if not self._interleaved:
                samples = samples.transpose(-1, -2)
            samples = samples.reshape(samples.shape[:-2] + op.shape[-1:])
        return samples

    if settings.ciq_samples.on():
        base_samples = _draw_flat_base_samples().unsqueeze(-1)
        solves, weights, _, _ = contour_integral_quad(
            op.evaluate_kernel(),
            base_samples,
            inverse=False,
            num_contour_quadrature=settings.num_contour_quadrature.value(),
        )
        return (solves * weights).sum(0).squeeze(-1)

    if op.size()[-2:] == torch.Size([1, 1]):
        covar_root = op.to_dense().sqrt()
    else:
        covar_root = op.root_decomposition().root

    base_samples = _draw_flat_base_samples()
    # Move the sample dimension last so the root can left-multiply: `... x N x num_samples`.
    base_samples = base_samples.permute(*range(1, base_samples.dim()), 0)
    # A low-rank root (`... x N x k`, k < N) only consumes the first k base coordinates. Compare
    # the rank directly; a lexicographic torch.Size comparison would also weigh batch sizes.
    if covar_root.size(-1) < op.size(-1):
        base_samples = base_samples[..., : covar_root.size(-1), :]
    samples = covar_root.matmul(base_samples).permute(-1, *range(base_samples.dim() - 1)).contiguous()
    return samples
|
|
331
|
+
|
|
332
|
+
def rsample(self, sample_shape: torch.Size = torch.Size(), base_samples: Optional[Tensor] = None, **kwargs) -> Tensor:
    r"""
    Generates a `sample_shape` shaped reparameterized sample or `sample_shape`
    shaped batch of reparameterized samples if the distribution parameters
    are batched.

    For the MultivariateQExponential distribution, this is accomplished through:

    .. math::
        \boldsymbol \mu + \mathbf L \boldsymbol \epsilon

    where :math:`\boldsymbol \mu \in \mathcal R^N` is the QEP mean,
    :math:`\mathbf L \in \mathcal R^{N \times N}` is a "root" of the
    covariance matrix :math:`\mathbf K` (i.e. :math:`\mathbf L \mathbf
    L^\top = \mathbf K`), and :math:`\boldsymbol \epsilon \in \mathcal R^N` is a
    vector of (approximately) m.i.u. standard Q-Exponential random variables.

    :param sample_shape: The number of samples to generate. Any sequence of ints (e.g. a list)
        is accepted. Ignored when `base_samples` is given. (Default: `torch.Size([])`.)
    :param base_samples: The `*sample_shape x *batch_shape x N` tensor of
        m.i.u. (or approximately m.i.u.) standard Q-Exponential samples to
        reparameterize. (Default: None.)
    :param kwargs: When `base_samples` is None, forwarded to :py:meth:`zero_mean_qep_samples`.
    :return: A `*sample_shape x *batch_shape x N` tensor of m.i.u. reparameterized samples.
    :raises RuntimeError: If the trailing shape of `base_samples` disagrees with the mean and its
        last dimension exceeds the rank of the covariance root.
    """
    # Accept plain sequences (e.g. list[int], as Pyro may pass); `.numel()` below needs a Size.
    sample_shape = torch.Size(sample_shape)
    covar = self.lazy_covariance_matrix

    if base_samples is None:
        num_samples = sample_shape.numel() or 1
        res = self.zero_mean_qep_samples(covar, num_samples, **kwargs) + self.loc.unsqueeze(0)
        return res.view(sample_shape + self.loc.shape)

    covar_root = covar.root_decomposition().root

    # Make sure that the base samples agree with the distribution
    if (
        self.loc.shape != base_samples.shape[-self.loc.dim() :]
        and covar_root.shape[-1] < base_samples.shape[-1]
    ):
        raise RuntimeError(
            "The size of base_samples (minus sample shape dimensions) should agree with the size "
            "of self.loc. Expected ...{} but got {}".format(self.loc.shape, base_samples.shape)
        )

    # Determine what the appropriate sample_shape parameter is
    sample_shape = base_samples.shape[: base_samples.dim() - self.loc.dim()]

    # Reshape samples to be batch_size x num_dim x num_samples (or num_dim x num_samples)
    base_samples = base_samples.view(-1, *self.loc.shape[:-1], covar_root.shape[-1])
    base_samples = base_samples.permute(*range(1, self.loc.dim() + 1), 0)

    # If necessary, adjust base_samples for rank of root decomposition
    if covar_root.shape[-1] < base_samples.shape[-2]:
        base_samples = base_samples[..., : covar_root.shape[-1], :]
    elif covar_root.shape[-1] > base_samples.shape[-2]:
        # NOTE(review): falls back to the transposed root instead of raising; confirm this is intended.
        covar_root = covar_root.transpose(-2, -1)
    res = covar_root.matmul(base_samples) + self.loc.unsqueeze(-1)

    # Permute and reshape new samples to be original size
    res = res.permute(-1, *range(self.loc.dim())).contiguous()
    return res.view(sample_shape + self.loc.shape)
|
|
415
|
+
|
|
416
|
+
def sample(self, sample_shape: torch.Size = torch.Size(), base_samples: Optional[Tensor] = None, **kwargs) -> Tensor:
    r"""
    Generate a `sample_shape` shaped sample (or batch of samples if the distribution is batched)
    by calling :py:meth:`rsample` with gradient tracking disabled. These samples cannot be
    backpropagated through.

    :param sample_shape: The number of samples to generate. (Default: `torch.Size([])`.)
    :param base_samples: The `*sample_shape x *batch_shape x N` tensor of
        m.i.u. (or approximately m.i.u.) standard Q-Exponential samples to
        reparameterize. (Default: None.)
    :param kwargs: Forwarded to :py:meth:`rsample`.
    :return: A `*sample_shape x *batch_shape x N` tensor of m.i.u. samples.
    """
    with torch.no_grad():
        draws = self.rsample(sample_shape=sample_shape, base_samples=base_samples, **kwargs)
    return draws
|
|
432
|
+
|
|
433
|
+
@property
def stddev(self) -> Tensor:
    """Standard deviation: the elementwise square root of ``variance``."""
    var = self.variance  # the `variance` property clamps to a positive minimum, so sqrt is safe
    return torch.sqrt(var)
|
|
437
|
+
|
|
438
|
+
def to_data_uncorrelated_dist(self) -> MultivariateQExponential:
    """
    Drop all covariance between the dimensions of a `... x N` QEP distribution.

    The result is a QEP with the same mean and power whose covariance is a DiagLinearOperator
    holding only this distribution's marginal variances. Its batch and event shapes are unchanged.

    :returns: A (data-uncorrelated) Q-Exponential distribution with a diagonal covariance.
    """
    marginal_variances = self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2)
    diagonal_covar = DiagLinearOperator(marginal_variances)
    return self.__class__(mean=self.mean, covariance_matrix=diagonal_covar, power=self.power)
|
|
458
|
+
|
|
459
|
+
# Alternative name for `to_data_uncorrelated_dist`; both names refer to the same function object.
to_data_independent_dist = to_data_uncorrelated_dist
|
|
460
|
+
|
|
461
|
+
@property
def variance(self) -> Tensor:
    """
    Marginal variances (the diagonal of the covariance), clamped from below.

    Values below ``settings.min_variance`` for the tensor's dtype trigger a NumericalWarning and
    are raised to that floor.
    """
    if not self.islazy:
        variance = super().variance
    else:
        # overwrite this since torch uses unbroadcasted_scale_tril for this
        diagonal = self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2)
        diagonal = diagonal.view(diagonal.shape[:-1] + self._event_shape)
        variance = diagonal.expand(self._batch_shape + self._event_shape)

    floor = settings.min_variance.value(variance.dtype)
    if variance.lt(floor).any():
        warnings.warn(
            f"Negative variance values detected. "
            "This is likely due to numerical instabilities. "
            f"Rounding negative variances up to {floor}.",
            NumericalWarning,
        )
        variance = variance.clamp_min(floor)
    return variance
|
|
483
|
+
|
|
484
|
+
def __add__(self, other: MultivariateQExponential) -> MultivariateQExponential:
    """
    Add another MultivariateQExponential or a Python int/float.

    For another QEP, means and covariances are summed and ``self.power`` is kept (the other
    operand's power is not used). For an int/float, only the mean is shifted.

    :raises RuntimeError: For any other operand type.
    """
    if isinstance(other, MultivariateQExponential):
        summed_mean = self.mean + other.mean
        summed_covar = self.lazy_covariance_matrix + other.lazy_covariance_matrix
        return self.__class__(mean=summed_mean, covariance_matrix=summed_covar, power=self.power)
    if isinstance(other, (int, float)):
        return self.__class__(self.mean + other, self.lazy_covariance_matrix, self.power)
    raise RuntimeError("Unsupported type {} for addition w/ MultivariateQExponential".format(type(other)))
|
|
495
|
+
|
|
496
|
+
def __getitem__(self, idx) -> MultivariateQExponential:
    r"""
    Return a new MultivariateQExponential describing the random variable after indexing.

    ``idx`` is applied to the mean as given. The covariance matrix is indexed to match:

    * If ``idx`` only indexes batch dimensions (fewer entries than the mean has dimensions,
      and no ellipsis before the final entry), the lazy covariance matrix is indexed with
      ``idx`` directly.
    * Otherwise the final entry of ``idx`` addresses the mean's last dimension and is applied
      to the covariance's last two dimensions:

      - an ``int`` selects that element of the covariance diagonal and wraps the result in a
        ``DiagLinearOperator``;
      - a ``slice`` is applied to rows and columns;
      - ``...`` leaves those dimensions untouched;
      - any other index is applied to the rows and then to the columns.

    The new distribution keeps ``self.power``.

    :param idx: Index to apply to the mean. The covariance matrix is indexed accordingly.
    :raises IndexError: If ``idx`` contains ambiguous multiple ellipses, or has more entries
        than the mean has dimensions.
    """
    key = idx if isinstance(idx, tuple) else (idx,)
    if len(key) > self.mean.dim() and Ellipsis in key:
        # More entries than mean dimensions: remove every Ellipsis from the index.
        key = tuple(entry for entry in key if entry != Ellipsis)
        if len(key) < self.mean.dim():
            raise IndexError("Multiple ambiguous ellipsis in index!")

    batch_key = key[:-1]
    event_key = key[-1]
    new_mean = self.mean[key]
    ndim = self.mean.dim()

    touches_batch_only = len(key) <= ndim - 1 and (Ellipsis not in batch_key)
    if touches_batch_only:
        new_cov = self.lazy_covariance_matrix[key]
    elif len(key) > ndim:
        raise IndexError(f"Index {key} has too many dimensions")
    else:
        # ``event_key`` addresses the mean's last dimension, i.e. the covariance's last two.
        covar = self.lazy_covariance_matrix
        if isinstance(event_key, int):
            new_cov = DiagLinearOperator(covar.diagonal(dim1=-1, dim2=-2)[(*batch_key, event_key)])
        elif isinstance(event_key, slice):
            new_cov = covar[(*batch_key, event_key, event_key)]
        elif event_key is Ellipsis:
            new_cov = covar[batch_key]
        else:
            selected_rows = covar[(*batch_key, event_key, slice(None, None, None))]
            new_cov = selected_rows[..., event_key]
    return self.__class__(mean=new_mean, covariance_matrix=new_cov, power=self.power)
|
|
536
|
+
|
|
537
|
+
def __mul__(self, other: Number) -> MultivariateQExponential:
    """
    Scale the random variable by a Python scalar ``c``.

    Returns a distribution with mean ``mean * c`` and covariance ``covariance * c**2``,
    keeping ``power``. Multiplying by exactly 1 returns ``self`` itself (no copy).

    :raises RuntimeError: If ``other`` is not an ``int`` or ``float``.
    """
    is_scalar = isinstance(other, (int, float))
    if not is_scalar:
        raise RuntimeError("Can only multiply by scalars")
    if other == 1:
        return self
    scaled_mean = self.mean * other
    scaled_cov = self.lazy_covariance_matrix * (other**2)
    return self.__class__(mean=scaled_mean, covariance_matrix=scaled_cov, power=self.power)
|
|
543
|
+
|
|
544
|
+
def __radd__(self, other: MultivariateQExponential) -> MultivariateQExponential:
    """
    Right-hand addition (``other + self``).

    A left operand equal to 0, such as the start value of the builtin ``sum``, returns
    ``self`` unchanged. Anything else is delegated to ``__add__``.
    """
    return self if other == 0 else self.__add__(other)
|
|
548
|
+
|
|
549
|
+
def __truediv__(self, other: Number) -> MultivariateQExponential:
    """
    Divide the random variable by a scalar.

    Implemented as multiplication by ``1.0 / other``. The scalar-only restriction (and
    ``RuntimeError``) of ``__mul__`` therefore applies to the reciprocal.
    """
    reciprocal = 1.0 / other
    return self.__mul__(reciprocal)
|
|
551
|
+
|
|
552
|
+
|
|
553
|
+
@register_kl(MultivariateQExponential, MultivariateQExponential)
def kl_qep_qep(p_dist: MultivariateQExponential, q_dist: MultivariateQExponential, exact: bool = False) -> Tensor:
    r"""
    Compute KL(p_dist || q_dist) for two MultivariateQExponential distributions.

    The batch shapes of both distributions are broadcast first. Define:

    * :math:`\delta = \mu_p - \mu_q`;
    * :math:`d` as the event size;
    * :math:`T = \mathrm{tr}(\Sigma_q^{-1}\Sigma_p) + \delta^\top \Sigma_q^{-1} \delta`;
    * :math:`q` = ``q_dist.power`` and :math:`p` = ``p_dist.power``.

    The base value is

    .. math::
        \tfrac12\left[\log|\Sigma_q| - \log|\Sigma_p| + T^{q/2} - d^{e}\right],

    where :math:`e = 1` if ``exact`` else :math:`p/2`. When ``q_dist.power != 2`` the
    following correction is added:

    .. math::
        \tfrac d2\left[-(q/2 - 1)\log T - (p/2 - 1)\,c\right],

    where :math:`c = \tfrac{2}{p} H(\chi^2_d)` if ``exact`` else :math:`-\log d`. When both
    powers equal 2 this is the Gaussian KL divergence.

    :param p_dist: First distribution (the one the expectation is taken under).
    :param q_dist: Second distribution.
    :param exact: Selects the ``exact`` variants of the terms above.
        ``torch.distributions.kl_divergence`` calls registered functions with only
        ``(p, q)``, so ``exact`` is ``False`` on that path.
    :return: The (approximate) KL divergence as a tensor.
    """
    output_shape = torch.broadcast_shapes(p_dist.batch_shape, q_dist.batch_shape)
    if output_shape != p_dist.batch_shape:
        p_dist = p_dist.expand(output_shape)
    if output_shape != q_dist.batch_shape:
        q_dist = q_dist.expand(output_shape)

    q_mean = q_dist.loc
    q_covar = q_dist.lazy_covariance_matrix

    p_mean = p_dist.loc
    p_covar = p_dist.lazy_covariance_matrix
    # ``to_dense()`` always returns a plain Tensor, so the root can be concatenated directly.
    # (The previous ``isinstance(root_p_covar, LinearOperator)`` fallback was unreachable.)
    root_p_covar = p_covar.root_decomposition().root.to_dense()

    mean_diffs = p_mean - q_mean
    dim = float(mean_diffs.size(-1))
    # Solving against [delta, L_p] (with L_p L_p^T = Sigma_p) yields
    # delta^T Sigma_q^{-1} delta + tr(Sigma_q^{-1} Sigma_p) in one call, alongside log|Sigma_q|.
    inv_quad_rhs = torch.cat([mean_diffs.unsqueeze(-1), root_p_covar], -1)
    logdet_p_covar = p_covar.logdet()
    trace_plus_inv_quad_form, logdet_q_covar = q_covar.inv_quad_logdet(inv_quad_rhs=inv_quad_rhs, logdet=True)

    # Compute the KL Divergence.
    res = 0.5 * sum(
        [
            logdet_q_covar,
            logdet_p_covar.mul(-1),
            trace_plus_inv_quad_form ** (q_dist.power / 2.0),
            -(dim ** (1 if exact else p_dist.power / 2.0)),
        ]
    )
    if q_dist.power != 2:
        # The exact value is intractable; an approximation is provided instead.
        # NOTE(review): the ``exact`` branch uses the entropy of Chi2(dim). Confirm this is the
        # intended quantity (E[log chi2_d] would be digamma(dim / 2) + log 2).
        p_radial_term = 2.0 / p_dist.power * Chi2(dim).entropy() if exact else -math.log(dim)
        correction = dim / 2.0 * sum(
            [
                -(q_dist.power / 2.0 - 1) * torch.log(trace_plus_inv_quad_form),
                -(p_dist.power / 2.0 - 1) * p_radial_term,
            ]
        )
        # Out-of-place add. An in-place ``res += correction`` fails whenever ``correction``
        # broadcasts to a wider shape than ``res``.
        res = res + correction
    return res
|