qpytorch-0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of qpytorch might be problematic. Click here for more details.

Files changed (102)
  1. qpytorch/__init__.py +327 -0
  2. qpytorch/constraints/__init__.py +3 -0
  3. qpytorch/distributions/__init__.py +21 -0
  4. qpytorch/distributions/delta.py +86 -0
  5. qpytorch/distributions/multitask_multivariate_qexponential.py +435 -0
  6. qpytorch/distributions/multivariate_qexponential.py +581 -0
  7. qpytorch/distributions/power.py +113 -0
  8. qpytorch/distributions/qexponential.py +153 -0
  9. qpytorch/functions/__init__.py +58 -0
  10. qpytorch/kernels/__init__.py +80 -0
  11. qpytorch/kernels/grid_interpolation_kernel.py +213 -0
  12. qpytorch/kernels/inducing_point_kernel.py +151 -0
  13. qpytorch/kernels/kernel.py +695 -0
  14. qpytorch/kernels/matern32_kernel_grad.py +155 -0
  15. qpytorch/kernels/matern52_kernel_grad.py +194 -0
  16. qpytorch/kernels/matern52_kernel_gradgrad.py +248 -0
  17. qpytorch/kernels/polynomial_kernel_grad.py +88 -0
  18. qpytorch/kernels/qexponential_symmetrized_kl_kernel.py +61 -0
  19. qpytorch/kernels/rbf_kernel_grad.py +125 -0
  20. qpytorch/kernels/rbf_kernel_gradgrad.py +186 -0
  21. qpytorch/kernels/rff_kernel.py +153 -0
  22. qpytorch/lazy/__init__.py +9 -0
  23. qpytorch/likelihoods/__init__.py +66 -0
  24. qpytorch/likelihoods/bernoulli_likelihood.py +75 -0
  25. qpytorch/likelihoods/beta_likelihood.py +76 -0
  26. qpytorch/likelihoods/gaussian_likelihood.py +472 -0
  27. qpytorch/likelihoods/laplace_likelihood.py +59 -0
  28. qpytorch/likelihoods/likelihood.py +437 -0
  29. qpytorch/likelihoods/likelihood_list.py +60 -0
  30. qpytorch/likelihoods/multitask_gaussian_likelihood.py +542 -0
  31. qpytorch/likelihoods/multitask_qexponential_likelihood.py +545 -0
  32. qpytorch/likelihoods/noise_models.py +184 -0
  33. qpytorch/likelihoods/qexponential_likelihood.py +494 -0
  34. qpytorch/likelihoods/softmax_likelihood.py +97 -0
  35. qpytorch/likelihoods/student_t_likelihood.py +90 -0
  36. qpytorch/means/__init__.py +23 -0
  37. qpytorch/metrics/__init__.py +17 -0
  38. qpytorch/mlls/__init__.py +53 -0
  39. qpytorch/mlls/_approximate_mll.py +79 -0
  40. qpytorch/mlls/deep_approximate_mll.py +30 -0
  41. qpytorch/mlls/deep_predictive_log_likelihood.py +32 -0
  42. qpytorch/mlls/exact_marginal_log_likelihood.py +96 -0
  43. qpytorch/mlls/gamma_robust_variational_elbo.py +106 -0
  44. qpytorch/mlls/inducing_point_kernel_added_loss_term.py +69 -0
  45. qpytorch/mlls/kl_qexponential_added_loss_term.py +41 -0
  46. qpytorch/mlls/leave_one_out_pseudo_likelihood.py +73 -0
  47. qpytorch/mlls/marginal_log_likelihood.py +48 -0
  48. qpytorch/mlls/predictive_log_likelihood.py +76 -0
  49. qpytorch/mlls/sum_marginal_log_likelihood.py +40 -0
  50. qpytorch/mlls/variational_elbo.py +77 -0
  51. qpytorch/models/__init__.py +72 -0
  52. qpytorch/models/approximate_qep.py +115 -0
  53. qpytorch/models/deep_qeps/__init__.py +22 -0
  54. qpytorch/models/deep_qeps/deep_qep.py +155 -0
  55. qpytorch/models/deep_qeps/dspp.py +114 -0
  56. qpytorch/models/exact_prediction_strategies.py +880 -0
  57. qpytorch/models/exact_qep.py +349 -0
  58. qpytorch/models/model_list.py +100 -0
  59. qpytorch/models/pyro/__init__.py +28 -0
  60. qpytorch/models/pyro/_pyro_mixin.py +57 -0
  61. qpytorch/models/pyro/distributions/__init__.py +5 -0
  62. qpytorch/models/pyro/pyro_qep.py +105 -0
  63. qpytorch/models/qep.py +7 -0
  64. qpytorch/models/qeplvm/__init__.py +6 -0
  65. qpytorch/models/qeplvm/bayesian_qeplvm.py +40 -0
  66. qpytorch/models/qeplvm/latent_variable.py +102 -0
  67. qpytorch/module.py +30 -0
  68. qpytorch/optim/__init__.py +5 -0
  69. qpytorch/priors/__init__.py +42 -0
  70. qpytorch/priors/qep_priors.py +81 -0
  71. qpytorch/test/__init__.py +22 -0
  72. qpytorch/test/base_likelihood_test_case.py +106 -0
  73. qpytorch/test/model_test_case.py +150 -0
  74. qpytorch/test/variational_test_case.py +400 -0
  75. qpytorch/utils/__init__.py +38 -0
  76. qpytorch/utils/warnings.py +37 -0
  77. qpytorch/variational/__init__.py +47 -0
  78. qpytorch/variational/_variational_distribution.py +61 -0
  79. qpytorch/variational/_variational_strategy.py +391 -0
  80. qpytorch/variational/additive_grid_interpolation_variational_strategy.py +90 -0
  81. qpytorch/variational/batch_decoupled_variational_strategy.py +256 -0
  82. qpytorch/variational/cholesky_variational_distribution.py +65 -0
  83. qpytorch/variational/ciq_variational_strategy.py +352 -0
  84. qpytorch/variational/delta_variational_distribution.py +41 -0
  85. qpytorch/variational/grid_interpolation_variational_strategy.py +113 -0
  86. qpytorch/variational/independent_multitask_variational_strategy.py +114 -0
  87. qpytorch/variational/lmc_variational_strategy.py +248 -0
  88. qpytorch/variational/mean_field_variational_distribution.py +58 -0
  89. qpytorch/variational/multitask_variational_strategy.py +317 -0
  90. qpytorch/variational/natural_variational_distribution.py +152 -0
  91. qpytorch/variational/nearest_neighbor_variational_strategy.py +487 -0
  92. qpytorch/variational/orthogonally_decoupled_variational_strategy.py +128 -0
  93. qpytorch/variational/tril_natural_variational_distribution.py +130 -0
  94. qpytorch/variational/uncorrelated_multitask_variational_strategy.py +114 -0
  95. qpytorch/variational/unwhitened_variational_strategy.py +225 -0
  96. qpytorch/variational/variational_strategy.py +280 -0
  97. qpytorch/version.py +4 -0
  98. qpytorch-0.1.dist-info/LICENSE +21 -0
  99. qpytorch-0.1.dist-info/METADATA +177 -0
  100. qpytorch-0.1.dist-info/RECORD +102 -0
  101. qpytorch-0.1.dist-info/WHEEL +5 -0
  102. qpytorch-0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,435 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import torch
4
+ from linear_operator import LinearOperator, to_linear_operator
5
+ from linear_operator.operators import (
6
+ BlockDiagLinearOperator,
7
+ BlockInterleavedLinearOperator,
8
+ CatLinearOperator,
9
+ DiagLinearOperator,
10
+ )
11
+
12
+ from .multivariate_qexponential import MultivariateQExponential
13
+
14
+
15
class MultitaskMultivariateQExponential(MultivariateQExponential):
    """
    Constructs a multi-output multivariate Q-Exponential random variable, based on mean and covariance.
    Can be multi-output multivariate, or a batch of multi-output multivariate Q-Exponentials.

    Passing a matrix mean corresponds to a multi-output multivariate Q-Exponential.
    Passing a batch matrix mean corresponds to a batch of multi-output multivariate Q-Exponentials.

    :param torch.Tensor mean: An `n x t` or batch `b x n x t` matrix of means for the QEP distribution.
    :param ~linear_operator.operators.LinearOperator covariance_matrix: An `... x NT x NT` (batch)
        covariance matrix of the QEP distribution.
    :param power: (default=2.0) (scalar) power of QEP distribution.
    :param bool validate_args: (default=False) If True, validate `mean` and `covariance_matrix` arguments.
    :param bool interleaved: (default=True) If True, covariance matrix is interpreted as block-diagonal w.r.t.
        inter-task covariances for each observation. If False, it is interpreted as block-diagonal
        w.r.t. inter-observation covariance for each task.
    """

    # NOTE(review): the default ``power`` tensor is evaluated once at definition time and is therefore
    # shared by every call that relies on the default -- avoid modifying it in place.
    def __init__(self, mean, covariance_matrix, power=torch.tensor(2.0), validate_args=False, interleaved=True):
        if not torch.is_tensor(mean) and not isinstance(mean, LinearOperator):
            raise RuntimeError("The mean of a MultitaskMultivariateQExponential must be a Tensor or LinearOperator")

        if not torch.is_tensor(covariance_matrix) and not isinstance(covariance_matrix, LinearOperator):
            raise RuntimeError("The covariance of a MultitaskMultivariateQExponential must be a Tensor or LinearOperator")

        if mean.dim() < 2:
            raise RuntimeError("mean should be a matrix or a batch matrix (batch mode)")

        # Ensure that shapes are broadcasted appropriately across the mean and covariance
        # Means can have singleton dimensions for either the `n` or `t` dimensions
        batch_shape = torch.broadcast_shapes(mean.shape[:-2], covariance_matrix.shape[:-2])
        if mean.shape[-2:].numel() != covariance_matrix.size(-1):
            if covariance_matrix.size(-1) % mean.shape[-2:].numel():
                raise RuntimeError(
                    f"mean shape {mean.shape} is incompatible with covariance shape {covariance_matrix.shape}"
                )
            elif mean.size(-2) == 1:
                # A `1 x t` mean is broadcast over the data dimension.
                mean = mean.expand(*batch_shape, covariance_matrix.size(-1) // mean.size(-1), mean.size(-1))
            elif mean.size(-1) == 1:
                # An `n x 1` mean is broadcast over the task dimension.
                mean = mean.expand(*batch_shape, mean.size(-2), covariance_matrix.size(-2) // mean.size(-2))
            else:
                raise RuntimeError(
                    f"mean shape {mean.shape} is incompatible with covariance shape {covariance_matrix.shape}"
                )
        else:
            mean = mean.expand(*batch_shape, *mean.shape[-2:])

        self._output_shape = mean.shape
        # TODO: Instead of transpose / view operations, use a PermutationLinearOperator (see #539)
        # to handle interleaving
        self._interleaved = interleaved
        if self._interleaved:
            # Observation-major flattening: all tasks of observation 0, then observation 1, ...
            mean_qep = mean.reshape(*mean.shape[:-2], -1)
        else:
            # Task-major flattening: all observations of task 0, then task 1, ...
            mean_qep = mean.transpose(-1, -2).reshape(*mean.shape[:-2], -1)
        super().__init__(mean=mean_qep, covariance_matrix=covariance_matrix, power=power, validate_args=validate_args)

    @property
    def base_sample_shape(self):
        """
        Returns the shape of a base sample (without batching) that is used to
        generate a single sample.
        """
        base_sample_shape = self.event_shape
        return base_sample_shape

    @property
    def event_shape(self):
        """The `n x t` (data x tasks) event shape of the distribution."""
        return self._output_shape[-2:]

    @classmethod
    def from_batch_qep(cls, batch_qep, task_dim=-1):
        """
        Reinterpret a batch of multivariate q-exponential distributions as an (uncorrelated) multitask
        multivariate q-exponential distribution.

        :param ~qpytorch.distributions.MultivariateQExponential batch_qep: The base QEP distribution.
            (This distribution should have at least one batch dimension).
        :param int task_dim: Which batch dimension should be interpreted as the dimension for the independent tasks.
        :raises ValueError: if ``task_dim`` does not refer to one of the batch dimensions of ``batch_qep``.
        :returns: the uncorrelated multitask distribution
        :rtype: qpytorch.distributions.MultitaskMultivariateQExponential

        Example:
            >>> # model is a qpytorch.models.VariationalQEP
            >>> # likelihood is a qpytorch.likelihoods.Likelihood
            >>> mean = torch.randn(4, 2, 3)
            >>> covar_factor = torch.randn(4, 2, 3, 3)
            >>> covar = covar_factor @ covar_factor.transpose(-1, -2)
            >>> power = torch.tensor(1.0)
            >>> qep = qpytorch.distributions.MultivariateQExponential(mean, covar, power)
            >>> print(qep.event_shape, qep.batch_shape)
            >>> # torch.Size([3]), torch.Size([4, 2])
            >>>
            >>> mqep = MultitaskMultivariateQExponential.from_batch_qep(qep, task_dim=-1)
            >>> print(mqep.event_shape, mqep.batch_shape)
            >>> # torch.Size([3, 2]), torch.Size([4])
        """
        orig_task_dim = task_dim
        num_batch_dims = len(batch_qep.batch_shape)
        task_dim = task_dim if task_dim >= 0 else (num_batch_dims + task_dim)
        # Valid task dims are 0 .. num_batch_dims - 1; index num_batch_dims would be the event dimension.
        if task_dim < 0 or task_dim >= num_batch_dims:
            raise ValueError(
                f"task_dim of {orig_task_dim} is incompatible with QEP batch shape of {batch_qep.batch_shape}"
            )

        num_dim = batch_qep.mean.dim()
        res = cls(
            # Move the task dimension to the end so the mean becomes `... x n x t`.
            mean=batch_qep.mean.permute(*range(0, task_dim), *range(task_dim + 1, num_dim), task_dim),
            covariance_matrix=BlockInterleavedLinearOperator(batch_qep.lazy_covariance_matrix, block_dim=task_dim),
            power=batch_qep.power,
        )
        return res

    @classmethod
    def from_uncorrelated_qeps(cls, qeps):
        """
        Convert an iterable of QEPs into a :obj:`~qpytorch.distributions.MultitaskMultivariateQExponential`.
        The resulting distribution will have ``len(qeps)`` tasks, and the tasks will be uncorrelated.

        :param qeps: The base QEP distributions (an iterable of
            :obj:`~qpytorch.distributions.MultivariateQExponential`). The ``power`` of the first one is used.
        :raises ValueError: if fewer than 2 QEPs are given, if any is already multitask, or if the
            event shapes differ.
        :returns: the uncorrelated multitask distribution
        :rtype: qpytorch.distributions.MultitaskMultivariateQExponential

        Example:
            >>> # model is a qpytorch.models.VariationalQEP
            >>> # likelihood is a qpytorch.likelihoods.Likelihood
            >>> mean = torch.randn(4, 3)
            >>> covar_factor = torch.randn(4, 3, 3)
            >>> covar = covar_factor @ covar_factor.transpose(-1, -2)
            >>> power = torch.tensor(1.0)
            >>> qep1 = qpytorch.distributions.MultivariateQExponential(mean, covar, power)
            >>>
            >>> mean = torch.randn(4, 3)
            >>> covar_factor = torch.randn(4, 3, 3)
            >>> covar = covar_factor @ covar_factor.transpose(-1, -2)
            >>> qep2 = qpytorch.distributions.MultivariateQExponential(mean, covar, power)
            >>>
            >>> mqep = MultitaskMultivariateQExponential.from_uncorrelated_qeps([qep1, qep2])
            >>> print(mqep.event_shape, mqep.batch_shape)
            >>> # torch.Size([3, 2]), torch.Size([4])
        """
        # Materialize so that generic iterables (e.g. generators) support len() and indexing below.
        qeps = list(qeps)
        if len(qeps) < 2:
            raise ValueError("Must provide at least 2 QEPs to form a MultitaskMultivariateQExponential")
        if any(isinstance(qep, MultitaskMultivariateQExponential) for qep in qeps):
            raise ValueError("Cannot accept MultitaskMultivariateQExponentials")
        if not all(m.batch_shape == qeps[0].batch_shape for m in qeps[1:]):
            batch_shape = torch.broadcast_shapes(*(m.batch_shape for m in qeps))
            qeps = [qep.expand(batch_shape) for qep in qeps]
        if not all(m.event_shape == qeps[0].event_shape for m in qeps[1:]):
            raise ValueError("All MultivariateQExponentials must have the same event shape")
        mean = torch.stack([qep.mean for qep in qeps], -1)
        # TODO: To do the following efficiently, we don't want to evaluate the
        # covariance matrices. Instead, we want to use the lazies directly in the
        # BlockDiagLinearOperator. This will require implementing a new BatchLinearOperator:

        # https://github.com/cornellius-gp/gpytorch/issues/468
        covar_blocks_lazy = CatLinearOperator(
            *[qep.lazy_covariance_matrix.unsqueeze(0) for qep in qeps], dim=0, output_device=mean.device
        )
        covar_lazy = BlockDiagLinearOperator(covar_blocks_lazy, block_dim=0)
        # The block-diagonal covariance is task-major, hence interleaved=False.
        return cls(mean=mean, covariance_matrix=covar_lazy, power=qeps[0].power, interleaved=False)

    @classmethod
    def from_repeated_qep(cls, qep, num_tasks):
        """
        Convert a single QEP into a :obj:`~qpytorch.distributions.MultitaskMultivariateQExponential`,
        where each task shares the same mean and covariance.

        :param ~qpytorch.distributions.MultivariateQExponential qep: The base QEP distribution.
        :param int num_tasks: How many tasks to create.
        :returns: the uncorrelated multitask distribution
        :rtype: qpytorch.distributions.MultitaskMultivariateQExponential

        Example:
            >>> # model is a qpytorch.models.VariationalQEP
            >>> # likelihood is a qpytorch.likelihoods.Likelihood
            >>> mean = torch.randn(4, 3)
            >>> covar_factor = torch.randn(4, 3, 3)
            >>> covar = covar_factor @ covar_factor.transpose(-1, -2)
            >>> qep = qpytorch.distributions.MultivariateQExponential(mean, covar)
            >>> print(qep.event_shape, qep.batch_shape)
            >>> # torch.Size([3]), torch.Size([4])
            >>>
            >>> mqep = MultitaskMultivariateQExponential.from_repeated_qep(qep, num_tasks=2)
            >>> print(mqep.event_shape, mqep.batch_shape)
            >>> # torch.Size([3, 2]), torch.Size([4])
        """
        return cls.from_batch_qep(qep.expand(torch.Size([num_tasks]) + qep.batch_shape), task_dim=0)

    def expand(self, batch_size):
        """Return a new distribution whose mean and covariance are expanded to the given batch size."""
        new_mean = self.mean.expand(torch.Size(batch_size) + self.mean.shape[-2:])
        new_covar = self._covar.expand(torch.Size(batch_size) + self._covar.shape[-2:])
        res = self.__class__(new_mean, new_covar, power=self.power, interleaved=self._interleaved)
        return res

    def get_base_samples(self, sample_shape=torch.Size(), **kwargs):
        """Draw base samples from the parent distribution and reshape them to `sample_shape x ... x n x t`."""
        base_samples = super().get_base_samples(sample_shape, **kwargs)
        if not self._interleaved:
            # flip shape of last two dimensions
            new_shape = sample_shape + self._output_shape[:-2] + self._output_shape[:-3:-1]
            return base_samples.view(new_shape).transpose(-1, -2).contiguous()
        return base_samples.view(*sample_shape, *self._output_shape)

    def log_prob(self, value):
        """
        Log density of ``value`` (shape `... x n x t`).

        The parent distribution works on a flattened vector. In the non-interleaved case that vector is
        task-major, so the `n x t` value is transposed to `t x n` before flattening.
        """
        if not self._interleaved:
            # (n x t) -> (t x n): flattening this gives all observations of task 0, then task 1, ...
            value = value.transpose(-1, -2)
        return super().log_prob(value.reshape(*value.shape[:-2], -1))

    @property
    def mean(self):
        """The mean reshaped to `... x n x t`."""
        mean = super().mean
        if not self._interleaved:
            # flip shape of last two dimensions
            new_shape = self._output_shape[:-2] + self._output_shape[:-3:-1]
            return mean.view(new_shape).transpose(-1, -2).contiguous()
        return mean.view(self._output_shape)

    @property
    def num_tasks(self):
        """Number of tasks `t` (the last dimension of the output shape)."""
        return self._output_shape[-1]

    def rsample(self, sample_shape=torch.Size(), base_samples=None, **kwargs):
        """
        Reparameterized samples of shape `sample_shape x ... x n x t`.

        :raises RuntimeError: if the trailing dimensions of ``base_samples`` do not match ``self.mean``.
        """
        if base_samples is not None:
            # Make sure that the base samples agree with the distribution
            mean_shape = self.mean.shape
            base_sample_shape = base_samples.shape[-self.mean.ndimension() :]
            if mean_shape != base_sample_shape:
                raise RuntimeError(
                    "The shape of base_samples (minus sample shape dimensions) should agree with the shape "
                    "of self.mean. Expected ...{} but got {}".format(mean_shape, base_sample_shape)
                )
            sample_shape = base_samples.shape[: -self.mean.ndimension()]
            # reshape (not view) so that non-contiguous base samples are accepted as well.
            base_samples = base_samples.reshape(*sample_shape, *self.loc.shape)

        samples = super().rsample(sample_shape=sample_shape, base_samples=base_samples, **kwargs)
        if not self._interleaved:
            # flip shape of last two dimensions
            new_shape = sample_shape + self._output_shape[:-2] + self._output_shape[:-3:-1]
            return samples.view(new_shape).transpose(-1, -2).contiguous()
        return samples.view(sample_shape + self._output_shape)

    def to_data_uncorrelated_dist(self, jitter_val=1e-4):
        """
        Convert a multitask QEP into a batch of (non-multitask) QEPs, one per data point.
        The result retains the inter-task covariances, but gets rid of the inter-data covariances.

        :param float jitter_val: Jitter added to the diagonal of each per-data-point task covariance.
        :returns: the batched data-uncorrelated QEP, with batch shape `... x n` and event shape `t`
        :rtype: qpytorch.distributions.MultivariateQExponential
        """
        # Create batch distribution where all data are independent, but the tasks are dependent
        full_covar = self.lazy_covariance_matrix
        num_data, num_tasks = self.mean.shape[-2:]
        if self._interleaved:
            # Flat index of (data i, task j) is i * num_tasks + j.
            data_indices = torch.arange(0, num_data * num_tasks, num_tasks, device=full_covar.device).view(-1, 1, 1)
            task_indices = torch.arange(num_tasks, device=full_covar.device)
        else:
            # Flat index of (data i, task j) is j * num_data + i.
            data_indices = torch.arange(num_data, device=full_covar.device).view(-1, 1, 1)
            task_indices = torch.arange(0, num_data * num_tasks, num_data, device=full_covar.device)
        # Advanced indexing broadcasts to `n x t x t`: one task covariance block per data point.
        task_covars = full_covar[
            ..., data_indices + task_indices.unsqueeze(-2), data_indices + task_indices.unsqueeze(-1)
        ]
        return MultivariateQExponential(self.mean, to_linear_operator(task_covars).add_jitter(jitter_val=jitter_val), self.power)

    # to_data_independent_dist = to_data_uncorrelated_dist # alias to the same function with a more appropriate name

    @property
    def variance(self):
        """The marginal variances reshaped to `... x n x t`."""
        var = super().variance
        if not self._interleaved:
            # flip shape of last two dimensions
            new_shape = self._output_shape[:-2] + self._output_shape[:-3:-1]
            return var.view(new_shape).transpose(-1, -2).contiguous()
        return var.view(self._output_shape)

    def __getitem__(self, idx) -> MultivariateQExponential:
        """
        Constructs a new MultivariateQExponential that represents a random variable
        modified by an indexing operation.

        The mean and covariance matrix arguments are indexed accordingly.

        :param Any idx: Index to apply to the mean. The covariance matrix is indexed accordingly.
        :raises IndexError: if the index has more than one ellipsis or too many dimensions.
        :returns: If indices specify a slice for samples and tasks, returns a
            MultitaskMultivariateQExponential, else returns a MultivariateQExponential.
        """

        # Normalize index to a tuple
        if not isinstance(idx, tuple):
            idx = (idx,)

        # Identity check avoids evaluating `==` against tensor indices.
        ellipsis_positions = [pos for pos, entry in enumerate(idx) if entry is Ellipsis]
        if ellipsis_positions:
            # Replace ellipsis '...' with explicit indices
            if len(ellipsis_positions) > 1:
                raise IndexError("Only one ellipsis '...' is supported!")
            ellipsis_location = ellipsis_positions[0]
            prefix = idx[:ellipsis_location]
            suffix = idx[ellipsis_location + 1 :]
            infix_length = self.mean.dim() - len(prefix) - len(suffix)
            if infix_length < 0:
                raise IndexError(f"Index {idx} has too many dimensions")
            idx = prefix + (slice(None),) * infix_length + suffix
        elif len(idx) == self.mean.dim() - 1:
            # Normalize indices ignoring the task-index to include it
            idx = idx + (slice(None),)

        new_mean = self.mean[idx]

        # We now create a covariance matrix appropriate for new_mean
        if len(idx) <= self.mean.dim() - 2:
            # We are only indexing the batch dimensions in this case
            return MultitaskMultivariateQExponential(
                mean=new_mean,
                covariance_matrix=self.lazy_covariance_matrix[idx],
                power=self.power,
                interleaved=self._interleaved,
            )
        elif len(idx) > self.mean.dim():
            raise IndexError(f"Index {idx} has too many dimensions")
        else:
            # We have an index that extends over all dimensions
            batch_idx = idx[:-2]
            # (row, col) are ordered so that the flat covariance index is row * num_cols + col.
            if self._interleaved:
                row_idx = idx[-2]
                col_idx = idx[-1]
                num_rows = self._output_shape[-2]
                num_cols = self._output_shape[-1]
            else:
                row_idx = idx[-1]
                col_idx = idx[-2]
                num_rows = self._output_shape[-1]
                num_cols = self._output_shape[-2]

            if isinstance(row_idx, int) and isinstance(col_idx, int):
                # Single sample with single task
                row_idx = _normalize_index(row_idx, num_rows)
                col_idx = _normalize_index(col_idx, num_cols)
                new_cov = DiagLinearOperator(
                    self.lazy_covariance_matrix.diagonal()[batch_idx + (row_idx * num_cols + col_idx,)]
                )
                return MultivariateQExponential(mean=new_mean, covariance_matrix=new_cov, power=self.power)
            elif isinstance(row_idx, int) and isinstance(col_idx, slice):
                # A block of the covariance matrix
                row_idx = _normalize_index(row_idx, num_rows)
                col_idx = _normalize_slice(col_idx, num_cols)
                new_slice = slice(
                    col_idx.start + row_idx * num_cols,
                    col_idx.stop + row_idx * num_cols,
                    col_idx.step,
                )
                new_cov = self.lazy_covariance_matrix[batch_idx + (new_slice, new_slice)]
                return MultivariateQExponential(mean=new_mean, covariance_matrix=new_cov, power=self.power)
            elif isinstance(row_idx, slice) and isinstance(col_idx, int):
                # A block of the reversely interleaved covariance matrix:
                # rows r = start, start + step, ... map to flat indices r * num_cols + col_idx.
                row_idx = _normalize_slice(row_idx, num_rows)
                col_idx = _normalize_index(col_idx, num_cols)
                new_slice = slice(
                    row_idx.start * num_cols + col_idx,
                    row_idx.stop * num_cols + col_idx,
                    row_idx.step * num_cols,
                )
                new_cov = self.lazy_covariance_matrix[batch_idx + (new_slice, new_slice)]
                return MultivariateQExponential(mean=new_mean, covariance_matrix=new_cov, power=self.power)
            elif (
                isinstance(row_idx, slice)
                and isinstance(col_idx, slice)
                and row_idx == col_idx == slice(None, None, None)
            ):
                new_cov = self.lazy_covariance_matrix[batch_idx]
                return MultitaskMultivariateQExponential(
                    mean=new_mean,
                    covariance_matrix=new_cov,
                    power=self.power,
                    interleaved=self._interleaved,
                    validate_args=False,
                )
            elif isinstance(row_idx, slice) or isinstance(col_idx, slice):
                # slice x slice or indices x slice or slice x indices
                device = self.lazy_covariance_matrix.device
                if isinstance(row_idx, slice):
                    row_idx = torch.arange(num_rows, device=device)[row_idx]
                if isinstance(col_idx, slice):
                    col_idx = torch.arange(num_cols, device=device)[col_idx]
                row_grid, col_grid = torch.meshgrid(row_idx, col_idx, indexing="ij")
                indices = (row_grid * num_cols + col_grid).reshape(-1)
                new_cov = self.lazy_covariance_matrix[batch_idx + (indices,)][..., indices]
                return MultitaskMultivariateQExponential(
                    mean=new_mean, covariance_matrix=new_cov, power=self.power, interleaved=self._interleaved, validate_args=False
                )
            else:
                # row_idx and col_idx have pairs of indices
                indices = row_idx * num_cols + col_idx
                new_cov = self.lazy_covariance_matrix[batch_idx + (indices,)][..., indices]
                return MultivariateQExponential(
                    mean=new_mean,
                    covariance_matrix=new_cov,
                    power=self.power
                )

    def __repr__(self) -> str:
        return f"MultitaskMultivariateQExponential(mean shape: {self._output_shape})"
412
+
413
+
414
+ def _normalize_index(i: int, dim_size: int) -> int:
415
+ if i < 0:
416
+ return dim_size + i
417
+ else:
418
+ return i
419
+
420
+
421
+ def _normalize_slice(s: slice, dim_size: int) -> slice:
422
+ start = s.start
423
+ if start is None:
424
+ start = 0
425
+ elif start < 0:
426
+ start = dim_size + start
427
+ stop = s.stop
428
+ if stop is None:
429
+ stop = dim_size
430
+ elif stop < 0:
431
+ stop = dim_size + stop
432
+ step = s.step
433
+ if step is None:
434
+ step = 1
435
+ return slice(start, stop, step)