qpytorch-0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of qpytorch might be problematic.

Files changed (102)
  1. qpytorch/__init__.py +327 -0
  2. qpytorch/constraints/__init__.py +3 -0
  3. qpytorch/distributions/__init__.py +21 -0
  4. qpytorch/distributions/delta.py +86 -0
  5. qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
  6. qpytorch/distributions/multivariate_qexponential.py +581 -0
  7. qpytorch/distributions/power.py +113 -0
  8. qpytorch/distributions/qexponential.py +153 -0
  9. qpytorch/functions/__init__.py +58 -0
  10. qpytorch/kernels/__init__.py +80 -0
  11. qpytorch/kernels/grid_interpolation_kernel.py +213 -0
  12. qpytorch/kernels/inducing_point_kernel.py +151 -0
  13. qpytorch/kernels/kernel.py +695 -0
  14. qpytorch/kernels/matern32_kernel_grad.py +155 -0
  15. qpytorch/kernels/matern52_kernel_grad.py +194 -0
  16. qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
  17. qpytorch/kernels/polynomial_kernel_grad.py +88 -0
  18. qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
  19. qpytorch/kernels/rbf_kernel_grad.py +125 -0
  20. qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
  21. qpytorch/kernels/rff_kernel.py +153 -0
  22. qpytorch/lazy/__init__.py +9 -0
  23. qpytorch/likelihoods/__init__.py +66 -0
  24. qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
  25. qpytorch/likelihoods/beta_likelihood.py +76 -0
  26. qpytorch/likelihoods/gaussian_likelihood.py +472 -0
  27. qpytorch/likelihoods/laplace_likelihood.py +59 -0
  28. qpytorch/likelihoods/likelihood.py +437 -0
  29. qpytorch/likelihoods/likelihood_list.py +60 -0
  30. qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
  31. qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
  32. qpytorch/likelihoods/noise_models.py +184 -0
  33. qpytorch/likelihoods/qexponential_likelihood.py +494 -0
  34. qpytorch/likelihoods/softmax_likelihood.py +97 -0
  35. qpytorch/likelihoods/student_t_likelihood.py +90 -0
  36. qpytorch/means/__init__.py +23 -0
  37. qpytorch/metrics/__init__.py +17 -0
  38. qpytorch/mlls/__init__.py +53 -0
  39. qpytorch/mlls/_approximate_mll.py +79 -0
  40. qpytorch/mlls/deep_approximate_mll.py +30 -0
  41. qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
  42. qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
  43. qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
  44. qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
  45. qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
  46. qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
  47. qpytorch/mlls/marginal_log_likelihood.py +48 -0
  48. qpytorch/mlls/predictive_log_likelihood.py +76 -0
  49. qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
  50. qpytorch/mlls/variational_elbo.py +77 -0
  51. qpytorch/models/__init__.py +72 -0
  52. qpytorch/models/approximate_qep.py +115 -0
  53. qpytorch/models/deep_qeps/__init__.py +22 -0
  54. qpytorch/models/deep_qeps/deep_qep.py +155 -0
  55. qpytorch/models/deep_qeps/dspp.py +114 -0
  56. qpytorch/models/exact_prediction_strategies.py +880 -0
  57. qpytorch/models/exact_qep.py +349 -0
  58. qpytorch/models/model_list.py +100 -0
  59. qpytorch/models/pyro/__init__.py +28 -0
  60. qpytorch/models/pyro/_pyro_mixin.py +57 -0
  61. qpytorch/models/pyro/distributions/__init__.py +5 -0
  62. qpytorch/models/pyro/pyro_qep.py +105 -0
  63. qpytorch/models/qep.py +7 -0
  64. qpytorch/models/qeplvm/__init__.py +6 -0
  65. qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
  66. qpytorch/models/qeplvm/latent_variable.py +102 -0
  67. qpytorch/module.py +30 -0
  68. qpytorch/optim/__init__.py +5 -0
  69. qpytorch/priors/__init__.py +42 -0
  70. qpytorch/priors/qep_priors.py +81 -0
  71. qpytorch/test/__init__.py +22 -0
  72. qpytorch/test/base_likelihood_test_case.py +106 -0
  73. qpytorch/test/model_test_case.py +150 -0
  74. qpytorch/test/variational_test_case.py +400 -0
  75. qpytorch/utils/__init__.py +38 -0
  76. qpytorch/utils/warnings.py +37 -0
  77. qpytorch/variational/__init__.py +47 -0
  78. qpytorch/variational/_variational_distribution.py +61 -0
  79. qpytorch/variational/_variational_strategy.py +391 -0
  80. qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
  81. qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
  82. qpytorch/variational/cholesky_variational_distribution.py +65 -0
  83. qpytorch/variational/ciq_variational_strategy.py +352 -0
  84. qpytorch/variational/delta_variational_distribution.py +41 -0
  85. qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
  86. qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
  87. qpytorch/variational/lmc_variational_strategy.py +248 -0
  88. qpytorch/variational/mean_field_variational_distribution.py +58 -0
  89. qpytorch/variational/multitask_variational_strategy.py +317 -0
  90. qpytorch/variational/natural_variational_distribution.py +152 -0
  91. qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
  92. qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
  93. qpytorch/variational/tril_natural_variational_distribution.py +130 -0
  94. qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
  95. qpytorch/variational/unwhitened_variational_strategy.py +225 -0
  96. qpytorch/variational/variational_strategy.py +280 -0
  97. qpytorch/version.py +4 -0
  98. qpytorch-0.1.dist-info/LICENSE +21 -0
  99. qpytorch-0.1.dist-info/METADATA +177 -0
  100. qpytorch-0.1.dist-info/RECORD +102 -0
  101. qpytorch-0.1.dist-info/WHEEL +5 -0
  102. qpytorch-0.1.dist-info/top_level.txt +1 -0
qpytorch/likelihoods/beta_likelihood.py
@@ -0,0 +1,76 @@
+ #!/usr/bin/env python3
+
+ from typing import Any, Optional
+
+ import torch
+ from torch import Tensor
+ from torch.distributions import Beta
+
+ from ..constraints import Interval, Positive
+ from ..distributions import base_distributions
+ from ..priors import Prior
+ from .likelihood import _OneDimensionalLikelihood
+
+
+ class BetaLikelihood(_OneDimensionalLikelihood):
+     r"""
+     A Beta likelihood for regressing over percentages.
+
+     The Beta distribution is parameterized by :math:`\alpha > 0` and :math:`\beta > 0` parameters,
+     which roughly correspond to the number of prior positive and negative observations.
+     We instead parameterize it through a mixture :math:`m \in [0, 1]` and a scale :math:`s > 0` parameter.
+
+     .. math::
+         \begin{equation*}
+             \alpha = ms, \quad \beta = (1-m)s
+         \end{equation*}
+
+     The mixture parameter is the output of the GP passed through a sigmoid (inverse logit)
+     function :math:`\sigma(\cdot)`. The scale parameter is learned.
+
+     .. math::
+         p(y \mid f) = \text{Beta} \left( \sigma(f) s , (1 - \sigma(f)) s\right)
+
+     :param batch_shape: The batch shape of the learned noise parameter (default: []).
+     :param scale_prior: Prior for scale parameter :math:`s`.
+     :param scale_constraint: Constraint for scale parameter :math:`s`.
+
+     :ivar torch.Tensor scale: :math:`s` parameter (scale)
+     """
+
+     def __init__(
+         self,
+         batch_shape: torch.Size = torch.Size([]),
+         scale_prior: Optional[Prior] = None,
+         scale_constraint: Optional[Interval] = None,
+     ) -> None:
+         super().__init__()
+
+         if scale_constraint is None:
+             scale_constraint = Positive()
+
+         self.raw_scale = torch.nn.Parameter(torch.ones(*batch_shape, 1))
+         if scale_prior is not None:
+             self.register_prior("scale_prior", scale_prior, lambda m: m.scale, lambda m, v: m._set_scale(v))
+
+         self.register_constraint("raw_scale", scale_constraint)
+
+     @property
+     def scale(self) -> Tensor:
+         return self.raw_scale_constraint.transform(self.raw_scale)
+
+     @scale.setter
+     def scale(self, value: Tensor) -> None:
+         self._set_scale(value)
+
+     def _set_scale(self, value: Tensor) -> None:
+         if not torch.is_tensor(value):
+             value = torch.as_tensor(value).to(self.raw_scale)
+         self.initialize(raw_scale=self.raw_scale_constraint.inverse_transform(value))
+
+     def forward(self, function_samples: Tensor, *args: Any, **kwargs: Any) -> Beta:
+         mixture = torch.sigmoid(function_samples)
+         scale = self.scale
+         # Both concentrations are shifted by +1 so that alpha, beta > 1 and the density stays bounded.
+         alpha = mixture * scale + 1
+         beta = scale - alpha + 2  # equals (1 - mixture) * scale + 1
+         return base_distributions.Beta(concentration1=alpha, concentration0=beta)
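Taken together, the m/s parameterization means a function sample f is squashed to a mixture m = sigmoid(f), and the learned scale s sets how concentrated the resulting Beta is. A minimal usage sketch with hypothetical values (it assumes qpytorch dispatches plain Tensor inputs through forward, as gpytorch's Likelihood.__call__ does):

    import torch
    from qpytorch.likelihoods import BetaLikelihood

    likelihood = BetaLikelihood()              # learnable scale s, initialized to 1
    f = torch.randn(5)                         # stand-in for GP/QEP function samples
    dist = likelihood(f)                       # Beta with alpha = m*s + 1, beta = (1-m)*s + 1

    m, s = torch.sigmoid(f), likelihood.scale
    assert torch.allclose(dist.concentration1, m * s + 1)
    assert torch.allclose(dist.concentration0, (1 - m) * s + 1)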
qpytorch/likelihoods/gaussian_likelihood.py
@@ -0,0 +1,472 @@
+ #!/usr/bin/env python3
+ import math
+ import warnings
+ from copy import deepcopy
+ from typing import Any, Optional, Tuple, Union
+
+ import torch
+ from linear_operator.operators import LinearOperator, MaskedLinearOperator, ZeroLinearOperator
+ from torch import Tensor
+ from torch.distributions import Distribution, Normal
+
+ from .. import settings
+ from ..constraints import Interval
+ from ..distributions import base_distributions, MultivariateNormal
+ from ..priors import Prior
+ from ..utils.warnings import GPInputWarning
+ from .likelihood import Likelihood
+ from .noise_models import FixedGaussianNoise, HomoskedasticNoise, Noise
+
+
+ class _GaussianLikelihoodBase(Likelihood):
+     """Base class for Gaussian likelihoods, supporting general heteroskedastic noise models."""
+
+     has_analytic_marginal = True
+
+     def __init__(self, noise_covar: Union[Noise, FixedGaussianNoise], **kwargs: Any) -> None:
+         super().__init__()
+         param_transform = kwargs.get("param_transform")
+         if param_transform is not None:
+             warnings.warn(
+                 "The 'param_transform' argument is now deprecated. If you want to use a different "
+                 "transformation, specify a different 'noise_constraint' instead.",
+                 DeprecationWarning,
+             )
+
+         self.noise_covar = noise_covar
+
+     def _shaped_noise_covar(self, base_shape: torch.Size, *params: Any, **kwargs: Any) -> Union[Tensor, LinearOperator]:
+         return self.noise_covar(*params, shape=base_shape, **kwargs)
+
+     def expected_log_prob(self, target: Tensor, input: MultivariateNormal, *params: Any, **kwargs: Any) -> Tensor:
+         noise = self._shaped_noise_covar(input.mean.shape, *params, **kwargs).diagonal(dim1=-1, dim2=-2)
+         # Potentially reshape the noise to deal with the multitask case
+         noise = noise.view(*noise.shape[:-1], *input.event_shape)
+
+         # Handle NaN values if enabled
+         nan_policy = settings.observation_nan_policy.value()
+         if nan_policy == "mask":
+             observed = settings.observation_nan_policy._get_observed(target, input.event_shape)
+             input = MultivariateNormal(
+                 mean=input.mean[..., observed],
+                 covariance_matrix=MaskedLinearOperator(
+                     input.lazy_covariance_matrix, observed.reshape(-1), observed.reshape(-1)
+                 ),
+             )
+             noise = noise[..., observed]
+             target = target[..., observed]
+         elif nan_policy == "fill":
+             missing = torch.isnan(target)
+             target = settings.observation_nan_policy._fill_tensor(target)
+
+         mean, variance = input.mean, input.variance
+         res = ((target - mean).square() + variance) / noise + noise.log() + math.log(2 * math.pi)
+         res = res.mul(-0.5)
+
+         if nan_policy == "fill":
+             res = res * ~missing
+
+         # Do appropriate summation for multitask Gaussian likelihoods
+         num_event_dim = len(input.event_shape)
+         if num_event_dim > 1:
+             res = res.sum(list(range(-1, -num_event_dim, -1)))
+
+         return res
+
+     def forward(self, function_samples: Tensor, *params: Any, **kwargs: Any) -> Normal:
+         noise = self._shaped_noise_covar(function_samples.shape, *params, **kwargs).diagonal(dim1=-1, dim2=-2)
+         return base_distributions.Normal(function_samples, noise.sqrt())
+
+     def log_marginal(
+         self, observations: Tensor, function_dist: MultivariateNormal, *params: Any, **kwargs: Any
+     ) -> Tensor:
+         marginal = self.marginal(function_dist, *params, **kwargs)
+
+         # Handle NaN values if enabled
+         nan_policy = settings.observation_nan_policy.value()
+         if nan_policy == "mask":
+             observed = settings.observation_nan_policy._get_observed(observations, marginal.event_shape)
+             marginal = MultivariateNormal(
+                 mean=marginal.mean[..., observed],
+                 covariance_matrix=MaskedLinearOperator(
+                     marginal.lazy_covariance_matrix, observed.reshape(-1), observed.reshape(-1)
+                 ),
+             )
+             observations = observations[..., observed]
+         elif nan_policy == "fill":
+             missing = torch.isnan(observations)
+             observations = settings.observation_nan_policy._fill_tensor(observations)
+
+         # We're making everything conditionally independent
+         indep_dist = base_distributions.Normal(marginal.mean, marginal.variance.clamp_min(1e-8).sqrt())
+         res = indep_dist.log_prob(observations)
+
+         if nan_policy == "fill":
+             res = res * ~missing
+
+         # Do appropriate summation for multitask Gaussian likelihoods
+         num_event_dim = len(marginal.event_shape)
+         if num_event_dim > 1:
+             res = res.sum(list(range(-1, -num_event_dim, -1)))
+         return res
+
+     def marginal(self, function_dist: MultivariateNormal, *params: Any, **kwargs: Any) -> MultivariateNormal:
+         mean, covar = function_dist.mean, function_dist.lazy_covariance_matrix
+         noise_covar = self._shaped_noise_covar(mean.shape, *params, **kwargs)
+         full_covar = covar + noise_covar
+         return function_dist.__class__(mean, full_covar)
+
+
+ class GaussianLikelihood(_GaussianLikelihoodBase):
+     r"""
+     The standard likelihood for regression.
+     Assumes a standard homoskedastic noise model:
+
+     .. math::
+         y = f + \epsilon, \quad \epsilon \sim \mathcal N (0, \sigma^2)
+
+     where :math:`\sigma^2` is a noise parameter.
+
+     .. note::
+         This likelihood can be used for exact or approximate inference.
+
+     .. note::
+         GaussianLikelihood has an analytic marginal distribution.
+
+     :param noise_prior: Prior for noise parameter :math:`\sigma^2`.
+     :param noise_constraint: Constraint for noise parameter :math:`\sigma^2`.
+     :param batch_shape: The batch shape of the learned noise parameter (default: []).
+     :param kwargs:
+
+     :ivar torch.Tensor noise: :math:`\sigma^2` parameter (noise)
+     """
+
+     def __init__(
+         self,
+         noise_prior: Optional[Prior] = None,
+         noise_constraint: Optional[Interval] = None,
+         batch_shape: torch.Size = torch.Size(),
+         **kwargs: Any,
+     ) -> None:
+         noise_covar = HomoskedasticNoise(
+             noise_prior=noise_prior, noise_constraint=noise_constraint, batch_shape=batch_shape
+         )
+         super().__init__(noise_covar=noise_covar)
+
+     @property
+     def noise(self) -> Tensor:
+         return self.noise_covar.noise
+
+     @noise.setter
+     def noise(self, value: Tensor) -> None:
+         self.noise_covar.initialize(noise=value)
+
+     @property
+     def raw_noise(self) -> Tensor:
+         return self.noise_covar.raw_noise
+
+     @raw_noise.setter
+     def raw_noise(self, value: Tensor) -> None:
+         self.noise_covar.initialize(raw_noise=value)
+
+     def marginal(self, function_dist: MultivariateNormal, *args: Any, **kwargs: Any) -> MultivariateNormal:
+         r"""
+         :return: Analytic marginal :math:`p(\mathbf y)`.
+         """
+         return super().marginal(function_dist, *args, **kwargs)
+
+
+ class GaussianLikelihoodWithMissingObs(GaussianLikelihood):
+     r"""
+     The standard likelihood for regression with support for missing values.
+     Assumes a standard homoskedastic noise model:
+
+     .. math::
+         y = f + \epsilon, \quad \epsilon \sim \mathcal N (0, \sigma^2)
+
+     where :math:`\sigma^2` is a noise parameter. Values of y that are NaN do
+     not impact the likelihood calculation.
+
+     .. note::
+         This likelihood can be used for exact or approximate inference.
+
+     .. warning::
+         This likelihood is deprecated in favor of :class:`qpytorch.settings.observation_nan_policy`.
+
+     :param noise_prior: Prior for noise parameter :math:`\sigma^2`.
+     :type noise_prior: ~gpytorch.priors.Prior, optional
+     :param noise_constraint: Constraint for noise parameter :math:`\sigma^2`.
+     :type noise_constraint: ~gpytorch.constraints.Interval, optional
+     :param batch_shape: The batch shape of the learned noise parameter (default: []).
+     :type batch_shape: torch.Size, optional
+     :ivar torch.Tensor noise: :math:`\sigma^2` parameter (noise)
+
+     .. note::
+         GaussianLikelihoodWithMissingObs has an analytic marginal distribution.
+     """
+
+     MISSING_VALUE_FILL: float = -999.0
+
+     def __init__(self, **kwargs: Any) -> None:
+         warnings.warn(
+             "GaussianLikelihoodWithMissingObs is replaced by qpytorch.settings.observation_nan_policy('fill').",
+             DeprecationWarning,
+         )
+         super().__init__(**kwargs)
+
+     def _get_masked_obs(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+         missing_idx = x.isnan()
+         x_masked = x.masked_fill(missing_idx, self.MISSING_VALUE_FILL)
+         return missing_idx, x_masked
+
+     def expected_log_prob(self, target: Tensor, input: MultivariateNormal, *params: Any, **kwargs: Any) -> Tensor:
+         missing_idx, target = self._get_masked_obs(target)
+         res = super().expected_log_prob(target, input, *params, **kwargs)
+         return res * ~missing_idx
+
+     def log_marginal(
+         self, observations: Tensor, function_dist: MultivariateNormal, *params: Any, **kwargs: Any
+     ) -> Tensor:
+         missing_idx, observations = self._get_masked_obs(observations)
+         res = super().log_marginal(observations, function_dist, *params, **kwargs)
+         return res * ~missing_idx
+
+     def marginal(self, function_dist: MultivariateNormal, *args: Any, **kwargs: Any) -> MultivariateNormal:
+         r"""
+         :return: Analytic marginal :math:`p(\mathbf y)`.
+         """
+         return super().marginal(function_dist, *args, **kwargs)
+
+
+ class FixedNoiseGaussianLikelihood(_GaussianLikelihoodBase):
+     r"""
+     A likelihood that assumes fixed heteroskedastic noise. This is useful when you have fixed, known observation
+     noise for each training example.
+
+     Note that this likelihood takes an additional argument when you call it, `noise`, that adds a specified amount
+     of noise to the passed MultivariateNormal. This allows for adding known observational noise to test data.
+
+     .. note::
+         This likelihood can be used for exact or approximate inference.
+
+     .. note::
+         FixedNoiseGaussianLikelihood has an analytic marginal distribution.
+
+     :param noise: Known observation noise (variance) for each training example.
+     :type noise: torch.Tensor (... x N)
+     :param learn_additional_noise: Set to True if you additionally want to
+         learn added diagonal noise, similar to GaussianLikelihood.
+     :type learn_additional_noise: bool, optional
+     :param batch_shape: The batch shape of the learned noise parameter (default:
+         []) if :obj:`learn_additional_noise=True`.
+     :type batch_shape: torch.Size, optional
+
+     :ivar torch.Tensor noise: :math:`\sigma^2` parameter (noise)
+
+     Example:
+         >>> train_x = torch.randn(55, 2)
+         >>> noises = torch.ones(55) * 0.01
+         >>> likelihood = FixedNoiseGaussianLikelihood(noise=noises, learn_additional_noise=True)
+         >>> pred_y = likelihood(gp_model(train_x))
+         >>>
+         >>> test_x = torch.randn(21, 2)
+         >>> test_noises = torch.ones(21) * 0.02
+         >>> pred_y = likelihood(gp_model(test_x), noise=test_noises)
+     """
+
+     def __init__(
+         self,
+         noise: Tensor,
+         learn_additional_noise: Optional[bool] = False,
+         batch_shape: Optional[torch.Size] = torch.Size(),
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(noise_covar=FixedGaussianNoise(noise=noise))
+
+         self.second_noise_covar: Optional[HomoskedasticNoise] = None
+         if learn_additional_noise:
+             noise_prior = kwargs.get("noise_prior", None)
+             noise_constraint = kwargs.get("noise_constraint", None)
+             self.second_noise_covar = HomoskedasticNoise(
+                 noise_prior=noise_prior, noise_constraint=noise_constraint, batch_shape=batch_shape
+             )
+
+     @property
+     def noise(self) -> Tensor:
+         return self.noise_covar.noise + self.second_noise
+
+     @noise.setter
+     def noise(self, value: Tensor) -> None:
+         self.noise_covar.initialize(noise=value)
+
+     @property
+     def second_noise(self) -> Union[float, Tensor]:
+         if self.second_noise_covar is None:
+             return 0.0
+         else:
+             return self.second_noise_covar.noise
+
+     @second_noise.setter
+     def second_noise(self, value: Tensor) -> None:
+         if self.second_noise_covar is None:
+             raise RuntimeError(
+                 "Attempting to set secondary learned noise for FixedNoiseGaussianLikelihood, "
+                 "but learn_additional_noise was False!"
+             )
+         self.second_noise_covar.initialize(noise=value)
+
+     def get_fantasy_likelihood(self, **kwargs: Any) -> "FixedNoiseGaussianLikelihood":
+         if "noise" not in kwargs:
+             raise RuntimeError("FixedNoiseGaussianLikelihood.fantasize requires a `noise` kwarg")
+         old_noise_covar = self.noise_covar
+         self.noise_covar = None  # pyre-fixme[8]
+         fantasy_likelihood = deepcopy(self)
+         self.noise_covar = old_noise_covar
+
+         old_noise = old_noise_covar.noise
+         new_noise = kwargs.get("noise")
+         if old_noise.dim() != new_noise.dim():
+             old_noise = old_noise.expand(*new_noise.shape[:-1], old_noise.shape[-1])
+         fantasy_likelihood.noise_covar = FixedGaussianNoise(noise=torch.cat([old_noise, new_noise], -1))
+         return fantasy_likelihood
+
+     def _shaped_noise_covar(self, base_shape: torch.Size, *params: Any, **kwargs: Any) -> Union[Tensor, LinearOperator]:
+         if len(params) > 0:
+             # we can infer the shape from the params
+             shape = None
+         else:
+             # here shape[:-1] is the batch shape requested, and shape[-1] is `n`, the number of points
+             shape = base_shape
+
+         res = self.noise_covar(*params, shape=shape, **kwargs)
+
+         if self.second_noise_covar is not None:
+             res = res + self.second_noise_covar(*params, shape=shape, **kwargs)
+         elif isinstance(res, ZeroLinearOperator):
+             warnings.warn(
+                 "You have passed data through a FixedNoiseGaussianLikelihood that did not match the size "
+                 "of the fixed noise, *and* you did not specify noise. This is treated as a no-op.",
+                 GPInputWarning,
+             )
+
+         return res
+
+     def marginal(self, function_dist: MultivariateNormal, *args: Any, **kwargs: Any) -> MultivariateNormal:
+         r"""
+         :return: Analytic marginal :math:`p(\mathbf y)`.
+         """
+         return super().marginal(function_dist, *args, **kwargs)
+
+
+ class DirichletClassificationLikelihood(FixedNoiseGaussianLikelihood):
+     r"""
+     A classification likelihood that treats the labels as regression targets with fixed heteroskedastic noise.
+     From Milios et al., NeurIPS 2018 [https://arxiv.org/abs/1805.10915].
+
+     .. note::
+         This likelihood can be used for exact or approximate inference.
+
+     .. note::
+         DirichletClassificationLikelihood has an analytic marginal distribution.
+
+     :param targets: (... x N) Classification labels.
+     :param alpha_epsilon: Tuning parameter for the scaling of the likelihood targets. We'd suggest 0.01 or setting
+         it via cross-validation.
+     :param learn_additional_noise: Set to True if you additionally want to
+         learn added diagonal noise, similar to GaussianLikelihood.
+     :param batch_shape: The batch shape of the learned noise parameter (default:
+         []) if :obj:`learn_additional_noise=True`.
+
+     :ivar torch.Tensor noise: :math:`\sigma^2` parameter (noise)
+
+     Example:
+         >>> train_x = torch.randn(55, 1)
+         >>> labels = torch.round(train_x).long()
+         >>> likelihood = DirichletClassificationLikelihood(targets=labels, learn_additional_noise=True)
+         >>> pred_y = likelihood(gp_model(train_x))
+         >>>
+         >>> test_x = torch.randn(21, 1)
+         >>> test_labels = torch.round(test_x).long()
+         >>> pred_y = likelihood(gp_model(test_x), targets=test_labels)
+     """
+
+     def _prepare_targets(
+         self,
+         targets: Tensor,
+         num_classes: Optional[int] = None,
+         alpha_epsilon: float = 0.01,
+         dtype: torch.dtype = torch.float,
+     ) -> Tuple[Tensor, Tensor, int]:
+         if num_classes is None:
+             num_classes = int(targets.max() + 1)
+
+         # set alpha = \alpha_\epsilon
+         alpha = alpha_epsilon * torch.ones(targets.shape[-1], num_classes, device=targets.device, dtype=dtype)
+
+         # alpha[class_labels] = 1 + \alpha_\epsilon
+         alpha[torch.arange(len(targets)), targets] = alpha[torch.arange(len(targets)), targets] + 1.0
+
+         # sigma^2 = log(1 / alpha + 1)
+         sigma2_i = torch.log(alpha.reciprocal() + 1.0)
+
+         # y = log(alpha) - 0.5 * sigma^2
+         transformed_targets = alpha.log() - 0.5 * sigma2_i
+
+         return sigma2_i.transpose(-2, -1).type(dtype), transformed_targets.type(dtype), num_classes
+
+     def __init__(
+         self,
+         targets: Tensor,
+         alpha_epsilon: float = 0.01,
+         learn_additional_noise: Optional[bool] = False,
+         batch_shape: torch.Size = torch.Size(),
+         dtype: torch.dtype = torch.float,
+         **kwargs: Any,
+     ) -> None:
+         sigma2_labels, transformed_targets, num_classes = self._prepare_targets(
+             targets, alpha_epsilon=alpha_epsilon, dtype=dtype
+         )
+         super().__init__(
+             noise=sigma2_labels,
+             learn_additional_noise=learn_additional_noise,
+             batch_shape=torch.Size((num_classes,)),
+             **kwargs,
+         )
+         self.transformed_targets: Tensor = transformed_targets.transpose(-2, -1)
+         self.num_classes: int = num_classes
+         self.targets: Tensor = targets
+         self.alpha_epsilon: float = alpha_epsilon
+
+     def get_fantasy_likelihood(self, **kwargs: Any) -> "DirichletClassificationLikelihood":
+         # we assume that the number of classes does not change.
+         if "targets" not in kwargs:
+             raise RuntimeError("DirichletClassificationLikelihood.fantasize requires a `targets` kwarg")
+
+         old_noise_covar = self.noise_covar
+         self.noise_covar = None  # pyre-fixme[8]
+         fantasy_likelihood = deepcopy(self)
+         self.noise_covar = old_noise_covar
+
+         old_noise = old_noise_covar.noise
+         new_targets = kwargs.get("targets")
+         new_noise, new_targets, _ = fantasy_likelihood._prepare_targets(new_targets, alpha_epsilon=self.alpha_epsilon)
+         fantasy_likelihood.targets = torch.cat([fantasy_likelihood.targets, new_targets], -1)
+
+         if old_noise.dim() != new_noise.dim():
+             old_noise = old_noise.expand(*new_noise.shape[:-1], old_noise.shape[-1])
+
+         fantasy_likelihood.noise_covar = FixedGaussianNoise(noise=torch.cat([old_noise, new_noise], -1))
+         return fantasy_likelihood
+
+     def marginal(self, function_dist: MultivariateNormal, *args: Any, **kwargs: Any) -> MultivariateNormal:
+         r"""
+         :return: Analytic marginal :math:`p(\mathbf y)`.
+         """
+         return super().marginal(function_dist, *args, **kwargs)
+
+     def __call__(self, input: Union[Tensor, MultivariateNormal], *args: Any, **kwargs: Any) -> Distribution:
+         if "targets" in kwargs:
+             targets = kwargs.pop("targets")
+             dtype = self.transformed_targets.dtype
+             new_noise, _, _ = self._prepare_targets(targets, dtype=dtype)
+             kwargs["noise"] = new_noise
+         return super().__call__(input, *args, **kwargs)
+
+
+ class GaussianDirichletClassificationLikelihood(DirichletClassificationLikelihood):
+     pass
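The `_prepare_targets` transformation above is the heart of the Milios et al. approach: each label is smoothed into Dirichlet concentrations alpha, which are then moment-matched to per-class regression targets y = log(alpha) - sigma^2 / 2 with fixed noise sigma^2 = log(1/alpha + 1). A self-contained sketch of that arithmetic (illustrative only; it mirrors the method without the class's dtype/device handling):

    import torch

    targets = torch.tensor([0, 2, 1, 2])                  # N = 4 labels, C = 3 classes
    alpha_epsilon = 0.01
    num_classes = int(targets.max() + 1)

    alpha = alpha_epsilon * torch.ones(len(targets), num_classes)
    alpha[torch.arange(len(targets)), targets] += 1.0     # observed class gets 1 + eps

    sigma2 = torch.log(alpha.reciprocal() + 1.0)          # fixed per-(point, class) noise
    y = alpha.log() - 0.5 * sigma2                        # regression targets

    print(sigma2.T.shape, y.T.shape)                      # torch.Size([3, 4]) each: one GP per class

Each class thus becomes an independent regression problem with fixed heteroskedastic noise, which is why the likelihood passes noise=sigma2_labels and batch_shape=torch.Size((num_classes,)) to its parent.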
qpytorch/likelihoods/laplace_likelihood.py
@@ -0,0 +1,59 @@
+ #!/usr/bin/env python3
+
+ from typing import Any, Optional
+
+ import torch
+ from torch import Tensor
+ from torch.distributions import Laplace
+
+ from ..constraints import Interval, Positive
+ from ..distributions import base_distributions
+ from ..priors import Prior
+ from .likelihood import _OneDimensionalLikelihood
+
+
+ class LaplaceLikelihood(_OneDimensionalLikelihood):
+     r"""
+     A Laplace likelihood/noise model for GP/QEP regression.
+     It has one learnable parameter: :math:`\sigma^2`, the noise.
+
+     :param batch_shape: The batch shape of the learned noise parameter (default: []).
+     :param noise_prior: Prior for noise parameter :math:`\sigma^2`.
+     :param noise_constraint: Constraint for noise parameter :math:`\sigma^2`.
+
+     :ivar torch.Tensor noise: :math:`\sigma^2` parameter (noise)
+     """
+
+     def __init__(
+         self,
+         batch_shape: torch.Size = torch.Size([]),
+         noise_prior: Optional[Prior] = None,
+         noise_constraint: Optional[Interval] = None,
+     ) -> None:
+         super().__init__()
+
+         if noise_constraint is None:
+             noise_constraint = Positive()
+
+         self.raw_noise = torch.nn.Parameter(torch.zeros(*batch_shape, 1))
+
+         if noise_prior is not None:
+             self.register_prior("noise_prior", noise_prior, lambda m: m.noise, lambda m, v: m._set_noise(v))
+
+         self.register_constraint("raw_noise", noise_constraint)
+
+     @property
+     def noise(self) -> Tensor:
+         return self.raw_noise_constraint.transform(self.raw_noise)
+
+     @noise.setter
+     def noise(self, value: Tensor) -> None:
+         self._set_noise(value)
+
+     def _set_noise(self, value: Tensor) -> None:
+         if not torch.is_tensor(value):
+             value = torch.as_tensor(value).to(self.raw_noise)
+         self.initialize(raw_noise=self.raw_noise_constraint.inverse_transform(value))
+
+     def forward(self, function_samples: Tensor, *args: Any, **kwargs: Any) -> Laplace:
+         # The Laplace scale is the square root of the learned noise (variance) parameter.
+         return base_distributions.Laplace(loc=function_samples, scale=self.noise.sqrt())
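Because the Laplace distribution has heavier tails than a Gaussian, this likelihood is a common drop-in when the data contain outliers. A minimal sketch with hypothetical values (it assumes the noise setter accepts a scalar tensor and that calling the likelihood on a plain Tensor returns the conditional p(y | f), as in the Beta example above):

    import torch
    from qpytorch.likelihoods import LaplaceLikelihood

    likelihood = LaplaceLikelihood()
    likelihood.noise = torch.tensor(0.25)      # sets raw_noise through the Positive() constraint

    f = torch.randn(10)                        # stand-in for GP/QEP function samples
    y = f + 0.1 * torch.randn(10)
    log_probs = likelihood(f).log_prob(y)      # Laplace(loc=f, scale=sqrt(0.25)) log-densities
    print(log_probs.shape)                     # torch.Size([10])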