qpytorch 0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of qpytorch might be problematic. Click here for more details.
- qpytorch/__init__.py +327 -0
- qpytorch/constraints/__init__.py +3 -0
- qpytorch/distributions/__init__.py +21 -0
- qpytorch/distributions/delta.py +86 -0
- qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
- qpytorch/distributions/multivariate_qexponential.py +581 -0
- qpytorch/distributions/power.py +113 -0
- qpytorch/distributions/qexponential.py +153 -0
- qpytorch/functions/__init__.py +58 -0
- qpytorch/kernels/__init__.py +80 -0
- qpytorch/kernels/grid_interpolation_kernel.py +213 -0
- qpytorch/kernels/inducing_point_kernel.py +151 -0
- qpytorch/kernels/kernel.py +695 -0
- qpytorch/kernels/matern32_kernel_grad.py +155 -0
- qpytorch/kernels/matern52_kernel_grad.py +194 -0
- qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
- qpytorch/kernels/polynomial_kernel_grad.py +88 -0
- qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
- qpytorch/kernels/rbf_kernel_grad.py +125 -0
- qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
- qpytorch/kernels/rff_kernel.py +153 -0
- qpytorch/lazy/__init__.py +9 -0
- qpytorch/likelihoods/__init__.py +66 -0
- qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
- qpytorch/likelihoods/beta_likelihood.py +76 -0
- qpytorch/likelihoods/gaussian_likelihood.py +472 -0
- qpytorch/likelihoods/laplace_likelihood.py +59 -0
- qpytorch/likelihoods/likelihood.py +437 -0
- qpytorch/likelihoods/likelihood_list.py +60 -0
- qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
- qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
- qpytorch/likelihoods/noise_models.py +184 -0
- qpytorch/likelihoods/qexponential_likelihood.py +494 -0
- qpytorch/likelihoods/softmax_likelihood.py +97 -0
- qpytorch/likelihoods/student_t_likelihood.py +90 -0
- qpytorch/means/__init__.py +23 -0
- qpytorch/metrics/__init__.py +17 -0
- qpytorch/mlls/__init__.py +53 -0
- qpytorch/mlls/_approximate_mll.py +79 -0
- qpytorch/mlls/deep_approximate_mll.py +30 -0
- qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
- qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
- qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
- qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
- qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
- qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
- qpytorch/mlls/marginal_log_likelihood.py +48 -0
- qpytorch/mlls/predictive_log_likelihood.py +76 -0
- qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
- qpytorch/mlls/variational_elbo.py +77 -0
- qpytorch/models/__init__.py +72 -0
- qpytorch/models/approximate_qep.py +115 -0
- qpytorch/models/deep_qeps/__init__.py +22 -0
- qpytorch/models/deep_qeps/deep_qep.py +155 -0
- qpytorch/models/deep_qeps/dspp.py +114 -0
- qpytorch/models/exact_prediction_strategies.py +880 -0
- qpytorch/models/exact_qep.py +349 -0
- qpytorch/models/model_list.py +100 -0
- qpytorch/models/pyro/__init__.py +28 -0
- qpytorch/models/pyro/_pyro_mixin.py +57 -0
- qpytorch/models/pyro/distributions/__init__.py +5 -0
- qpytorch/models/pyro/pyro_qep.py +105 -0
- qpytorch/models/qep.py +7 -0
- qpytorch/models/qeplvm/__init__.py +6 -0
- qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
- qpytorch/models/qeplvm/latent_variable.py +102 -0
- qpytorch/module.py +30 -0
- qpytorch/optim/__init__.py +5 -0
- qpytorch/priors/__init__.py +42 -0
- qpytorch/priors/qep_priors.py +81 -0
- qpytorch/test/__init__.py +22 -0
- qpytorch/test/base_likelihood_test_case.py +106 -0
- qpytorch/test/model_test_case.py +150 -0
- qpytorch/test/variational_test_case.py +400 -0
- qpytorch/utils/__init__.py +38 -0
- qpytorch/utils/warnings.py +37 -0
- qpytorch/variational/__init__.py +47 -0
- qpytorch/variational/_variational_distribution.py +61 -0
- qpytorch/variational/_variational_strategy.py +391 -0
- qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
- qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
- qpytorch/variational/cholesky_variational_distribution.py +65 -0
- qpytorch/variational/ciq_variational_strategy.py +352 -0
- qpytorch/variational/delta_variational_distribution.py +41 -0
- qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
- qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
- qpytorch/variational/lmc_variational_strategy.py +248 -0
- qpytorch/variational/mean_field_variational_distribution.py +58 -0
- qpytorch/variational/multitask_variational_strategy.py +317 -0
- qpytorch/variational/natural_variational_distribution.py +152 -0
- qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
- qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
- qpytorch/variational/tril_natural_variational_distribution.py +130 -0
- qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
- qpytorch/variational/unwhitened_variational_strategy.py +225 -0
- qpytorch/variational/variational_strategy.py +280 -0
- qpytorch/version.py +4 -0
- qpytorch-0.1.dist-info/LICENSE +21 -0
- qpytorch-0.1.dist-info/METADATA +177 -0
- qpytorch-0.1.dist-info/RECORD +102 -0
- qpytorch-0.1.dist-info/WHEEL +5 -0
- qpytorch-0.1.dist-info/top_level.txt +1 -0
qpytorch/__init__.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
from typing import Optional, Tuple, Union
|
|
4
|
+
|
|
5
|
+
import linear_operator
|
|
6
|
+
import torch
|
|
7
|
+
from linear_operator import LinearOperator
|
|
8
|
+
from torch import Tensor
|
|
9
|
+
|
|
10
|
+
from gpytorch import (
|
|
11
|
+
beta_features,
|
|
12
|
+
settings,
|
|
13
|
+
)
|
|
14
|
+
from . import (
|
|
15
|
+
distributions,
|
|
16
|
+
kernels,
|
|
17
|
+
lazy,
|
|
18
|
+
likelihoods,
|
|
19
|
+
means,
|
|
20
|
+
metrics,
|
|
21
|
+
mlls,
|
|
22
|
+
models,
|
|
23
|
+
optim,
|
|
24
|
+
priors,
|
|
25
|
+
utils,
|
|
26
|
+
variational,
|
|
27
|
+
)
|
|
28
|
+
from .functions import inv_matmul, log_normal_cdf, logdet, matmul # Deprecated
|
|
29
|
+
from .mlls import ExactMarginalLogLikelihood
|
|
30
|
+
from gpytorch.module import Module
|
|
31
|
+
|
|
32
|
+
# Type alias used to annotate matrix arguments that may be given either as a
# LinearOperator or as a dense torch Tensor.
Anysor = Union[LinearOperator, Tensor]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def add_diagonal(input: Anysor, diag: Tensor) -> LinearOperator:
    r"""
    Add a diagonal term to the matrix (or batch of matrices) :math:`\mathbf A`.

    This is a thin wrapper that forwards to :func:`linear_operator.add_diagonal`.

    :param input: The matrix (or batch of matrices) :math:`\mathbf A` (... x N x N).
    :param diag: The diagonal :math:`\mathbf d` to add.
    :return: :math:`\mathbf A + \text{diag}(\mathbf d)`, where :math:`\mathbf A` is the linear operator
        and :math:`\mathbf d` is the diagonal component.
    """
    result = linear_operator.add_diagonal(input=input, diag=diag)
    return result
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def add_jitter(input: Anysor, jitter_val: float = 1e-3) -> Anysor:
    r"""
    Add a small constant ("jitter") to every diagonal entry of :math:`\mathbf A`.

    Equivalent to :meth:`~linear_operator.operators.LinearOperator.add_diagonal` called with a
    scalar tensor; the work is delegated to :func:`linear_operator.add_jitter`.

    :param input: The matrix (or batch of matrices) :math:`\mathbf A` (... x N x N).
    :param jitter_val: The scalar :math:`\alpha` added to the diagonal.
    :return: :math:`\mathbf A + \alpha \mathbf I`, where :math:`\alpha` is :attr:`jitter_val`.
    """
    jittered = linear_operator.add_jitter(input=input, jitter_val=jitter_val)
    return jittered
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def diagonalization(input: Anysor, method: Optional[str] = None) -> Tuple[Tensor, Tensor]:
    r"""
    Compute a (usually partial) diagonalization of a symmetric positive definite
    matrix (or batch of matrices) :math:`\mathbf A`.

    Two methods are available: ``"lanczos"`` runs the Lanczos algorithm, while
    ``"symeig"`` runs ``LinearOperator.symeig``. The call is forwarded to
    :func:`linear_operator.diagonalization`.

    :param input: The matrix (or batch of matrices) :math:`\mathbf A` (... x N x N).
    :param method: Either ``"lanczos"`` or ``"symeig"``. When omitted, the method is
        chosen based on the size of the matrix.
    :return: The eigenvalues and eigenvectors representing the diagonalization.
    """
    evals_and_evecs = linear_operator.diagonalization(input=input, method=method)
    return evals_and_evecs
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def dsmm(
    sparse_mat: Union[torch.sparse.HalfTensor, torch.sparse.FloatTensor, torch.sparse.DoubleTensor],
    dense_mat: Tensor,
) -> Tensor:
    r"""
    Performs the (batch) matrix multiplication :math:`\mathbf{SD}`
    where :math:`\mathbf S` is a sparse matrix and :math:`\mathbf D` is a dense matrix.

    The computation is delegated to :func:`linear_operator.dsmm`.

    :param sparse_mat: Sparse matrix :math:`\mathbf S` (... x M x N)
    :param dense_mat: Dense matrix :math:`\mathbf D` (... x N x O)
    :return: :math:`\mathbf S \mathbf D` (... x M x O)
    """
    return linear_operator.dsmm(sparse_mat=sparse_mat, dense_mat=dense_mat)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def inv_quad(input: Anysor, inv_quad_rhs: Tensor, reduce_inv_quad: bool = True) -> Tensor:
    r"""
    Compute an inverse quadratic form of :math:`\mathbf A` against several right hand sides:

    .. math::
        \text{tr}\left( \mathbf R^\top \mathbf A^{-1} \mathbf R \right),

    where :math:`\mathbf A` is a positive definite matrix (or batch of matrices) and
    :math:`\mathbf R` holds the right hand sides (:attr:`inv_quad_rhs`).

    With :attr:`reduce_inv_quad` set to ``False`` the per-column values are returned instead:

    .. math::
        \text{diag}\left( \mathbf R^\top \mathbf A^{-1} \mathbf R \right).

    The computation is delegated to :func:`linear_operator.inv_quad`.

    :param input: :math:`\mathbf A` - the positive definite matrix (... X N X N)
    :param inv_quad_rhs: :math:`\mathbf R` - the right hand sides of the inverse quadratic term (... x N x M)
    :param reduce_inv_quad: Whether to compute
        :math:`\text{tr}\left( \mathbf R^\top \mathbf A^{-1} \mathbf R \right)`
        or :math:`\text{diag}\left( \mathbf R^\top \mathbf A^{-1} \mathbf R \right)`.
    :returns: The inverse quadratic term, of shape (...) when ``reduce_inv_quad=True``
        and (... x M) otherwise.
    """
    quad_term = linear_operator.inv_quad(
        input=input,
        inv_quad_rhs=inv_quad_rhs,
        reduce_inv_quad=reduce_inv_quad,
    )
    return quad_term
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def inv_quad_logdet(
    input: Anysor,
    inv_quad_rhs: Optional[Tensor] = None,
    logdet: bool = False,
    reduce_inv_quad: bool = True,
) -> Tuple[Tensor, Tensor]:
    r"""
    Calls both :func:`inv_quad` and :func:`logdet` on a positive definite matrix (or batch) :math:`\mathbf A`.
    However, calling this method is far more efficient and stable than calling each method independently.

    The computation is delegated to :func:`linear_operator.inv_quad_logdet`.

    :param input: :math:`\mathbf A` - the positive definite matrix (... X N X N)
    :param inv_quad_rhs: :math:`\mathbf R` - the right hand sides of the inverse quadratic term (... x N x M)
    :param logdet: Whether or not to compute the
        logdet term :math:`\log \vert \mathbf A \vert`.
    :param reduce_inv_quad: Whether to compute
        :math:`\text{tr}\left( \mathbf R^\top \mathbf A^{-1} \mathbf R \right)`
        or :math:`\text{diag}\left( \mathbf R^\top \mathbf A^{-1} \mathbf R \right)`.
    :returns: The inverse quadratic term (or None), and the logdet term (or None).
        If `reduce_inv_quad=True`, the inverse quadratic term is of shape (...). Otherwise, it is (... x M).
    """
    return linear_operator.inv_quad_logdet(
        input=input,
        inv_quad_rhs=inv_quad_rhs,
        logdet=logdet,
        reduce_inv_quad=reduce_inv_quad,
    )
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def pivoted_cholesky(
    input: Anysor,
    rank: int,
    error_tol: Optional[float] = None,
    return_pivots: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    r"""
    Performs a partial pivoted Cholesky factorization of a positive definite matrix (or batch of matrices)
    :math:`\mathbf A`, i.e. :math:`\mathbf L \mathbf L^\top \approx \mathbf A`.
    The partial pivoted Cholesky factor :math:`\mathbf L \in \mathbb R^{N \times \text{rank}}`
    forms a low rank approximation to the LinearOperator.

    The pivots are selected greedily, corresponding to the maximum diagonal element in the
    residual after each Cholesky iteration. See `Harbrecht et al., 2012`_.

    :param input: The matrix (or batch of matrices) :math:`\mathbf A` (... x N x N).
    :param rank: The size of the partial pivoted Cholesky factor.
    :param error_tol: Defines an optional stopping criterion.
        If the residual of the factorization is less than :attr:`error_tol`, then the
        factorization will exit early. This will result in a :math:`\leq \text{ rank}` factor.
    :param return_pivots: Whether or not to return the pivots alongside
        the partial pivoted Cholesky factor.
    :return: The `... x N x rank` factor (and optionally the `... x N` pivots if :attr:`return_pivots` is True).

    .. _Harbrecht et al., 2012:
        https://www.sciencedirect.com/science/article/pii/S0168927411001814
    """
    # error_tol must be forwarded; otherwise the documented early-stopping criterion has no effect.
    return linear_operator.pivoted_cholesky(
        input=input,
        rank=rank,
        error_tol=error_tol,
        return_pivots=return_pivots,
    )
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def root_decomposition(input: Anysor, method: Optional[str] = None) -> LinearOperator:
    r"""
    Build a (usually low-rank) root decomposition of a positive definite matrix
    (or batch of matrices) :math:`\mathbf A`.

    The resulting root can be used to draw samples from a Gaussian distribution or to
    obtain a low-rank version of the matrix. Work is delegated to
    :func:`linear_operator.root_decomposition`.

    :param input: The matrix (or batch of matrices) :math:`\mathbf A` (... x N x N).
    :param method: Decomposition method: "cholesky", "lanczos", "symeig",
        "pivoted_cholesky", or "svd".
    :return: A linear operator :math:`\mathbf R` with :math:`\mathbf R \mathbf R^\top \approx \mathbf A`.
    """
    root = linear_operator.root_decomposition(input=input, method=method)
    return root
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def root_inv_decomposition(
    input: Anysor,
    initial_vectors: Optional[Tensor] = None,
    test_vectors: Optional[Tensor] = None,
    method: Optional[str] = None,
) -> LinearOperator:
    r"""
    Build a (usually low-rank) inverse root decomposition of the PSD matrix :math:`\mathbf A`.

    The decomposition is computed with a partial Lanczos tridiagonalization by
    :func:`linear_operator.root_inv_decomposition`.

    :param input: The matrix (or batch of matrices) :math:`\mathbf A` (... x N x N).
    :param initial_vectors: Candidate vectors for initializing the Lanczos decomposition;
        the best one (as judged by :attr:`test_vectors`) is used.
    :param test_vectors: Vectors used to test the accuracy of the decomposition.
    :param method: Root decomposition method (symeig, diagonalization, lanczos, or cholesky).
    :return: A linear operator :math:`\mathbf R` with :math:`\mathbf R \mathbf R^\top \approx \mathbf A^{-1}`.
    """
    options = dict(
        initial_vectors=initial_vectors,
        test_vectors=test_vectors,
        method=method,
    )
    return linear_operator.root_inv_decomposition(input=input, **options)
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def solve(input: Anysor, rhs: Tensor, lhs: Optional[Tensor] = None) -> Tensor:
    r"""
    Given a positive definite matrix (or batch of matrices) :math:`\mathbf A`,
    computes a linear solve with right hand side :math:`\mathbf R`:

    .. math::
        \begin{equation}
            \mathbf A^{-1} \mathbf R,
        \end{equation}

    where :math:`\mathbf R` is :attr:`rhs` and :math:`\mathbf A` is the matrix given by :attr:`input`.

    .. note::
        Unlike :func:`torch.linalg.solve`, this function can take an optional :attr:`lhs` argument.
        If it is supplied, :func:`solve` computes

        .. math::
            \begin{equation}
                \mathbf L \mathbf A^{-1} \mathbf R,
            \end{equation}

        where :math:`\mathbf L` is :attr:`lhs`.
        Supplying this can reduce the number of solver calls required in the backward pass.

    The computation is delegated to :func:`linear_operator.solve`.

    :param input: The matrix (or batch of matrices) :math:`\mathbf A` (... x N x N).
    :param rhs: :math:`\mathbf R` - the right hand side
    :param lhs: :math:`\mathbf L` - the left hand side
    :return: :math:`\mathbf A^{-1} \mathbf R` or :math:`\mathbf L \mathbf A^{-1} \mathbf R`.
    """
    return linear_operator.solve(input=input, rhs=rhs, lhs=lhs)
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def sqrt_inv_matmul(input: Anysor, rhs: Tensor, lhs: Optional[Tensor] = None) -> Tensor:
    r"""
    Multiply a right hand side by the inverse square root of a positive definite matrix.

    For a positive definite matrix (or batch of matrices) :math:`\mathbf A` and a
    right hand side :math:`\mathbf R`, this computes

    .. math::
        \begin{equation}
            \mathbf A^{-1/2} \mathbf R,
        \end{equation}

    or, when :attr:`lhs` (:math:`\mathbf L`) is given,

    .. math::
        \begin{equation}
            \mathbf L \mathbf A^{-1/2} \mathbf R.
        \end{equation}

    (Supplying :attr:`lhs` can reduce the number of solver calls required in the backward pass.)
    The computation is delegated to :func:`linear_operator.sqrt_inv_matmul`.

    :param input: The matrix (or batch of matrices) :math:`\mathbf A` (... x N x N).
    :param rhs: :math:`\mathbf R` - the right hand side
    :param lhs: :math:`\mathbf L` - the left hand side
    :return: :math:`\mathbf A^{-1/2} \mathbf R` or :math:`\mathbf L \mathbf A^{-1/2} \mathbf R`.
    """
    product = linear_operator.sqrt_inv_matmul(input=input, rhs=rhs, lhs=lhs)
    return product
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
# Version string written by setuptools_scm into qpytorch/version.py.
# "Unknown" is kept as the placeholder whenever that import fails for any reason.
__version__ = "Unknown"  # pragma: no cover
try:
    from qpytorch.version import version as __version__
except Exception:  # pragma: no cover
    pass  # keep the "Unknown" placeholder
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
# Public API of the package; ``from qpytorch import *`` exports exactly these names.
__all__ = [
    # Submodules
    "distributions",
    "kernels",
    "lazy",
    "likelihoods",
    "means",
    "metrics",
    "mlls",
    "models",
    "optim",
    "priors",
    "utils",
    "variational",
    # Classes
    "Module",
    "ExactMarginalLogLikelihood",
    # Functions
    "add_diagonal",
    "add_jitter",
    "diagonalization",  # defined in this module but previously missing from the export list
    "dsmm",
    "inv_quad",
    "inv_quad_logdet",
    "pivoted_cholesky",
    "root_decomposition",
    "root_inv_decomposition",
    "solve",
    "sqrt_inv_matmul",
    # Context managers
    "beta_features",
    "settings",
    # Other
    "__version__",
    # Deprecated
    "inv_matmul",
    "logdet",
    "log_normal_cdf",
    "matmul",
]
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
from .delta import Delta
|
|
4
|
+
from gpytorch.distributions.distribution import Distribution
|
|
5
|
+
from .qexponential import QExponential
|
|
6
|
+
from gpytorch.distributions.multivariate_normal import MultivariateNormal
|
|
7
|
+
from gpytorch.distributions.multitask_multivariate_normal import MultitaskMultivariateNormal
|
|
8
|
+
from .multivariate_qexponential import MultivariateQExponential
|
|
9
|
+
from .multitask_multivariate_qexponential import MultitaskMultivariateQExponential
|
|
10
|
+
from .power import Power
|
|
11
|
+
|
|
12
|
+
# ``base_distributions`` is Pyro's distributions module when Pyro is importable,
# and falls back to ``torch.distributions`` otherwise.
try:
    from pyro import distributions as base_distributions
except ImportError:
    from torch import distributions as base_distributions
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Names re-exported by ``from qpytorch.distributions import *``.
__all__ = [
    "Delta",
    "QExponential",
    "Distribution",
    "MultivariateNormal",
    "MultitaskMultivariateNormal",
    "MultivariateQExponential",
    "MultitaskMultivariateQExponential",
    "Power",
    "base_distributions",
]
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
import numbers
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torch.distributions import constraints
|
|
7
|
+
from torch.distributions.kl import register_kl
|
|
8
|
+
|
|
9
|
+
from gpytorch.distributions.distribution import Distribution
|
|
10
|
+
from gpytorch.distributions.multivariate_normal import MultivariateNormal
|
|
11
|
+
from .multivariate_qexponential import MultivariateQExponential
|
|
12
|
+
|
|
13
|
+
try:
    from pyro.distributions import Delta

except ImportError:
    # Mostly copied from https://github.com/pyro-ppl/pyro/blob/dev/pyro/distributions/delta.py
    class Delta(Distribution):
        """
        Degenerate discrete distribution (a single point).

        Discrete distribution that assigns probability one to the single element in
        its support. Delta distribution parameterized by a random choice should not
        be used with MCMC based inference, as doing so produces incorrect results.

        :param torch.Tensor v: The single support element.
        :param torch.Tensor log_density: An optional density for this Delta. This
            is useful to keep the class of :class:`Delta` distributions closed
            under differentiable transformation.
        :param int event_dim: Optional event dimension, defaults to zero.
        """

        arg_constraints = {"v": constraints.real, "log_density": constraints.real}
        has_rsample = True

        def __init__(self, v, log_density=0.0, event_dim=0, validate_args=None):
            if event_dim > v.dim():
                raise ValueError("Expected event_dim <= v.dim(), actual {} vs {}".format(event_dim, v.dim()))
            # The trailing ``event_dim`` dims of v form the event shape; the rest are batch dims.
            batch_dim = v.dim() - event_dim
            batch_shape = v.shape[:batch_dim]
            event_shape = v.shape[batch_dim:]
            if isinstance(log_density, numbers.Number):
                # Broadcast a scalar log-density to one value per batch element.
                log_density = torch.full(batch_shape, log_density, dtype=v.dtype, device=v.device)
            elif validate_args and log_density.shape != batch_shape:
                # Report the expected (batch) shape first and the offending shape second.
                raise ValueError("Expected log_density.shape = {}, actual {}".format(batch_shape, log_density.shape))
            self.v = v
            self.log_density = log_density
            super().__init__(batch_shape, event_shape, validate_args=validate_args)

        def expand(self, batch_shape, _instance=None):
            """Return a new Delta whose ``v``/``log_density`` are expanded to ``batch_shape``."""
            new = self._get_checked_instance(Delta, _instance)
            batch_shape = torch.Size(batch_shape)
            new.v = self.v.expand(batch_shape + self.event_shape)
            new.log_density = self.log_density.expand(batch_shape)
            # Initialise the *new* instance; a bare ``super().__init__`` would re-initialise ``self``
            # and leave ``new`` without batch/event shapes.
            super(Delta, new).__init__(batch_shape, self.event_shape, validate_args=False)
            new._validate_args = self._validate_args
            return new

        def rsample(self, sample_shape=torch.Size()):
            """Return ``v`` expanded (not copied) to ``sample_shape + v.shape``."""
            shape = sample_shape + self.v.shape
            return self.v.expand(shape)

        def log_prob(self, x):
            """
            Log-probability of ``x``: ``log(1) = 0`` where ``x == v`` and ``log(0) = -inf`` elsewhere,
            summed over the event dimensions, plus ``log_density``.
            """
            v = self.v.expand(self.batch_shape + self.event_shape)
            log_prob = (x == v).type(x.dtype).log()
            if len(self.event_shape):
                log_prob = log_prob.sum(list(range(-1, -len(self.event_shape) - 1, -1)))
            return log_prob + self.log_density

        @property
        def mean(self):
            """The support point ``v``."""
            return self.v

        @property
        def variance(self):
            """Zeros with the same shape as ``v``."""
            return torch.zeros_like(self.v)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@register_kl(Delta, MultivariateNormal)
def kl_mvn_mvn(p_dist, q_dist):
    """
    KL hook registered for a ``Delta`` ``p_dist`` against a ``MultivariateNormal`` ``q_dist``.

    Returns the negative log-density of ``q_dist`` evaluated at the point mass of ``p_dist``
    (its ``mean``). Note: despite the name, the registered pair is (Delta, MultivariateNormal).
    """
    point_mass = p_dist.mean
    log_q = q_dist.log_prob(point_mass)
    return -log_q
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@register_kl(Delta, MultivariateQExponential)
def kl_qep_qep(p_dist, q_dist):
    """
    KL hook registered for a ``Delta`` ``p_dist`` against a ``MultivariateQExponential`` ``q_dist``.

    Returns the negative log-density of ``q_dist`` evaluated at the point mass of ``p_dist``
    (its ``mean``). Note: despite the name, the registered pair is (Delta, MultivariateQExponential).
    """
    point_mass = p_dist.mean
    log_q = q_dist.log_prob(point_mass)
    return -log_q
|