qpytorch-0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of qpytorch might be problematic.

Files changed (102)
  1. qpytorch/__init__.py +327 -0
  2. qpytorch/constraints/__init__.py +3 -0
  3. qpytorch/distributions/__init__.py +21 -0
  4. qpytorch/distributions/delta.py +86 -0
  5. qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
  6. qpytorch/distributions/multivariate_qexponential.py +581 -0
  7. qpytorch/distributions/power.py +113 -0
  8. qpytorch/distributions/qexponential.py +153 -0
  9. qpytorch/functions/__init__.py +58 -0
  10. qpytorch/kernels/__init__.py +80 -0
  11. qpytorch/kernels/grid_interpolation_kernel.py +213 -0
  12. qpytorch/kernels/inducing_point_kernel.py +151 -0
  13. qpytorch/kernels/kernel.py +695 -0
  14. qpytorch/kernels/matern32_kernel_grad.py +155 -0
  15. qpytorch/kernels/matern52_kernel_grad.py +194 -0
  16. qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
  17. qpytorch/kernels/polynomial_kernel_grad.py +88 -0
  18. qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
  19. qpytorch/kernels/rbf_kernel_grad.py +125 -0
  20. qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
  21. qpytorch/kernels/rff_kernel.py +153 -0
  22. qpytorch/lazy/__init__.py +9 -0
  23. qpytorch/likelihoods/__init__.py +66 -0
  24. qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
  25. qpytorch/likelihoods/beta_likelihood.py +76 -0
  26. qpytorch/likelihoods/gaussian_likelihood.py +472 -0
  27. qpytorch/likelihoods/laplace_likelihood.py +59 -0
  28. qpytorch/likelihoods/likelihood.py +437 -0
  29. qpytorch/likelihoods/likelihood_list.py +60 -0
  30. qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
  31. qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
  32. qpytorch/likelihoods/noise_models.py +184 -0
  33. qpytorch/likelihoods/qexponential_likelihood.py +494 -0
  34. qpytorch/likelihoods/softmax_likelihood.py +97 -0
  35. qpytorch/likelihoods/student_t_likelihood.py +90 -0
  36. qpytorch/means/__init__.py +23 -0
  37. qpytorch/metrics/__init__.py +17 -0
  38. qpytorch/mlls/__init__.py +53 -0
  39. qpytorch/mlls/_approximate_mll.py +79 -0
  40. qpytorch/mlls/deep_approximate_mll.py +30 -0
  41. qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
  42. qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
  43. qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
  44. qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
  45. qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
  46. qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
  47. qpytorch/mlls/marginal_log_likelihood.py +48 -0
  48. qpytorch/mlls/predictive_log_likelihood.py +76 -0
  49. qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
  50. qpytorch/mlls/variational_elbo.py +77 -0
  51. qpytorch/models/__init__.py +72 -0
  52. qpytorch/models/approximate_qep.py +115 -0
  53. qpytorch/models/deep_qeps/__init__.py +22 -0
  54. qpytorch/models/deep_qeps/deep_qep.py +155 -0
  55. qpytorch/models/deep_qeps/dspp.py +114 -0
  56. qpytorch/models/exact_prediction_strategies.py +880 -0
  57. qpytorch/models/exact_qep.py +349 -0
  58. qpytorch/models/model_list.py +100 -0
  59. qpytorch/models/pyro/__init__.py +28 -0
  60. qpytorch/models/pyro/_pyro_mixin.py +57 -0
  61. qpytorch/models/pyro/distributions/__init__.py +5 -0
  62. qpytorch/models/pyro/pyro_qep.py +105 -0
  63. qpytorch/models/qep.py +7 -0
  64. qpytorch/models/qeplvm/__init__.py +6 -0
  65. qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
  66. qpytorch/models/qeplvm/latent_variable.py +102 -0
  67. qpytorch/module.py +30 -0
  68. qpytorch/optim/__init__.py +5 -0
  69. qpytorch/priors/__init__.py +42 -0
  70. qpytorch/priors/qep_priors.py +81 -0
  71. qpytorch/test/__init__.py +22 -0
  72. qpytorch/test/base_likelihood_test_case.py +106 -0
  73. qpytorch/test/model_test_case.py +150 -0
  74. qpytorch/test/variational_test_case.py +400 -0
  75. qpytorch/utils/__init__.py +38 -0
  76. qpytorch/utils/warnings.py +37 -0
  77. qpytorch/variational/__init__.py +47 -0
  78. qpytorch/variational/_variational_distribution.py +61 -0
  79. qpytorch/variational/_variational_strategy.py +391 -0
  80. qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
  81. qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
  82. qpytorch/variational/cholesky_variational_distribution.py +65 -0
  83. qpytorch/variational/ciq_variational_strategy.py +352 -0
  84. qpytorch/variational/delta_variational_distribution.py +41 -0
  85. qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
  86. qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
  87. qpytorch/variational/lmc_variational_strategy.py +248 -0
  88. qpytorch/variational/mean_field_variational_distribution.py +58 -0
  89. qpytorch/variational/multitask_variational_strategy.py +317 -0
  90. qpytorch/variational/natural_variational_distribution.py +152 -0
  91. qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
  92. qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
  93. qpytorch/variational/tril_natural_variational_distribution.py +130 -0
  94. qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
  95. qpytorch/variational/unwhitened_variational_strategy.py +225 -0
  96. qpytorch/variational/variational_strategy.py +280 -0
  97. qpytorch/version.py +4 -0
  98. qpytorch-0.1.dist-info/LICENSE +21 -0
  99. qpytorch-0.1.dist-info/METADATA +177 -0
  100. qpytorch-0.1.dist-info/RECORD +102 -0
  101. qpytorch-0.1.dist-info/WHEEL +5 -0
  102. qpytorch-0.1.dist-info/top_level.txt +1 -0
qpytorch/distributions/multivariate_qexponential.py
@@ -0,0 +1,581 @@
#!/usr/bin/env python3

from __future__ import annotations

import math
import warnings
from numbers import Number
from typing import Optional, Tuple, Union

import torch
from linear_operator import to_dense, to_linear_operator
from linear_operator.operators import DiagLinearOperator, LinearOperator, RootLinearOperator
from torch import Tensor
from torch.distributions import MultivariateNormal as TMultivariateNormal, Chi2
from torch.distributions.kl import register_kl
from torch.distributions.utils import _standard_normal, lazy_property

from .. import settings
from ..utils.warnings import NumericalWarning
from gpytorch.distributions.distribution import Distribution


class MultivariateQExponential(TMultivariateNormal, Distribution):
    """
    Constructs a multivariate q-exponential random variable, based on a mean and a covariance matrix, whose density is

    .. math::

        p(x; \\mu, C) = \\frac{q}{2} (2\\pi)^{-\\frac{N}{2}} |C|^{-\\frac{1}{2}}
        r^{\\left(\\frac{q}{2}-1\\right)\\frac{N}{2}} \\exp\\left\\{ -\\frac{1}{2} r^{\\frac{q}{2}} \\right\\}, \\quad
        r(x) = (x - \\mu)^T C^{-1} (x - \\mu).

    The result can be multivariate, or a batch of multivariate q-exponentials.
    Passing a vector mean corresponds to a single multivariate q-exponential.
    Passing a matrix mean corresponds to a batch of multivariate q-exponentials.

    :param mean: `... x N` mean of the qep distribution.
    :param covariance_matrix: `... x N x N` covariance matrix of the qep distribution.
    :param power: (scalar) power of the qep distribution. (Default: 2.)
    :param validate_args: If True, validate the `mean` and `covariance_matrix` arguments. (Default: False.)

    :ivar torch.Size base_sample_shape: The shape of a base sample (without
        batching) that is used to generate a single sample.
    :ivar torch.Tensor covariance_matrix: The covariance matrix, represented as a dense :class:`torch.Tensor`.
    :ivar ~linear_operator.LinearOperator lazy_covariance_matrix: The covariance matrix, represented
        as a :class:`~linear_operator.LinearOperator`.
    :ivar torch.Tensor mean: The mean.
    :ivar torch.Tensor stddev: The standard deviation.
    :ivar torch.Tensor variance: The variance.
    """

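    # Usage sketch (editor's illustration, not part of the released source;
    # assumes the class is exported from `qpytorch.distributions` as the module
    # name suggests). With power q = 2 the density above reduces to a
    # multivariate Gaussian, which gives a quick consistency check:
    #
    #   >>> import torch
    #   >>> from qpytorch.distributions import MultivariateQExponential
    #   >>> mean, covar = torch.zeros(3), torch.eye(3)
    #   >>> qep = MultivariateQExponential(mean, covar, power=torch.tensor(2.0))
    #   >>> mvn = torch.distributions.MultivariateNormal(mean, covar)
    #   >>> x = torch.randn(3)
    #   >>> bool(torch.isclose(qep.log_prob(x), mvn.log_prob(x)))
    #   True
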
    def __init__(
        self,
        mean: Tensor,
        covariance_matrix: Union[Tensor, LinearOperator],
        power: Tensor = torch.tensor(2.0),
        validate_args: bool = False,
    ):
        self._islazy = isinstance(mean, LinearOperator) or isinstance(covariance_matrix, LinearOperator)
        if self._islazy:
            if validate_args:
                ms = mean.size(-1)
                cs1 = covariance_matrix.size(-1)
                cs2 = covariance_matrix.size(-2)
                if not (ms == cs1 and ms == cs2):
                    raise ValueError(f"Wrong shapes in {self._repr_sizes(mean, covariance_matrix)}")
            self.loc = mean
            self._covar = covariance_matrix
            self.__unbroadcasted_scale_tril = None
            self._validate_args = validate_args
            batch_shape = torch.broadcast_shapes(self.loc.shape[:-1], covariance_matrix.shape[:-2])
            event_shape = self.loc.shape[-1:]
            # TODO: Integrate argument validation for LinearOperators into torch.distribution validation logic
            super(TMultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=False)
        else:
            super().__init__(loc=mean, covariance_matrix=covariance_matrix, validate_args=validate_args)
        self.power = power

    def _extended_shape(self, sample_shape: torch.Size = torch.Size()) -> torch.Size:
        """
        Returns the size of the sample returned by the distribution, given
        a `sample_shape`. Note that the batch and event shapes of a distribution
        instance are fixed at the time of construction. If `sample_shape` is empty,
        the returned shape is upcast to (1,).

        :param sample_shape: the size of the sample to be drawn.
        """
        if not isinstance(sample_shape, torch.Size):
            sample_shape = torch.Size(sample_shape)
        return sample_shape + self._batch_shape + self.base_sample_shape

    @staticmethod
    def _repr_sizes(mean: Tensor, covariance_matrix: Union[Tensor, LinearOperator], power: Tensor = torch.tensor(2.0)) -> str:
        return f"MultivariateQExponential(loc: {mean.size()}, scale: {covariance_matrix.size()}, pow: {power.size()})"

    @property
    def _unbroadcasted_scale_tril(self) -> Tensor:
        if self.islazy and self.__unbroadcasted_scale_tril is None:
            # Cache the Cholesky decomposition of the lazy covariance matrix.
            ust = to_dense(self.lazy_covariance_matrix.cholesky())
            self.__unbroadcasted_scale_tril = ust
        return self.__unbroadcasted_scale_tril

    @_unbroadcasted_scale_tril.setter
    def _unbroadcasted_scale_tril(self, ust: Tensor):
        if self.islazy:
            raise NotImplementedError("Cannot set _unbroadcasted_scale_tril for lazy QEP distributions")
        else:
            self.__unbroadcasted_scale_tril = ust

    def add_jitter(self, noise: float = 1e-4) -> MultivariateQExponential:
        r"""
        Adds a small constant diagonal to the QEP covariance matrix for numerical stability.

        :param noise: The size of the constant diagonal.
        """
        return self.__class__(self.mean, self.lazy_covariance_matrix.add_jitter(noise), self.power)

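    # Illustrative sketch (`qep` is the hypothetical instance from the note near
    # the top of the class): jitter is the usual remedy when a Cholesky
    # factorization of the covariance fails for numerical reasons:
    #
    #   >>> stable = qep.add_jitter(1e-4)   # covariance becomes C + 1e-4 * I
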
    @property
    def base_sample_shape(self) -> torch.Size:
        base_sample_shape = self.event_shape
        if isinstance(self.lazy_covariance_matrix, RootLinearOperator):
            base_sample_shape = self.lazy_covariance_matrix.root.shape[-1:]
        return base_sample_shape

    @lazy_property
    def covariance_matrix(self) -> Tensor:
        if self.islazy:
            return self._covar.to_dense()
        else:
            return super().covariance_matrix

    @property
    def rescalor(self) -> Tensor:
        n = self.event_shape[0]
        return torch.exp(
            (2. / self.power * math.log(2) - math.log(n) + torch.lgamma(n / 2. + 2. / self.power) - math.lgamma(n / 2.)) / 2.
        )

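    # Editor's note on the formula above: in closed form,
    #
    #   rescalor = sqrt( 2^(2/q) * Gamma(N/2 + 2/q) / (N * Gamma(N/2)) ),
    #
    # evaluated in log-space for numerical stability. For a standard
    # q-exponential base sample z (a uniform direction scaled by a chi-square
    # radius to the power 1/q, as constructed in `get_base_samples` below),
    # this equals the per-coordinate standard deviation sqrt(E[||z||^2] / N),
    # so dividing by `rescalor` normalizes the marginal variances to one.
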
    def confidence_region(self, rescale: bool = False) -> Tuple[Tensor, Tensor]:
        """
        Returns 2 standard deviations above and below the mean.

        :param rescale: If True, scale the band by `rescalor` to account for the
            marginal variance of the q-exponential (see `rescalor`). (Default: False.)
        :return: Pair of tensors of size `... x N`, where N is the
            dimensionality of the random variable. The first (second) Tensor is the
            lower (upper) end of the confidence region.
        """
        std2 = self.stddev.mul(2).mul(self.rescalor if rescale else 1)
        mean = self.mean
        return mean.sub(std2), mean.add(std2)

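    # Illustrative (hypothetical `qep` as above): without rescaling the band is
    # exactly two standard deviations on each side of the mean:
    #
    #   >>> lower, upper = qep.confidence_region()
    #   >>> bool(torch.allclose(upper - lower, 4 * qep.stddev))
    #   True
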
    def expand(self, batch_size: torch.Size) -> MultivariateQExponential:
        r"""
        See :py:meth:`torch.distributions.Distribution.expand
        <torch.distributions.distribution.Distribution.expand>`.
        """
        # NOTE: Pyro may call this method with list[int] instead of torch.Size.
        batch_size = torch.Size(batch_size)
        new_loc = self.loc.expand(batch_size + self.loc.shape[-1:])
        if self.islazy:
            new_covar = self._covar.expand(batch_size + self._covar.shape[-2:])
            new = self.__class__(mean=new_loc, covariance_matrix=new_covar, power=self.power)
            if self.__unbroadcasted_scale_tril is not None:
                # Reuse the scale tril if available.
                new.__unbroadcasted_scale_tril = self.__unbroadcasted_scale_tril.expand(
                    batch_size + self.__unbroadcasted_scale_tril.shape[-2:]
                )
        else:
            # A non-lazy QEP is represented using scale_tril in PyTorch.
            # Constructing the new instance from scale_tril avoids unnecessary computation.
            # Initialize using __new__, so that we can skip __init__ and use scale_tril directly.
            new = self.__new__(type(self))
            new._islazy = False
            new_scale_tril = self.__unbroadcasted_scale_tril.expand(
                batch_size + self.__unbroadcasted_scale_tril.shape[-2:]
            )
            super(MultivariateQExponential, new).__init__(loc=new_loc, scale_tril=new_scale_tril)
            new.power = self.power
            # Set the covariance matrix, since it is always available for a QPyTorch QEP.
            new.covariance_matrix = self.covariance_matrix.expand(batch_size + self.covariance_matrix.shape[-2:])
        return new

    def unsqueeze(self, dim: int) -> MultivariateQExponential:
        r"""
        Constructs a new MultivariateQExponential with the batch shape unsqueezed
        by the given dimension.
        For example, if `self.batch_shape = torch.Size([2, 3])` and `dim = 0`, then
        the returned MultivariateQExponential will have `batch_shape = torch.Size([1, 2, 3])`.
        If `dim = -1`, then the returned MultivariateQExponential will have
        `batch_shape = torch.Size([2, 3, 1])`.
        """
        if dim > len(self.batch_shape) or dim < -len(self.batch_shape) - 1:
            raise IndexError(
                "Dimension out of range (expected to be in range of "
                f"[{-len(self.batch_shape) - 1}, {len(self.batch_shape)}], but got {dim})."
            )
        if dim < 0:
            # If dim is negative, get the positive equivalent.
            dim = len(self.batch_shape) + dim + 1

        new_loc = self.loc.unsqueeze(dim)
        if self.islazy:
            new_covar = self._covar.unsqueeze(dim)
            new = self.__class__(mean=new_loc, covariance_matrix=new_covar, power=self.power)
            if self.__unbroadcasted_scale_tril is not None:
                # Reuse the scale tril if available.
                new.__unbroadcasted_scale_tril = self.__unbroadcasted_scale_tril.unsqueeze(dim)
        else:
            # A non-lazy QEP is represented using scale_tril in PyTorch.
            # Constructing the new instance from scale_tril avoids unnecessary computation.
            # Initialize using __new__, so that we can skip __init__ and use scale_tril directly.
            new = self.__new__(type(self))
            new._islazy = False
            new_scale_tril = self.__unbroadcasted_scale_tril.unsqueeze(dim)
            super(MultivariateQExponential, new).__init__(loc=new_loc, scale_tril=new_scale_tril)
            new.power = self.power
            # Set the covariance matrix, since it is always available for a QPyTorch QEP.
            new.covariance_matrix = self.covariance_matrix.unsqueeze(dim)
        return new

    def get_base_samples(self, sample_shape: torch.Size = torch.Size(), rescale: bool = False) -> Tensor:
        r"""
        Returns marginally identical but uncorrelated (m.i.u.) standard Q-Exponential samples to be used with
        :py:meth:`MultivariateQExponential.rsample(base_samples=base_samples)
        <qpytorch.distributions.MultivariateQExponential.rsample>`.

        :param sample_shape: The number of samples to generate. (Default: `torch.Size([])`.)
        :param rescale: If True, divide the samples by `rescalor` so that they have
            unit marginal variance. (Default: False.)
        :return: A `*sample_shape x *batch_shape x N` tensor of m.i.u. standard Q-Exponential samples.
        """
        with torch.no_grad():
            shape = self._extended_shape(sample_shape)
            base_samples = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
            if self.power != 2:
                # A standard q-exponential sample is a uniform direction on the sphere
                # scaled by a chi-square radius raised to the power 1/q.
                radius = Chi2(shape[-1]).sample(shape[:-1] + torch.Size([1])).to(self.loc.device) ** (1. / self.power)
                base_samples = torch.nn.functional.normalize(base_samples, dim=-1) * radius
            if rescale:
                base_samples /= self.rescalor
            return base_samples

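    # Illustrative (hypothetical `qep` as above): base samples carry the
    # randomness, and `rsample` turns them into correlated draws, so they can be
    # generated once and reused:
    #
    #   >>> z = qep.get_base_samples(torch.Size([1000]))
    #   >>> samples = qep.rsample(base_samples=z)
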
    @property
    def islazy(self) -> bool:
        # Whether the covariance is stored as a LinearOperator (set in __init__).
        return self._islazy

    @lazy_property
    def lazy_covariance_matrix(self) -> LinearOperator:
        if self.islazy:
            return self._covar
        else:
            return to_linear_operator(super().covariance_matrix)

    def log_prob(self, value: Tensor) -> Tensor:
        r"""
        See :py:meth:`torch.distributions.Distribution.log_prob
        <torch.distributions.distribution.Distribution.log_prob>`.
        """
        if settings.fast_computations.log_prob.off():
            return super().log_prob(value)

        if self._validate_args:
            self._validate_sample(value)

        mean, covar, power = self.loc, self.lazy_covariance_matrix, self.power
        diff = value - mean

        # Repeat the covar to match the batch shape of diff
        if diff.shape[:-1] != covar.batch_shape:
            if len(diff.shape[:-1]) < len(covar.batch_shape):
                diff = diff.expand(covar.shape[:-1])
            else:
                padded_batch_shape = (*(1 for _ in range(diff.dim() + 1 - covar.dim())), *covar.batch_shape)
                covar = covar.repeat(
                    *(diff_size // covar_size for diff_size, covar_size in zip(diff.shape[:-1], padded_batch_shape)),
                    1,
                    1,
                )

        # Get the log determinant and the first part of the quadratic form
        covar = covar.evaluate_kernel()
        inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True)

        res = -0.5 * sum([inv_quad ** (power / 2.), logdet, diff.size(-1) * math.log(2 * math.pi)])
        if power != 2:
            res += sum([0.5 * diff.size(-1) * (power / 2. - 1) * torch.log(inv_quad), torch.log(power / 2.)])
        return res

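    # Illustrative: `log_prob` broadcasts a batch of values against the
    # distribution's batch shape, e.g. for the hypothetical 3-dimensional
    # `qep` above:
    #
    #   >>> qep.log_prob(torch.randn(5, 3)).shape
    #   torch.Size([5])
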
    def entropy(self, exact: bool = False) -> Tensor:
        r"""
        See :py:meth:`torch.distributions.Distribution.entropy
        <torch.distributions.distribution.Distribution.entropy>`.
        """
        d = self._event_shape[0]  # self.loc.shape[-1]
        if self.islazy:
            logdet = self.lazy_covariance_matrix.logdet()
            res = 0.5 * sum([d * math.log(2 * math.pi), logdet, d ** (1 if exact else self.power / 2.)])
        else:
            res = super().entropy()
            if not exact:
                res += 0.5 * (-d + d ** (self.power / 2.))
        if self.power != 2:
            res += sum([
                d / 2. * (self.power / 2. - 1) * (2. / self.power * Chi2(d).entropy() if exact else -math.log(d)),
                -torch.log(self.power / 2.),
            ])
        return res

    def zero_mean_qep_samples(self, op: LinearOperator, num_samples: int, **kwargs) -> Tensor:
        r"""
        Assumes that the LinearOperator :math:`\mathbf A` is a covariance
        matrix, or a batch of covariance matrices.
        Returns samples from a zero-mean QEP, defined by :math:`\mathcal Q( \mathbf 0, \mathbf A)`.

        :param num_samples: Number of samples to draw.
        :return: Samples from the QEP :math:`\mathcal Q( \mathbf 0, \mathbf A)`.
        """
        from linear_operator.utils.contour_integral_quad import contour_integral_quad

        if settings.ciq_samples.on():
            base_samples = self.get_base_samples(torch.Size([num_samples]), **kwargs)
            if len(self.event_shape) == 2:  # multitask case
                if not self._interleaved:
                    base_samples = base_samples.transpose(-1, -2)
                base_samples = base_samples.reshape(base_samples.shape[:-2] + op.shape[-1:])
            base_samples = base_samples.unsqueeze(-1)
            solves, weights, _, _ = contour_integral_quad(
                op.evaluate_kernel(),
                base_samples,
                inverse=False,
                num_contour_quadrature=settings.num_contour_quadrature.value(),
            )
            return (solves * weights).sum(0).squeeze(-1)
        else:
            if op.size()[-2:] == torch.Size([1, 1]):
                covar_root = op.to_dense().sqrt()
            else:
                covar_root = op.root_decomposition().root

            base_samples = self.get_base_samples(torch.Size([num_samples]), **kwargs)
            if len(self.event_shape) == 2:  # multitask case
                if not self._interleaved:
                    base_samples = base_samples.transpose(-1, -2)
                base_samples = base_samples.reshape(base_samples.shape[:-2] + op.shape[-1:])
            # Move the sample dimension last, apply the root, then move it back to the front.
            base_samples = base_samples.permute(*range(1, base_samples.dim()), 0)
            if covar_root.shape < op.shape:
                # Low-rank root: keep only as many base-sample rows as the root has columns.
                base_samples = base_samples[..., :covar_root.size(-1), :]
            samples = covar_root.matmul(base_samples).permute(-1, *range(base_samples.dim() - 1)).contiguous()
            return samples

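    # Illustrative sketch: drawing correlated zero-mean samples straight from a
    # covariance operator (hypothetical `qep` as above; with the GPyTorch-style
    # default of `ciq_samples` being off, this takes the root-decomposition branch):
    #
    #   >>> from linear_operator import to_linear_operator
    #   >>> A = to_linear_operator(torch.eye(3))
    #   >>> qep.zero_mean_qep_samples(A, num_samples=10).shape
    #   torch.Size([10, 3])
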
    def rsample(self, sample_shape: torch.Size = torch.Size(), base_samples: Optional[Tensor] = None, **kwargs) -> Tensor:
        r"""
        Generates a `sample_shape` shaped reparameterized sample or a `sample_shape`
        shaped batch of reparameterized samples if the distribution parameters
        are batched.

        For the MultivariateQExponential distribution, this is accomplished through:

        .. math::
            \boldsymbol \mu + \mathbf L \boldsymbol \epsilon

        where :math:`\boldsymbol \mu \in \mathcal R^N` is the QEP mean,
        :math:`\mathbf L \in \mathcal R^{N \times N}` is a "root" of the
        covariance matrix :math:`\mathbf K` (i.e. :math:`\mathbf L \mathbf
        L^\top = \mathbf K`), and :math:`\boldsymbol \epsilon \in \mathcal R^N` is a
        vector of (approximately) m.i.u. standard Q-Exponential random variables.

        :param sample_shape: The number of samples to generate. (Default: `torch.Size([])`.)
        :param base_samples: The `*sample_shape x *batch_shape x N` tensor of
            m.i.u. (or approximately m.i.u.) standard Q-Exponential samples to
            reparameterize. (Default: None.)
        :return: A `*sample_shape x *batch_shape x N` tensor of m.i.u. reparameterized samples.
        """
        covar = self.lazy_covariance_matrix
        if base_samples is None:
            # Create some samples
            num_samples = sample_shape.numel() or 1
            res = self.zero_mean_qep_samples(covar, num_samples, **kwargs) + self.loc.unsqueeze(0)
            res = res.view(sample_shape + self.loc.shape)
        else:
            covar_root = covar.root_decomposition().root

            # Make sure that the base samples agree with the distribution
            if (
                self.loc.shape != base_samples.shape[-self.loc.dim():]
                and covar_root.shape[-1] < base_samples.shape[-1]
            ):
                raise RuntimeError(
                    "The size of base_samples (minus sample shape dimensions) should agree with the size "
                    "of self.loc. Expected ...{} but got {}".format(self.loc.shape, base_samples.shape)
                )

            # Determine what the appropriate sample_shape parameter is
            sample_shape = base_samples.shape[: base_samples.dim() - self.loc.dim()]

            # Reshape samples to be batch_size x num_dim x num_samples
            # or num_dim x num_samples
            base_samples = base_samples.view(-1, *self.loc.shape[:-1], covar_root.shape[-1])
            base_samples = base_samples.permute(*range(1, self.loc.dim() + 1), 0)

            # Now reparameterize those base samples.
            # If necessary, adjust base_samples for the rank of the root decomposition.
            if covar_root.shape[-1] < base_samples.shape[-2]:
                base_samples = base_samples[..., : covar_root.shape[-1], :]
            elif covar_root.shape[-1] > base_samples.shape[-2]:
                covar_root = covar_root.transpose(-2, -1)
            res = covar_root.matmul(base_samples) + self.loc.unsqueeze(-1)

            # Permute and reshape the new samples back to the original size
            res = res.permute(-1, *range(self.loc.dim())).contiguous()
            res = res.view(sample_shape + self.loc.shape)

        return res

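    # Illustrative: fixing `base_samples` yields common random numbers across
    # calls, which is what keeps pathwise / reparameterized Monte Carlo
    # objectives smooth (hypothetical `qep` as above):
    #
    #   >>> z = qep.get_base_samples(torch.Size([64]))
    #   >>> bool(torch.equal(qep.rsample(base_samples=z), qep.rsample(base_samples=z)))
    #   True
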
    def sample(self, sample_shape: torch.Size = torch.Size(), base_samples: Optional[Tensor] = None, **kwargs) -> Tensor:
        r"""
        Generates a `sample_shape` shaped sample or a `sample_shape`
        shaped batch of samples if the distribution parameters
        are batched.

        Note that these samples are not reparameterized and therefore cannot be backpropagated through.

        :param sample_shape: The number of samples to generate. (Default: `torch.Size([])`.)
        :param base_samples: The `*sample_shape x *batch_shape x N` tensor of
            m.i.u. (or approximately m.i.u.) standard Q-Exponential samples to
            reparameterize. (Default: None.)
        :return: A `*sample_shape x *batch_shape x N` tensor of m.i.u. samples.
        """
        with torch.no_grad():
            return self.rsample(sample_shape=sample_shape, base_samples=base_samples, **kwargs)

    @property
    def stddev(self) -> Tensor:
        # self.variance is guaranteed to be positive, because we do clamping.
        return self.variance.sqrt()

    def to_data_uncorrelated_dist(self) -> MultivariateQExponential:
        """
        Convert a `... x N` QEP distribution into a batch of uncorrelated Q-Exponential distributions.
        Essentially, this throws away all covariance information
        and treats all dimensions as batch dimensions.

        :returns: A (data-uncorrelated) Q-Exponential distribution with batch shape `*batch_shape x N`.
        """
        # Create a batch distribution where all data are uncorrelated, but the tasks are dependent
        new_cov = DiagLinearOperator(self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2))
        return self.__class__(mean=self.mean, covariance_matrix=new_cov, power=self.power)

    to_data_independent_dist = to_data_uncorrelated_dist  # alias kept under the GPyTorch-style name

    @property
    def variance(self) -> Tensor:
        if self.islazy:
            # Overwrite this, since torch uses unbroadcasted_scale_tril for it
            diag = self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2)
            diag = diag.view(diag.shape[:-1] + self._event_shape)
            variance = diag.expand(self._batch_shape + self._event_shape)
        else:
            variance = super().variance

        # Check to make sure that variance isn't lower than the minimum allowed value (default 1e-6).
        # This ensures that all variances are positive.
        min_variance = settings.min_variance.value(variance.dtype)
        if variance.lt(min_variance).any():
            warnings.warn(
                "Negative variance values detected. "
                "This is likely due to numerical instabilities. "
                f"Rounding negative variances up to {min_variance}.",
                NumericalWarning,
            )
            variance = variance.clamp_min(min_variance)
        return variance

    def __add__(self, other: MultivariateQExponential) -> MultivariateQExponential:
        if isinstance(other, MultivariateQExponential):
            return self.__class__(
                mean=self.mean + other.mean,
                covariance_matrix=(self.lazy_covariance_matrix + other.lazy_covariance_matrix),
                power=self.power,
            )
        elif isinstance(other, (int, float)):
            return self.__class__(self.mean + other, self.lazy_covariance_matrix, self.power)
        else:
            raise RuntimeError("Unsupported type {} for addition w/ MultivariateQExponential".format(type(other)))

    def __getitem__(self, idx) -> MultivariateQExponential:
        r"""
        Constructs a new MultivariateQExponential that represents a random variable
        modified by an indexing operation.

        The mean and covariance matrix arguments are indexed accordingly.

        :param idx: Index to apply to the mean. The covariance matrix is indexed accordingly.
        """
        if not isinstance(idx, tuple):
            idx = (idx,)
        if len(idx) > self.mean.dim() and Ellipsis in idx:
            idx = tuple(i for i in idx if i != Ellipsis)
            if len(idx) < self.mean.dim():
                raise IndexError("Multiple ambiguous ellipsis in index!")

        rest_idx = idx[:-1]
        last_idx = idx[-1]
        new_mean = self.mean[idx]

        if len(idx) <= self.mean.dim() - 1 and (Ellipsis not in rest_idx):
            # We are only indexing the batch dimensions in this case
            new_cov = self.lazy_covariance_matrix[idx]
        elif len(idx) > self.mean.dim():
            raise IndexError(f"Index {idx} has too many dimensions")
        else:
            # In this case we know last_idx corresponds to the last dimension
            # of mean and the last two dimensions of lazy_covariance_matrix
            if isinstance(last_idx, int):
                new_cov = DiagLinearOperator(
                    self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2)[(*rest_idx, last_idx)]
                )
            elif isinstance(last_idx, slice):
                new_cov = self.lazy_covariance_matrix[(*rest_idx, last_idx, last_idx)]
            elif last_idx is Ellipsis:
                new_cov = self.lazy_covariance_matrix[rest_idx]
            else:
                new_cov = self.lazy_covariance_matrix[(*rest_idx, last_idx, slice(None, None, None))][..., last_idx]
        return self.__class__(mean=new_mean, covariance_matrix=new_cov, power=self.power)

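    # Illustrative (hypothetical shapes): indexing peels batch dimensions and
    # slices the event dimension consistently in the mean and covariance:
    #
    #   >>> d = MultivariateQExponential(torch.zeros(2, 3, 4), torch.eye(4).expand(2, 3, 4, 4))
    #   >>> d[0].batch_shape, d[0].event_shape
    #   (torch.Size([3]), torch.Size([4]))
    #   >>> d[0, 1, :2].event_shape
    #   torch.Size([2])
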
    def __mul__(self, other: Number) -> MultivariateQExponential:
        if not (isinstance(other, int) or isinstance(other, float)):
            raise RuntimeError("Can only multiply by scalars")
        if other == 1:
            return self
        return self.__class__(
            mean=self.mean * other,
            covariance_matrix=self.lazy_covariance_matrix * (other ** 2),
            power=self.power,
        )

    def __radd__(self, other: MultivariateQExponential) -> MultivariateQExponential:
        if other == 0:
            return self
        return self.__add__(other)

    def __truediv__(self, other: Number) -> MultivariateQExponential:
        return self.__mul__(1.0 / other)


@register_kl(MultivariateQExponential, MultivariateQExponential)
def kl_qep_qep(p_dist: MultivariateQExponential, q_dist: MultivariateQExponential, exact: bool = False) -> Tensor:
    output_shape = torch.broadcast_shapes(p_dist.batch_shape, q_dist.batch_shape)
    if output_shape != p_dist.batch_shape:
        p_dist = p_dist.expand(output_shape)
    if output_shape != q_dist.batch_shape:
        q_dist = q_dist.expand(output_shape)

    q_mean = q_dist.loc
    q_covar = q_dist.lazy_covariance_matrix

    p_mean = p_dist.loc
    p_covar = p_dist.lazy_covariance_matrix
    root_p_covar = p_covar.root_decomposition().root.to_dense()

    mean_diffs = p_mean - q_mean
    dim = float(mean_diffs.size(-1))
    if isinstance(root_p_covar, LinearOperator):
        # Right now this just catches the case where root_p_covar is a DiagLinearOperator,
        # but we may want to be smarter about this in the future
        root_p_covar = root_p_covar.to_dense()
    inv_quad_rhs = torch.cat([mean_diffs.unsqueeze(-1), root_p_covar], -1)
    logdet_p_covar = p_covar.logdet()
    trace_plus_inv_quad_form, logdet_q_covar = q_covar.inv_quad_logdet(inv_quad_rhs=inv_quad_rhs, logdet=True)

    # Compute the KL divergence.
    res = 0.5 * sum([
        logdet_q_covar,
        logdet_p_covar.mul(-1),
        trace_plus_inv_quad_form ** (q_dist.power / 2.),
        -dim ** (1 if exact else p_dist.power / 2.),
    ])
    if q_dist.power != 2:
        # The exact value is intractable; an approximation is provided instead unless `exact` is set.
        res += dim / 2. * sum([
            -(q_dist.power / 2. - 1) * torch.log(trace_plus_inv_quad_form),
            -(p_dist.power / 2. - 1) * (2. / p_dist.power * Chi2(dim).entropy() if exact else -math.log(dim)),
        ])
    return res
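
# Illustrative sketch: once registered via `register_kl`, the divergence is
# reachable through the standard torch dispatcher (hypothetical distributions
# with matching power):
#
#   >>> from torch.distributions import kl_divergence
#   >>> p = MultivariateQExponential(torch.zeros(3), torch.eye(3))
#   >>> q = MultivariateQExponential(torch.ones(3), 2 * torch.eye(3))
#   >>> kl = kl_divergence(p, q)   # dispatches to kl_qep_qep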