qpytorch 0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of qpytorch might be problematic. Click here for more details.

Files changed (102)
  1. qpytorch/__init__.py +327 -0
  2. qpytorch/constraints/__init__.py +3 -0
  3. qpytorch/distributions/__init__.py +21 -0
  4. qpytorch/distributions/delta.py +86 -0
  5. qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
  6. qpytorch/distributions/multivariate_qexponential.py +581 -0
  7. qpytorch/distributions/power.py +113 -0
  8. qpytorch/distributions/qexponential.py +153 -0
  9. qpytorch/functions/__init__.py +58 -0
  10. qpytorch/kernels/__init__.py +80 -0
  11. qpytorch/kernels/grid_interpolation_kernel.py +213 -0
  12. qpytorch/kernels/inducing_point_kernel.py +151 -0
  13. qpytorch/kernels/kernel.py +695 -0
  14. qpytorch/kernels/matern32_kernel_grad.py +155 -0
  15. qpytorch/kernels/matern52_kernel_grad.py +194 -0
  16. qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
  17. qpytorch/kernels/polynomial_kernel_grad.py +88 -0
  18. qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
  19. qpytorch/kernels/rbf_kernel_grad.py +125 -0
  20. qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
  21. qpytorch/kernels/rff_kernel.py +153 -0
  22. qpytorch/lazy/__init__.py +9 -0
  23. qpytorch/likelihoods/__init__.py +66 -0
  24. qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
  25. qpytorch/likelihoods/beta_likelihood.py +76 -0
  26. qpytorch/likelihoods/gaussian_likelihood.py +472 -0
  27. qpytorch/likelihoods/laplace_likelihood.py +59 -0
  28. qpytorch/likelihoods/likelihood.py +437 -0
  29. qpytorch/likelihoods/likelihood_list.py +60 -0
  30. qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
  31. qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
  32. qpytorch/likelihoods/noise_models.py +184 -0
  33. qpytorch/likelihoods/qexponential_likelihood.py +494 -0
  34. qpytorch/likelihoods/softmax_likelihood.py +97 -0
  35. qpytorch/likelihoods/student_t_likelihood.py +90 -0
  36. qpytorch/means/__init__.py +23 -0
  37. qpytorch/metrics/__init__.py +17 -0
  38. qpytorch/mlls/__init__.py +53 -0
  39. qpytorch/mlls/_approximate_mll.py +79 -0
  40. qpytorch/mlls/deep_approximate_mll.py +30 -0
  41. qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
  42. qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
  43. qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
  44. qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
  45. qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
  46. qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
  47. qpytorch/mlls/marginal_log_likelihood.py +48 -0
  48. qpytorch/mlls/predictive_log_likelihood.py +76 -0
  49. qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
  50. qpytorch/mlls/variational_elbo.py +77 -0
  51. qpytorch/models/__init__.py +72 -0
  52. qpytorch/models/approximate_qep.py +115 -0
  53. qpytorch/models/deep_qeps/__init__.py +22 -0
  54. qpytorch/models/deep_qeps/deep_qep.py +155 -0
  55. qpytorch/models/deep_qeps/dspp.py +114 -0
  56. qpytorch/models/exact_prediction_strategies.py +880 -0
  57. qpytorch/models/exact_qep.py +349 -0
  58. qpytorch/models/model_list.py +100 -0
  59. qpytorch/models/pyro/__init__.py +28 -0
  60. qpytorch/models/pyro/_pyro_mixin.py +57 -0
  61. qpytorch/models/pyro/distributions/__init__.py +5 -0
  62. qpytorch/models/pyro/pyro_qep.py +105 -0
  63. qpytorch/models/qep.py +7 -0
  64. qpytorch/models/qeplvm/__init__.py +6 -0
  65. qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
  66. qpytorch/models/qeplvm/latent_variable.py +102 -0
  67. qpytorch/module.py +30 -0
  68. qpytorch/optim/__init__.py +5 -0
  69. qpytorch/priors/__init__.py +42 -0
  70. qpytorch/priors/qep_priors.py +81 -0
  71. qpytorch/test/__init__.py +22 -0
  72. qpytorch/test/base_likelihood_test_case.py +106 -0
  73. qpytorch/test/model_test_case.py +150 -0
  74. qpytorch/test/variational_test_case.py +400 -0
  75. qpytorch/utils/__init__.py +38 -0
  76. qpytorch/utils/warnings.py +37 -0
  77. qpytorch/variational/__init__.py +47 -0
  78. qpytorch/variational/_variational_distribution.py +61 -0
  79. qpytorch/variational/_variational_strategy.py +391 -0
  80. qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
  81. qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
  82. qpytorch/variational/cholesky_variational_distribution.py +65 -0
  83. qpytorch/variational/ciq_variational_strategy.py +352 -0
  84. qpytorch/variational/delta_variational_distribution.py +41 -0
  85. qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
  86. qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
  87. qpytorch/variational/lmc_variational_strategy.py +248 -0
  88. qpytorch/variational/mean_field_variational_distribution.py +58 -0
  89. qpytorch/variational/multitask_variational_strategy.py +317 -0
  90. qpytorch/variational/natural_variational_distribution.py +152 -0
  91. qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
  92. qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
  93. qpytorch/variational/tril_natural_variational_distribution.py +130 -0
  94. qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
  95. qpytorch/variational/unwhitened_variational_strategy.py +225 -0
  96. qpytorch/variational/variational_strategy.py +280 -0
  97. qpytorch/version.py +4 -0
  98. qpytorch-0.1.dist-info/LICENSE +21 -0
  99. qpytorch-0.1.dist-info/METADATA +177 -0
  100. qpytorch-0.1.dist-info/RECORD +102 -0
  101. qpytorch-0.1.dist-info/WHEEL +5 -0
  102. qpytorch-0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,151 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import copy
4
+ import math
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ from linear_operator import to_dense
9
+ from linear_operator.operators import (
10
+ DiagLinearOperator,
11
+ LowRankRootAddedDiagLinearOperator,
12
+ LowRankRootLinearOperator,
13
+ MatmulLinearOperator,
14
+ )
15
+ from linear_operator.utils.cholesky import psd_safe_cholesky
16
+ from torch import Tensor
17
+
18
+ from .. import settings
19
+ from ..distributions import MultivariateNormal
20
+ from ..likelihoods import Likelihood
21
+ from ..mlls import InducingPointKernelAddedLossTerm
22
+ from ..models import exact_prediction_strategies
23
+ from .kernel import Kernel
24
+
25
+
26
class InducingPointKernel(Kernel):
    """Low-rank approximation of ``base_kernel`` through a set of learnable inducing points ``Z``.

    The covariance between ``x1`` and ``x2`` is ``K(x1, Z) K(Z, Z)^{-1} K(Z, x2)``. It is represented
    through the root ``K(x, Z) U^{-1}``, where ``U`` is the upper Cholesky factor of ``K(Z, Z)``
    (``K(Z, Z) = U^T U``, so ``U^{-1} U^{-T} = K(Z, Z)^{-1}``). Predictions use
    ``exact_prediction_strategies.SGPRPredictionStrategy`` (see :meth:`prediction_strategy`).

    Eval-mode caching: while ``self.training`` is False, ``K(Z, Z)`` and ``U^{-1}`` are computed
    once and then reused. They are only dropped by :meth:`_clear_cache`.
    NOTE(review): the visible code never calls ``_clear_cache`` itself. Presumably the
    surrounding framework does so when parameters or modes change — TODO confirm.

    :param base_kernel: The kernel being approximated.
    :param inducing_points: Initial inducing point locations. A 1-D tensor is treated as ``n``
        scalar locations and gets a trailing feature dimension. The tensor is stored as the
        learnable parameter ``inducing_points``.
    :param likelihood: Passed to the ``InducingPointKernelAddedLossTerm`` that :meth:`forward`
        registers in training mode. :meth:`__deepcopy__` shares it rather than copying it.
    :param active_dims: Forwarded to :class:`Kernel`.
    """

    def __init__(
        self,
        base_kernel: Kernel,
        inducing_points: Tensor,
        likelihood: Likelihood,
        active_dims: Optional[Tuple[int, ...]] = None,
    ):
        super().__init__(active_dims=active_dims)
        self.base_kernel = base_kernel
        self.likelihood = likelihood

        # Promote a 1-D tensor of scalar locations to an (n, 1) matrix of points.
        if inducing_points.ndimension() == 1:
            inducing_points = inducing_points.unsqueeze(-1)

        self.register_parameter(name="inducing_points", parameter=torch.nn.Parameter(inducing_points))
        # Slot that forward() fills with a fresh loss term on every training-mode call.
        self.register_added_loss_term("inducing_point_loss_term")

    def _clear_cache(self):
        """Drop the eval-mode caches of ``K(Z, Z)`` and its inverse root, if they exist."""
        if hasattr(self, "_cached_kernel_mat"):
            del self._cached_kernel_mat
        if hasattr(self, "_cached_kernel_inv_root"):
            del self._cached_kernel_inv_root

    @property
    def _inducing_mat(self):
        """Dense ``K(Z, Z)``. In eval mode it is cached after the first evaluation."""
        if not self.training and hasattr(self, "_cached_kernel_mat"):
            return self._cached_kernel_mat
        res = to_dense(self.base_kernel(self.inducing_points, self.inducing_points))
        if not self.training:
            self._cached_kernel_mat = res
        return res

    @property
    def _inducing_inv_root(self):
        """``U^{-1}``, where ``U`` is the upper Cholesky factor of ``K(Z, Z)``.

        The result satisfies ``U^{-1} U^{-T} = K(Z, Z)^{-1}``. In eval mode it is cached after the
        first evaluation.
        """
        if not self.training and hasattr(self, "_cached_kernel_inv_root"):
            return self._cached_kernel_inv_root
        chol = psd_safe_cholesky(self._inducing_mat, upper=True)
        eye = torch.eye(chol.size(-1), device=chol.device, dtype=chol.dtype)
        # Solve U X = I for X = U^{-1}, using the triangular structure instead of a general inverse.
        inv_root = torch.linalg.solve_triangular(chol, eye, upper=True)
        if not self.training:
            self._cached_kernel_inv_root = inv_root
        return inv_root

    def _get_covariance(self, x1, x2):
        """Return ``K(x1, Z) K(Z, Z)^{-1} K(Z, x2)`` as a linear operator.

        If ``torch.equal(x1, x2)``, the result is a symmetric ``LowRankRootLinearOperator``. In eval
        mode with ``settings.sgpr_diagonal_correction`` on, a non-negative diagonal term is added.
        This raises the diagonal to the base kernel's exact diagonal wherever the low-rank
        diagonal falls short. Otherwise the result is a ``MatmulLinearOperator`` of the two roots.
        """
        k_x1u = to_dense(self.base_kernel(x1, self.inducing_points))
        # Fetch once: outside eval mode this triggers a Cholesky factorization.
        inv_root = self._inducing_inv_root
        left_root = k_x1u.matmul(inv_root)

        if torch.equal(x1, x2):
            covar = LowRankRootLinearOperator(left_root)

            # Diagonal correction for predictive posterior
            if not self.training and settings.sgpr_diagonal_correction.on():
                correction = (self.base_kernel(x1, x2, diag=True) - covar.diagonal(dim1=-1, dim2=-2)).clamp(0, math.inf)
                covar = LowRankRootAddedDiagLinearOperator(covar, DiagLinearOperator(correction))
            return covar

        k_x2u = to_dense(self.base_kernel(x2, self.inducing_points))
        right_root = k_x2u.matmul(inv_root)
        return MatmulLinearOperator(left_root, right_root.transpose(-1, -2))

    def _covar_diag(self, inputs):
        """Return the exact diagonal of ``base_kernel`` at ``inputs``, wrapped in a ``DiagLinearOperator``.

        A 1-D ``inputs`` tensor is promoted to an (n, 1) matrix, as in ``__init__``.
        """
        if inputs.ndimension() == 1:
            inputs = inputs.unsqueeze(-1)
        return DiagLinearOperator(to_dense(self.base_kernel(inputs, diag=True)))

    def forward(self, x1, x2, diag=False, **kwargs):
        """Compute the inducing-point covariance between ``x1`` and ``x2``.

        In training mode, ``x1`` must equal ``x2``. Each call then replaces the
        ``"inducing_point_loss_term"`` with a new ``InducingPointKernelAddedLossTerm``, built from:
        a zero-mean distribution with the exact base-kernel diagonal, a zero-mean distribution with
        the low-rank covariance, and ``self.likelihood``.
        NOTE(review): both distributions are ``MultivariateNormal`` even though this package
        targets q-exponential processes — confirm this is intended.

        :param diag: If True, return only the diagonal (as produced by ``covar.diagonal``).
        :param kwargs: Accepted for interface compatibility; unused here.
        :raises RuntimeError: If training and ``x1`` differs from ``x2``.
        """
        # Reject invalid training input before doing any kernel or Cholesky work.
        if self.training and not torch.equal(x1, x2):
            raise RuntimeError("x1 should equal x2 in training mode")

        covar = self._get_covariance(x1, x2)

        if self.training:
            zero_mean = torch.zeros_like(x1.select(-1, 0))
            new_added_loss_term = InducingPointKernelAddedLossTerm(
                MultivariateNormal(zero_mean, self._covar_diag(x1)),
                MultivariateNormal(zero_mean, covar),
                self.likelihood,
            )
            self.update_added_loss_term("inducing_point_loss_term", new_added_loss_term)

        if diag:
            return covar.diagonal(dim1=-1, dim2=-2)
        return covar

    def num_outputs_per_input(self, x1, x2):
        """Delegate to ``base_kernel``."""
        return self.base_kernel.num_outputs_per_input(x1, x2)

    def __deepcopy__(self, memo):
        """Deep-copy the kernel, preserving the training flag and any eval-mode caches.

        ``base_kernel`` and ``inducing_points`` are deep-copied through ``memo``, as the ``copy``
        protocol requires. This keeps objects shared with the rest of an enclosing deep copy
        shared in the result. ``likelihood`` is passed by reference, so the copy shares it with
        the original. The cached ``K(Z, Z)`` and inverse-root tensors, if present, are also shared
        rather than copied.
        """
        cp = self.__class__(
            base_kernel=copy.deepcopy(self.base_kernel, memo),
            inducing_points=copy.deepcopy(self.inducing_points, memo),
            likelihood=self.likelihood,
            active_dims=self.active_dims,
        )
        # __init__ leaves the new module in training mode, which would make the copied caches
        # unreachable. Restore only this module's flag: cp.train(...) would recurse into the
        # shared likelihood and mutate the original's state.
        cp.training = self.training

        if hasattr(self, "_cached_kernel_inv_root"):
            cp._cached_kernel_inv_root = self._cached_kernel_inv_root
        if hasattr(self, "_cached_kernel_mat"):
            cp._cached_kernel_mat = self._cached_kernel_mat

        return cp

    def prediction_strategy(self, train_inputs, train_prior_dist, train_labels, likelihood):
        """Return an ``SGPRPredictionStrategy`` for fast predictive variances."""
        # Allow for fast variances
        return exact_prediction_strategies.SGPRPredictionStrategy(
            train_inputs, train_prior_dist, train_labels, likelihood
        )