qpytorch-0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of qpytorch might be problematic.

Files changed (102)
  1. qpytorch/__init__.py +327 -0
  2. qpytorch/constraints/__init__.py +3 -0
  3. qpytorch/distributions/__init__.py +21 -0
  4. qpytorch/distributions/delta.py +86 -0
  5. qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
  6. qpytorch/distributions/multivariate_qexponential.py +581 -0
  7. qpytorch/distributions/power.py +113 -0
  8. qpytorch/distributions/qexponential.py +153 -0
  9. qpytorch/functions/__init__.py +58 -0
  10. qpytorch/kernels/__init__.py +80 -0
  11. qpytorch/kernels/grid_interpolation_kernel.py +213 -0
  12. qpytorch/kernels/inducing_point_kernel.py +151 -0
  13. qpytorch/kernels/kernel.py +695 -0
  14. qpytorch/kernels/matern32_kernel_grad.py +155 -0
  15. qpytorch/kernels/matern52_kernel_grad.py +194 -0
  16. qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
  17. qpytorch/kernels/polynomial_kernel_grad.py +88 -0
  18. qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
  19. qpytorch/kernels/rbf_kernel_grad.py +125 -0
  20. qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
  21. qpytorch/kernels/rff_kernel.py +153 -0
  22. qpytorch/lazy/__init__.py +9 -0
  23. qpytorch/likelihoods/__init__.py +66 -0
  24. qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
  25. qpytorch/likelihoods/beta_likelihood.py +76 -0
  26. qpytorch/likelihoods/gaussian_likelihood.py +472 -0
  27. qpytorch/likelihoods/laplace_likelihood.py +59 -0
  28. qpytorch/likelihoods/likelihood.py +437 -0
  29. qpytorch/likelihoods/likelihood_list.py +60 -0
  30. qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
  31. qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
  32. qpytorch/likelihoods/noise_models.py +184 -0
  33. qpytorch/likelihoods/qexponential_likelihood.py +494 -0
  34. qpytorch/likelihoods/softmax_likelihood.py +97 -0
  35. qpytorch/likelihoods/student_t_likelihood.py +90 -0
  36. qpytorch/means/__init__.py +23 -0
  37. qpytorch/metrics/__init__.py +17 -0
  38. qpytorch/mlls/__init__.py +53 -0
  39. qpytorch/mlls/_approximate_mll.py +79 -0
  40. qpytorch/mlls/deep_approximate_mll.py +30 -0
  41. qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
  42. qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
  43. qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
  44. qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
  45. qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
  46. qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
  47. qpytorch/mlls/marginal_log_likelihood.py +48 -0
  48. qpytorch/mlls/predictive_log_likelihood.py +76 -0
  49. qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
  50. qpytorch/mlls/variational_elbo.py +77 -0
  51. qpytorch/models/__init__.py +72 -0
  52. qpytorch/models/approximate_qep.py +115 -0
  53. qpytorch/models/deep_qeps/__init__.py +22 -0
  54. qpytorch/models/deep_qeps/deep_qep.py +155 -0
  55. qpytorch/models/deep_qeps/dspp.py +114 -0
  56. qpytorch/models/exact_prediction_strategies.py +880 -0
  57. qpytorch/models/exact_qep.py +349 -0
  58. qpytorch/models/model_list.py +100 -0
  59. qpytorch/models/pyro/__init__.py +28 -0
  60. qpytorch/models/pyro/_pyro_mixin.py +57 -0
  61. qpytorch/models/pyro/distributions/__init__.py +5 -0
  62. qpytorch/models/pyro/pyro_qep.py +105 -0
  63. qpytorch/models/qep.py +7 -0
  64. qpytorch/models/qeplvm/__init__.py +6 -0
  65. qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
  66. qpytorch/models/qeplvm/latent_variable.py +102 -0
  67. qpytorch/module.py +30 -0
  68. qpytorch/optim/__init__.py +5 -0
  69. qpytorch/priors/__init__.py +42 -0
  70. qpytorch/priors/qep_priors.py +81 -0
  71. qpytorch/test/__init__.py +22 -0
  72. qpytorch/test/base_likelihood_test_case.py +106 -0
  73. qpytorch/test/model_test_case.py +150 -0
  74. qpytorch/test/variational_test_case.py +400 -0
  75. qpytorch/utils/__init__.py +38 -0
  76. qpytorch/utils/warnings.py +37 -0
  77. qpytorch/variational/__init__.py +47 -0
  78. qpytorch/variational/_variational_distribution.py +61 -0
  79. qpytorch/variational/_variational_strategy.py +391 -0
  80. qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
  81. qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
  82. qpytorch/variational/cholesky_variational_distribution.py +65 -0
  83. qpytorch/variational/ciq_variational_strategy.py +352 -0
  84. qpytorch/variational/delta_variational_distribution.py +41 -0
  85. qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
  86. qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
  87. qpytorch/variational/lmc_variational_strategy.py +248 -0
  88. qpytorch/variational/mean_field_variational_distribution.py +58 -0
  89. qpytorch/variational/multitask_variational_strategy.py +317 -0
  90. qpytorch/variational/natural_variational_distribution.py +152 -0
  91. qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
  92. qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
  93. qpytorch/variational/tril_natural_variational_distribution.py +130 -0
  94. qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
  95. qpytorch/variational/unwhitened_variational_strategy.py +225 -0
  96. qpytorch/variational/variational_strategy.py +280 -0
  97. qpytorch/version.py +4 -0
  98. qpytorch-0.1.dist-info/LICENSE +21 -0
  99. qpytorch-0.1.dist-info/METADATA +177 -0
  100. qpytorch-0.1.dist-info/RECORD +102 -0
  101. qpytorch-0.1.dist-info/WHEEL +5 -0
  102. qpytorch-0.1.dist-info/top_level.txt +1 -0

qpytorch/likelihoods/softmax_likelihood.py
@@ -0,0 +1,97 @@
+#!/usr/bin/env python3
+
+import warnings
+from typing import Any, Optional, Union
+
+import torch
+from torch import Tensor
+from torch.distributions import Categorical, Distribution
+
+from ..distributions import base_distributions, MultitaskMultivariateNormal, MultitaskMultivariateQExponential
+from ..priors import Prior
+from .likelihood import Likelihood
+
+
+class SoftmaxLikelihood(Likelihood):
+    r"""
+    Implements the Softmax (multiclass) likelihood used for GP (QEP) classification.
+
+    .. math::
+        p(\mathbf y \mid \mathbf f) = \text{Softmax} \left( \mathbf W \mathbf f \right)
+
+    :math:`\mathbf W` is a set of linear mixing weights applied to the latent functions :math:`\mathbf f`.
+
+    :param num_features: Dimensionality of the latent function :math:`\mathbf f`.
+    :param num_classes: Number of classes.
+    :param mixing_weights: (Default: `True`) Whether to learn a linear mixing weight :math:`\mathbf W` applied to
+        the latent function :math:`\mathbf f`. If `False`, then :math:`\mathbf W = \mathbf I`.
+    :param mixing_weights_prior: Prior to use over the mixing weights :math:`\mathbf W`.
+
+    :ivar torch.Tensor mixing_weights: (Optional) mixing weights.
+    """
+
+    def __init__(
+        self,
+        num_features: Optional[int] = None,
+        num_classes: int = None,  # pyre-fixme[9]
+        mixing_weights: bool = True,
+        mixing_weights_prior: Optional[Prior] = None,
+    ) -> None:
+        super().__init__()
+        if num_classes is None:
+            raise ValueError("num_classes is required")
+        self.num_classes = num_classes
+        if mixing_weights:
+            if num_features is None:
+                raise ValueError("num_features is required with mixing weights")
+            self.num_features: int = num_features
+            self.register_parameter(
+                name="mixing_weights",
+                parameter=torch.nn.Parameter(torch.randn(num_classes, num_features).div_(num_features)),
+            )
+            if mixing_weights_prior is not None:
+                self.register_prior("mixing_weights_prior", mixing_weights_prior, "mixing_weights")
+        else:
+            self.num_features = num_classes
+            self.mixing_weights: Optional[torch.nn.Parameter] = None
+
+    def forward(self, function_samples: Tensor, *params: Any, **kwargs: Any) -> Categorical:
+        num_data, num_features = function_samples.shape[-2:]
+
+        # Catch legacy mode
+        if num_data == self.num_features:
+            warnings.warn(
+                "The input to SoftmaxLikelihood should be a MultitaskMultivariateNormal or MultitaskMultivariateQExponential (num_data x num_tasks). "
+                "Batch MultivariateNormal inputs (num_tasks x num_data) will be deprecated.",
+                DeprecationWarning,
+            )
+            function_samples = function_samples.transpose(-1, -2)
+            num_data, num_features = function_samples.shape[-2:]
+
+        if num_features != self.num_features:
+            raise RuntimeError("There should be %d features" % self.num_features)
+
+        if self.mixing_weights is not None:
+            mixed_fs = function_samples @ self.mixing_weights.t()  # num_data x num_classes
+        else:
+            mixed_fs = function_samples
+        res = base_distributions.Categorical(logits=mixed_fs)
+        return res
+
+    def __call__(self, input: Union[Tensor, MultitaskMultivariateNormal, MultitaskMultivariateQExponential], *args: Any, **kwargs: Any) -> Distribution:
+        if isinstance(input, Distribution):
+            if not isinstance(input, MultitaskMultivariateNormal) and not hasattr(input, 'power'):
+                warnings.warn(
+                    "The input to SoftmaxLikelihood should be a MultitaskMultivariateNormal (num_data x num_tasks). "
+                    "Batch MultivariateNormal inputs (num_tasks x num_data) will be deprecated.",
+                    DeprecationWarning,
+                )
+                input = MultitaskMultivariateNormal.from_batch_mvn(input)
+            elif not isinstance(input, MultitaskMultivariateQExponential) and hasattr(input, 'power'):
+                warnings.warn(
+                    "The input to SoftmaxLikelihood should be a MultitaskMultivariateQExponential (num_data x num_tasks). "
+                    "Batch MultivariateQExponential inputs (num_tasks x num_data) will be deprecated.",
+                    DeprecationWarning,
+                )
+                input = MultitaskMultivariateQExponential.from_batch_qep(input)
+        return super().__call__(input, *args, **kwargs)
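
A minimal usage sketch of the likelihood above (shapes and variable names are illustrative, not part of the package). Calling the likelihood on a plain tensor falls through to `forward`, which returns a `Categorical` over classes:

    import torch
    import qpytorch

    # Mix 4 latent functions into 3 classes via a learned W (num_classes x num_features)
    likelihood = qpytorch.likelihoods.SoftmaxLikelihood(num_features=4, num_classes=3)

    f_samples = torch.randn(10, 4)        # num_data x num_features
    categorical = likelihood(f_samples)   # Categorical with logits of shape (10, 3)
    labels = categorical.sample()         # (10,) class indices

Passing a batch distribution shaped num_tasks x num_data instead triggers the deprecation path above, which transposes into the multitask layout.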

qpytorch/likelihoods/student_t_likelihood.py
@@ -0,0 +1,90 @@
+#!/usr/bin/env python3
+
+from typing import Any, Optional
+
+import torch
+from torch import Tensor
+from torch.distributions import StudentT
+
+from ..constraints import GreaterThan, Interval, Positive
+from ..distributions import base_distributions
+from ..priors import Prior
+from .likelihood import _OneDimensionalLikelihood
+
+
+class StudentTLikelihood(_OneDimensionalLikelihood):
+    r"""
+    A Student T likelihood/noise model for GP (QEP) regression.
+    It has two learnable parameters: :math:`\nu` (the degrees of freedom) and
+    :math:`\sigma^2` (the noise).
+
+    :param batch_shape: The batch shape of the learned noise parameter (default: []).
+    :param noise_prior: Prior for noise parameter :math:`\sigma^2`.
+    :param noise_constraint: Constraint for noise parameter :math:`\sigma^2`.
+    :param deg_free_prior: Prior for deg_free parameter :math:`\nu`.
+    :param deg_free_constraint: Constraint for deg_free parameter :math:`\nu`.
+
+    :var torch.Tensor deg_free: :math:`\nu` parameter (degrees of freedom)
+    :var torch.Tensor noise: :math:`\sigma^2` parameter (noise)
+    """
+
+    def __init__(
+        self,
+        batch_shape: torch.Size = torch.Size([]),
+        deg_free_prior: Optional[Prior] = None,
+        deg_free_constraint: Optional[Interval] = None,
+        noise_prior: Optional[Prior] = None,
+        noise_constraint: Optional[Interval] = None,
+    ) -> None:
+        super().__init__()
+
+        if deg_free_constraint is None:
+            deg_free_constraint = GreaterThan(2)
+
+        if noise_constraint is None:
+            noise_constraint = Positive()
+
+        self.raw_deg_free = torch.nn.Parameter(torch.zeros(*batch_shape, 1))
+        self.raw_noise = torch.nn.Parameter(torch.zeros(*batch_shape, 1))
+
+        if noise_prior is not None:
+            self.register_prior("noise_prior", noise_prior, lambda m: m.noise, lambda m, v: m._set_noise(v))
+
+        self.register_constraint("raw_noise", noise_constraint)
+
+        if deg_free_prior is not None:
+            self.register_prior("deg_free_prior", deg_free_prior, lambda m: m.deg_free, lambda m, v: m._set_deg_free(v))
+
+        self.register_constraint("raw_deg_free", deg_free_constraint)
+
+        # Rough initialization
+        self.initialize(deg_free=7)
+
+    @property
+    def deg_free(self) -> Tensor:
+        return self.raw_deg_free_constraint.transform(self.raw_deg_free)
+
+    @deg_free.setter
+    def deg_free(self, value: Tensor) -> None:
+        self._set_deg_free(value)
+
+    def _set_deg_free(self, value: Tensor) -> None:
+        if not torch.is_tensor(value):
+            value = torch.as_tensor(value).to(self.raw_deg_free)
+        self.initialize(raw_deg_free=self.raw_deg_free_constraint.inverse_transform(value))
+
+    @property
+    def noise(self) -> Tensor:
+        return self.raw_noise_constraint.transform(self.raw_noise)
+
+    @noise.setter
+    def noise(self, value: Tensor) -> None:
+        self._set_noise(value)
+
+    def _set_noise(self, value: Tensor) -> None:
+        if not torch.is_tensor(value):
+            value = torch.as_tensor(value).to(self.raw_noise)
+        self.initialize(raw_noise=self.raw_noise_constraint.inverse_transform(value))
+
+    def forward(self, function_samples: Tensor, *args: Any, **kwargs: Any) -> StudentT:
+        return base_distributions.StudentT(df=self.deg_free, loc=function_samples, scale=self.noise.sqrt())
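
The `raw_*` parameter plus registered-constraint pattern above is the standard hyperparameter idiom in this codebase: optimizers see the unconstrained raw tensor, while the property maps it through the constraint's transform. A small sketch of that round trip (values illustrative):

    import torch
    import qpytorch

    likelihood = qpytorch.likelihoods.StudentTLikelihood()

    # The setter stores inverse_transform(0.25) in raw_noise ...
    likelihood.noise = torch.tensor(0.25)

    # ... and the getter applies the Positive() transform on the way back out.
    print(likelihood.noise)     # ~0.25
    print(likelihood.deg_free)  # ~7.0, from the rough initialization in __init__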

qpytorch/means/__init__.py
@@ -0,0 +1,23 @@
+#!/usr/bin/env python3
+
+from gpytorch.means.constant_mean import ConstantMean
+from gpytorch.means.constant_mean_grad import ConstantMeanGrad
+from gpytorch.means.constant_mean_gradgrad import ConstantMeanGradGrad
+from gpytorch.means.linear_mean import LinearMean
+from gpytorch.means.linear_mean_grad import LinearMeanGrad
+from gpytorch.means.linear_mean_gradgrad import LinearMeanGradGrad
+from gpytorch.means.mean import Mean
+from gpytorch.means.multitask_mean import MultitaskMean
+from gpytorch.means.zero_mean import ZeroMean
+
+__all__ = [
+    "Mean",
+    "ConstantMean",
+    "ConstantMeanGrad",
+    "ConstantMeanGradGrad",
+    "LinearMean",
+    "LinearMeanGrad",
+    "LinearMeanGradGrad",
+    "MultitaskMean",
+    "ZeroMean",
+]

qpytorch/metrics/__init__.py
@@ -0,0 +1,17 @@
+from gpytorch.metrics import (
+    mean_absolute_error,
+    mean_squared_error,
+    mean_standardized_log_loss,
+    negative_log_predictive_density,
+    quantile_coverage_error,
+    standardized_mean_squared_error,
+)
+
+__all__ = [
+    "mean_absolute_error",
+    "mean_squared_error",
+    "standardized_mean_squared_error",
+    "mean_standardized_log_loss",
+    "negative_log_predictive_density",
+    "quantile_coverage_error",
+]
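
These re-exported metrics share gpytorch's calling convention: a predictive distribution followed by the held-out targets. A hedged sketch (assumes a trained `model`/`likelihood` and test tensors `test_x`, `test_y`):

    import torch
    from qpytorch.metrics import mean_squared_error, negative_log_predictive_density

    model.eval()
    likelihood.eval()
    with torch.no_grad():
        pred_dist = likelihood(model(test_x))  # predictive distribution over test targets

    mse = mean_squared_error(pred_dist, test_y)
    nlpd = negative_log_predictive_density(pred_dist, test_y)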

qpytorch/mlls/__init__.py
@@ -0,0 +1,53 @@
+#!/usr/bin/env python3
+
+import warnings
+
+from gpytorch.mlls.added_loss_term import AddedLossTerm
+from .deep_approximate_mll import DeepApproximateMLL
+from .deep_predictive_log_likelihood import DeepPredictiveLogLikelihood
+from .exact_marginal_log_likelihood import ExactMarginalLogLikelihood
+from .gamma_robust_variational_elbo import GammaRobustVariationalELBO
+from .inducing_point_kernel_added_loss_term import InducingPointKernelAddedLossTerm
+from gpytorch.mlls.kl_gaussian_added_loss_term import KLGaussianAddedLossTerm
+from .kl_qexponential_added_loss_term import KLQExponentialAddedLossTerm
+from .leave_one_out_pseudo_likelihood import LeaveOneOutPseudoLikelihood
+from .marginal_log_likelihood import MarginalLogLikelihood
+from gpytorch.mlls.noise_model_added_loss_term import NoiseModelAddedLossTerm
+from .predictive_log_likelihood import PredictiveLogLikelihood
+from .sum_marginal_log_likelihood import SumMarginalLogLikelihood
+from .variational_elbo import VariationalELBO
+
+
+# Deprecated for 0.4 release
+class VariationalMarginalLogLikelihood(VariationalELBO):
+    def __init__(self, *args, **kwargs):
+        # Remove after 1.0
+        warnings.warn(
+            "VariationalMarginalLogLikelihood is deprecated. Please use VariationalELBO instead.", DeprecationWarning
+        )
+        super().__init__(*args, **kwargs)
+
+
+class VariationalELBOEmpirical(VariationalELBO):
+    def __init__(self, *args, **kwargs):
+        # Remove after 1.0
+        warnings.warn("VariationalELBOEmpirical is deprecated. Please use VariationalELBO instead.", DeprecationWarning)
+        super().__init__(*args, **kwargs)
+
+
+__all__ = [
+    "AddedLossTerm",
+    "DeepApproximateMLL",
+    "DeepPredictiveLogLikelihood",
+    "ExactMarginalLogLikelihood",
+    "InducingPointKernelAddedLossTerm",
+    "LeaveOneOutPseudoLikelihood",
+    "KLGaussianAddedLossTerm",
+    "KLQExponentialAddedLossTerm",
+    "MarginalLogLikelihood",
+    "NoiseModelAddedLossTerm",
+    "PredictiveLogLikelihood",
+    "GammaRobustVariationalELBO",
+    "SumMarginalLogLikelihood",
+    "VariationalELBO",
+]

qpytorch/mlls/_approximate_mll.py
@@ -0,0 +1,79 @@
+#!/usr/bin/env python3
+
+from abc import ABC, abstractmethod
+
+import torch
+
+from .marginal_log_likelihood import MarginalLogLikelihood
+
+
+class _ApproximateMarginalLogLikelihood(MarginalLogLikelihood, ABC):
+    r"""
+    An approximate marginal log likelihood (typically a bound) for approximate GP (QEP) models.
+    We expect that the model is a :obj:`gpytorch.models.ApproximateGP` or :obj:`qpytorch.models.ApproximateQEP`.
+
+    Args:
+        likelihood (:obj:`qpytorch.likelihoods.Likelihood`):
+            The likelihood for the model
+        model (:obj:`gpytorch.models.ApproximateGP` or :obj:`qpytorch.models.ApproximateQEP`):
+            The approximate GP (QEP) model
+        num_data (int):
+            The total number of training data points (necessary for SGD)
+        beta (float - default 1.):
+            A multiplicative factor for the KL divergence term.
+            Setting it to 1 (default) recovers true variational inference
+            (as derived in `Scalable Variational Gaussian (Q-Exponential) Process Classification`_).
+            Setting it to anything less than 1 reduces the regularization effect of the model
+            (similarly to what was proposed in `the beta-VAE paper`_).
+        combine_terms (bool):
+            Whether or not to sum the expected NLL with the KL terms (default True)
+    """

+
+    def __init__(self, likelihood, model, num_data, beta=1.0, combine_terms=True):
+        super().__init__(likelihood, model)
+        self.combine_terms = combine_terms
+        self.num_data = num_data
+        self.beta = beta
+
+    @abstractmethod
+    def _log_likelihood_term(self, approximate_dist_f, target, **kwargs):
+        raise NotImplementedError
+
+    def forward(self, approximate_dist_f, target, **kwargs):
+        r"""
+        Computes the Variational ELBO given :math:`q(\mathbf f)` and :math:`\mathbf y`.
+        Calling this function will call the likelihood's `expected_log_prob` function.
+
+        Args:
+            approximate_dist_f (:obj:`gpytorch.distributions.MultivariateNormal` or :obj:`qpytorch.distributions.MultivariateQExponential`):
+                :math:`q(\mathbf f)`, the outputs of the latent function (the :obj:`gpytorch.models.ApproximateGP` or :obj:`qpytorch.models.ApproximateQEP`)
+            target (`torch.Tensor`):
+                :math:`\mathbf y`, the target values
+
+        Keyword Args:
+            Additional arguments passed to the likelihood's `expected_log_prob` function.
+        """
+        # Get likelihood term and KL term
+        num_batch = approximate_dist_f.event_shape[0]
+        log_likelihood = self._log_likelihood_term(approximate_dist_f, target, **kwargs).div(num_batch)
+        kl_divergence = self.model.variational_strategy.kl_divergence().div(self.num_data / self.beta)
+
+        # Add any additional registered loss terms
+        added_loss = torch.zeros_like(log_likelihood)
+        had_added_losses = False
+        for added_loss_term in self.model.added_loss_terms():
+            added_loss.add_(added_loss_term.loss())
+            had_added_losses = True
+
+        # Log prior term
+        log_prior = torch.zeros_like(log_likelihood)
+        for name, module, prior, closure, _ in self.named_priors():
+            log_prior.add_(prior.log_prob(closure(module)).sum().div(self.num_data))
+
+        if self.combine_terms:
+            return log_likelihood - kl_divergence + log_prior - added_loss
+        else:
+            if had_added_losses:
+                return log_likelihood, kl_divergence, log_prior, added_loss
+            else:
+                return log_likelihood, kl_divergence, log_prior
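
One practical consequence of `combine_terms` above: with `combine_terms=False`, `forward` returns the individual terms instead of a scalar bound, which is handy for logging each piece during training. A sketch with hypothetical names (note the return grows to four terms if any added-loss terms are registered):

    mll = qpytorch.mlls.VariationalELBO(likelihood, model, num_data=1000, combine_terms=False)

    output = model(train_x)
    log_lik, kl_div, log_prior = mll(output, train_y)  # three terms: no added losses registered
    loss = -(log_lik - kl_div + log_prior)
    loss.backward()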

qpytorch/mlls/deep_approximate_mll.py
@@ -0,0 +1,30 @@
+from ._approximate_mll import _ApproximateMarginalLogLikelihood
+
+
+class DeepApproximateMLL(_ApproximateMarginalLogLikelihood):
+    """
+    A wrapper to make QPyTorch approximate marginal log likelihoods compatible with Deep QEPs.
+
+    Example:
+        >>> deep_mll = qpytorch.mlls.DeepApproximateMLL(
+        >>>     qpytorch.mlls.VariationalELBO(likelihood, model, num_data=1000)
+        >>> )
+
+    :param ~qpytorch.mlls._ApproximateMarginalLogLikelihood base_mll: The base
+        approximate MLL
+    """
+
+    def __init__(self, base_mll):
+        if not base_mll.combine_terms:
+            raise ValueError(
+                "The base marginal log likelihood object should combine terms "
+                "when used in conjunction with a DeepApproximateMLL."
+            )
+        super().__init__(base_mll.likelihood, base_mll.model, num_data=base_mll.num_data, beta=base_mll.beta)
+        self.base_mll = base_mll
+
+    def _log_likelihood_term(self, approximate_dist_f, target, **kwargs):
+        return self.base_mll._log_likelihood_term(approximate_dist_f, target, **kwargs).mean(0)
+
+    def forward(self, approximate_dist_f, target, **kwargs):
+        return self.base_mll.forward(approximate_dist_f, target, **kwargs).mean(0)

qpytorch/mlls/deep_predictive_log_likelihood.py
@@ -0,0 +1,32 @@
+from ..models.deep_qeps.dspp import DSPP
+from ._approximate_mll import _ApproximateMarginalLogLikelihood
+
+
+class DeepPredictiveLogLikelihood(_ApproximateMarginalLogLikelihood):
+    """
+    An implementation of the predictive log likelihood, extended to DSPPs as discussed in Jankowiak et al., 2020.
+
+    If you are using a DSPP model, this is the loss object you want to create and optimize over.
+
+    This loss object is compatible only with models of type :obj:`~qpytorch.models.deep_qeps.DSPP`.
+    """
+
+    def __init__(self, likelihood, model, num_data, beta=1.0, combine_terms=True):
+        if not combine_terms:
+            raise ValueError(
+                "The marginal log likelihood object should combine terms "
+                "when used in conjunction with a DeepPredictiveLogLikelihood."
+            )
+
+        if not isinstance(model, DSPP):
+            raise ValueError("DeepPredictiveLogLikelihood can only be used with a DSPP model.")
+
+        super().__init__(likelihood, model, num_data, beta, combine_terms)
+
+    def _log_likelihood_term(self, approximate_dist_f, target, **kwargs):
+        base_log_marginal = self.likelihood.log_marginal(target, approximate_dist_f, **kwargs)
+        deep_log_marginal = self.model.quad_weights.unsqueeze(-1) + base_log_marginal
+
+        deep_log_prob = deep_log_marginal.logsumexp(dim=0)
+
+        return deep_log_prob.sum(-1)
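
Because `_log_likelihood_term` adds the model's `quad_weights` in log space and then takes a `logsumexp` over the quadrature dimension, this objective marginalizes the likelihood over the DSPP's quadrature mixture rather than averaging per-sample ELBOs. A hedged construction sketch (assumes `model` is a `qpytorch.models.deep_qeps.DSPP` and `train_x`, `train_y` are training tensors):

    mll = qpytorch.mlls.DeepPredictiveLogLikelihood(
        likelihood, model, num_data=train_x.size(0), beta=0.05
    )

    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()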

qpytorch/mlls/exact_marginal_log_likelihood.py
@@ -0,0 +1,96 @@
+#!/usr/bin/env python3
+
+from linear_operator.operators import MaskedLinearOperator
+
+from .. import settings
+from ..distributions import MultivariateNormal, MultivariateQExponential
+from ..likelihoods import _GaussianLikelihoodBase, _QExponentialLikelihoodBase
+from .marginal_log_likelihood import MarginalLogLikelihood
+
+
+class ExactMarginalLogLikelihood(MarginalLogLikelihood):
+    """
+    The exact marginal log likelihood (MLL) for an exact Gaussian (Q-Exponential) process with a
+    Gaussian (Q-Exponential) likelihood.
+
+    .. note::
+        This module will not work with anything other than a :obj:`~qpytorch.likelihoods.GaussianLikelihood`
+        (:obj:`~qpytorch.likelihoods.QExponentialLikelihood`) and a :obj:`~gpytorch.models.ExactGP` (:obj:`~qpytorch.models.ExactQEP`).
+        It also cannot be used in conjunction with stochastic optimization.
+
+    :param ~qpytorch.likelihoods.GaussianLikelihood (~qpytorch.likelihoods.QExponentialLikelihood) likelihood: The Gaussian (Q-Exponential) likelihood for the model
+    :param ~gpytorch.models.ExactGP (~qpytorch.models.ExactQEP) model: The exact GP (QEP) model
+
+    Example:
+        >>> # model is a qpytorch.models.ExactGP or qpytorch.models.ExactQEP
+        >>> # likelihood is a qpytorch.likelihoods.Likelihood
+        >>> mll = qpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
+        >>>
+        >>> output = model(train_x)
+        >>> loss = -mll(output, train_y)
+        >>> loss.backward()
+    """
+
+    def __init__(self, likelihood, model):
+        if not isinstance(likelihood, (_GaussianLikelihoodBase, _QExponentialLikelihoodBase)):
+            raise RuntimeError("Likelihood must be Gaussian or Q-Exponential for exact inference")
+        super(ExactMarginalLogLikelihood, self).__init__(likelihood, model)
+
+    def _add_other_terms(self, res, params):
+        # Add additional terms (SGPR / learned inducing points, heteroskedastic likelihood models)
+        for added_loss_term in self.model.added_loss_terms():
+            res = res.add(added_loss_term.loss(*params))
+
+        # Add log probs of priors on the (functions of) parameters
+        res_ndim = res.ndim
+        for name, module, prior, closure, _ in self.model.named_priors():
+            prior_term = prior.log_prob(closure(module))
+            res.add_(prior_term.view(*prior_term.shape[:res_ndim], -1).sum(dim=-1))
+
+        return res
+
+    def forward(self, function_dist, target, *params, **kwargs):
+        r"""
+        Computes the MLL given :math:`p(\mathbf f)` and :math:`\mathbf y`.
+
+        :param ~gpytorch.distributions.MultivariateNormal or ~qpytorch.distributions.MultivariateQExponential function_dist: :math:`p(\mathbf f)`,
+            the outputs of the latent function (the :obj:`gpytorch.models.ExactGP` or :obj:`qpytorch.models.ExactQEP`)
+        :param torch.Tensor target: :math:`\mathbf y`, the target values
+        :rtype: torch.Tensor
+        :return: Exact MLL. Output shape corresponds to batch shape of the model/input data.
+        """
+        if not isinstance(function_dist, (MultivariateNormal, MultivariateQExponential)):
+            raise RuntimeError("ExactMarginalLogLikelihood can only operate on Gaussian or Q-Exponential random variables")
+
+        # Determine output likelihood
+        output = self.likelihood(function_dist, *params, **kwargs)
+
+        # Remove NaN values if enabled
+        if settings.observation_nan_policy.value() == "mask":
+            observed = settings.observation_nan_policy._get_observed(target, output.event_shape)
+            if isinstance(function_dist, MultivariateNormal):
+                output = MultivariateNormal(
+                    mean=output.mean[..., observed],
+                    covariance_matrix=MaskedLinearOperator(
+                        output.lazy_covariance_matrix, observed.reshape(-1), observed.reshape(-1)
+                    ),
+                )
+            elif isinstance(function_dist, MultivariateQExponential):
+                output = MultivariateQExponential(
+                    mean=output.mean[..., observed],
+                    covariance_matrix=MaskedLinearOperator(
+                        output.lazy_covariance_matrix, observed.reshape(-1), observed.reshape(-1)
+                    ),
+                    power=output.power,
+                )
+            target = target[..., observed]
+        elif settings.observation_nan_policy.value() == "fill":
+            raise ValueError("NaN observation policy 'fill' is not supported by ExactMarginalLogLikelihood!")
+
+        # Get the log prob of the marginal distribution
+        res = output.log_prob(target)
+        res = self._add_other_terms(res, params)
+
+        # Scale by the amount of data we have
+        num_data = function_dist.event_shape.numel()
+        return res.div_(num_data)
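
The masking branch above only activates when the observation NaN policy is set to "mask". Assuming qpytorch re-exports gpytorch's `observation_nan_policy` context-manager setting (the module does `from .. import settings`), a sketch of training against partially missing targets:

    import torch
    import qpytorch

    train_y_gappy = train_y.clone()
    train_y_gappy[::10] = float("nan")  # pretend every 10th observation is missing

    mll = qpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
    with qpytorch.settings.observation_nan_policy("mask"):
        output = model(train_x)
        loss = -mll(output, train_y_gappy)  # NaN entries are masked out of the MLL
        loss.backward()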

qpytorch/mlls/gamma_robust_variational_elbo.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python3
+
+import math
+
+import numpy as np
+import torch
+
+from ..likelihoods import _GaussianLikelihoodBase, _QExponentialLikelihoodBase
+from ._approximate_mll import _ApproximateMarginalLogLikelihood
+
+
+class GammaRobustVariationalELBO(_ApproximateMarginalLogLikelihood):
+    r"""
+    An alternative to the variational evidence lower bound (ELBO), proposed by `Knoblauch, 2019`_.
+    It is derived by replacing the log-likelihood term in the ELBO with a :math:`\gamma`-divergence:
+
+    .. math::
+
+       \begin{align*}
+          \mathcal{L}_{\gamma} &=
+          \sum_{i=1}^N \mathbb{E}_{q( \mathbf u)} \left[
+             -\frac{\gamma}{\gamma - 1}
+             \frac{
+                p( y_i \! \mid \! \mathbf u, x_i)^{\gamma - 1}
+             }{
+                \int p(y \mid \mathbf u, x_i)^{\gamma} \: dy
+             }
+          \right] - \beta \: \text{KL} \left[ q( \mathbf u) \Vert p( \mathbf u) \right]
+       \end{align*}
+
+    where :math:`N` is the number of datapoints, :math:`\gamma` is a hyperparameter,
+    :math:`q(\mathbf u)` is the variational distribution for
+    the inducing function values, and :math:`p(\mathbf u)` is the prior distribution for the inducing function
+    values.
+
+    :math:`\beta` is a scaling constant for the KL divergence.
+
+    .. note::
+        This module will only work with :obj:`~qpytorch.likelihoods.GaussianLikelihood` or :obj:`~qpytorch.likelihoods.QExponentialLikelihood`.
+
+    :param ~qpytorch.likelihoods.GaussianLikelihood (~qpytorch.likelihoods.QExponentialLikelihood) likelihood: The likelihood for the model
+    :param ~gpytorch.models.ApproximateGP (~qpytorch.models.ApproximateQEP) model: The approximate GP (QEP) model
+    :param int num_data: The total number of training data points (necessary for SGD)
+    :param float beta: (optional, default=1.) A multiplicative factor for the KL divergence term.
+        Setting it to anything less than 1 reduces the regularization effect of the model
+        (similarly to what was proposed in `the beta-VAE paper`_).
+    :param float gamma: (optional, default=1.03) The :math:`\gamma`-divergence hyperparameter.
+    :param bool combine_terms: (optional, default=True) Whether or not to sum the
+        expected NLL with the KL terms
+
+    Example:
+        >>> # model is a qpytorch.models.ApproximateGP or qpytorch.models.ApproximateQEP
+        >>> # likelihood is a qpytorch.likelihoods.Likelihood
+        >>> mll = qpytorch.mlls.GammaRobustVariationalELBO(likelihood, model, num_data=100, beta=0.5, gamma=1.03)
+        >>>
+        >>> output = model(train_x)
+        >>> loss = -mll(output, train_y)
+        >>> loss.backward()
+
+    .. _Knoblauch, 2019:
+        https://arxiv.org/pdf/1904.02303.pdf
+    .. _Knoblauch, Jewson, Damoulas 2019:
+        https://arxiv.org/pdf/1904.02063.pdf
+    """
+
+    def __init__(self, likelihood, model, gamma=1.03, *args, **kwargs):
+        if not isinstance(likelihood, (_GaussianLikelihoodBase, _QExponentialLikelihoodBase)):
+            raise RuntimeError("Likelihood must be Gaussian or Q-Exponential for exact inference")
+        super().__init__(likelihood, model, *args, **kwargs)
+        if gamma <= 1.0:
+            raise ValueError("gamma should be > 1.0")
+        self.gamma = gamma
+
+    def _log_likelihood_term(self, variational_dist_f, target, *args, **kwargs):
+        shifted_gamma = self.gamma - 1
+
+        muf, varf = variational_dist_f.mean, variational_dist_f.variance
+
+        # Get noise from likelihood
+        noise = self.likelihood._shaped_noise_covar(muf.shape, *args, **kwargs).diagonal(dim1=-1, dim2=-2)
+        # Potentially reshape the noise to deal with the multitask case
+        noise = noise.view(*noise.shape[:-1], *variational_dist_f.event_shape)
+
+        # adapted from https://github.com/JeremiasKnoblauch/GVIPublic/
+        mut = shifted_gamma * target / noise + muf / varf
+        sigmat = 1.0 / (shifted_gamma / noise + 1.0 / varf)
+        log_integral = -0.5 * shifted_gamma * torch.log(2.0 * math.pi * noise) - 0.5 * np.log1p(shifted_gamma)
+        log_tempered = (
+            -math.log(shifted_gamma)
+            - 0.5 * shifted_gamma * torch.log(2.0 * math.pi * noise)
+            - 0.5 * torch.log1p(shifted_gamma * varf / noise)
+            - 0.5 * (shifted_gamma * target.pow(2.0) / noise)
+            - 0.5 * muf.pow(2.0) / varf
+            + 0.5 * mut.pow(2.0) * sigmat
+        )
+        # TODO: verify for Q-Exponential
+
+        factor = log_tempered + shifted_gamma / self.gamma * log_integral
+        factor = self.gamma * factor.exp()
+
+        # Do appropriate summation for multitask Gaussian (Q-Exponential) likelihoods
+        num_event_dim = len(variational_dist_f.event_shape)
+        if num_event_dim > 1:
+            factor = factor.sum(list(range(-1, -num_event_dim, -1)))
+
+        return factor.sum(-1)
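
For a Gaussian likelihood, the denominator integral in the bound above has a closed form, which is exactly what `log_integral` computes: with :math:`p(y \mid f) = \mathcal{N}(y; f, \sigma^2)`,

.. math::

    \int p(y \mid f)^{\gamma} \, dy = (2 \pi \sigma^2)^{(1 - \gamma)/2} \, \gamma^{-1/2},
    \qquad
    \log \int p(y \mid f)^{\gamma} \, dy = -\frac{\gamma - 1}{2} \log(2 \pi \sigma^2) - \frac{1}{2} \log \gamma,

which matches `-0.5 * shifted_gamma * torch.log(2.0 * math.pi * noise) - 0.5 * np.log1p(shifted_gamma)` in the code, since `log1p(shifted_gamma) == log(gamma)`. The corresponding q-exponential form is still marked `TODO: verify for Q-Exponential` in the source.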