qpytorch-0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of qpytorch might be problematic.

Files changed (102)
  1. qpytorch/__init__.py +327 -0
  2. qpytorch/constraints/__init__.py +3 -0
  3. qpytorch/distributions/__init__.py +21 -0
  4. qpytorch/distributions/delta.py +86 -0
  5. qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
  6. qpytorch/distributions/multivariate_qexponential.py +581 -0
  7. qpytorch/distributions/power.py +113 -0
  8. qpytorch/distributions/qexponential.py +153 -0
  9. qpytorch/functions/__init__.py +58 -0
  10. qpytorch/kernels/__init__.py +80 -0
  11. qpytorch/kernels/grid_interpolation_kernel.py +213 -0
  12. qpytorch/kernels/inducing_point_kernel.py +151 -0
  13. qpytorch/kernels/kernel.py +695 -0
  14. qpytorch/kernels/matern32_kernel_grad.py +155 -0
  15. qpytorch/kernels/matern52_kernel_grad.py +194 -0
  16. qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
  17. qpytorch/kernels/polynomial_kernel_grad.py +88 -0
  18. qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
  19. qpytorch/kernels/rbf_kernel_grad.py +125 -0
  20. qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
  21. qpytorch/kernels/rff_kernel.py +153 -0
  22. qpytorch/lazy/__init__.py +9 -0
  23. qpytorch/likelihoods/__init__.py +66 -0
  24. qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
  25. qpytorch/likelihoods/beta_likelihood.py +76 -0
  26. qpytorch/likelihoods/gaussian_likelihood.py +472 -0
  27. qpytorch/likelihoods/laplace_likelihood.py +59 -0
  28. qpytorch/likelihoods/likelihood.py +437 -0
  29. qpytorch/likelihoods/likelihood_list.py +60 -0
  30. qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
  31. qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
  32. qpytorch/likelihoods/noise_models.py +184 -0
  33. qpytorch/likelihoods/qexponential_likelihood.py +494 -0
  34. qpytorch/likelihoods/softmax_likelihood.py +97 -0
  35. qpytorch/likelihoods/student_t_likelihood.py +90 -0
  36. qpytorch/means/__init__.py +23 -0
  37. qpytorch/metrics/__init__.py +17 -0
  38. qpytorch/mlls/__init__.py +53 -0
  39. qpytorch/mlls/_approximate_mll.py +79 -0
  40. qpytorch/mlls/deep_approximate_mll.py +30 -0
  41. qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
  42. qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
  43. qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
  44. qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
  45. qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
  46. qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
  47. qpytorch/mlls/marginal_log_likelihood.py +48 -0
  48. qpytorch/mlls/predictive_log_likelihood.py +76 -0
  49. qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
  50. qpytorch/mlls/variational_elbo.py +77 -0
  51. qpytorch/models/__init__.py +72 -0
  52. qpytorch/models/approximate_qep.py +115 -0
  53. qpytorch/models/deep_qeps/__init__.py +22 -0
  54. qpytorch/models/deep_qeps/deep_qep.py +155 -0
  55. qpytorch/models/deep_qeps/dspp.py +114 -0
  56. qpytorch/models/exact_prediction_strategies.py +880 -0
  57. qpytorch/models/exact_qep.py +349 -0
  58. qpytorch/models/model_list.py +100 -0
  59. qpytorch/models/pyro/__init__.py +28 -0
  60. qpytorch/models/pyro/_pyro_mixin.py +57 -0
  61. qpytorch/models/pyro/distributions/__init__.py +5 -0
  62. qpytorch/models/pyro/pyro_qep.py +105 -0
  63. qpytorch/models/qep.py +7 -0
  64. qpytorch/models/qeplvm/__init__.py +6 -0
  65. qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
  66. qpytorch/models/qeplvm/latent_variable.py +102 -0
  67. qpytorch/module.py +30 -0
  68. qpytorch/optim/__init__.py +5 -0
  69. qpytorch/priors/__init__.py +42 -0
  70. qpytorch/priors/qep_priors.py +81 -0
  71. qpytorch/test/__init__.py +22 -0
  72. qpytorch/test/base_likelihood_test_case.py +106 -0
  73. qpytorch/test/model_test_case.py +150 -0
  74. qpytorch/test/variational_test_case.py +400 -0
  75. qpytorch/utils/__init__.py +38 -0
  76. qpytorch/utils/warnings.py +37 -0
  77. qpytorch/variational/__init__.py +47 -0
  78. qpytorch/variational/_variational_distribution.py +61 -0
  79. qpytorch/variational/_variational_strategy.py +391 -0
  80. qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
  81. qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
  82. qpytorch/variational/cholesky_variational_distribution.py +65 -0
  83. qpytorch/variational/ciq_variational_strategy.py +352 -0
  84. qpytorch/variational/delta_variational_distribution.py +41 -0
  85. qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
  86. qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
  87. qpytorch/variational/lmc_variational_strategy.py +248 -0
  88. qpytorch/variational/mean_field_variational_distribution.py +58 -0
  89. qpytorch/variational/multitask_variational_strategy.py +317 -0
  90. qpytorch/variational/natural_variational_distribution.py +152 -0
  91. qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
  92. qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
  93. qpytorch/variational/tril_natural_variational_distribution.py +130 -0
  94. qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
  95. qpytorch/variational/unwhitened_variational_strategy.py +225 -0
  96. qpytorch/variational/variational_strategy.py +280 -0
  97. qpytorch/version.py +4 -0
  98. qpytorch-0.1.dist-info/LICENSE +21 -0
  99. qpytorch-0.1.dist-info/METADATA +177 -0
  100. qpytorch-0.1.dist-info/RECORD +102 -0
  101. qpytorch-0.1.dist-info/WHEEL +5 -0
  102. qpytorch-0.1.dist-info/top_level.txt +1 -0
qpytorch/__init__.py ADDED
@@ -0,0 +1,327 @@
+ #!/usr/bin/env python3
+
+ from typing import Optional, Tuple, Union
+
+ import linear_operator
+ import torch
+ from linear_operator import LinearOperator
+ from torch import Tensor
+
+ from gpytorch import (
+     beta_features,
+     settings,
+ )
+ from . import (
+     distributions,
+     kernels,
+     lazy,
+     likelihoods,
+     means,
+     metrics,
+     mlls,
+     models,
+     optim,
+     priors,
+     utils,
+     variational,
+ )
+ from .functions import inv_matmul, log_normal_cdf, logdet, matmul  # Deprecated
+ from .mlls import ExactMarginalLogLikelihood
+ from gpytorch.module import Module
+
+ Anysor = Union[LinearOperator, Tensor]
+
+
+ def add_diagonal(input: Anysor, diag: Tensor) -> LinearOperator:
+     r"""
+     Adds an element to the diagonal of the matrix :math:`\mathbf A`.
+
+     :param input: The matrix (or batch of matrices) :math:`\mathbf A` (... x N x N).
+     :param diag: Diagonal to add
+     :return: :math:`\mathbf A + \text{diag}(\mathbf d)`, where :math:`\mathbf A` is the linear operator
+         and :math:`\mathbf d` is the diagonal component
+     """
+     return linear_operator.add_diagonal(input=input, diag=diag)
+
+
+ def add_jitter(input: Anysor, jitter_val: float = 1e-3) -> Anysor:
+     r"""
+     Adds jitter (i.e., a small diagonal component) to the matrix this
+     LinearOperator represents.
+     This is equivalent to calling :meth:`~linear_operator.operators.LinearOperator.add_diagonal`
+     with a scalar tensor.
+
+     :param input: The matrix (or batch of matrices) :math:`\mathbf A` (... x N x N).
+     :param jitter_val: The diagonal component to add
+     :return: :math:`\mathbf A + \alpha \mathbf I`, where :math:`\mathbf A` is the linear operator
+         and :math:`\alpha` is :attr:`jitter_val`.
+     """
+     return linear_operator.add_jitter(input=input, jitter_val=jitter_val)
+
+
+ def diagonalization(input: Anysor, method: Optional[str] = None) -> Tuple[Tensor, Tensor]:
+     r"""
+     Returns a (usually partial) diagonalization of a symmetric positive definite matrix
+     (or batch of matrices) :math:`\mathbf A`.
+     Options are either "lanczos" or "symeig". "lanczos" runs Lanczos, while
+     "symeig" runs LinearOperator.symeig.
+
+     :param input: The matrix (or batch of matrices) :math:`\mathbf A` (... x N x N).
+     :param method: Specify the method to use ("lanczos" or "symeig"). The method will be determined
+         based on size if not specified.
+     :return: eigenvalues and eigenvectors representing the diagonalization.
+     """
+     return linear_operator.diagonalization(input=input, method=method)
+
+
+ def dsmm(
+     sparse_mat: Union[torch.sparse.HalfTensor, torch.sparse.FloatTensor, torch.sparse.DoubleTensor],
+     dense_mat: Tensor,
+ ) -> Tensor:
+     r"""
+     Performs the (batch) matrix multiplication :math:`\mathbf{SD}`
+     where :math:`\mathbf S` is a sparse matrix and :math:`\mathbf D` is a dense matrix.
+
+     :param sparse_mat: Sparse matrix :math:`\mathbf S` (... x M x N)
+     :param dense_mat: Dense matrix :math:`\mathbf D` (... x N x O)
+     :return: :math:`\mathbf S \mathbf D` (... x M x O)
+     """
+     return linear_operator.dsmm(sparse_mat=sparse_mat, dense_mat=dense_mat)
+
+
+ def inv_quad(input: Anysor, inv_quad_rhs: Tensor, reduce_inv_quad: bool = True) -> Tensor:
+     r"""
+     Computes an inverse quadratic form (w.r.t. :math:`\mathbf A`) with several right hand sides, i.e.:
+
+     .. math::
+         \text{tr}\left( \mathbf R^\top \mathbf A^{-1} \mathbf R \right),
+
+     where :math:`\mathbf A` is a positive definite matrix (or batch of matrices) and :math:`\mathbf R`
+     represents the right hand sides (:attr:`inv_quad_rhs`).
+
+     If :attr:`reduce_inv_quad` is set to false (and :attr:`inv_quad_rhs` is supplied),
+     the function instead computes
+
+     .. math::
+         \text{diag}\left( \mathbf R^\top \mathbf A^{-1} \mathbf R \right).
+
+     :param input: :math:`\mathbf A` - the positive definite matrix (... x N x N)
+     :param inv_quad_rhs: :math:`\mathbf R` - the right hand sides of the inverse quadratic term (... x N x M)
+     :param reduce_inv_quad: Whether to compute
+         :math:`\text{tr}\left( \mathbf R^\top \mathbf A^{-1} \mathbf R \right)`
+         or :math:`\text{diag}\left( \mathbf R^\top \mathbf A^{-1} \mathbf R \right)`.
+     :returns: The inverse quadratic term.
+         If `reduce_inv_quad=True`, the inverse quadratic term is of shape (...). Otherwise, it is (... x M).
+     """
+     return linear_operator.inv_quad(input=input, inv_quad_rhs=inv_quad_rhs, reduce_inv_quad=reduce_inv_quad)
+
+
+ def inv_quad_logdet(
+     input: Anysor,
+     inv_quad_rhs: Optional[Tensor] = None,
+     logdet: bool = False,
+     reduce_inv_quad: bool = True,
+ ) -> Tuple[Tensor, Tensor]:
+     r"""
+     Computes both the :func:`inv_quad` and :func:`logdet` terms of a positive definite matrix (or batch) :math:`\mathbf A`.
+     Calling this method is far more efficient and stable than calling each method independently.
+
+     :param input: :math:`\mathbf A` - the positive definite matrix (... x N x N)
+     :param inv_quad_rhs: :math:`\mathbf R` - the right hand sides of the inverse quadratic term (... x N x M)
+     :param logdet: Whether or not to compute the
+         logdet term :math:`\log \vert \mathbf A \vert`.
+     :param reduce_inv_quad: Whether to compute
+         :math:`\text{tr}\left( \mathbf R^\top \mathbf A^{-1} \mathbf R \right)`
+         or :math:`\text{diag}\left( \mathbf R^\top \mathbf A^{-1} \mathbf R \right)`.
+     :returns: The inverse quadratic term (or None), and the logdet term (or None).
+         If `reduce_inv_quad=True`, the inverse quadratic term is of shape (...). Otherwise, it is (... x M).
+     """
+     return linear_operator.inv_quad_logdet(
+         input=input,
+         inv_quad_rhs=inv_quad_rhs,
+         logdet=logdet,
+         reduce_inv_quad=reduce_inv_quad,
+     )
+
+
+ def pivoted_cholesky(
+     input: Anysor,
+     rank: int,
+     error_tol: Optional[float] = None,
+     return_pivots: bool = False,
+ ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
+     r"""
+     Performs a partial pivoted Cholesky factorization of a positive definite matrix (or batch of matrices):
+     :math:`\mathbf L \mathbf L^\top \approx \mathbf A`.
+     The partial pivoted Cholesky factor :math:`\mathbf L \in \mathbb R^{N \times \text{rank}}`
+     forms a low rank approximation to the LinearOperator.
+
+     The pivots are selected greedily, corresponding to the maximum diagonal element in the
+     residual after each Cholesky iteration. See `Harbrecht et al., 2012`_.
+
+     :param input: The matrix (or batch of matrices) :math:`\mathbf A` (... x N x N).
+     :param rank: The size of the partial pivoted Cholesky factor.
+     :param error_tol: Defines an optional stopping criterion.
+         If the residual of the factorization is less than :attr:`error_tol`, then the
+         factorization will exit early. This will result in a :math:`\leq \text{ rank}` factor.
+     :param return_pivots: Whether or not to return the pivots alongside
+         the partial pivoted Cholesky factor.
+     :return: The `... x N x rank` factor (and optionally the `... x N` pivots if :attr:`return_pivots` is True).
+
+     .. _Harbrecht et al., 2012:
+         https://www.sciencedirect.com/science/article/pii/S0168927411001814
+     """
+     return linear_operator.pivoted_cholesky(input=input, rank=rank, error_tol=error_tol, return_pivots=return_pivots)
+
+
+ def root_decomposition(input: Anysor, method: Optional[str] = None) -> LinearOperator:
+     r"""
+     Returns a (usually low-rank) root decomposition linear operator of the
+     positive definite matrix (or batch of matrices) :math:`\mathbf A`.
+     This can be used for sampling from a Gaussian distribution, or for obtaining a
+     low-rank version of a matrix.
+
+     :param input: The matrix (or batch of matrices) :math:`\mathbf A` (... x N x N).
+     :param method: Which method to use to perform the root decomposition. Choices are:
+         "cholesky", "lanczos", "symeig", "pivoted_cholesky", or "svd".
+     :return: A tensor :math:`\mathbf R` such that :math:`\mathbf R \mathbf R^\top \approx \mathbf A`.
+     """
+     return linear_operator.root_decomposition(input=input, method=method)
+
+
+ def root_inv_decomposition(
+     input: Anysor,
+     initial_vectors: Optional[Tensor] = None,
+     test_vectors: Optional[Tensor] = None,
+     method: Optional[str] = None,
+ ) -> LinearOperator:
+     r"""
+     Returns a (usually low-rank) inverse root decomposition linear operator
+     of the PSD LinearOperator :math:`\mathbf A`.
+     This can be used for sampling from a Gaussian distribution, or for obtaining a
+     low-rank version of a matrix.
+
+     The root_inv_decomposition is performed using a partial Lanczos tridiagonalization.
+
+     :param input: The matrix (or batch of matrices) :math:`\mathbf A` (... x N x N).
+     :param initial_vectors: Vectors used to initialize the Lanczos decomposition.
+         The best initialization vector (determined by :attr:`test_vectors`) will be chosen.
+     :param test_vectors: Vectors used to test the accuracy of the decomposition.
+     :param method: Root decomposition method to use (symeig, diagonalization, lanczos, or cholesky).
+     :return: A tensor :math:`\mathbf R` such that :math:`\mathbf R \mathbf R^\top \approx \mathbf A^{-1}`.
+     """
+     return linear_operator.root_inv_decomposition(
+         input=input,
+         initial_vectors=initial_vectors,
+         test_vectors=test_vectors,
+         method=method,
+     )
+
+
+ def solve(input: Anysor, rhs: Tensor, lhs: Optional[Tensor] = None) -> Tensor:
+     r"""
+     Given a positive definite matrix (or batch of matrices) :math:`\mathbf A`,
+     computes a linear solve with right hand side :math:`\mathbf R`:
+
+     .. math::
+         \begin{equation}
+             \mathbf A^{-1} \mathbf R,
+         \end{equation}
+
+     where :math:`\mathbf R` is :attr:`rhs` and :math:`\mathbf A` is the LinearOperator.
+
+     .. note::
+         Unlike :func:`torch.linalg.solve`, this function can take an optional :attr:`lhs` argument.
+         If this is supplied, :func:`qpytorch.solve` computes
+
+         .. math::
+             \begin{equation}
+                 \mathbf L \mathbf A^{-1} \mathbf R,
+             \end{equation}
+
+         where :math:`\mathbf L` is :attr:`lhs`.
+         Supplying this can reduce the number of solver calls required in the backward pass.
+
+     :param input: The matrix (or batch of matrices) :math:`\mathbf A` (... x N x N).
+     :param rhs: :math:`\mathbf R` - the right hand side
+     :param lhs: :math:`\mathbf L` - the left hand side
+     :return: :math:`\mathbf A^{-1} \mathbf R` or :math:`\mathbf L \mathbf A^{-1} \mathbf R`.
+     """
+     return linear_operator.solve(input=input, rhs=rhs, lhs=lhs)
+
+
+ def sqrt_inv_matmul(input: Anysor, rhs: Tensor, lhs: Optional[Tensor] = None) -> Tensor:
+     r"""
+     Given a positive definite matrix (or batch of matrices) :math:`\mathbf A`
+     and a right hand side :math:`\mathbf R`,
+     computes
+
+     .. math::
+         \begin{equation}
+             \mathbf A^{-1/2} \mathbf R.
+         \end{equation}
+
+     If :attr:`lhs` is supplied, computes
+
+     .. math::
+         \begin{equation}
+             \mathbf L \mathbf A^{-1/2} \mathbf R,
+         \end{equation}
+
+     where :math:`\mathbf L` is :attr:`lhs`.
+     (Supplying :attr:`lhs` can reduce the number of solver calls required in the backward pass.)
+
+     :param input: The matrix (or batch of matrices) :math:`\mathbf A` (... x N x N).
+     :param rhs: :math:`\mathbf R` - the right hand side
+     :param lhs: :math:`\mathbf L` - the left hand side
+     :return: :math:`\mathbf A^{-1/2} \mathbf R` or :math:`\mathbf L \mathbf A^{-1/2} \mathbf R`.
+     """
+     return linear_operator.sqrt_inv_matmul(input=input, rhs=rhs, lhs=lhs)
+
+
+ # Read version number as written by setuptools_scm
+ try:
+     from qpytorch.version import version as __version__
+ except Exception:  # pragma: no cover
+     __version__ = "Unknown"  # pragma: no cover
+
+
+ __all__ = [
+     # Submodules
+     "distributions",
+     "kernels",
+     "lazy",
+     "likelihoods",
+     "means",
+     "metrics",
+     "mlls",
+     "models",
+     "optim",
+     "priors",
+     "utils",
+     "variational",
+     # Classes
+     "Module",
+     "ExactMarginalLogLikelihood",
+     # Functions
+     "add_diagonal",
+     "add_jitter",
+     "dsmm",
+     "inv_quad",
+     "inv_quad_logdet",
+     "pivoted_cholesky",
+     "root_decomposition",
+     "root_inv_decomposition",
+     "solve",
+     "sqrt_inv_matmul",
+     # Context managers
+     "beta_features",
+     "settings",
+     # Other
+     "__version__",
+     # Deprecated
+     "inv_matmul",
+     "logdet",
+     "log_normal_cdf",
+     "matmul",
+ ]
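
The wrappers above simply forward to linear_operator, so they accept either a plain torch.Tensor or a LinearOperator. A minimal usage sketch of the functional API exported here (the matrix, shapes, and values are illustrative only, not from the package):

    import torch
    import qpytorch

    # Build a small positive definite matrix and stabilize it with jitter.
    X = torch.randn(5, 5)
    A = qpytorch.add_jitter(X @ X.mT, jitter_val=1e-3)

    # Inverse quadratic term tr(R^T A^{-1} R) and log|A| in a single call,
    # which is cheaper and more stable than computing them separately.
    R = torch.randn(5, 2)
    inv_quad_term, logdet_term = qpytorch.inv_quad_logdet(A, inv_quad_rhs=R, logdet=True)

    # Linear solve A^{-1} R.
    solution = qpytorch.solve(A, R)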
qpytorch/constraints/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from gpytorch.constraints import GreaterThan, Interval, LessThan, Positive
+
+ __all__ = ["GreaterThan", "Interval", "LessThan", "Positive"]
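
These are straight re-exports of the gpytorch constraint classes, so parameter constraints behave exactly as they do in gpytorch. A minimal sketch (values are illustrative; Positive uses a softplus transform by default in gpytorch):

    import torch
    from qpytorch.constraints import Positive

    constraint = Positive()
    raw = torch.tensor(-2.0)
    value = constraint.transform(raw)  # maps the raw parameter into (0, inf)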
qpytorch/distributions/__init__.py ADDED
@@ -0,0 +1,21 @@
+ #!/usr/bin/env python3
+
+ from .delta import Delta
+ from gpytorch.distributions.distribution import Distribution
+ from .qexponential import QExponential
+ from gpytorch.distributions.multivariate_normal import MultivariateNormal
+ from gpytorch.distributions.multitask_multivariate_normal import MultitaskMultivariateNormal
+ from .multivariate_qexponential import MultivariateQExponential
+ from .multitask_multivariate_qexponential import MultitaskMultivariateQExponential
+ from .power import Power
+
+ # Get the set of distributions from either PyTorch or Pyro
+ try:
+     # If pyro is installed, use that set of base distributions
+     import pyro.distributions as base_distributions
+ except ImportError:
+     # Otherwise, use PyTorch
+     import torch.distributions as base_distributions
+
+
+ __all__ = ["Delta", "QExponential", "Distribution", "MultivariateNormal", "MultitaskMultivariateNormal", "MultivariateQExponential", "MultitaskMultivariateQExponential", "Power", "base_distributions"]
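
Because base_distributions resolves to pyro.distributions when Pyro is installed and falls back to torch.distributions otherwise, downstream code can construct base distributions without knowing which backend is active. A small sketch (the parameters are illustrative; Normal exists in both backends):

    from qpytorch.distributions import base_distributions

    noise = base_distributions.Normal(loc=0.0, scale=1.0)
    sample = noise.sample()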
qpytorch/distributions/delta.py ADDED
@@ -0,0 +1,86 @@
+ #!/usr/bin/env python3
+
+ import numbers
+
+ import torch
+ from torch.distributions import constraints
+ from torch.distributions.kl import register_kl
+
+ from gpytorch.distributions.distribution import Distribution
+ from gpytorch.distributions.multivariate_normal import MultivariateNormal
+ from .multivariate_qexponential import MultivariateQExponential
+
+ try:
+     from pyro.distributions import Delta
+
+ except ImportError:
+     # Mostly copied from https://github.com/pyro-ppl/pyro/blob/dev/pyro/distributions/delta.py
+     class Delta(Distribution):
+         """
+         Degenerate discrete distribution (a single point).
+
+         Discrete distribution that assigns probability one to the single element in
+         its support. A Delta distribution parameterized by a random choice should not
+         be used with MCMC-based inference, as doing so produces incorrect results.
+
+         :param torch.Tensor v: The single support element.
+         :param torch.Tensor log_density: An optional density for this Delta. This
+             is useful to keep the class of :class:`Delta` distributions closed
+             under differentiable transformation.
+         :param int event_dim: Optional event dimension, defaults to zero.
+         """
+
+         arg_constraints = {"v": constraints.real, "log_density": constraints.real}
+         has_rsample = True
+
+         def __init__(self, v, log_density=0.0, event_dim=0, validate_args=None):
+             if event_dim > v.dim():
+                 raise ValueError("Expected event_dim <= v.dim(), actual {} vs {}".format(event_dim, v.dim()))
+             batch_dim = v.dim() - event_dim
+             batch_shape = v.shape[:batch_dim]
+             event_shape = v.shape[batch_dim:]
+             if isinstance(log_density, numbers.Number):
+                 log_density = torch.full(batch_shape, log_density, dtype=v.dtype, device=v.device)
+             elif validate_args and log_density.shape != batch_shape:
+                 raise ValueError("Expected log_density.shape = {}, actual {}".format(batch_shape, log_density.shape))
+             self.v = v
+             self.log_density = log_density
+             super().__init__(batch_shape, event_shape, validate_args=validate_args)
+
+         def expand(self, batch_shape, _instance=None):
+             new = self._get_checked_instance(Delta, _instance)
+             batch_shape = torch.Size(batch_shape)
+             new.v = self.v.expand(batch_shape + self.event_shape)
+             new.log_density = self.log_density.expand(batch_shape)
+             # Initialize the *new* instance (a bare super().__init__ here would re-initialize self).
+             super(Delta, new).__init__(batch_shape, self.event_shape, validate_args=False)
+             new._validate_args = self._validate_args
+             return new
+
+         def rsample(self, sample_shape=torch.Size()):
+             shape = sample_shape + self.v.shape
+             return self.v.expand(shape)
+
+         def log_prob(self, x):
+             v = self.v.expand(self.batch_shape + self.event_shape)
+             log_prob = (x == v).type(x.dtype).log()
+             if len(self.event_shape):
+                 log_prob = log_prob.sum(list(range(-1, -len(self.event_shape) - 1, -1)))
+             return log_prob + self.log_density
+
+         @property
+         def mean(self):
+             return self.v
+
+         @property
+         def variance(self):
+             return torch.zeros_like(self.v)
+
+
+ @register_kl(Delta, MultivariateNormal)
+ def kl_delta_mvn(p_dist, q_dist):
+     return -q_dist.log_prob(p_dist.mean)
+
+
+ @register_kl(Delta, MultivariateQExponential)
+ def kl_delta_qep(p_dist, q_dist):
+     return -q_dist.log_prob(p_dist.mean)
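
With the register_kl hooks above, torch.distributions.kl_divergence of a Delta against a multivariate Gaussian (or q-exponential) reduces to -q.log_prob(p.mean), i.e. the negative log density of the point mass under q. A minimal sketch (tensor values are illustrative):

    import torch
    from torch.distributions.kl import kl_divergence
    from gpytorch.distributions import MultivariateNormal
    from qpytorch.distributions import Delta

    v = torch.zeros(3)
    p = Delta(v, event_dim=1)  # point mass at v with event_shape (3,)
    q = MultivariateNormal(torch.zeros(3), torch.eye(3))

    # Dispatches to the kl_delta_mvn registration above: -q.log_prob(v).
    kl = kl_divergence(p, q)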