gpjax 0.13.2__py3-none-any.whl → 0.13.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpjax/__init__.py +3 -3
- gpjax/citation.py +13 -0
- gpjax/dataset.py +1 -1
- gpjax/fit.py +1 -1
- gpjax/gps.py +273 -63
- gpjax/kernels/__init__.py +1 -1
- gpjax/kernels/base.py +2 -2
- gpjax/kernels/computations/__init__.py +1 -1
- gpjax/kernels/computations/base.py +1 -1
- gpjax/kernels/computations/constant_diagonal.py +1 -1
- gpjax/kernels/computations/dense.py +1 -1
- gpjax/kernels/computations/diagonal.py +1 -1
- gpjax/kernels/computations/eigen.py +1 -1
- gpjax/kernels/non_euclidean/__init__.py +1 -1
- gpjax/kernels/non_euclidean/graph.py +18 -6
- gpjax/kernels/non_euclidean/utils.py +1 -1
- gpjax/kernels/nonstationary/__init__.py +1 -1
- gpjax/kernels/nonstationary/arccosine.py +1 -1
- gpjax/kernels/nonstationary/linear.py +1 -1
- gpjax/kernels/nonstationary/polynomial.py +1 -1
- gpjax/kernels/stationary/__init__.py +1 -1
- gpjax/kernels/stationary/base.py +1 -1
- gpjax/kernels/stationary/matern12.py +1 -1
- gpjax/kernels/stationary/matern32.py +1 -1
- gpjax/kernels/stationary/matern52.py +1 -1
- gpjax/kernels/stationary/periodic.py +1 -1
- gpjax/kernels/stationary/powered_exponential.py +1 -1
- gpjax/kernels/stationary/rational_quadratic.py +1 -1
- gpjax/kernels/stationary/rbf.py +1 -1
- gpjax/kernels/stationary/utils.py +1 -1
- gpjax/kernels/stationary/white.py +1 -1
- gpjax/likelihoods.py +234 -0
- gpjax/mean_functions.py +2 -2
- gpjax/objectives.py +56 -1
- gpjax/parameters.py +8 -1
- gpjax/scan.py +1 -1
- gpjax/variational_families.py +129 -0
- {gpjax-0.13.2.dist-info → gpjax-0.13.4.dist-info}/METADATA +13 -13
- gpjax-0.13.4.dist-info/RECORD +52 -0
- gpjax-0.13.2.dist-info/RECORD +0 -52
- {gpjax-0.13.2.dist-info → gpjax-0.13.4.dist-info}/WHEEL +0 -0
- {gpjax-0.13.2.dist-info → gpjax-0.13.4.dist-info}/licenses/LICENSE.txt +0 -0
gpjax/__init__.py
CHANGED
@@ -38,9 +38,9 @@ from gpjax.fit import (
 
 __license__ = "MIT"
 __description__ = "Gaussian processes in JAX and Flax"
-__url__ = "https://github.com/
-__contributors__ = "https://github.com/
-__version__ = "0.13.
+__url__ = "https://github.com/thomaspinder/GPJax"
+__contributors__ = "https://github.com/thomaspinder/GPJax/graphs/contributors"
+__version__ = "0.13.4"
 
 __all__ = [
     "gps",
gpjax/citation.py
CHANGED
@@ -23,6 +23,7 @@ from gpjax.kernels import (
     Matern32,
     Matern52,
 )
+from gpjax.likelihoods import HeteroscedasticGaussian
 
 CitationType = Union[None, str, Dict[str, str]]
 
@@ -149,3 +150,15 @@ def _(tree) -> PaperCitation:
         booktitle="Advances in neural information processing systems",
         citation_type="article",
     )
+
+
+@cite.register(HeteroscedasticGaussian)
+def _(tree) -> PaperCitation:
+    return PaperCitation(
+        citation_key="lazaro2011variational",
+        authors="Lázaro-Gredilla, Miguel and Titsias, Michalis",
+        title="Variational heteroscedastic Gaussian process regression",
+        year="2011",
+        booktitle="Proceedings of the 28th International Conference on Machine Learning (ICML)",
+        citation_type="inproceedings",
+    )
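The new handler registers a reference for the heteroscedastic likelihood with the `cite` dispatcher. A minimal sketch of looking it up, assuming `HeteroscedasticGaussian` can be constructed like the existing `Gaussian` likelihood (its constructor is not shown in this diff):

```python
from gpjax.citation import cite
from gpjax.likelihoods import HeteroscedasticGaussian

# The constructor argument below is an assumption (mirroring the existing
# Gaussian likelihood); only the cite registration itself appears in this diff.
likelihood = HeteroscedasticGaussian(num_datapoints=100)

citation = cite(likelihood)
print(citation.citation_key)  # lazaro2011variational
```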
gpjax/dataset.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/fit.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2023 The
+# Copyright 2023 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/gps.py
CHANGED
@@ -4,7 +4,7 @@
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-#
+# http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@@ -13,10 +13,13 @@
 # limitations under the License.
 # ==============================================================================
 # from __future__ import annotations
+
 from abc import abstractmethod
+from typing import Literal
 
 import beartype.typing as tp
 from flax import nnx
+import jax
 import jax.numpy as jnp
 import jax.random as jr
 from jaxtyping import (
@@ -29,16 +32,21 @@ from gpjax.distributions import GaussianDistribution
 from gpjax.kernels import RFF
 from gpjax.kernels.base import AbstractKernel
 from gpjax.likelihoods import (
+    AbstractHeteroscedasticLikelihood,
     AbstractLikelihood,
     Gaussian,
+    HeteroscedasticGaussian,
     NonGaussian,
 )
 from gpjax.linalg import (
     Dense,
+    Diagonal,
     psd,
     solve,
 )
-from gpjax.linalg.operations import
+from gpjax.linalg.operations import (
+    lower_cholesky,
+)
 from gpjax.linalg.utils import add_jitter
 from gpjax.mean_functions import AbstractMeanFunction
 from gpjax.parameters import (
@@ -56,6 +64,7 @@ M = tp.TypeVar("M", bound=AbstractMeanFunction)
 L = tp.TypeVar("L", bound=AbstractLikelihood)
 NGL = tp.TypeVar("NGL", bound=NonGaussian)
 GL = tp.TypeVar("GL", bound=Gaussian)
+HL = tp.TypeVar("HL", bound=AbstractHeteroscedasticLikelihood)
 
 
 class AbstractPrior(nnx.Module, tp.Generic[M, K]):
@@ -77,7 +86,12 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
         self.mean_function = mean_function
         self.jitter = jitter
 
-    def __call__(
+    def __call__(
+        self,
+        test_inputs: Num[Array, "N D"],
+        *,
+        return_covariance_type: Literal["dense", "diagonal"] = "dense",
+    ) -> GaussianDistribution:
         r"""Evaluate the Gaussian process at the given points.
 
         The output of this function is a
@@ -91,15 +105,27 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
 
         Args:
             test_inputs: Input locations where the GP should be evaluated.
+            return_covariance_type: Literal denoting whether to return the full covariance
+                of the joint predictive distribution at the test_inputs (dense)
+                or just the the standard-deviation of the predictive distribution at
+                the test_inputs.
 
         Returns:
             GaussianDistribution: A multivariate normal random variable representation
                 of the Gaussian process.
         """
-        return self.predict(
+        return self.predict(
+            test_inputs,
+            return_covariance_type=return_covariance_type,
+        )
 
     @abstractmethod
-    def predict(
+    def predict(
+        self,
+        test_inputs: Num[Array, "N D"],
+        *,
+        return_covariance_type: Literal["dense", "diagonal"] = "dense",
+    ) -> GaussianDistribution:
         r"""Evaluate the predictive distribution.
 
         Compute the latent function's multivariate normal distribution for a
@@ -108,6 +134,10 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
 
         Args:
             test_inputs: Input locations where the GP should be evaluated.
+            return_covariance_type: Literal denoting whether to return the full covariance
+                of the joint predictive distribution at the test_inputs (dense)
+                or just the the standard-deviation of the predictive distribution at
+                the test_inputs.
 
         Returns:
             GaussianDistribution: A multivariate normal random variable representation
@@ -123,8 +153,8 @@ class Prior(AbstractPrior[M, K]):
     r"""A Gaussian process prior object.
 
     The GP is parameterised by a
-    [mean](https://docs.
-    and [kernel](https://docs.
+    [mean](https://docs.thomaspinder.com/api/mean_functions/)
+    and [kernel](https://docs.thomaspinder.com/api/kernels/base/)
     function.
 
     A Gaussian process prior parameterised by a mean function $m(\cdot)$ and a kernel
@@ -220,7 +250,12 @@ class Prior(AbstractPrior[M, K]):
         """
         return self.__mul__(other)
 
-    def predict(
+    def predict(
+        self,
+        test_inputs: Num[Array, "N D"],
+        *,
+        return_covariance_type: Literal["dense", "diagonal"] = "dense",
+    ) -> GaussianDistribution:
         r"""Compute the predictive prior distribution for a given set of
         parameters. The output of this function is a function that computes
         a TFP distribution for a given set of inputs.
@@ -241,17 +276,43 @@ class Prior(AbstractPrior[M, K]):
         Args:
             test_inputs (Float[Array, "N D"]): The inputs at which to evaluate the
                 prior distribution.
+            return_covariance_type: Literal denoting whether to return the full covariance
+                of the joint predictive distribution at the test_inputs (dense)
+                or just the the standard-deviation of the predictive distribution at
+                the test_inputs.
 
         Returns:
             GaussianDistribution: A multivariate normal random variable representation
                 of the Gaussian process.
         """
+
+        def _return_full_covariance(
+            t: Num[Array, "N D"],
+        ) -> Dense:
+            Kxx = self.kernel.gram(t)
+            Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter)
+            Kxx = psd(Dense(Kxx_dense))
+            return Kxx
+
+        def _return_diagonal_covariance(
+            t: Num[Array, "N D"],
+        ) -> Dense:
+            Kxx = self.kernel.diagonal(t).diagonal
+            Kxx += self.jitter
+            Kxx = psd(Dense(Diagonal(Kxx).to_dense()))
+            return Kxx
+
         mean_at_test = self.mean_function(test_inputs)
-
-
-
+        cov = jax.lax.cond(
+            return_covariance_type == "dense",
+            _return_full_covariance,
+            _return_diagonal_covariance,
+            test_inputs,
+        )
 
-        return GaussianDistribution(
+        return GaussianDistribution(
+            loc=jnp.atleast_1d(mean_at_test.squeeze()), scale=cov
+        )
 
     def sample_approx(
         self,
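The hunks above thread a new keyword-only `return_covariance_type` argument through `AbstractPrior.__call__` and `Prior.predict`, using `jax.lax.cond` to choose between the full Gram matrix and its diagonal. A minimal usage sketch, assuming a standard GPJax prior (Zero mean function, RBF kernel); only the keyword itself comes from this diff:

```python
import gpjax as gpx
import jax.numpy as jnp

xtest = jnp.linspace(-3.0, 3.0, 50).reshape(-1, 1)

# Illustrative prior; the Zero mean and RBF kernel are assumptions, not part of this diff.
prior = gpx.gps.Prior(
    mean_function=gpx.mean_functions.Zero(),
    kernel=gpx.kernels.RBF(),
)

# Default behaviour: a dense 50 x 50 prior covariance over the test points.
dense_dist = prior(xtest)

# New in 0.13.4: keep only the marginal variances.
diag_dist = prior(xtest, return_covariance_type="diagonal")
```

Note that both branches wrap the result in a `Dense` operator, so the two `lax.cond` outputs share the same pytree structure.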
@@ -329,7 +390,7 @@ P = tp.TypeVar("P", bound=AbstractPrior)
 
 #######################
 # GP Posteriors
-#######################
+#######################from gpjax.linalg.operators import LinearOperator
 class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
     r"""Abstract Gaussian process posterior.
 
@@ -356,7 +417,11 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
         self.jitter = jitter
 
     def __call__(
-        self,
+        self,
+        test_inputs: Num[Array, "N D"],
+        train_data: Dataset,
+        *,
+        return_covariance_type: Literal["dense", "diagonal"] = "dense",
     ) -> GaussianDistribution:
         r"""Evaluate the Gaussian process posterior at the given points.
 
@@ -372,16 +437,28 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
         Args:
             test_inputs: Input locations where the GP should be evaluated.
             train_data: Training dataset to condition on.
+            return_covariance_type: Literal denoting whether to return the full covariance
+                of the joint predictive distribution at the test_inputs (dense)
+                or just the the standard-deviation of the predictive distribution at
+                the test_inputs.
 
         Returns:
             GaussianDistribution: A multivariate normal random variable representation
                 of the Gaussian process.
         """
-        return self.predict(
+        return self.predict(
+            test_inputs,
+            train_data,
+            return_covariance_type=return_covariance_type,
+        )
 
     @abstractmethod
     def predict(
-        self,
+        self,
+        test_inputs: Num[Array, "N D"],
+        train_data: Dataset,
+        *,
+        return_covariance_type: Literal["dense", "diagonal"] = "dense",
     ) -> GaussianDistribution:
         r"""Compute the latent function's multivariate normal distribution for a
         given set of parameters. For any class inheriting the `AbstractPosterior` class,
@@ -390,6 +467,10 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
         Args:
             test_inputs: Input locations where the GP should be evaluated.
             train_data: Training dataset to condition on.
+            return_covariance_type: Literal denoting whether to return the full covariance
+                of the joint predictive distribution at the test_inputs (dense)
+                or just the the standard-deviation of the predictive distribution at
+                the test_inputs.
 
         Returns:
             GaussianDistribution: A multivariate normal random variable representation
@@ -398,6 +479,22 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
         raise NotImplementedError
 
 
+class LatentPosterior(AbstractPosterior[P, L]):
+    r"""A posterior shell used to expose prior structure without inference."""
+
+    def predict(
+        self,
+        test_inputs: Num[Array, "N D"],
+        train_data: Dataset,
+        *,
+        return_covariance_type: Literal["dense", "diagonal"] = "dense",
+    ) -> GaussianDistribution:
+        raise NotImplementedError(
+            "LatentPosteriors are a lightweight wrapper for priors and do not "
+            "implement predictive distributions. Use a variational family for inference."
+        )
+
+
 class ConjugatePosterior(AbstractPosterior[P, GL]):
     r"""A Conjuate Gaussian process posterior object.
 
@@ -442,8 +539,10 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
 
     def predict(
         self,
-        test_inputs: Num[Array, "
+        test_inputs: Num[Array, "M D"],
         train_data: Dataset,
+        *,
+        return_covariance_type: Literal["dense", "diagonal"] = "dense",
     ) -> GaussianDistribution:
         r"""Query the predictive posterior distribution.
 
@@ -454,13 +553,13 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
 
         The predictive distribution of a conjugate GP is given by
         $$
-
-
+            p(\mathbf{f}^{\star}\mid \mathbf{y}) & = \int p(\mathbf{f}^{\star} \mathbf{f} \mid \mathbf{y})\\
+            & =\mathcal{N}(\mathbf{f}^{\star} \boldsymbol{\mu}_{\mid \mathbf{y}}, \boldsymbol{\Sigma}_{\mid \mathbf{y}}
         $$
         where
         $$
-
-
+            \boldsymbol{\mu}_{\mid \mathbf{y}} & = k(\mathbf{x}^{\star}, \mathbf{x})\left(k(\mathbf{x}, \mathbf{x}')+\sigma^2\mathbf{I}_n\right)^{-1}\mathbf{y} \\
+            \boldsymbol{\Sigma}_{\mid \mathbf{y}} & =k(\mathbf{x}^{\star}, \mathbf{x}^{\star\prime}) -k(\mathbf{x}^{\star}, \mathbf{x})\left( k(\mathbf{x}, \mathbf{x}') + \sigma^2\mathbf{I}_n \right)^{-1}k(\mathbf{x}, \mathbf{x}^{\star}).
         $$
 
         The conditioning set is a GPJax `Dataset` object, whilst predictions
@@ -486,44 +585,65 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
             predictive distribution is evaluated.
             train_data (Dataset): A `gpx.Dataset` object that contains the input and
                 output data used for training dataset.
+            return_covariance_type: Literal denoting whether to return the full covariance
+                of the joint predictive distribution at the test_inputs (dense)
+                or just the the standard-deviation of the predictive distribution at
+                the test_inputs.
 
         Returns:
             GaussianDistribution: A function that accepts an input array and
                 returns the predictive distribution as a `GaussianDistribution`.
         """
-
-
-
-        # Unpack test inputs
-        t = test_inputs
-
+        x = train_data.X
+        y = train_data.y
         # Observation noise o²
-        obs_noise = self.likelihood.obs_stddev.value
+        obs_noise = jnp.square(self.likelihood.obs_stddev.value)
         mx = self.prior.mean_function(x)
-
         # Precompute Gram matrix, Kxx, at training inputs, x
         Kxx = self.prior.kernel.gram(x)
-
-        Kxx = Dense(Kxx_dense)
+        Kxx = add_jitter(Kxx.to_dense(), self.jitter)
 
-        Sigma_dense = Kxx
+        Sigma_dense = Kxx + jnp.eye(Kxx.shape[0]) * obs_noise
         Sigma = psd(Dense(Sigma_dense))
         L_sigma = lower_cholesky(Sigma)
 
-
-        Ktt = self.prior.kernel.gram(t)
-        Kxt = self.prior.kernel.cross_covariance(x, t)
+        Kxt = self.prior.kernel.cross_covariance(x, test_inputs)
 
         L_inv_Kxt = solve(L_sigma, Kxt)
         L_inv_y_diff = solve(L_sigma, y - mx)
 
+        mean_t = self.prior.mean_function(test_inputs)
         mean = mean_t + jnp.matmul(L_inv_Kxt.T, L_inv_y_diff)
 
-
-
-
-
-
+        def _return_full_covariance(
+            L_inv_Kxt: Num[Array, "N M"],
+            t: Num[Array, "M D"],
+        ) -> Dense:
+            Ktt = self.prior.kernel.gram(t)
+            covariance = Ktt.to_dense() - jnp.matmul(L_inv_Kxt.T, L_inv_Kxt)
+            covariance = add_jitter(covariance, self.prior.jitter)
+            covariance = psd(Dense(covariance))
+            return covariance
+
+        def _return_diagonal_covariance(
+            L_inv_Kxt: Num[Array, "N M"],
+            t: Num[Array, "M D"],
+        ) -> Dense:
+            Ktt = self.prior.kernel.diagonal(t).diagonal
+            covariance = Ktt - jnp.einsum("ij, ji->i", L_inv_Kxt.T, L_inv_Kxt)
+            covariance += self.prior.jitter
+            covariance = psd(Dense(jnp.diag(jnp.atleast_1d(covariance.squeeze()))))
+            return covariance
+
+        cov = jax.lax.cond(
+            return_covariance_type == "dense",
+            _return_full_covariance,
+            _return_diagonal_covariance,
+            L_inv_Kxt,
+            test_inputs,
+        )
+
+        return GaussianDistribution(loc=jnp.atleast_1d(mean.squeeze()), scale=cov)
 
     def sample_approx(
         self,
@@ -567,7 +687,7 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
 
         Returns:
             FunctionalSample: A function representing an approximate sample from the Gaussian
-
+                process prior.
         """
         if (not isinstance(num_samples, int)) or num_samples <= 0:
             raise ValueError("num_samples must be a positive integer")
@@ -586,7 +706,7 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
         canonical_weights = solve(
             Sigma,
             y + eps - jnp.inner(Phi, fourier_weights),
-        ) #
+        ) # [N, B]
 
         def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]:
             fourier_features = fourier_feature_fn(test_inputs)
@@ -648,7 +768,11 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):
         self.key = key
 
     def predict(
-        self,
+        self,
+        test_inputs: Num[Array, "M D"],
+        train_data: Dataset,
+        *,
+        return_covariance_type: Literal["dense", "diagonal"] = "dense",
     ) -> GaussianDistribution:
         r"""Query the predictive posterior distribution.
 
@@ -660,50 +784,112 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):
         transformed through the likelihood function's inverse link function.
 
         Args:
-
-
+            test_inputs (Num[Array, "N D"]): A Jax array of test inputs at which the
+                predictive distribution is evaluated.
+            train_data (Dataset): A `gpx.Dataset` object that contains the input
+                and output data used for training dataset.
+            return_covariance_type: Literal denoting whether to return the full covariance
+                of the joint predictive distribution at the test_inputs (dense)
+                or just the the standard-deviation of the predictive distribution at
+                the test_inputs.
 
         Returns:
             GaussianDistribution: A function that accepts an
             input array and returns the predictive distribution as
             a `dx.Distribution`.
         """
-        # Unpack training data
         x = train_data.X
-
-        # Unpack mean function and kernel
+        t = test_inputs
         mean_function = self.prior.mean_function
         kernel = self.prior.kernel
 
-        # Precompute lower triangular of Gram matrix
+        # Precompute lower triangular of Gram matrix
        Kxx = kernel.gram(x)
        Kxx_dense = add_jitter(Kxx.to_dense(), self.prior.jitter)
        Kxx = psd(Dense(Kxx_dense))
        Lx = lower_cholesky(Kxx)
 
-
-        t = test_inputs
-
-        # Compute terms of the posterior predictive distribution
-        Ktx = kernel.cross_covariance(t, x)
-        Ktt = kernel.gram(t)
-        mean_t = mean_function(t)
-
+        Kxt = kernel.cross_covariance(x, t)
         # Lx⁻¹ Kxt
-        Lx_inv_Kxt = solve(Lx,
+        Lx_inv_Kxt = solve(Lx, Kxt)
 
+        mean_t = mean_function(t)
         # Whitened function values, wx, corresponding to the inputs, x
         wx = self.latent.value
 
         # μt + Ktx Lx⁻¹ wx
         mean = mean_t + jnp.matmul(Lx_inv_Kxt.T, wx)
 
-
-
-
-
+        def _return_full_covariance(
+            Lx_inv_Kxt: Num[Array, "N M"],
+            t: Num[Array, "M D"],
+        ) -> Dense:
+            Ktt = kernel.gram(t)
+            covariance = Ktt.to_dense() - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt)
+            covariance = add_jitter(covariance, self.prior.jitter)
+            covariance = psd(Dense(covariance))
+
+            return covariance
+
+        def _return_diagonal_covariance(
+            Lx_inv_Kxt: Num[Array, "N M"],
+            t: Num[Array, "M D"],
+        ) -> Dense:
+            Ktt = kernel.diagonal(t).diagonal
+            covariance = Ktt - jnp.einsum("ij, ji->i", Lx_inv_Kxt.T, Lx_inv_Kxt)
+            covariance += self.prior.jitter
+            # It would be nice to return a Diagonal here, but the pytree needs
+            # to be the same for both cond branches and the other branch needs
+            # to return a Dense.
+            # They are both LinearOperators, but they inherit from that class
+            # and hence are not the same pytree anymore.
+            covariance = psd(Dense(jnp.diag(jnp.atleast_1d(covariance.squeeze()))))
+
+            return covariance
+
+        cov = jax.lax.cond(
+            return_covariance_type == "dense",
+            _return_full_covariance,
+            _return_diagonal_covariance,
+            Lx_inv_Kxt,
+            test_inputs,
+        )
+
+        return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), cov)
+
+
+class HeteroscedasticPosterior(LatentPosterior[P, HL]):
+    r"""Posterior shell for heteroscedastic likelihoods.
+
+    The posterior retains both the signal and noise priors; inference is delegated
+    to variational families and specialised objectives.
+    """
+
+    def __init__(
+        self,
+        prior: AbstractPrior[M, K],
+        likelihood: HL,
+        jitter: float = 1e-6,
+    ):
+        if likelihood.noise_prior is None:
+            raise ValueError("Heteroscedastic likelihoods require a noise_prior.")
+        super().__init__(prior=prior, likelihood=likelihood, jitter=jitter)
+        self.noise_prior = likelihood.noise_prior
+        self.noise_posterior = LatentPosterior(
+            prior=self.noise_prior, likelihood=likelihood, jitter=jitter
+        )
+
+
+class ChainedPosterior(HeteroscedasticPosterior[P, HL]):
+    r"""Posterior routed for heteroscedastic likelihoods using chained bounds."""
 
-
+    def __init__(
+        self,
+        prior: AbstractPrior[M, K],
+        likelihood: HL,
+        jitter: float = 1e-6,
+    ):
+        super().__init__(prior=prior, likelihood=likelihood, jitter=jitter)
 
 
 #######################
@@ -721,6 +907,18 @@ def construct_posterior( # noqa: F811
 ) -> NonConjugatePosterior[P, NGL]: ...
 
 
+@tp.overload
+def construct_posterior( # noqa: F811
+    prior: P, likelihood: HeteroscedasticGaussian
+) -> HeteroscedasticPosterior[P, HeteroscedasticGaussian]: ...
+
+
+@tp.overload
+def construct_posterior( # noqa: F811
+    prior: P, likelihood: AbstractHeteroscedasticLikelihood
+) -> ChainedPosterior[P, AbstractHeteroscedasticLikelihood]: ...
+
+
 def construct_posterior(prior, likelihood): # noqa: F811
     r"""Utility function for constructing a posterior object from a prior and
     likelihood. The function will automatically select the correct posterior
@@ -740,6 +938,15 @@ def construct_posterior(prior, likelihood): # noqa: F811
     if isinstance(likelihood, Gaussian):
         return ConjugatePosterior(prior=prior, likelihood=likelihood)
 
+    if (
+        isinstance(likelihood, HeteroscedasticGaussian)
+        and likelihood.supports_tight_bound()
+    ):
+        return HeteroscedasticPosterior(prior=prior, likelihood=likelihood)
+
+    if isinstance(likelihood, AbstractHeteroscedasticLikelihood):
+        return ChainedPosterior(prior=prior, likelihood=likelihood)
+
     return NonConjugatePosterior(prior=prior, likelihood=likelihood)
 
 
@@ -778,7 +985,10 @@ __all__ = [
     "AbstractPrior",
     "Prior",
     "AbstractPosterior",
+    "LatentPosterior",
     "ConjugatePosterior",
     "NonConjugatePosterior",
+    "HeteroscedasticPosterior",
+    "ChainedPosterior",
     "construct_posterior",
 ]
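Taken together, the gps.py changes add the keyword-only `return_covariance_type` argument to the posterior classes as well, and teach `construct_posterior` to route heteroscedastic likelihoods to the new posterior types. A minimal sketch of the posterior side, assuming the usual GPJax regression workflow (Dataset, Zero mean, RBF kernel, Gaussian likelihood); only the keyword argument and the dispatch rules come from this diff:

```python
import gpjax as gpx
import jax.numpy as jnp

x = jnp.linspace(0.0, 1.0, 20).reshape(-1, 1)
y = jnp.sin(3.0 * x)
D = gpx.Dataset(X=x, y=y)

prior = gpx.gps.Prior(
    mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()
)

# A Gaussian likelihood still dispatches to ConjugatePosterior.
posterior = gpx.gps.construct_posterior(
    prior, gpx.likelihoods.Gaussian(num_datapoints=D.n)
)

# Full joint covariance over the test inputs (the default).
dense_dist = posterior.predict(x, D)

# Marginal variances only, via the new keyword-only argument.
diag_dist = posterior.predict(x, D, return_covariance_type="diagonal")
```

Per the dispatch added above, a `HeteroscedasticGaussian` likelihood whose `supports_tight_bound()` returns True would instead yield a `HeteroscedasticPosterior`, and any other `AbstractHeteroscedasticLikelihood` a `ChainedPosterior`.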
gpjax/kernels/__init__.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
gpjax/kernels/base.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 The
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -123,7 +123,7 @@ class AbstractKernel(nnx.Module):
         """
         return self.compute_engine.gram(self, x)
 
-    def diagonal(self, x: Num[Array, "N D"]) ->
+    def diagonal(self, x: Num[Array, "N D"]) -> LinearOperator:
         r"""Compute the diagonal of the gram matrix of the kernel.
 
         Args:
@@ -1,4 +1,4 @@
-# Copyright 2022 The
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
-# Copyright 2022 The
+# Copyright 2022 The thomaspinder Contributors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.