qpytorch 0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of qpytorch might be problematic. Click here for more details.

Files changed (102) hide show
  1. qpytorch/__init__.py +327 -0
  2. qpytorch/constraints/__init__.py +3 -0
  3. qpytorch/distributions/__init__.py +21 -0
  4. qpytorch/distributions/delta.py +86 -0
  5. qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
  6. qpytorch/distributions/multivariate_qexponential.py +581 -0
  7. qpytorch/distributions/power.py +113 -0
  8. qpytorch/distributions/qexponential.py +153 -0
  9. qpytorch/functions/__init__.py +58 -0
  10. qpytorch/kernels/__init__.py +80 -0
  11. qpytorch/kernels/grid_interpolation_kernel.py +213 -0
  12. qpytorch/kernels/inducing_point_kernel.py +151 -0
  13. qpytorch/kernels/kernel.py +695 -0
  14. qpytorch/kernels/matern32_kernel_grad.py +155 -0
  15. qpytorch/kernels/matern52_kernel_grad.py +194 -0
  16. qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
  17. qpytorch/kernels/polynomial_kernel_grad.py +88 -0
  18. qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
  19. qpytorch/kernels/rbf_kernel_grad.py +125 -0
  20. qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
  21. qpytorch/kernels/rff_kernel.py +153 -0
  22. qpytorch/lazy/__init__.py +9 -0
  23. qpytorch/likelihoods/__init__.py +66 -0
  24. qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
  25. qpytorch/likelihoods/beta_likelihood.py +76 -0
  26. qpytorch/likelihoods/gaussian_likelihood.py +472 -0
  27. qpytorch/likelihoods/laplace_likelihood.py +59 -0
  28. qpytorch/likelihoods/likelihood.py +437 -0
  29. qpytorch/likelihoods/likelihood_list.py +60 -0
  30. qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
  31. qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
  32. qpytorch/likelihoods/noise_models.py +184 -0
  33. qpytorch/likelihoods/qexponential_likelihood.py +494 -0
  34. qpytorch/likelihoods/softmax_likelihood.py +97 -0
  35. qpytorch/likelihoods/student_t_likelihood.py +90 -0
  36. qpytorch/means/__init__.py +23 -0
  37. qpytorch/metrics/__init__.py +17 -0
  38. qpytorch/mlls/__init__.py +53 -0
  39. qpytorch/mlls/_approximate_mll.py +79 -0
  40. qpytorch/mlls/deep_approximate_mll.py +30 -0
  41. qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
  42. qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
  43. qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
  44. qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
  45. qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
  46. qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
  47. qpytorch/mlls/marginal_log_likelihood.py +48 -0
  48. qpytorch/mlls/predictive_log_likelihood.py +76 -0
  49. qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
  50. qpytorch/mlls/variational_elbo.py +77 -0
  51. qpytorch/models/__init__.py +72 -0
  52. qpytorch/models/approximate_qep.py +115 -0
  53. qpytorch/models/deep_qeps/__init__.py +22 -0
  54. qpytorch/models/deep_qeps/deep_qep.py +155 -0
  55. qpytorch/models/deep_qeps/dspp.py +114 -0
  56. qpytorch/models/exact_prediction_strategies.py +880 -0
  57. qpytorch/models/exact_qep.py +349 -0
  58. qpytorch/models/model_list.py +100 -0
  59. qpytorch/models/pyro/__init__.py +28 -0
  60. qpytorch/models/pyro/_pyro_mixin.py +57 -0
  61. qpytorch/models/pyro/distributions/__init__.py +5 -0
  62. qpytorch/models/pyro/pyro_qep.py +105 -0
  63. qpytorch/models/qep.py +7 -0
  64. qpytorch/models/qeplvm/__init__.py +6 -0
  65. qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
  66. qpytorch/models/qeplvm/latent_variable.py +102 -0
  67. qpytorch/module.py +30 -0
  68. qpytorch/optim/__init__.py +5 -0
  69. qpytorch/priors/__init__.py +42 -0
  70. qpytorch/priors/qep_priors.py +81 -0
  71. qpytorch/test/__init__.py +22 -0
  72. qpytorch/test/base_likelihood_test_case.py +106 -0
  73. qpytorch/test/model_test_case.py +150 -0
  74. qpytorch/test/variational_test_case.py +400 -0
  75. qpytorch/utils/__init__.py +38 -0
  76. qpytorch/utils/warnings.py +37 -0
  77. qpytorch/variational/__init__.py +47 -0
  78. qpytorch/variational/_variational_distribution.py +61 -0
  79. qpytorch/variational/_variational_strategy.py +391 -0
  80. qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
  81. qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
  82. qpytorch/variational/cholesky_variational_distribution.py +65 -0
  83. qpytorch/variational/ciq_variational_strategy.py +352 -0
  84. qpytorch/variational/delta_variational_distribution.py +41 -0
  85. qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
  86. qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
  87. qpytorch/variational/lmc_variational_strategy.py +248 -0
  88. qpytorch/variational/mean_field_variational_distribution.py +58 -0
  89. qpytorch/variational/multitask_variational_strategy.py +317 -0
  90. qpytorch/variational/natural_variational_distribution.py +152 -0
  91. qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
  92. qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
  93. qpytorch/variational/tril_natural_variational_distribution.py +130 -0
  94. qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
  95. qpytorch/variational/unwhitened_variational_strategy.py +225 -0
  96. qpytorch/variational/variational_strategy.py +280 -0
  97. qpytorch/version.py +4 -0
  98. qpytorch-0.1.dist-info/LICENSE +21 -0
  99. qpytorch-0.1.dist-info/METADATA +177 -0
  100. qpytorch-0.1.dist-info/RECORD +102 -0
  101. qpytorch-0.1.dist-info/WHEEL +5 -0
  102. qpytorch-0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,542 @@
1
+ #!/usr/bin/env python3
2
+
3
+ from typing import Any, Optional, Tuple, Union
4
+
5
+ import torch
6
+ from linear_operator import to_linear_operator
7
+ from linear_operator.operators import (
8
+ ConstantDiagLinearOperator,
9
+ DiagLinearOperator,
10
+ KroneckerProductDiagLinearOperator,
11
+ KroneckerProductLinearOperator,
12
+ LinearOperator,
13
+ RootLinearOperator,
14
+ BlockDiagLinearOperator,
15
+ ZeroLinearOperator
16
+ )
17
+ from torch import Tensor
18
+ from torch.distributions import Normal
19
+
20
+ from ..constraints import GreaterThan, Interval
21
+ from ..distributions import base_distributions, MultitaskMultivariateNormal, Distribution
22
+ from ..lazy import LazyEvaluatedKernelTensor
23
+ from ..likelihoods import _GaussianLikelihoodBase, Likelihood
24
+ from ..priors import Prior
25
+ from .noise_models import FixedGaussianNoise, MultitaskHomoskedasticNoise, Noise
26
+
27
+
28
class _MultitaskGaussianLikelihoodBase(_GaussianLikelihoodBase):
    r"""
    Base class for multi-task Gaussian Likelihoods, supporting general heteroskedastic noise models.

    Subclasses are expected to provide the attributes read by :meth:`marginal` and
    :meth:`_shaped_noise_covar` (``has_global_noise``, ``has_task_noise``, ``noise`` and, depending on
    ``rank``, ``raw_task_noises`` / ``task_noise_covar_factor``); this base class does not define them.

    :param num_tasks: Number of tasks.
    :param noise_covar: A model for the noise covariance. This can be a simple homoskedastic noise model, or a GP
        that is to be fitted on the observed measurement errors.
    :param rank: The rank of the task noise covariance matrix to fit. If `rank`
        is set to 0, then a diagonal covariance matrix is fit.
    :param task_correlation_prior: Prior to use over the task noise correlation
        matrix. Only used when :math:`\text{rank} > 0`.
    :param batch_shape: Number of batches.
    """

    def __init__(
        self,
        num_tasks: int,
        noise_covar: Union[Noise, FixedGaussianNoise],
        rank: int = 0,
        task_correlation_prior: Optional[Prior] = None,
        batch_shape: torch.Size = torch.Size(),
    ) -> None:
        """
        :raises ValueError: If ``rank > num_tasks``, or if a ``task_correlation_prior`` is given with ``rank == 0``.
        """
        super().__init__(noise_covar=noise_covar)
        if rank != 0:
            if rank > num_tasks:
                raise ValueError(f"Cannot have rank ({rank}) greater than num_tasks ({num_tasks})")
            tidcs = torch.tril_indices(num_tasks, rank, dtype=torch.long)
            self.tidcs: Tensor = tidcs[:, 1:]  # (1, 1) must be 1.0, no need to parameterize this
            task_noise_corr = torch.randn(*batch_shape, self.tidcs.size(-1))
            self.register_parameter("task_noise_corr", torch.nn.Parameter(task_noise_corr))
            if task_correlation_prior is not None:
                # The closure must return the correlation matrix itself, so the method is *called* here
                # (handing back the bound method would make the prior evaluate a function object).
                self.register_prior(
                    "MultitaskErrorCorrelationPrior", task_correlation_prior, lambda m: m._eval_corr_matrix()
                )
        elif task_correlation_prior is not None:
            raise ValueError("Can only specify task_correlation_prior if rank>0")
        self.num_tasks = num_tasks
        self.rank = rank

    def _eval_corr_matrix(self) -> Tensor:
        """
        Build the task noise correlation matrix from the ``task_noise_corr`` parameter.

        The parameter fills the entries of a factor selected by ``self.tidcs`` (the remaining diagonal entries
        stay at one), each row of the factor is normalized to unit Euclidean norm, and the factor is multiplied
        by its own transpose.
        """
        tnc = self.task_noise_corr
        fac_diag = torch.ones(*tnc.shape[:-1], self.num_tasks, device=tnc.device, dtype=tnc.dtype)
        Cfac = torch.diag_embed(fac_diag)
        Cfac[..., self.tidcs[0], self.tidcs[1]] = self.task_noise_corr
        # squared rows must sum to one for this to be a correlation matrix
        C = Cfac / Cfac.pow(2).sum(dim=-1, keepdim=True).sqrt()
        return C @ C.transpose(-1, -2)

    def marginal(
        self, function_dist: MultitaskMultivariateNormal, *params: Any, **kwargs: Any
    ) -> MultitaskMultivariateNormal:  # pyre-ignore[14]
        r"""
        If :math:`\text{rank} = 0`, adds the task noises to the diagonal of the
        covariance matrix of the supplied
        :obj:`~gpytorch.distributions.MultivariateNormal` or
        :obj:`~gpytorch.distributions.MultitaskMultivariateNormal`. Otherwise,
        adds a rank `rank` covariance matrix to it.

        To accomplish this, we form a new
        :obj:`~linear_operator.operators.KroneckerProductLinearOperator`
        between :math:`I_{n}`, an identity matrix with size equal to the data
        and a (not necessarily diagonal) matrix containing the task noises
        :math:`D_{t}`.

        We also incorporate a shared `noise` parameter from the base
        :class:`qpytorch.likelihoods.GaussianLikelihood` that we extend.

        The final covariance matrix after this method is then
        :math:`\mathbf K + \mathbf D_{t} \otimes \mathbf I_{n} + \sigma^{2} \mathbf I_{nt}`.

        :param function_dist: Random variable whose covariance
            matrix is a :obj:`~linear_operator.operators.LinearOperator` we intend to augment.
        :rtype: `gpytorch.distributions.MultitaskMultivariateNormal`:
        :return: A new random variable whose covariance matrix is a
            :obj:`~linear_operator.operators.LinearOperator` with
            :math:`\mathbf D_{t} \otimes \mathbf I_{n}` and :math:`\sigma^{2} \mathbf I_{nt}` added.
        """
        mean, covar = function_dist.mean, function_dist.lazy_covariance_matrix

        # ensure that sumKroneckerLT is actually called
        if isinstance(covar, LazyEvaluatedKernelTensor):
            covar = covar.evaluate_kernel()

        # NOTE: *params / **kwargs are accepted but not forwarded to the noise construction here.
        covar_kron_lt = self._shaped_noise_covar(
            mean.shape, add_noise=self.has_global_noise, interleaved=function_dist._interleaved
        )
        covar = covar + covar_kron_lt

        # Rebuild with the same distribution class and interleaving as the input.
        return function_dist.__class__(mean, covar, interleaved=function_dist._interleaved)

    def _shaped_noise_covar(
        self, shape: torch.Size, add_noise: Optional[bool] = True, interleaved: bool = True, *params: Any, **kwargs: Any
    ) -> LinearOperator:
        """
        Build the noise covariance operator for outputs of the given shape.

        :param shape: Shape of the mean / function samples; ``shape[-2]`` is used as the number of data points
            and ``shape[:-2]`` as the batch shape.
        :param add_noise: Whether the global ``noise`` should be folded into the task noise
            (only honored when ``self.has_global_noise`` is also true).
        :param interleaved: Selects the Kronecker factor order: ``I (x) D_T`` if true, ``D_T (x) I`` otherwise.
        :return: A constant diagonal operator when there is no task noise, otherwise a Kronecker-structured operator.
        """
        if not self.has_task_noise:
            # Only the global noise exists: a constant diagonal over all n * num_tasks entries.
            noise = ConstantDiagLinearOperator(self.noise, diag_shape=shape[-2] * self.num_tasks)
            return noise

        if self.rank == 0:
            task_noises = self.raw_task_noises_constraint.transform(self.raw_task_noises)
            task_var_lt = DiagLinearOperator(task_noises)
            dtype, device = task_noises.dtype, task_noises.device
            ckl_init = KroneckerProductDiagLinearOperator
        else:
            task_noise_covar_factor = self.task_noise_covar_factor
            task_var_lt = RootLinearOperator(task_noise_covar_factor)
            dtype, device = task_noise_covar_factor.dtype, task_noise_covar_factor.device
            ckl_init = KroneckerProductLinearOperator

        # Identity over the data dimension, batched like the requested shape.
        eye_lt = ConstantDiagLinearOperator(
            torch.ones(*shape[:-2], 1, dtype=dtype, device=device), diag_shape=shape[-2]
        )
        task_var_lt = task_var_lt.expand(*shape[:-2], *task_var_lt.matrix_shape)  # pyre-ignore[6]

        # to add the latent noise we exploit the fact that
        # I \kron D_T + \sigma^2 I_{NT} = I \kron (D_T + \sigma^2 I)
        # which allows us to move the latent noise inside the task dependent noise
        # thereby allowing exploitation of Kronecker structure in this likelihood.
        if add_noise and self.has_global_noise:
            noise = ConstantDiagLinearOperator(self.noise, diag_shape=task_var_lt.shape[-1])
            task_var_lt = task_var_lt + noise

        if interleaved:
            covar_kron_lt = ckl_init(eye_lt, task_var_lt)
        else:
            covar_kron_lt = ckl_init(task_var_lt, eye_lt)

        return covar_kron_lt

    def forward(self, function_samples: Tensor, *params: Any, **kwargs: Any) -> Normal:
        """
        Return the conditional distribution of the observations given function samples.

        The diagonal of the shaped noise covariance is reshaped to the trailing two dimensions of
        ``function_samples`` and its square root is used as the Normal scale; the last dimension is then
        treated as an event dimension via ``Independent(..., 1)``.
        """
        noise = self._shaped_noise_covar(function_samples.shape, *params, **kwargs).diagonal(dim1=-1, dim2=-2)
        noise = noise.reshape(*noise.shape[:-1], *function_samples.shape[-2:])
        return base_distributions.Independent(base_distributions.Normal(function_samples, noise.sqrt()), 1)
160
+
161
+
162
class MultitaskGaussianLikelihood(_MultitaskGaussianLikelihoodBase):
    r"""
    A convenient extension of the :class:`~qpytorch.likelihoods.GaussianLikelihood` to the multitask setting that allows
    for a full cross-task covariance structure for the noise. The fitted covariance matrix has rank `rank`.
    If a strictly diagonal task noise covariance matrix is desired, then rank=0 should be set. (This option still
    allows for a different `noise` parameter for each task.)

    Like the Gaussian likelihood, this object can be used with exact inference.

    .. note::
        At least one of :attr:`has_global_noise` or :attr:`has_task_noise` should be specified.

    .. note::
        MultitaskGaussianLikelihood has an analytic marginal distribution.

    :param num_tasks: Number of tasks.
    :param rank: The rank of the task noise covariance matrix to fit. If `rank`
        is set to 0, then a diagonal covariance matrix is fit.
    :param batch_shape: Number of batches.
    :param task_prior: Prior to use over the task noise covariance
        matrix. Only allowed when :math:`\text{rank} > 0`.
    :param noise_prior: Prior registered on the task noises (when `rank` is 0) and on the global noise.
    :param noise_constraint: Constraint applied to both the raw task noises and the raw global noise.
        Defaults to ``GreaterThan(1e-4)``.
    :param has_global_noise: Whether to include a :math:`\sigma^2 \mathbf I_{nt}` term in the noise model.
    :param has_task_noise: Whether to include task-specific noise terms, which add
        :math:`\mathbf I_n \otimes \mathbf D_T` into the noise model.

    :ivar torch.Tensor task_noise_covar: The inter-task noise covariance matrix
    :ivar torch.Tensor task_noises: (Optional) task specific noise variances (added onto the `task_noise_covar`)
    :ivar torch.Tensor noise: (Optional) global noise variance (added onto the `task_noise_covar`)
    """

    def __init__(
        self,
        num_tasks: int,
        rank: int = 0,
        batch_shape: torch.Size = torch.Size(),
        task_prior: Optional[Prior] = None,
        noise_prior: Optional[Prior] = None,
        noise_constraint: Optional[Interval] = None,
        has_global_noise: bool = True,
        has_task_noise: bool = True,
    ) -> None:
        """
        :raises ValueError: If both ``has_task_noise`` and ``has_global_noise`` are false.
        :raises RuntimeError: If ``task_prior`` is given while ``rank == 0``.
        """
        # Starts the MRO lookup *after* Likelihood -- presumably to bypass the Gaussian base-class
        # initializers, which require a ``noise_covar`` argument this class does not have.
        super(Likelihood, self).__init__()  # pyre-ignore[20]
        if noise_constraint is None:
            noise_constraint = GreaterThan(1e-4)

        if not has_task_noise and not has_global_noise:
            raise ValueError(
                "At least one of has_task_noise or has_global_noise must be specified. "
                "Attempting to specify a likelihood that has no noise terms."
            )

        if has_task_noise:
            if rank == 0:
                # Diagonal task noise: one constrained variance per task.
                self.register_parameter(
                    name="raw_task_noises", parameter=torch.nn.Parameter(torch.zeros(*batch_shape, num_tasks))
                )
                self.register_constraint("raw_task_noises", noise_constraint)
                if noise_prior is not None:
                    self.register_prior("raw_task_noises_prior", noise_prior, lambda m: m.task_noises)
                if task_prior is not None:
                    raise RuntimeError("Cannot set a `task_prior` if rank=0")
            else:
                # Low-rank task noise: covariance is factor @ factor^T.
                self.register_parameter(
                    name="task_noise_covar_factor",
                    parameter=torch.nn.Parameter(torch.randn(*batch_shape, num_tasks, rank)),
                )
                if task_prior is not None:
                    # The closure must return the covariance matrix, so the method is *called* here
                    # (handing back the bound method would make the prior evaluate a function object).
                    self.register_prior(
                        "MultitaskErrorCovariancePrior", task_prior, lambda m: m._eval_covar_matrix()
                    )
        self.num_tasks = num_tasks
        self.rank = rank

        if has_global_noise:
            self.register_parameter(name="raw_noise", parameter=torch.nn.Parameter(torch.zeros(*batch_shape, 1)))
            self.register_constraint("raw_noise", noise_constraint)
            if noise_prior is not None:
                self.register_prior("raw_noise_prior", noise_prior, lambda m: m.noise)

        self.has_global_noise = has_global_noise
        self.has_task_noise = has_task_noise

    @property
    def noise(self) -> Optional[Tensor]:
        """Global noise variance: the constraint transform applied to ``raw_noise``."""
        return self.raw_noise_constraint.transform(self.raw_noise)

    @noise.setter
    def noise(self, value: Union[float, Tensor]) -> None:
        self._set_noise(value)

    @property
    def task_noises(self) -> Optional[Tensor]:
        """
        Per-task noise variances (only defined for ``rank == 0``).

        :raises AttributeError: If ``rank > 0``.
        """
        if self.rank == 0:
            return self.raw_task_noises_constraint.transform(self.raw_task_noises)
        else:
            raise AttributeError(
                f"Cannot retrieve diagonal task noises when covariance has rank {self.rank} > 0"
            )

    @task_noises.setter
    def task_noises(self, value: Union[float, Tensor]) -> None:
        if self.rank == 0:
            self._set_task_noises(value)
        else:
            raise AttributeError(f"Cannot set diagonal task noises when covariance has rank {self.rank} > 0")

    def _set_noise(self, value: Union[float, Tensor]) -> None:
        """Initialize ``raw_noise`` so that the constrained noise equals ``value``."""
        self.initialize(raw_noise=self.raw_noise_constraint.inverse_transform(value))

    def _set_task_noises(self, value: Union[float, Tensor]) -> None:
        """Initialize ``raw_task_noises`` so that the constrained task noises equal ``value``."""
        self.initialize(raw_task_noises=self.raw_task_noises_constraint.inverse_transform(value))

    @property
    def task_noise_covar(self) -> Tensor:
        """
        Inter-task noise covariance ``factor @ factor^T`` (only defined for ``rank > 0``).

        :raises AttributeError: If ``rank == 0``.
        """
        if self.rank > 0:
            return self.task_noise_covar_factor.matmul(self.task_noise_covar_factor.transpose(-1, -2))
        else:
            raise AttributeError("Cannot retrieve task noises when covariance is diagonal.")

    @task_noise_covar.setter
    def task_noise_covar(self, value: Tensor) -> None:
        # internally uses a pivoted cholesky decomposition to construct a low rank
        # approximation of the covariance
        if self.rank > 0:
            with torch.no_grad():
                self.task_noise_covar_factor.data = to_linear_operator(value).pivoted_cholesky(rank=self.rank)
        else:
            raise AttributeError("Cannot set non-diagonal task noises when covariance is diagonal.")

    def _eval_covar_matrix(self) -> Tensor:
        """
        Return ``factor @ factor^T + noise * I`` over the tasks.

        Reads both ``task_noise_covar_factor`` and ``noise``, so it requires ``rank > 0`` and a global noise.
        """
        covar_factor = self.task_noise_covar_factor
        noise = self.noise
        D = noise * torch.eye(self.num_tasks, dtype=noise.dtype, device=noise.device)  # pyre-fixme[16]
        return covar_factor.matmul(covar_factor.transpose(-1, -2)) + D

    def marginal(
        self, function_dist: MultitaskMultivariateNormal, *args: Any, **kwargs: Any
    ) -> MultitaskMultivariateNormal:
        r"""
        :return: Analytic marginal :math:`p(\mathbf y)`.
        """
        return super().marginal(function_dist, *args, **kwargs)
302
+
303
+
304
class MultitaskFixedNoiseGaussianLikelihood(_MultitaskGaussianLikelihoodBase):
    r"""
    A convenient extension of the :class:`~qpytorch.likelihoods.FixedNoiseGaussianLikelihood` to the multitask setting
    that assumes fixed heteroscedastic noise. This is useful when you have fixed, known observation
    noise for each training example.

    Note that this likelihood takes an additional argument when you call it, `noise`, that adds a specified amount
    of noise to the passed MultivariateNormal. This allows for adding known observational noise to test data.

    .. note::
        This likelihood can be used for exact or approximate inference.

    :param num_tasks: Number of tasks.
    :param noise: Known observation noise (variance) for each training example.
    :type noise: torch.Tensor (... x N)
    :param rank: The rank of the task noise covariance matrix to fit. If `rank`
        is set to 0, then a diagonal covariance matrix is fit.
    :param learn_additional_noise: Set to true if you additionally want to
        learn added diagonal noise, similar to GaussianLikelihood.
    :type learn_additional_noise: bool, optional
    :param batch_shape: The batch shape of the learned noise parameter (default
        []) if :obj:`learn_additional_noise=True`.
    :type batch_shape: torch.Size, optional
    :param kwargs: ``noise_prior`` and ``noise_constraint`` are read from here and used for the
        additional learned noise when :obj:`learn_additional_noise=True`.

    :var torch.Tensor noise: :math:`\sigma^2` parameter (noise)

    .. note::
        MultitaskFixedNoiseGaussianLikelihood has an analytic marginal distribution.

    Example:
        >>> num_tasks = 2
        >>> train_x = torch.randn(55, 2)
        >>> noises = torch.ones(55) * 0.01
        >>> likelihood = MultitaskFixedNoiseGaussianLikelihood(num_tasks=num_tasks, noise=noises, learn_additional_noise=True)
        >>> pred_y = likelihood(gp_model(train_x))
        >>>
        >>> test_x = torch.randn(21, 2)
        >>> test_noises = torch.ones(21) * 0.02
        >>> pred_y = likelihood(gp_model(test_x), noise=test_noises)
    """

    def __init__(
        self,
        num_tasks: int,
        noise: Tensor,
        rank: int = 0,
        learn_additional_noise: Optional[bool] = False,
        batch_shape: Optional[torch.Size] = torch.Size(),
        **kwargs: Any,
    ) -> None:
        super().__init__(num_tasks=num_tasks, noise_covar=FixedGaussianNoise(noise=noise), rank=rank, batch_shape=batch_shape)

        # Optional learned homoskedastic noise added on top of the fixed noise.
        self.second_noise_covar: Optional[MultitaskHomoskedasticNoise] = None
        if learn_additional_noise:
            noise_prior = kwargs.get("noise_prior", None)
            noise_constraint = kwargs.get("noise_constraint", None)
            self.second_noise_covar = MultitaskHomoskedasticNoise(
                num_tasks=1, noise_prior=noise_prior, noise_constraint=noise_constraint, batch_shape=batch_shape
            )

    @property
    def noise(self) -> Tensor:
        """Fixed noise plus the additional learned noise (``0.0`` when none is learned)."""
        return self.noise_covar.noise + self.second_noise

    @noise.setter
    def noise(self, value: Tensor) -> None:
        self.noise_covar.initialize(noise=value)

    @property
    def second_noise(self) -> Union[float, Tensor]:
        """The additional learned noise, or ``0.0`` if ``learn_additional_noise`` was false."""
        if self.second_noise_covar is None:
            return 0.0
        else:
            return self.second_noise_covar.noise

    @second_noise.setter
    def second_noise(self, value: Tensor) -> None:
        if self.second_noise_covar is None:
            raise RuntimeError(
                "Attempting to set secondary learned noise for MultitaskFixedNoiseGaussianLikelihood, "
                "but learn_additional_noise must have been False!"
            )
        self.second_noise_covar.initialize(noise=value)

    def get_fantasy_likelihood(self, **kwargs: Any) -> "MultitaskFixedNoiseGaussianLikelihood":
        """
        Return a copy of this likelihood whose fixed noise is the old noise concatenated (along the last
        dimension) with the ``noise`` keyword argument.

        :raises RuntimeError: If no ``noise`` keyword argument is given.
        """
        # Local import: ``deepcopy`` is not imported at module level in this file.
        from copy import deepcopy

        if "noise" not in kwargs:
            raise RuntimeError("MultitaskFixedNoiseGaussianLikelihood.fantasize requires a `noise` kwarg")
        # Detach the noise model while copying so it is not duplicated (it is replaced below anyway);
        # always re-attach it to ``self``, even if the copy fails.
        old_noise_covar = self.noise_covar
        self.noise_covar = None  # pyre-fixme[8]
        try:
            fantasy_liklihood = deepcopy(self)
        finally:
            self.noise_covar = old_noise_covar

        old_noise = old_noise_covar.noise
        new_noise = kwargs["noise"]
        if old_noise.dim() != new_noise.dim():
            # Broadcast the old noise over the leading (batch) dimensions of the new noise.
            old_noise = old_noise.expand(*new_noise.shape[:-1], old_noise.shape[-1])
        fantasy_liklihood.noise_covar = FixedGaussianNoise(noise=torch.cat([old_noise, new_noise], -1))
        return fantasy_liklihood

    def _shaped_noise_covar(self, base_shape: torch.Size, *params: Any, **kwargs: Any) -> Union[Tensor, LinearOperator]:
        """
        Evaluate the fixed (plus optional learned) noise model and wrap it in a block-diagonal operator.

        :param base_shape: Shape of the mean / function samples; its last two dimensions are swapped
            before being passed to the noise models as ``shape``.
        :param params: If non-empty, passed to the noise models, which then receive ``shape=None``.
        :param kwargs: Forwarded to the noise models (e.g. ``noise`` for test-time noise).
        """
        # Local import: ``warnings`` is not imported at module level in this file.
        import warnings

        if len(params) > 0:
            # we can infer the shape from the params
            shape = None
        else:
            # here shape[:-1] is the batch shape requested, and shape[-1] is `n`, the number of points
            # (the trailing two dims of base_shape are reversed to get there)
            shape = base_shape[:-2]+base_shape[-2:][::-1]

        res = self.noise_covar(*params, shape=shape, **kwargs)

        if self.second_noise_covar is not None:
            res = res + self.second_noise_covar(*params, shape=shape, **kwargs)
        elif isinstance(res, ZeroLinearOperator):
            # NOTE(review): ``GPInputWarning`` is not imported anywhere in this file; it is presumably
            # defined in ``..utils.warnings`` (as in GPyTorch) -- confirm the location.
            from ..utils.warnings import GPInputWarning

            warnings.warn(
                "You have passed data through a FixedNoiseGaussianLikelihood that did not match the size "
                "of the fixed noise, *and* you did not specify noise. This is treated as a no-op.",
                GPInputWarning,
            )

        return BlockDiagLinearOperator(res)

    def marginal(self, function_dist: MultitaskMultivariateNormal, *args: Any, **kwargs: Any) -> MultitaskMultivariateNormal:
        r"""
        :return: Analytic marginal :math:`p(\mathbf y)`.
        """
        # NOTE(review): the inherited ``marginal`` reads ``self.has_global_noise``, which this class does
        # not set in the visible code -- confirm that a parent class provides it.
        return super().marginal(function_dist, *args, **kwargs)
429
+
430
+
431
class MultitaskDirichletClassificationLikelihood(MultitaskFixedNoiseGaussianLikelihood):
    r"""
    A multi-classification likelihood that treats the labels as regression targets with fixed heteroscedastic noise.
    From Milios et al, NeurIPS, 2018 [https://arxiv.org/abs/1805.10915].

    .. note::
        This multitask likelihood can be used for exact or approximate inference and in deep models.

    :param targets: (... x N) Classification labels.
    :param alpha_epsilon: Tuning parameter for the scaling of the likelihood targets. We'd suggest 0.01 or setting
        via cross-validation.
    :param learn_additional_noise: Set to true if you additionally want to
        learn added diagonal noise, similar to GaussianLikelihood.
    :param batch_shape: The batch shape of the learned noise parameter (default
        []) if :obj:`learn_additional_noise=True`.
    :param dtype: Floating point type used for the derived noise and transformed targets.

    :ivar torch.Tensor noise: :math:`\sigma^2` parameter (noise)

    .. note::
        MultitaskDirichletClassificationLikelihood has an analytic marginal distribution.

    Example:
        >>> train_x = torch.randn(55, 1)
        >>> labels = torch.round(train_x).long()
        >>> likelihood = MultitaskDirichletClassificationLikelihood(targets=labels, learn_additional_noise=True)
        >>> pred_y = likelihood(gp_model(train_x))
        >>>
        >>> test_x = torch.randn(21, 1)
        >>> test_labels = torch.round(test_x).long()
        >>> pred_y = likelihood(gp_model(test_x), targets=test_labels)
    """

    def _prepare_targets(
        self,
        targets: Tensor,
        num_classes: Optional[int] = None,
        alpha_epsilon: float = 0.01,
        dtype: torch.dtype = torch.float,
    ) -> Tuple[Tensor, Tensor, int]:
        """
        Convert class labels into per-class noise variances and regression targets.

        :param targets: Integer class labels, used to index the class dimension.
        :param num_classes: Number of classes; inferred as ``targets.max() + 1`` when omitted.
        :param alpha_epsilon: Base concentration assigned to every class.
        :param dtype: Floating point type of the returned tensors.
        :return: ``(sigma2, transformed_targets, num_classes)`` where ``sigma2`` has its last two
            dimensions transposed relative to ``transformed_targets``.
        """
        if num_classes is None:
            num_classes = int(targets.max() + 1)
        # set alpha = \alpha_\epsilon
        alpha = alpha_epsilon * torch.ones(targets.shape[-1], num_classes, device=targets.device, dtype=dtype)

        # alpha[class_labels] = 1 + \alpha_\epsilon
        alpha[torch.arange(len(targets)), targets] = alpha[torch.arange(len(targets)), targets] + 1.0

        # sigma^2 = log(1 / alpha + 1)
        sigma2_i = torch.log(alpha.reciprocal() + 1.0)

        # y = log(alpha) - 0.5 * sigma^2
        transformed_targets = alpha.log() - 0.5 * sigma2_i

        return sigma2_i.transpose(-2, -1).type(dtype), transformed_targets.type(dtype), num_classes

    def __init__(
        self,
        targets: Tensor,
        alpha_epsilon: float = 0.01,
        learn_additional_noise: Optional[bool] = False,
        batch_shape: torch.Size = torch.Size(),
        dtype: torch.dtype = torch.float,
        **kwargs: Any,
    ) -> None:
        sigma2_labels, transformed_targets, num_classes = self._prepare_targets(
            targets, alpha_epsilon=alpha_epsilon, dtype=dtype
        )
        # NOTE: the ``batch_shape`` argument is not forwarded; the parent always receives
        # ``(num_classes,)`` as its batch shape.
        super().__init__(
            num_tasks=num_classes,
            noise=sigma2_labels,
            learn_additional_noise=learn_additional_noise,
            batch_shape=torch.Size((num_classes,)),
            **kwargs,
        )
        self.transformed_targets: Tensor = transformed_targets.transpose(-2, -1)
        self.num_classes: int = num_classes
        self.targets: Tensor = targets
        self.alpha_epsilon: float = alpha_epsilon

    def get_fantasy_likelihood(self, **kwargs: Any) -> "MultitaskDirichletClassificationLikelihood":
        """
        Return a copy of this likelihood extended with the labels given in the ``targets`` keyword argument.

        The copy's ``targets``, ``transformed_targets`` and fixed noise are each the old values concatenated
        (along the last dimension) with the values derived from the new labels.

        :raises RuntimeError: If no ``targets`` keyword argument is given.
        """
        # Local import: ``deepcopy`` is not imported at module level in this file.
        from copy import deepcopy

        # we assume that the number of classes does not change.

        if "targets" not in kwargs:
            raise RuntimeError("MultitaskDirichletClassificationLikelihood.fantasize requires a `targets` kwarg")

        # Detach the noise model while copying so it is not duplicated (it is replaced below anyway);
        # always re-attach it to ``self``, even if the copy fails.
        old_noise_covar = self.noise_covar
        self.noise_covar = None  # pyre-fixme[8]
        try:
            fantasy_liklihood = deepcopy(self)
        finally:
            self.noise_covar = old_noise_covar

        old_noise = old_noise_covar.noise
        new_labels = kwargs["targets"]
        # Keywords are required here: the second positional parameter of ``_prepare_targets`` is
        # ``num_classes``, not ``alpha_epsilon``. Reusing ``self.num_classes`` keeps the class dimension
        # aligned with the existing noise.
        new_noise, new_transformed, _ = fantasy_liklihood._prepare_targets(
            new_labels,
            num_classes=self.num_classes,
            alpha_epsilon=self.alpha_epsilon,
            dtype=self.transformed_targets.dtype,
        )
        # Raw labels extend ``targets``; the (transposed) regression targets extend ``transformed_targets``.
        fantasy_liklihood.targets = torch.cat([fantasy_liklihood.targets, new_labels], -1)
        fantasy_liklihood.transformed_targets = torch.cat(
            [fantasy_liklihood.transformed_targets, new_transformed.transpose(-2, -1)], -1
        )

        if old_noise.dim() != new_noise.dim():
            old_noise = old_noise.expand(*new_noise.shape[:-1], old_noise.shape[-1])

        fantasy_liklihood.noise_covar = FixedGaussianNoise(noise=torch.cat([old_noise, new_noise], -1))
        return fantasy_liklihood

    def marginal(self, function_dist: MultitaskMultivariateNormal, *args: Any, **kwargs: Any) -> MultitaskMultivariateNormal:
        r"""
        :return: Analytic marginal :math:`p(\mathbf y)`.
        """
        return super().marginal(function_dist, *args, **kwargs)

    def __call__(self, input: Union[Tensor, MultitaskMultivariateNormal], *args: Any, **kwargs: Any) -> Distribution:
        """
        Call the parent likelihood, first translating an optional ``targets`` keyword argument into the
        equivalent ``noise`` keyword argument.
        """
        if "targets" in kwargs:
            targets = kwargs.pop("targets")
            dtype = self.transformed_targets.dtype
            # Use the class count and alpha_epsilon this likelihood was built with, so the derived noise
            # has the same class dimension and scaling as the training noise even when the new labels
            # do not contain the highest class.
            new_noise, _, _ = self._prepare_targets(
                targets, num_classes=self.num_classes, alpha_epsilon=self.alpha_epsilon, dtype=dtype
            )
            kwargs["noise"] = new_noise
        return super().__call__(input, *args, **kwargs)
540
+
541
class MultitaskGaussianDirichletClassificationLikelihood(MultitaskDirichletClassificationLikelihood):
    r"""
    Subclass of :class:`MultitaskDirichletClassificationLikelihood` that adds no attributes and overrides nothing.
    """