gpjax-0.13.2-py3-none-any.whl → gpjax-0.13.4-py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. gpjax/__init__.py +3 -3
  2. gpjax/citation.py +13 -0
  3. gpjax/dataset.py +1 -1
  4. gpjax/fit.py +1 -1
  5. gpjax/gps.py +273 -63
  6. gpjax/kernels/__init__.py +1 -1
  7. gpjax/kernels/base.py +2 -2
  8. gpjax/kernels/computations/__init__.py +1 -1
  9. gpjax/kernels/computations/base.py +1 -1
  10. gpjax/kernels/computations/constant_diagonal.py +1 -1
  11. gpjax/kernels/computations/dense.py +1 -1
  12. gpjax/kernels/computations/diagonal.py +1 -1
  13. gpjax/kernels/computations/eigen.py +1 -1
  14. gpjax/kernels/non_euclidean/__init__.py +1 -1
  15. gpjax/kernels/non_euclidean/graph.py +18 -6
  16. gpjax/kernels/non_euclidean/utils.py +1 -1
  17. gpjax/kernels/nonstationary/__init__.py +1 -1
  18. gpjax/kernels/nonstationary/arccosine.py +1 -1
  19. gpjax/kernels/nonstationary/linear.py +1 -1
  20. gpjax/kernels/nonstationary/polynomial.py +1 -1
  21. gpjax/kernels/stationary/__init__.py +1 -1
  22. gpjax/kernels/stationary/base.py +1 -1
  23. gpjax/kernels/stationary/matern12.py +1 -1
  24. gpjax/kernels/stationary/matern32.py +1 -1
  25. gpjax/kernels/stationary/matern52.py +1 -1
  26. gpjax/kernels/stationary/periodic.py +1 -1
  27. gpjax/kernels/stationary/powered_exponential.py +1 -1
  28. gpjax/kernels/stationary/rational_quadratic.py +1 -1
  29. gpjax/kernels/stationary/rbf.py +1 -1
  30. gpjax/kernels/stationary/utils.py +1 -1
  31. gpjax/kernels/stationary/white.py +1 -1
  32. gpjax/likelihoods.py +234 -0
  33. gpjax/mean_functions.py +2 -2
  34. gpjax/objectives.py +56 -1
  35. gpjax/parameters.py +8 -1
  36. gpjax/scan.py +1 -1
  37. gpjax/variational_families.py +129 -0
  38. {gpjax-0.13.2.dist-info → gpjax-0.13.4.dist-info}/METADATA +13 -13
  39. gpjax-0.13.4.dist-info/RECORD +52 -0
  40. gpjax-0.13.2.dist-info/RECORD +0 -52
  41. {gpjax-0.13.2.dist-info → gpjax-0.13.4.dist-info}/WHEEL +0 -0
  42. {gpjax-0.13.2.dist-info → gpjax-0.13.4.dist-info}/licenses/LICENSE.txt +0 -0
gpjax/__init__.py CHANGED
@@ -38,9 +38,9 @@ from gpjax.fit import (
 
 __license__ = "MIT"
 __description__ = "Gaussian processes in JAX and Flax"
-__url__ = "https://github.com/JaxGaussianProcesses/GPJax"
-__contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
-__version__ = "0.13.2"
+__url__ = "https://github.com/thomaspinder/GPJax"
+__contributors__ = "https://github.com/thomaspinder/GPJax/graphs/contributors"
+__version__ = "0.13.4"
 
 __all__ = [
     "gps",
gpjax/citation.py CHANGED
@@ -23,6 +23,7 @@ from gpjax.kernels import (
     Matern32,
     Matern52,
 )
+from gpjax.likelihoods import HeteroscedasticGaussian
 
 CitationType = Union[None, str, Dict[str, str]]
 
@@ -149,3 +150,15 @@ def _(tree) -> PaperCitation:
         booktitle="Advances in neural information processing systems",
         citation_type="article",
     )
+
+
+@cite.register(HeteroscedasticGaussian)
+def _(tree) -> PaperCitation:
+    return PaperCitation(
+        citation_key="lazaro2011variational",
+        authors="Lázaro-Gredilla, Miguel and Titsias, Michalis",
+        title="Variational heteroscedastic Gaussian process regression",
+        year="2011",
+        booktitle="Proceedings of the 28th International Conference on Machine Learning (ICML)",
+        citation_type="inproceedings",
+    )
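The handler registered above makes the Lázaro-Gredilla and Titsias (2011) reference retrievable through the same cite dispatcher used for the kernels. A minimal sketch of a lookup, assuming cite is importable from gpjax.citation and that HeteroscedasticGaussian accepts the usual num_datapoints constructor argument (both are assumptions; this diff does not show the likelihood's constructor):

    # Hypothetical lookup of the newly registered citation entry.
    from gpjax.citation import cite
    from gpjax.likelihoods import HeteroscedasticGaussian

    likelihood = HeteroscedasticGaussian(num_datapoints=100)  # assumed signature
    citation = cite(likelihood)  # dispatches on the likelihood type
    print(citation)  # PaperCitation with citation_key "lazaro2011variational"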
gpjax/dataset.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/fit.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2023 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/gps.py CHANGED
@@ -4,7 +4,7 @@
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-# http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@@ -13,10 +13,13 @@
 # limitations under the License.
 # ==============================================================================
 # from __future__ import annotations
+
 from abc import abstractmethod
+from typing import Literal
 
 import beartype.typing as tp
 from flax import nnx
+import jax
 import jax.numpy as jnp
 import jax.random as jr
 from jaxtyping import (
@@ -29,16 +32,21 @@ from gpjax.distributions import GaussianDistribution
 from gpjax.kernels import RFF
 from gpjax.kernels.base import AbstractKernel
 from gpjax.likelihoods import (
+    AbstractHeteroscedasticLikelihood,
     AbstractLikelihood,
     Gaussian,
+    HeteroscedasticGaussian,
     NonGaussian,
 )
 from gpjax.linalg import (
     Dense,
+    Diagonal,
     psd,
     solve,
 )
-from gpjax.linalg.operations import lower_cholesky
+from gpjax.linalg.operations import (
+    lower_cholesky,
+)
 from gpjax.linalg.utils import add_jitter
 from gpjax.mean_functions import AbstractMeanFunction
 from gpjax.parameters import (
@@ -56,6 +64,7 @@ M = tp.TypeVar("M", bound=AbstractMeanFunction)
 L = tp.TypeVar("L", bound=AbstractLikelihood)
 NGL = tp.TypeVar("NGL", bound=NonGaussian)
 GL = tp.TypeVar("GL", bound=Gaussian)
+HL = tp.TypeVar("HL", bound=AbstractHeteroscedasticLikelihood)
 
 
 class AbstractPrior(nnx.Module, tp.Generic[M, K]):
@@ -77,7 +86,12 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
         self.mean_function = mean_function
         self.jitter = jitter
 
-    def __call__(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution:
+    def __call__(
+        self,
+        test_inputs: Num[Array, "N D"],
+        *,
+        return_covariance_type: Literal["dense", "diagonal"] = "dense",
+    ) -> GaussianDistribution:
         r"""Evaluate the Gaussian process at the given points.
 
         The output of this function is a
@@ -91,15 +105,27 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
 
         Args:
             test_inputs: Input locations where the GP should be evaluated.
+            return_covariance_type: Literal denoting whether to return the full
+                covariance of the joint predictive distribution at the test_inputs
+                (dense) or just the standard deviation of the predictive
+                distribution at the test_inputs (diagonal).
 
         Returns:
             GaussianDistribution: A multivariate normal random variable representation
             of the Gaussian process.
         """
-        return self.predict(test_inputs)
+        return self.predict(
+            test_inputs,
+            return_covariance_type=return_covariance_type,
+        )
 
     @abstractmethod
-    def predict(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution:
+    def predict(
+        self,
+        test_inputs: Num[Array, "N D"],
+        *,
+        return_covariance_type: Literal["dense", "diagonal"] = "dense",
+    ) -> GaussianDistribution:
         r"""Evaluate the predictive distribution.
 
         Compute the latent function's multivariate normal distribution for a
@@ -108,6 +134,10 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
 
         Args:
             test_inputs: Input locations where the GP should be evaluated.
+            return_covariance_type: Literal denoting whether to return the full
+                covariance of the joint predictive distribution at the test_inputs
+                (dense) or just the standard deviation of the predictive
+                distribution at the test_inputs (diagonal).
 
         Returns:
             GaussianDistribution: A multivariate normal random variable representation
@@ -123,8 +153,8 @@ class Prior(AbstractPrior[M, K]):
     r"""A Gaussian process prior object.
 
     The GP is parameterised by a
-    [mean](https://docs.jaxgaussianprocesses.com/api/mean_functions/)
-    and [kernel](https://docs.jaxgaussianprocesses.com/api/kernels/base/)
+    [mean](https://docs.thomaspinder.com/api/mean_functions/)
+    and [kernel](https://docs.thomaspinder.com/api/kernels/base/)
     function.
 
     A Gaussian process prior parameterised by a mean function $m(\cdot)$ and a kernel
@@ -220,7 +250,12 @@ class Prior(AbstractPrior[M, K]):
         """
         return self.__mul__(other)
 
-    def predict(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution:
+    def predict(
+        self,
+        test_inputs: Num[Array, "N D"],
+        *,
+        return_covariance_type: Literal["dense", "diagonal"] = "dense",
+    ) -> GaussianDistribution:
         r"""Compute the predictive prior distribution for a given set of
         parameters. The output of this function is a function that computes
         a TFP distribution for a given set of inputs.
@@ -241,17 +276,43 @@ class Prior(AbstractPrior[M, K]):
         Args:
             test_inputs (Float[Array, "N D"]): The inputs at which to evaluate the
                 prior distribution.
+            return_covariance_type: Literal denoting whether to return the full
+                covariance of the joint predictive distribution at the test_inputs
+                (dense) or just the standard deviation of the predictive
+                distribution at the test_inputs (diagonal).
 
         Returns:
             GaussianDistribution: A multivariate normal random variable representation
             of the Gaussian process.
         """
+
+        def _return_full_covariance(
+            t: Num[Array, "N D"],
+        ) -> Dense:
+            Kxx = self.kernel.gram(t)
+            Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter)
+            Kxx = psd(Dense(Kxx_dense))
+            return Kxx
+
+        def _return_diagonal_covariance(
+            t: Num[Array, "N D"],
+        ) -> Dense:
+            Kxx = self.kernel.diagonal(t).diagonal
+            Kxx += self.jitter
+            Kxx = psd(Dense(Diagonal(Kxx).to_dense()))
+            return Kxx
+
         mean_at_test = self.mean_function(test_inputs)
-        Kxx = self.kernel.gram(test_inputs)
-        Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter)
-        Kxx = psd(Dense(Kxx_dense))
+        cov = jax.lax.cond(
+            return_covariance_type == "dense",
+            _return_full_covariance,
+            _return_diagonal_covariance,
+            test_inputs,
+        )
 
-        return GaussianDistribution(jnp.atleast_1d(mean_at_test.squeeze()), Kxx)
+        return GaussianDistribution(
+            loc=jnp.atleast_1d(mean_at_test.squeeze()), scale=cov
+        )
 
     def sample_approx(
         self,
@@ -329,7 +390,7 @@ P = tp.TypeVar("P", bound=AbstractPrior)
 
 #######################
 # GP Posteriors
-#######################
+#######################from gpjax.linalg.operators import LinearOperator
 class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
     r"""Abstract Gaussian process posterior.
 
@@ -356,7 +417,11 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
         self.jitter = jitter
 
     def __call__(
-        self, test_inputs: Num[Array, "N D"], train_data: Dataset
+        self,
+        test_inputs: Num[Array, "N D"],
+        train_data: Dataset,
+        *,
+        return_covariance_type: Literal["dense", "diagonal"] = "dense",
     ) -> GaussianDistribution:
         r"""Evaluate the Gaussian process posterior at the given points.
 
@@ -372,16 +437,28 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
         Args:
             test_inputs: Input locations where the GP should be evaluated.
             train_data: Training dataset to condition on.
+            return_covariance_type: Literal denoting whether to return the full
+                covariance of the joint predictive distribution at the test_inputs
+                (dense) or just the standard deviation of the predictive
+                distribution at the test_inputs (diagonal).
 
         Returns:
             GaussianDistribution: A multivariate normal random variable representation
             of the Gaussian process.
         """
-        return self.predict(test_inputs, train_data)
+        return self.predict(
+            test_inputs,
+            train_data,
+            return_covariance_type=return_covariance_type,
+        )
 
     @abstractmethod
     def predict(
-        self, test_inputs: Num[Array, "N D"], train_data: Dataset
+        self,
+        test_inputs: Num[Array, "N D"],
+        train_data: Dataset,
+        *,
+        return_covariance_type: Literal["dense", "diagonal"] = "dense",
     ) -> GaussianDistribution:
         r"""Compute the latent function's multivariate normal distribution for a
         given set of parameters. For any class inheriting the `AbstractPosterior` class,
@@ -390,6 +467,10 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
         Args:
             test_inputs: Input locations where the GP should be evaluated.
             train_data: Training dataset to condition on.
+            return_covariance_type: Literal denoting whether to return the full
+                covariance of the joint predictive distribution at the test_inputs
+                (dense) or just the standard deviation of the predictive
+                distribution at the test_inputs (diagonal).
 
         Returns:
             GaussianDistribution: A multivariate normal random variable representation
@@ -398,6 +479,22 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
         raise NotImplementedError
 
 
+class LatentPosterior(AbstractPosterior[P, L]):
+    r"""A posterior shell used to expose prior structure without inference."""
+
+    def predict(
+        self,
+        test_inputs: Num[Array, "N D"],
+        train_data: Dataset,
+        *,
+        return_covariance_type: Literal["dense", "diagonal"] = "dense",
+    ) -> GaussianDistribution:
+        raise NotImplementedError(
+            "LatentPosteriors are a lightweight wrapper for priors and do not "
+            "implement predictive distributions. Use a variational family for inference."
+        )
+
+
 class ConjugatePosterior(AbstractPosterior[P, GL]):
     r"""A Conjugate Gaussian process posterior object.
 
@@ -442,8 +539,10 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
 
     def predict(
         self,
-        test_inputs: Num[Array, "N D"],
+        test_inputs: Num[Array, "M D"],
         train_data: Dataset,
+        *,
+        return_covariance_type: Literal["dense", "diagonal"] = "dense",
     ) -> GaussianDistribution:
         r"""Query the predictive posterior distribution.
 
@@ -454,13 +553,13 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
 
         The predictive distribution of a conjugate GP is given by
         $$
-            p(\mathbf{f}^{\star}\mid \mathbf{y}) & = \int p(\mathbf{f}^{\star} \mathbf{f} \mid \mathbf{y})\\
-            & =\mathcal{N}(\mathbf{f}^{\star} \boldsymbol{\mu}_{\mid \mathbf{y}}, \boldsymbol{\Sigma}_{\mid \mathbf{y}}
+            p(\mathbf{f}^{\star}\mid \mathbf{y}) & = \int p(\mathbf{f}^{\star} \mathbf{f} \mid \mathbf{y})\\
+            & =\mathcal{N}(\mathbf{f}^{\star} \boldsymbol{\mu}_{\mid \mathbf{y}}, \boldsymbol{\Sigma}_{\mid \mathbf{y}}
         $$
         where
         $$
-            \boldsymbol{\mu}_{\mid \mathbf{y}} & = k(\mathbf{x}^{\star}, \mathbf{x})\left(k(\mathbf{x}, \mathbf{x}')+\sigma^2\mathbf{I}_n\right)^{-1}\mathbf{y} \\
-            \boldsymbol{\Sigma}_{\mid \mathbf{y}} & =k(\mathbf{x}^{\star}, \mathbf{x}^{\star\prime}) -k(\mathbf{x}^{\star}, \mathbf{x})\left( k(\mathbf{x}, \mathbf{x}') + \sigma^2\mathbf{I}_n \right)^{-1}k(\mathbf{x}, \mathbf{x}^{\star}).
+            \boldsymbol{\mu}_{\mid \mathbf{y}} & = k(\mathbf{x}^{\star}, \mathbf{x})\left(k(\mathbf{x}, \mathbf{x}')+\sigma^2\mathbf{I}_n\right)^{-1}\mathbf{y} \\
+            \boldsymbol{\Sigma}_{\mid \mathbf{y}} & =k(\mathbf{x}^{\star}, \mathbf{x}^{\star\prime}) -k(\mathbf{x}^{\star}, \mathbf{x})\left( k(\mathbf{x}, \mathbf{x}') + \sigma^2\mathbf{I}_n \right)^{-1}k(\mathbf{x}, \mathbf{x}^{\star}).
         $$
 
         The conditioning set is a GPJax `Dataset` object, whilst predictions
@@ -486,44 +585,65 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
                 predictive distribution is evaluated.
             train_data (Dataset): A `gpx.Dataset` object that contains the input and
                 output data used for training dataset.
+            return_covariance_type: Literal denoting whether to return the full
+                covariance of the joint predictive distribution at the test_inputs
+                (dense) or just the standard deviation of the predictive
+                distribution at the test_inputs (diagonal).
 
         Returns:
             GaussianDistribution: A function that accepts an input array and
                 returns the predictive distribution as a `GaussianDistribution`.
         """
-        # Unpack training data
-        x, y = train_data.X, train_data.y
-
-        # Unpack test inputs
-        t = test_inputs
-
+        x = train_data.X
+        y = train_data.y
         # Observation noise o²
-        obs_noise = self.likelihood.obs_stddev.value**2
+        obs_noise = jnp.square(self.likelihood.obs_stddev.value)
         mx = self.prior.mean_function(x)
-
         # Precompute Gram matrix, Kxx, at training inputs, x
         Kxx = self.prior.kernel.gram(x)
-        Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter)
-        Kxx = Dense(Kxx_dense)
+        Kxx = add_jitter(Kxx.to_dense(), self.jitter)
 
-        Sigma_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * obs_noise
+        Sigma_dense = Kxx + jnp.eye(Kxx.shape[0]) * obs_noise
         Sigma = psd(Dense(Sigma_dense))
         L_sigma = lower_cholesky(Sigma)
 
-        mean_t = self.prior.mean_function(t)
-        Ktt = self.prior.kernel.gram(t)
-        Kxt = self.prior.kernel.cross_covariance(x, t)
+        Kxt = self.prior.kernel.cross_covariance(x, test_inputs)
 
         L_inv_Kxt = solve(L_sigma, Kxt)
         L_inv_y_diff = solve(L_sigma, y - mx)
 
+        mean_t = self.prior.mean_function(test_inputs)
         mean = mean_t + jnp.matmul(L_inv_Kxt.T, L_inv_y_diff)
 
-        covariance = Ktt.to_dense() - jnp.matmul(L_inv_Kxt.T, L_inv_Kxt)
-        covariance = add_jitter(covariance, self.prior.jitter)
-        covariance = psd(Dense(covariance))
-
-        return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
+        def _return_full_covariance(
+            L_inv_Kxt: Num[Array, "N M"],
+            t: Num[Array, "M D"],
+        ) -> Dense:
+            Ktt = self.prior.kernel.gram(t)
+            covariance = Ktt.to_dense() - jnp.matmul(L_inv_Kxt.T, L_inv_Kxt)
+            covariance = add_jitter(covariance, self.prior.jitter)
+            covariance = psd(Dense(covariance))
+            return covariance
+
+        def _return_diagonal_covariance(
+            L_inv_Kxt: Num[Array, "N M"],
+            t: Num[Array, "M D"],
+        ) -> Dense:
+            Ktt = self.prior.kernel.diagonal(t).diagonal
+            covariance = Ktt - jnp.einsum("ij, ji->i", L_inv_Kxt.T, L_inv_Kxt)
+            covariance += self.prior.jitter
+            covariance = psd(Dense(jnp.diag(jnp.atleast_1d(covariance.squeeze()))))
+            return covariance
+
+        cov = jax.lax.cond(
+            return_covariance_type == "dense",
+            _return_full_covariance,
+            _return_diagonal_covariance,
+            L_inv_Kxt,
+            test_inputs,
+        )
+
+        return GaussianDistribution(loc=jnp.atleast_1d(mean.squeeze()), scale=cov)
 
     def sample_approx(
         self,
@@ -567,7 +687,7 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
 
         Returns:
             FunctionalSample: A function representing an approximate sample from the Gaussian
-                process prior.
+                process prior.
         """
         if (not isinstance(num_samples, int)) or num_samples <= 0:
             raise ValueError("num_samples must be a positive integer")
@@ -586,7 +706,7 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
         canonical_weights = solve(
             Sigma,
             y + eps - jnp.inner(Phi, fourier_weights),
-        )  # [N, B]
+        )  # [N, B]
 
         def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]:
             fourier_features = fourier_feature_fn(test_inputs)
@@ -648,7 +768,11 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):
         self.key = key
 
     def predict(
-        self, test_inputs: Num[Array, "N D"], train_data: Dataset
+        self,
+        test_inputs: Num[Array, "M D"],
+        train_data: Dataset,
+        *,
+        return_covariance_type: Literal["dense", "diagonal"] = "dense",
     ) -> GaussianDistribution:
         r"""Query the predictive posterior distribution.
 
@@ -660,50 +784,112 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):
         transformed through the likelihood function's inverse link function.
 
         Args:
-            train_data (Dataset): A `gpx.Dataset` object that contains the input
-                and output data used for training dataset.
+            test_inputs (Num[Array, "N D"]): A JAX array of test inputs at which the
+                predictive distribution is evaluated.
+            train_data (Dataset): A `gpx.Dataset` object that contains the input
+                and output data used for training dataset.
+            return_covariance_type: Literal denoting whether to return the full
+                covariance of the joint predictive distribution at the test_inputs
+                (dense) or just the standard deviation of the predictive
+                distribution at the test_inputs (diagonal).
 
         Returns:
             GaussianDistribution: A function that accepts an
                 input array and returns the predictive distribution as
                 a `dx.Distribution`.
         """
-        # Unpack training data
         x = train_data.X
-
-        # Unpack mean function and kernel
+        t = test_inputs
         mean_function = self.prior.mean_function
         kernel = self.prior.kernel
 
-        # Precompute lower triangular of Gram matrix, Lx, at training inputs, x
+        # Precompute lower triangular of Gram matrix
        Kxx = kernel.gram(x)
         Kxx_dense = add_jitter(Kxx.to_dense(), self.prior.jitter)
         Kxx = psd(Dense(Kxx_dense))
         Lx = lower_cholesky(Kxx)
 
-        # Unpack test inputs
-        t = test_inputs
-
-        # Compute terms of the posterior predictive distribution
-        Ktx = kernel.cross_covariance(t, x)
-        Ktt = kernel.gram(t)
-        mean_t = mean_function(t)
-
+        Kxt = kernel.cross_covariance(x, t)
         # Lx⁻¹ Kxt
-        Lx_inv_Kxt = solve(Lx, Ktx.T)
+        Lx_inv_Kxt = solve(Lx, Kxt)
 
+        mean_t = mean_function(t)
         # Whitened function values, wx, corresponding to the inputs, x
         wx = self.latent.value
 
         # μt + Ktx Lx⁻¹ wx
         mean = mean_t + jnp.matmul(Lx_inv_Kxt.T, wx)
 
-        # Ktt - Ktx Kxx⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently.
-        covariance = Ktt.to_dense() - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt)
-        covariance = add_jitter(covariance, self.prior.jitter)
-        covariance = psd(Dense(covariance))
+        def _return_full_covariance(
+            Lx_inv_Kxt: Num[Array, "N M"],
+            t: Num[Array, "M D"],
+        ) -> Dense:
+            Ktt = kernel.gram(t)
+            covariance = Ktt.to_dense() - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt)
+            covariance = add_jitter(covariance, self.prior.jitter)
+            covariance = psd(Dense(covariance))
+
+            return covariance
+
+        def _return_diagonal_covariance(
+            Lx_inv_Kxt: Num[Array, "N M"],
+            t: Num[Array, "M D"],
+        ) -> Dense:
+            Ktt = kernel.diagonal(t).diagonal
+            covariance = Ktt - jnp.einsum("ij, ji->i", Lx_inv_Kxt.T, Lx_inv_Kxt)
+            covariance += self.prior.jitter
+            # It would be nice to return a Diagonal here, but the pytree needs
+            # to be the same for both cond branches and the other branch needs
+            # to return a Dense.
+            # They are both LinearOperators, but they inherit from that class
+            # and hence are not the same pytree anymore.
+            covariance = psd(Dense(jnp.diag(jnp.atleast_1d(covariance.squeeze()))))
+
+            return covariance
+
+        cov = jax.lax.cond(
+            return_covariance_type == "dense",
+            _return_full_covariance,
+            _return_diagonal_covariance,
+            Lx_inv_Kxt,
+            test_inputs,
+        )
+
+        return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), cov)
+
+
+class HeteroscedasticPosterior(LatentPosterior[P, HL]):
+    r"""Posterior shell for heteroscedastic likelihoods.
+
+    The posterior retains both the signal and noise priors; inference is delegated
+    to variational families and specialised objectives.
+    """
+
+    def __init__(
+        self,
+        prior: AbstractPrior[M, K],
+        likelihood: HL,
+        jitter: float = 1e-6,
+    ):
+        if likelihood.noise_prior is None:
+            raise ValueError("Heteroscedastic likelihoods require a noise_prior.")
+        super().__init__(prior=prior, likelihood=likelihood, jitter=jitter)
+        self.noise_prior = likelihood.noise_prior
+        self.noise_posterior = LatentPosterior(
+            prior=self.noise_prior, likelihood=likelihood, jitter=jitter
+        )
+
+
+class ChainedPosterior(HeteroscedasticPosterior[P, HL]):
+    r"""Posterior routed for heteroscedastic likelihoods using chained bounds."""
 
-        return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
+    def __init__(
+        self,
+        prior: AbstractPrior[M, K],
+        likelihood: HL,
+        jitter: float = 1e-6,
+    ):
+        super().__init__(prior=prior, likelihood=likelihood, jitter=jitter)
 
 
 #######################
@@ -721,6 +907,18 @@ def construct_posterior(  # noqa: F811
 ) -> NonConjugatePosterior[P, NGL]: ...
 
 
+@tp.overload
+def construct_posterior(  # noqa: F811
+    prior: P, likelihood: HeteroscedasticGaussian
+) -> HeteroscedasticPosterior[P, HeteroscedasticGaussian]: ...
+
+
+@tp.overload
+def construct_posterior(  # noqa: F811
+    prior: P, likelihood: AbstractHeteroscedasticLikelihood
+) -> ChainedPosterior[P, AbstractHeteroscedasticLikelihood]: ...
+
+
 def construct_posterior(prior, likelihood):  # noqa: F811
     r"""Utility function for constructing a posterior object from a prior and
     likelihood. The function will automatically select the correct posterior
@@ -740,6 +938,15 @@ def construct_posterior(prior, likelihood):  # noqa: F811
     if isinstance(likelihood, Gaussian):
        return ConjugatePosterior(prior=prior, likelihood=likelihood)
 
+    if (
+        isinstance(likelihood, HeteroscedasticGaussian)
+        and likelihood.supports_tight_bound()
+    ):
+        return HeteroscedasticPosterior(prior=prior, likelihood=likelihood)
+
+    if isinstance(likelihood, AbstractHeteroscedasticLikelihood):
+        return ChainedPosterior(prior=prior, likelihood=likelihood)
+
     return NonConjugatePosterior(prior=prior, likelihood=likelihood)
 
 
@@ -778,7 +985,10 @@ __all__ = [
     "AbstractPrior",
     "Prior",
     "AbstractPosterior",
+    "LatentPosterior",
     "ConjugatePosterior",
     "NonConjugatePosterior",
+    "HeteroscedasticPosterior",
+    "ChainedPosterior",
     "construct_posterior",
 ]
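Taken together, the gps.py changes thread a keyword-only return_covariance_type argument from __call__ through every predict implementation, and use jax.lax.cond to choose between the full covariance and a diagonal one wrapped in a Dense operator so both branches return the same pytree structure. A minimal sketch of how a caller might exercise the new option against a conjugate posterior; the kernel, data, and shapes below are illustrative assumptions rather than part of this diff:

    import gpjax as gpx
    import jax.numpy as jnp
    import jax.random as jr

    key = jr.PRNGKey(0)
    x = jr.uniform(key, (50, 1))
    y = jnp.sin(3.0 * x) + 0.1 * jr.normal(key, (50, 1))
    data = gpx.Dataset(X=x, y=y)

    prior = gpx.gps.Prior(
        mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()
    )
    likelihood = gpx.likelihoods.Gaussian(num_datapoints=data.n)
    posterior = gpx.gps.construct_posterior(prior, likelihood)  # ConjugatePosterior

    xtest = jnp.linspace(0.0, 1.0, 100).reshape(-1, 1)
    # Default behaviour: full 100 x 100 predictive covariance.
    dense_dist = posterior(xtest, data, return_covariance_type="dense")
    # New option: marginal variances only, stored as a diagonal matrix.
    diag_dist = posterior(xtest, data, return_covariance_type="diagonal")

Under the new overloads, construct_posterior instead returns a HeteroscedasticPosterior when given a HeteroscedasticGaussian whose supports_tight_bound() is true, and a ChainedPosterior for other heteroscedastic likelihoods; both are shells that defer inference to the variational families and objectives also touched in this release.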
gpjax/kernels/__init__.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/kernels/base.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -123,7 +123,7 @@ class AbstractKernel(nnx.Module):
         """
         return self.compute_engine.gram(self, x)
 
-    def diagonal(self, x: Num[Array, "N D"]) -> Float[Array, " N"]:
+    def diagonal(self, x: Num[Array, "N D"]) -> LinearOperator:
         r"""Compute the diagonal of the gram matrix of the kernel.
 
         Args:
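The updated annotation on AbstractKernel.diagonal matches the new call sites in gps.py, which read the variance vector from the returned operator's .diagonal attribute instead of treating the result as a raw array. A minimal sketch of that calling convention; the kernel choice and input shape are illustrative assumptions:

    import jax.numpy as jnp
    from gpjax.kernels import RBF

    kernel = RBF()
    x = jnp.linspace(0.0, 1.0, 10).reshape(-1, 1)

    # diagonal() now returns a linear operator; the underlying variances are
    # exposed on its .diagonal attribute, mirroring the updated posterior code.
    diag_op = kernel.diagonal(x)
    variances = diag_op.diagonal  # expected shape (10,)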
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
-# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.