qpytorch-0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of qpytorch might be problematic.

Files changed (102)
  1. qpytorch/__init__.py +327 -0
  2. qpytorch/constraints/__init__.py +3 -0
  3. qpytorch/distributions/__init__.py +21 -0
  4. qpytorch/distributions/delta.py +86 -0
  5. qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
  6. qpytorch/distributions/multivariate_qexponential.py +581 -0
  7. qpytorch/distributions/power.py +113 -0
  8. qpytorch/distributions/qexponential.py +153 -0
  9. qpytorch/functions/__init__.py +58 -0
  10. qpytorch/kernels/__init__.py +80 -0
  11. qpytorch/kernels/grid_interpolation_kernel.py +213 -0
  12. qpytorch/kernels/inducing_point_kernel.py +151 -0
  13. qpytorch/kernels/kernel.py +695 -0
  14. qpytorch/kernels/matern32_kernel_grad.py +155 -0
  15. qpytorch/kernels/matern52_kernel_grad.py +194 -0
  16. qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
  17. qpytorch/kernels/polynomial_kernel_grad.py +88 -0
  18. qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
  19. qpytorch/kernels/rbf_kernel_grad.py +125 -0
  20. qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
  21. qpytorch/kernels/rff_kernel.py +153 -0
  22. qpytorch/lazy/__init__.py +9 -0
  23. qpytorch/likelihoods/__init__.py +66 -0
  24. qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
  25. qpytorch/likelihoods/beta_likelihood.py +76 -0
  26. qpytorch/likelihoods/gaussian_likelihood.py +472 -0
  27. qpytorch/likelihoods/laplace_likelihood.py +59 -0
  28. qpytorch/likelihoods/likelihood.py +437 -0
  29. qpytorch/likelihoods/likelihood_list.py +60 -0
  30. qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
  31. qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
  32. qpytorch/likelihoods/noise_models.py +184 -0
  33. qpytorch/likelihoods/qexponential_likelihood.py +494 -0
  34. qpytorch/likelihoods/softmax_likelihood.py +97 -0
  35. qpytorch/likelihoods/student_t_likelihood.py +90 -0
  36. qpytorch/means/__init__.py +23 -0
  37. qpytorch/metrics/__init__.py +17 -0
  38. qpytorch/mlls/__init__.py +53 -0
  39. qpytorch/mlls/_approximate_mll.py +79 -0
  40. qpytorch/mlls/deep_approximate_mll.py +30 -0
  41. qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
  42. qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
  43. qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
  44. qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
  45. qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
  46. qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
  47. qpytorch/mlls/marginal_log_likelihood.py +48 -0
  48. qpytorch/mlls/predictive_log_likelihood.py +76 -0
  49. qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
  50. qpytorch/mlls/variational_elbo.py +77 -0
  51. qpytorch/models/__init__.py +72 -0
  52. qpytorch/models/approximate_qep.py +115 -0
  53. qpytorch/models/deep_qeps/__init__.py +22 -0
  54. qpytorch/models/deep_qeps/deep_qep.py +155 -0
  55. qpytorch/models/deep_qeps/dspp.py +114 -0
  56. qpytorch/models/exact_prediction_strategies.py +880 -0
  57. qpytorch/models/exact_qep.py +349 -0
  58. qpytorch/models/model_list.py +100 -0
  59. qpytorch/models/pyro/__init__.py +28 -0
  60. qpytorch/models/pyro/_pyro_mixin.py +57 -0
  61. qpytorch/models/pyro/distributions/__init__.py +5 -0
  62. qpytorch/models/pyro/pyro_qep.py +105 -0
  63. qpytorch/models/qep.py +7 -0
  64. qpytorch/models/qeplvm/__init__.py +6 -0
  65. qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
  66. qpytorch/models/qeplvm/latent_variable.py +102 -0
  67. qpytorch/module.py +30 -0
  68. qpytorch/optim/__init__.py +5 -0
  69. qpytorch/priors/__init__.py +42 -0
  70. qpytorch/priors/qep_priors.py +81 -0
  71. qpytorch/test/__init__.py +22 -0
  72. qpytorch/test/base_likelihood_test_case.py +106 -0
  73. qpytorch/test/model_test_case.py +150 -0
  74. qpytorch/test/variational_test_case.py +400 -0
  75. qpytorch/utils/__init__.py +38 -0
  76. qpytorch/utils/warnings.py +37 -0
  77. qpytorch/variational/__init__.py +47 -0
  78. qpytorch/variational/_variational_distribution.py +61 -0
  79. qpytorch/variational/_variational_strategy.py +391 -0
  80. qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
  81. qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
  82. qpytorch/variational/cholesky_variational_distribution.py +65 -0
  83. qpytorch/variational/ciq_variational_strategy.py +352 -0
  84. qpytorch/variational/delta_variational_distribution.py +41 -0
  85. qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
  86. qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
  87. qpytorch/variational/lmc_variational_strategy.py +248 -0
  88. qpytorch/variational/mean_field_variational_distribution.py +58 -0
  89. qpytorch/variational/multitask_variational_strategy.py +317 -0
  90. qpytorch/variational/natural_variational_distribution.py +152 -0
  91. qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
  92. qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
  93. qpytorch/variational/tril_natural_variational_distribution.py +130 -0
  94. qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
  95. qpytorch/variational/unwhitened_variational_strategy.py +225 -0
  96. qpytorch/variational/variational_strategy.py +280 -0
  97. qpytorch/version.py +4 -0
  98. qpytorch-0.1.dist-info/LICENSE +21 -0
  99. qpytorch-0.1.dist-info/METADATA +177 -0
  100. qpytorch-0.1.dist-info/RECORD +102 -0
  101. qpytorch-0.1.dist-info/WHEEL +5 -0
  102. qpytorch-0.1.dist-info/top_level.txt +1 -0
qpytorch/variational/delta_variational_distribution.py
@@ -0,0 +1,41 @@
+ #!/usr/bin/env python3
+
+ from typing import Union
+
+ import torch
+
+ from ..distributions import Delta, Distribution, MultivariateNormal, MultivariateQExponential
+ from ._variational_distribution import _VariationalDistribution
+
+
+ class DeltaVariationalDistribution(_VariationalDistribution):
+     """
+     This :obj:`~qpytorch.variational._VariationalDistribution` object replaces a variational distribution
+     with a single particle. It is equivalent to doing MAP inference.
+
+     :param int num_inducing_points: Size of the variational distribution. This implies that the variational mean
+         should be this size, and the variational covariance matrix should have this many rows and columns.
+     :param batch_shape: Specifies an optional batch size
+         for the variational parameters. This is useful, for example, when doing additive variational inference.
+     :type batch_shape: :obj:`torch.Size`, optional
+     :param float mean_init_std: (Default: 1e-3) Standard deviation of the Gaussian (q-exponential) noise added to the mean initialization.
+     """
+
+     def __init__(
+         self,
+         num_inducing_points: int,
+         batch_shape: torch.Size = torch.Size([]),
+         mean_init_std: float = 1e-3,
+         **kwargs,
+     ):
+         super().__init__(num_inducing_points=num_inducing_points, batch_shape=batch_shape, mean_init_std=mean_init_std)
+         mean_init = torch.zeros(num_inducing_points)
+         mean_init = mean_init.repeat(*batch_shape, 1)
+         self.register_parameter(name="variational_mean", parameter=torch.nn.Parameter(mean_init))
+
+     def forward(self) -> Distribution:
+         return Delta(self.variational_mean)
+
+     def initialize_variational_distribution(self, prior_dist: Union[MultivariateNormal, MultivariateQExponential]) -> None:
+         self.variational_mean.data.copy_(prior_dist.mean)
+         self.variational_mean.data.add_(torch.randn_like(prior_dist.mean), alpha=self.mean_init_std)
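
For orientation, a minimal usage sketch of DeltaVariationalDistribution, assuming qpytorch mirrors gpytorch's ApproximateGP API (as the docstrings in this release suggest); the MAPApproximateGP class and all sizes below are illustrative, not from the package:

import torch
import qpytorch


class MAPApproximateGP(qpytorch.models.ApproximateGP):
    # A Delta distribution keeps a single particle per inducing point,
    # so optimizing the ELBO reduces to MAP inference over the inducing values.
    def __init__(self, inducing_points):
        variational_distribution = qpytorch.variational.DeltaVariationalDistribution(
            inducing_points.size(-2)
        )
        variational_strategy = qpytorch.variational.VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = qpytorch.means.ConstantMean()
        self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())

    def forward(self, x):
        return qpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )


model = MAPApproximateGP(inducing_points=torch.randn(16, 1))
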
qpytorch/variational/grid_interpolation_variational_strategy.py
@@ -0,0 +1,113 @@
+ #!/usr/bin/env python3
+
+ import torch
+ from linear_operator.operators import InterpolatedLinearOperator
+ from linear_operator.utils.interpolation import left_interp
+
+ from ..distributions import MultivariateNormal, MultivariateQExponential
+ from gpytorch.utils.interpolation import Interpolation
+ from gpytorch.utils.memoize import cached
+ from ._variational_strategy import _VariationalStrategy
+
+
+ class GridInterpolationVariationalStrategy(_VariationalStrategy):
+     r"""
+     This strategy constrains the inducing points to a grid and applies a deterministic
+     relationship between :math:`\mathbf f` and :math:`\mathbf u`.
+     It was introduced by `Wilson et al. (2016)`_.
+
+     Here, the inducing points are not learned. Instead, the strategy
+     automatically creates inducing points based on a set of grid sizes and grid
+     bounds.
+
+     .. _Wilson et al. (2016):
+         https://arxiv.org/abs/1611.00336
+
+     :param ~gpytorch.models.ApproximateGP (~qpytorch.models.ApproximateQEP) model: Model this strategy is applied to.
+         Typically passed in when the VariationalStrategy is created in the
+         __init__ method of the user-defined model.
+         It should have a ``power`` attribute if a Q-Exponential distribution is involved.
+     :param int grid_size: Size of the grid
+     :param list grid_bounds: Bounds of each dimension of the grid (should be a list of (float, float) tuples)
+     :param ~qpytorch.variational.VariationalDistribution variational_distribution: A
+         VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`
+     """
+
+     def __init__(self, model, grid_size, grid_bounds, variational_distribution):
+         grid = torch.zeros(grid_size, len(grid_bounds))
+         for i in range(len(grid_bounds)):
+             grid_diff = float(grid_bounds[i][1] - grid_bounds[i][0]) / (grid_size - 2)
+             grid[:, i] = torch.linspace(grid_bounds[i][0] - grid_diff, grid_bounds[i][1] + grid_diff, grid_size)
+
+         inducing_points = torch.zeros(int(pow(grid_size, len(grid_bounds))), len(grid_bounds))
+         prev_points = None
+         for i in range(len(grid_bounds)):
+             for j in range(grid_size):
+                 inducing_points[j * grid_size**i : (j + 1) * grid_size**i, i].fill_(grid[j, i])
+                 if prev_points is not None:
+                     inducing_points[j * grid_size**i : (j + 1) * grid_size**i, :i].copy_(prev_points)
+             prev_points = inducing_points[: grid_size ** (i + 1), : (i + 1)]
+
+         super(GridInterpolationVariationalStrategy, self).__init__(
+             model, inducing_points, variational_distribution, learn_inducing_locations=False
+         )
+         object.__setattr__(self, "model", model)
+
+         self.register_buffer("grid", grid)
+
+     def _compute_grid(self, inputs):
+         n_data, n_dimensions = inputs.size(-2), inputs.size(-1)
+         batch_shape = inputs.shape[:-2]
+
+         inputs = inputs.reshape(-1, n_dimensions)
+         interp_indices, interp_values = Interpolation().interpolate(self.grid, inputs)
+         interp_indices = interp_indices.view(*batch_shape, n_data, -1)
+         interp_values = interp_values.view(*batch_shape, n_data, -1)
+
+         if (interp_indices.dim() - 2) != len(self._variational_distribution.batch_shape):
+             batch_shape = torch.broadcast_shapes(interp_indices.shape[:-2], self._variational_distribution.batch_shape)
+             interp_indices = interp_indices.expand(*batch_shape, *interp_indices.shape[-2:])
+             interp_values = interp_values.expand(*batch_shape, *interp_values.shape[-2:])
+         return interp_indices, interp_values
+
+     @property
+     @cached(name="prior_distribution_memo")
+     def prior_distribution(self):
+         out = self.model.forward(self.inducing_points)
+         # TODO: investigate why jitter smaller than 1e-3 breaks some tests
+         if hasattr(self.model, 'power'):
+             res = MultivariateQExponential(out.mean, out.lazy_covariance_matrix.add_jitter(1e-3), power=self.model.power)
+         else:
+             res = MultivariateNormal(out.mean, out.lazy_covariance_matrix.add_jitter(1e-3))
+         return res
+
+     def forward(self, x, inducing_points, inducing_values, variational_inducing_covar=None):
+         if variational_inducing_covar is None:
+             raise RuntimeError(
+                 "GridInterpolationVariationalStrategy is only compatible with Gaussian (Q-Exponential) variational "
+                 f"distributions. Got {self.variational_distribution.__class__.__name__}."
+             )
+
+         variational_distribution = self.variational_distribution
+
+         # Get interpolations
+         interp_indices, interp_values = self._compute_grid(x)
+
+         # Compute test mean
+         # Left multiply samples by interpolation matrix
+         predictive_mean = left_interp(interp_indices, interp_values, inducing_values.unsqueeze(-1))
+         predictive_mean = predictive_mean.squeeze(-1)
+
+         # Compute test covar
+         predictive_covar = InterpolatedLinearOperator(
+             variational_distribution.lazy_covariance_matrix,
+             interp_indices,
+             interp_values,
+             interp_indices,
+             interp_values,
+         )
+         if hasattr(self.model, 'power'):
+             output = MultivariateQExponential(predictive_mean, predictive_covar, power=self.model.power)
+         else:
+             output = MultivariateNormal(predictive_mean, predictive_covar)
+         return output
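
A usage sketch for this strategy, assuming the gpytorch-style API referenced above: the variational distribution must have grid_size ** len(grid_bounds) inducing points, since the strategy builds the full grid itself. The KISSApproximateGP class and numbers are illustrative:

import torch
import qpytorch


class KISSApproximateGP(qpytorch.models.ApproximateGP):
    def __init__(self, grid_size=64, grid_bounds=((-1.0, 1.0),)):
        # One inducing point per grid location: grid_size ** len(grid_bounds)
        num_inducing = int(pow(grid_size, len(grid_bounds)))
        variational_distribution = qpytorch.variational.CholeskyVariationalDistribution(num_inducing)
        variational_strategy = qpytorch.variational.GridInterpolationVariationalStrategy(
            self, grid_size, list(grid_bounds), variational_distribution
        )
        super().__init__(variational_strategy)
        self.mean_module = qpytorch.means.ConstantMean()
        self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())

    def forward(self, x):
        return qpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )
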
qpytorch/variational/independent_multitask_variational_strategy.py
@@ -0,0 +1,114 @@
+ #!/usr/bin/env python3
+
+ import warnings
+
+ import torch
+ from linear_operator.operators import RootLinearOperator
+
+ from ..distributions import MultitaskMultivariateNormal, MultivariateNormal
+ from ..module import Module
+ from ._variational_strategy import _VariationalStrategy
+
+
+ class IndependentMultitaskVariationalStrategy(_VariationalStrategy):
+     """
+     IndependentMultitaskVariationalStrategy wraps an existing
+     :obj:`~qpytorch.variational.VariationalStrategy` to produce vector-valued (multi-task)
+     output distributions. Each task is independent of the others.
+
+     The output will either be a :obj:`~gpytorch.distributions.MultitaskMultivariateNormal` distribution
+     (if we wish to evaluate all tasks for each input) or a :obj:`~gpytorch.distributions.MultivariateNormal`
+     (if we wish to evaluate a single task for each input).
+
+     The base variational strategy is assumed to operate on a batch of GPs. One of the batch
+     dimensions corresponds to the multiple tasks.
+
+     :param ~qpytorch.variational.VariationalStrategy base_variational_strategy: Base variational strategy
+     :param int num_tasks: Number of tasks. Should correspond to the batch size of task_dim.
+     :param int task_dim: (Default: -1) Which batch dimension is the task dimension
+     """
+
+     def __init__(self, base_variational_strategy, num_tasks, task_dim=-1):
+         Module.__init__(self)
+         self.base_variational_strategy = base_variational_strategy
+         self.task_dim = task_dim
+         self.num_tasks = num_tasks
+
+     @property
+     def prior_distribution(self):
+         return self.base_variational_strategy.prior_distribution
+
+     @property
+     def variational_distribution(self):
+         return self.base_variational_strategy.variational_distribution
+
+     @property
+     def variational_params_initialized(self):
+         return self.base_variational_strategy.variational_params_initialized
+
+     def kl_divergence(self):
+         return super().kl_divergence().sum(dim=-1)
+
+     def __call__(self, x, task_indices=None, prior=False, **kwargs):
+         r"""
+         See :class:`LMCVariationalStrategy`.
+         """
+         function_dist = self.base_variational_strategy(x, prior=prior, **kwargs)
+
+         if task_indices is None:
+             # Every data point will get an output for each task
+             if (
+                 self.task_dim > 0
+                 and self.task_dim > len(function_dist.batch_shape)
+                 or self.task_dim < 0
+                 and self.task_dim + len(function_dist.batch_shape) < 0
+             ):
+                 return MultitaskMultivariateNormal.from_repeated_mvn(function_dist, num_tasks=self.num_tasks)
+             else:
+                 function_dist = MultitaskMultivariateNormal.from_batch_mvn(function_dist, task_dim=self.task_dim)
+                 assert function_dist.event_shape[-1] == self.num_tasks
+                 return function_dist
+
+         else:
+             # Each data point will get a single output corresponding to a single task
+
+             if self.task_dim > 0:
+                 raise RuntimeError(f"task_dim must be a negative indexed batch dimension: got {self.task_dim}.")
+             num_batch = len(function_dist.batch_shape)
+             task_dim = num_batch + self.task_dim
+
+             # Expand the task indices to the full batch/event shape
+             shape = list(function_dist.batch_shape + function_dist.event_shape)
+             shape[task_dim] = 1
+             task_indices = task_indices.expand(shape).squeeze(task_dim)
+
+             # Create a mask to choose the specific task assignment
+             task_mask = torch.nn.functional.one_hot(task_indices, num_classes=self.num_tasks)
+             task_mask = task_mask.permute(*range(0, task_dim), *range(task_dim + 1, num_batch + 1), task_dim)
+
+             mean = (function_dist.mean * task_mask).sum(task_dim)
+             covar = (function_dist.lazy_covariance_matrix * RootLinearOperator(task_mask[..., None])).sum(task_dim)
+             return MultivariateNormal(mean, covar)
+
+
+ class MultitaskVariationalStrategy(IndependentMultitaskVariationalStrategy):
+     """
+     Deprecated alias of :obj:`~qpytorch.variational.IndependentMultitaskVariationalStrategy`, which wraps
+     an existing :obj:`~qpytorch.variational.VariationalStrategy`
+     to produce a :obj:`~gpytorch.distributions.MultitaskMultivariateNormal` distribution.
+     All outputs will be independent of one another.
+
+     The base variational strategy is assumed to operate on a batch of GPs. One of the batch
+     dimensions corresponds to the multiple tasks.
+
+     :param ~qpytorch.variational.VariationalStrategy base_variational_strategy: Base variational strategy
+     :param int num_tasks: Number of tasks. Should correspond to the batch size of task_dim.
+     :param int task_dim: (Default: -1) Which batch dimension is the task dimension
+     """
+
+     def __init__(self, base_variational_strategy, num_tasks, task_dim=-1):
+         warnings.warn(
+             "MultitaskVariationalStrategy has been renamed to IndependentMultitaskVariationalStrategy",
+             DeprecationWarning,
+         )
+         super().__init__(base_variational_strategy, num_tasks, task_dim=task_dim)
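
A usage sketch, again assuming the gpytorch-style API the docstring describes: the wrapped strategy runs a batch of independent GPs, one per task, and the wrapper reshapes that batch into a multitask output. The IndependentMultitaskGP class and sizes are illustrative:

import torch
import qpytorch


class IndependentMultitaskGP(qpytorch.models.ApproximateGP):
    def __init__(self, num_tasks=4, num_inducing=16):
        # One batch of inducing points and variational parameters per task
        inducing_points = torch.randn(num_tasks, num_inducing, 1)
        variational_distribution = qpytorch.variational.CholeskyVariationalDistribution(
            num_inducing, batch_shape=torch.Size([num_tasks])
        )
        variational_strategy = qpytorch.variational.IndependentMultitaskVariationalStrategy(
            qpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            ),
            num_tasks=num_tasks,
        )
        super().__init__(variational_strategy)
        self.mean_module = qpytorch.means.ConstantMean(batch_shape=torch.Size([num_tasks]))
        self.covar_module = qpytorch.kernels.ScaleKernel(
            qpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_tasks])),
            batch_shape=torch.Size([num_tasks]),
        )

    def forward(self, x):
        return qpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )
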
qpytorch/variational/lmc_variational_strategy.py
@@ -0,0 +1,248 @@
+ #!/usr/bin/env python3
+
+ from typing import Optional, Union
+
+ import torch
+ from linear_operator.operators import KroneckerProductLinearOperator, RootLinearOperator
+ from linear_operator.utils.interpolation import left_interp
+ from torch import LongTensor, Tensor
+
+ from .. import settings
+ from ..distributions import MultitaskMultivariateNormal, MultitaskMultivariateQExponential, MultivariateNormal, MultivariateQExponential
+ from ..module import Module
+ from ._variational_strategy import _VariationalStrategy
+
+
+ def _select_lmc_coefficients(lmc_coefficients: torch.Tensor, indices: torch.LongTensor) -> torch.Tensor:
+     """
+     Given a list of indices for ... x N datapoints,
+     select the row from lmc_coefficients that corresponds to each datapoint.
+
+     lmc_coefficients: torch.Tensor ... x num_latents x ... x num_tasks
+     indices: torch.LongTensor ... x N
+     """
+     batch_shape = torch.broadcast_shapes(lmc_coefficients.shape[:-1], indices.shape[:-1])
+
+     # We will use the left_interp helper to do the indexing
+     lmc_coefficients = lmc_coefficients.expand(*batch_shape, lmc_coefficients.shape[-1])[..., None]
+     indices = indices.expand(*batch_shape, indices.shape[-1])[..., None]
+     res = left_interp(
+         indices,
+         torch.ones(indices.shape, dtype=torch.long, device=indices.device),
+         lmc_coefficients,
+     ).squeeze(-1)
+     return res
+
+
+ class LMCVariationalStrategy(_VariationalStrategy):
+     r"""
+     LMCVariationalStrategy is an implementation of the "Linear Model of Coregionalization"
+     for multitask GPs (QEPs). This model assumes that there are :math:`Q` latent functions
+     :math:`\mathbf g(\cdot) = [g^{(1)}(\cdot), \ldots, g^{(Q)}(\cdot)]`,
+     each of which is modelled by a GP (QEP).
+     The output functions (tasks) are linear combinations of the latent functions:
+
+     .. math::
+
+         f_{\text{task } i}( \mathbf x) = \sum_{q=1}^Q a_i^{(q)} g^{(q)} ( \mathbf x )
+
+     LMCVariationalStrategy wraps an existing :obj:`~qpytorch.variational.VariationalStrategy`.
+     The output will either be a :obj:`~gpytorch.distributions.MultitaskMultivariateNormal` (:obj:`~qpytorch.distributions.MultitaskMultivariateQExponential`) distribution
+     (if we wish to evaluate all tasks for each input) or a :obj:`~gpytorch.distributions.MultivariateNormal` (:obj:`~qpytorch.distributions.MultivariateQExponential`)
+     (if we wish to evaluate a single task for each input).
+
+     The base variational strategy is assumed to operate on a multi-batch of GPs (QEPs), where one
+     of the batch dimensions corresponds to the latent function dimension.
+
+     .. note::
+
+         The batch shape of the base :obj:`~qpytorch.variational.VariationalStrategy` does not
+         necessarily have to correspond to the batch shape of the underlying GP (QEP) objects.
+
+         For example, if the base variational strategy has a batch shape of `[3]` (corresponding
+         to 3 latent functions), the GP (QEP) kernel object could have a batch shape of `[3]` or no
+         batch shape. This would correspond to each of the latent functions having different kernels
+         or the same kernel, respectively.
+
+     Example:
+         >>> class LMCMultitaskGP(qpytorch.models.ApproximateGP):
+         >>>     '''
+         >>>     3 latent functions
+         >>>     5 output dimensions (tasks)
+         >>>     '''
+         >>>     def __init__(self):
+         >>>         # Each latent function shares the same inducing points
+         >>>         # We'll have 32 inducing points, and let's assume the input dimensionality is 2
+         >>>         inducing_points = torch.randn(32, 2)
+         >>>
+         >>>         # The variational parameters have a batch_shape of [3] - for 3 latent functions
+         >>>         variational_distribution = qpytorch.variational.MeanFieldVariationalDistribution(
+         >>>             inducing_points.size(-2), batch_shape=torch.Size([3]),
+         >>>         )
+         >>>         variational_strategy = qpytorch.variational.LMCVariationalStrategy(
+         >>>             qpytorch.variational.VariationalStrategy(
+         >>>                 self, inducing_points, variational_distribution, learn_inducing_locations=True,
+         >>>             ),
+         >>>             num_tasks=5,
+         >>>             num_latents=3,
+         >>>             latent_dim=-1,
+         >>>         )
+         >>>
+         >>>         # Each latent function has its own mean/kernel function
+         >>>         super().__init__(variational_strategy)
+         >>>         self.mean_module = qpytorch.means.ConstantMean(batch_shape=torch.Size([3]))
+         >>>         self.covar_module = qpytorch.kernels.ScaleKernel(
+         >>>             qpytorch.kernels.RBFKernel(batch_shape=torch.Size([3])),
+         >>>             batch_shape=torch.Size([3]),
+         >>>         )
+         >>>
+
+     :param base_variational_strategy: Base variational strategy
+     :param num_tasks: The total number of tasks (output functions)
+     :param num_latents: The total number of latent functions in each group
+     :param latent_dim: (Default: -1) Which batch dimension corresponds to the latent function batch.
+         **Must be negative indexed**
+     :param jitter_val: Amount of diagonal jitter to add for Cholesky factorization numerical stability
+     """
+
+     def __init__(
+         self,
+         base_variational_strategy: _VariationalStrategy,
+         num_tasks: int,
+         num_latents: int = 1,
+         latent_dim: int = -1,
+         jitter_val: Optional[float] = None,
+     ):
+         Module.__init__(self)
+         self.base_variational_strategy = base_variational_strategy
+         self.num_tasks = num_tasks
+         batch_shape = self.base_variational_strategy._variational_distribution.batch_shape
+
+         # Check that latent_dim is a valid (negative indexed) latent function dimension
+         if latent_dim >= 0:
+             raise RuntimeError(f"latent_dim must be a negative indexed batch dimension: got {latent_dim}.")
+         if not (batch_shape[latent_dim] == num_latents or batch_shape[latent_dim] == 1):
+             raise RuntimeError(
+                 f"Mismatch in num_latents: got a variational distribution of batch shape {batch_shape}, "
+                 f"expected the function dim {latent_dim} to be {num_latents}."
+             )
+         self.num_latents = num_latents
+         self.latent_dim = latent_dim
+
+         # Make the batch_shape
+         self.batch_shape = list(batch_shape)
+         del self.batch_shape[self.latent_dim]
+         self.batch_shape = torch.Size(self.batch_shape)
+
+         # LMC coefficients
+         lmc_coefficients = torch.randn(*batch_shape, self.num_tasks)
+         self.register_parameter("lmc_coefficients", torch.nn.Parameter(lmc_coefficients))
+
+         if jitter_val is None:
+             self.jitter_val = settings.variational_cholesky_jitter.value(
+                 self.base_variational_strategy.inducing_points.dtype
+             )
+         else:
+             self.jitter_val = jitter_val
+
+     @property
+     def prior_distribution(self) -> Union[MultivariateNormal, MultivariateQExponential]:
+         return self.base_variational_strategy.prior_distribution
+
+     @property
+     def variational_distribution(self) -> Union[MultivariateNormal, MultivariateQExponential]:
+         return self.base_variational_strategy.variational_distribution
+
+     @property
+     def variational_params_initialized(self) -> bool:
+         return self.base_variational_strategy.variational_params_initialized
+
+     def kl_divergence(self) -> Tensor:
+         return super().kl_divergence().sum(dim=self.latent_dim)
+
+     def __call__(
+         self, x: Tensor, prior: bool = False, task_indices: Optional[LongTensor] = None, **kwargs
+     ) -> Union[MultitaskMultivariateNormal, MultitaskMultivariateQExponential, MultivariateNormal, MultivariateQExponential]:
+         r"""
+         Computes the variational (or prior) distribution
+         :math:`q( \mathbf f \mid \mathbf X)` (or :math:`p( \mathbf f \mid \mathbf X)`).
+         There are two modes:
+
+         1. Compute **all tasks** for all inputs.
+            If this is the case, the task_indices attribute should be None.
+            The return type will be a (... x N x num_tasks)
+            :class:`~gpytorch.distributions.MultitaskMultivariateNormal` (:class:`~qpytorch.distributions.MultitaskMultivariateQExponential`).
+         2. Compute **one task** per input.
+            If this is the case, the (... x N) task_indices tensor should contain
+            the indices of each input's assigned task.
+            The return type will be a (... x N)
+            :class:`~gpytorch.distributions.MultivariateNormal` (:class:`~qpytorch.distributions.MultivariateQExponential`).
+
+         :param x: (... x N x D) Input locations to evaluate variational strategy
+         :param task_indices: (Default: None) Task index associated with each input.
+             If this **is not** provided, then the returned distribution evaluates every input on every task
+             (returns :class:`~gpytorch.distributions.MultitaskMultivariateNormal` or :class:`~qpytorch.distributions.MultitaskMultivariateQExponential`).
+             If this **is** provided, then the returned distribution evaluates each input only on its assigned task
+             (returns :class:`~gpytorch.distributions.MultivariateNormal` or :class:`~qpytorch.distributions.MultivariateQExponential`).
+         :param prior: (Default: False) If False, returns the variational distribution
+             :math:`q( \mathbf f \mid \mathbf X)`.
+             If True, returns the prior distribution
+             :math:`p( \mathbf f \mid \mathbf X)`.
+         :return: :math:`q( \mathbf f \mid \mathbf X)` (or the prior),
+             either for all tasks (if `task_indices == None`)
+             or for a specific task (if `task_indices != None`).
+         :rtype: ~gpytorch.distributions.MultitaskMultivariateNormal (~qpytorch.distributions.MultitaskMultivariateQExponential) (... x N x num_tasks)
+             or ~gpytorch.distributions.MultivariateNormal (~qpytorch.distributions.MultivariateQExponential) (... x N)
+         """
+         latent_dist = self.base_variational_strategy(x, prior=prior, **kwargs)
+         num_batch = len(latent_dist.batch_shape)
+         latent_dim = num_batch + self.latent_dim
+
+         if task_indices is None:
+             num_dim = num_batch + len(latent_dist.event_shape)
+
+             # Every data point will get an output for each task
+             # Therefore, we will set up the lmc_coefficients shape for a matmul
+             lmc_coefficients = self.lmc_coefficients.expand(*latent_dist.batch_shape, self.lmc_coefficients.size(-1))
+
+             # Mean: ... x N x num_tasks
+             latent_mean = latent_dist.mean.permute(*range(0, latent_dim), *range(latent_dim + 1, num_dim), latent_dim)
+             mean = latent_mean @ lmc_coefficients.permute(
+                 *range(0, latent_dim), *range(latent_dim + 1, num_dim - 1), latent_dim, -1
+             )
+
+             # Covar: ... x (N x num_tasks) x (N x num_tasks)
+             latent_covar = latent_dist.lazy_covariance_matrix
+             lmc_factor = RootLinearOperator(lmc_coefficients.unsqueeze(-1))
+             covar = KroneckerProductLinearOperator(latent_covar, lmc_factor).sum(latent_dim)
+             # Add a bit of jitter to make the covar PD
+             covar = covar.add_jitter(self.jitter_val)
+
+             # Done!
+             if isinstance(latent_dist, MultivariateNormal):
+                 function_dist = MultitaskMultivariateNormal(mean, covar)
+             elif isinstance(latent_dist, MultivariateQExponential):
+                 function_dist = MultitaskMultivariateQExponential(mean, covar, power=latent_dist.power)
+
+         else:
+             # Each data point will get a single output corresponding to a single task
+             # Therefore, we will select the appropriate lmc coefficients for each task
+             lmc_coefficients = _select_lmc_coefficients(self.lmc_coefficients, task_indices)
+
+             # Mean: ... x N
+             mean = (latent_dist.mean * lmc_coefficients).sum(latent_dim)
+
+             # Covar: ... x N x N
+             latent_covar = latent_dist.lazy_covariance_matrix
+             lmc_factor = RootLinearOperator(lmc_coefficients.unsqueeze(-1))
+             covar = (latent_covar * lmc_factor).sum(latent_dim)
+             # Add a bit of jitter to make the covar PD
+             covar = covar.add_jitter(self.jitter_val)
+
+             # Done!
+             if isinstance(latent_dist, MultivariateNormal):
+                 function_dist = MultivariateNormal(mean, covar)
+             elif isinstance(latent_dist, MultivariateQExponential):
+                 function_dist = MultivariateQExponential(mean, covar, power=latent_dist.power)
+
+         return function_dist
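
To illustrate the two call modes documented above, a short sketch reusing the docstring's example model (shapes assume its 3 latents / 5 tasks; kwargs forwarding through the model call assumes gpytorch-style ApproximateGP behavior):

import torch

model = LMCMultitaskGP()
x = torch.randn(10, 2)

# Mode 1: every input evaluated on every task
all_tasks = model(x)  # MultitaskMultivariateNormal, event shape 10 x 5

# Mode 2: each input assigned a single task
task_indices = torch.randint(0, 5, (10,))
one_task = model(x, task_indices=task_indices)  # MultivariateNormal, event shape 10
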
qpytorch/variational/mean_field_variational_distribution.py
@@ -0,0 +1,58 @@
+ #!/usr/bin/env python3
+
+ import torch
+ from linear_operator.operators import DiagLinearOperator
+
+ from ..distributions import MultivariateNormal, MultivariateQExponential
+ from ._variational_distribution import _VariationalDistribution
+
+
+ class MeanFieldVariationalDistribution(_VariationalDistribution):
+     """
+     A :obj:`~qpytorch.variational._VariationalDistribution` that is defined to be a multivariate normal (q-exponential) distribution
+     with a diagonal covariance matrix. This will not be as flexible/expressive as a
+     :obj:`~qpytorch.variational.CholeskyVariationalDistribution`.
+
+     :param int num_inducing_points: Size of the variational distribution. This implies that the variational mean
+         should be this size, and the variational covariance matrix should have this many rows and columns.
+     :param batch_shape: Specifies an optional batch size
+         for the variational parameters. This is useful, for example, when doing additive variational inference.
+     :type batch_shape: :obj:`torch.Size`, optional
+     :param float mean_init_std: (Default: 1e-3) Standard deviation of the Gaussian (q-exponential) noise added to the mean initialization.
+     """
+
+     def __init__(self, num_inducing_points, batch_shape=torch.Size([]), mean_init_std=1e-3, **kwargs):
+         super().__init__(num_inducing_points=num_inducing_points, batch_shape=batch_shape, mean_init_std=mean_init_std)
+         mean_init = torch.zeros(num_inducing_points)
+         covar_init = torch.ones(num_inducing_points)
+         mean_init = mean_init.repeat(*batch_shape, 1)
+         covar_init = covar_init.repeat(*batch_shape, 1)
+
+         self.register_parameter(name="variational_mean", parameter=torch.nn.Parameter(mean_init))
+         self.register_parameter(name="_variational_stddev", parameter=torch.nn.Parameter(covar_init))
+
+         if 'power' in kwargs:
+             self.power = kwargs.pop('power')
+
+     @property
+     def variational_stddev(self):
+         # TODO: if we don't multiply self._variational_stddev by a mask of ones, Pyro models fail.
+         # Not sure where this bug is occurring (in Pyro or PyTorch);
+         # throwing this in as a hotfix for now. We should investigate later.
+         mask = torch.ones_like(self._variational_stddev)
+         return self._variational_stddev.mul(mask).abs().clamp_min(1e-8)
+
+     def forward(self):
+         # TODO: if we don't multiply self._variational_stddev by a mask of ones, Pyro models fail.
+         # Not sure where this bug is occurring (in Pyro or PyTorch);
+         # throwing this in as a hotfix for now. We should investigate later.
+         mask = torch.ones_like(self._variational_stddev)
+         variational_covar = DiagLinearOperator(self._variational_stddev.mul(mask).pow(2))
+         if not hasattr(self, 'power'):
+             return MultivariateNormal(self.variational_mean, variational_covar)
+         else:
+             return MultivariateQExponential(self.variational_mean, variational_covar, power=self.power)
+
+     def initialize_variational_distribution(self, prior_dist):
+         self.variational_mean.data.copy_(prior_dist.mean)
+         self.variational_mean.data.add_(torch.randn_like(prior_dist.mean), alpha=self.mean_init_std)
+         self._variational_stddev.data.copy_(prior_dist.stddev)
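
A quick sketch of what the mean-field parameterization buys, assuming the API shown above: only O(M) variational parameters for M inducing points, versus O(M^2) for a full Cholesky factor.

import torch
import qpytorch

q = qpytorch.variational.MeanFieldVariationalDistribution(num_inducing_points=16)
dist = q()                          # diagonal-covariance MultivariateNormal
print(q.variational_mean.shape)     # torch.Size([16])
print(q.variational_stddev.shape)   # torch.Size([16]): one stddev per inducing point
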