gpjax 0.11.2__py3-none-any.whl → 0.12.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpjax/__init__.py +1 -4
- gpjax/distributions.py +16 -56
- gpjax/fit.py +11 -6
- gpjax/gps.py +61 -73
- gpjax/kernels/approximations/rff.py +2 -5
- gpjax/kernels/base.py +2 -5
- gpjax/kernels/computations/base.py +7 -7
- gpjax/kernels/computations/basis_functions.py +7 -6
- gpjax/kernels/computations/constant_diagonal.py +10 -12
- gpjax/kernels/computations/diagonal.py +6 -6
- gpjax/kernels/computations/eigen.py +1 -1
- gpjax/kernels/non_euclidean/graph.py +10 -11
- gpjax/kernels/nonstationary/arccosine.py +13 -21
- gpjax/kernels/nonstationary/polynomial.py +7 -8
- gpjax/kernels/stationary/periodic.py +3 -6
- gpjax/kernels/stationary/powered_exponential.py +3 -8
- gpjax/kernels/stationary/rational_quadratic.py +5 -8
- gpjax/likelihoods.py +11 -14
- gpjax/linalg/__init__.py +37 -0
- gpjax/linalg/operations.py +237 -0
- gpjax/linalg/operators.py +411 -0
- gpjax/linalg/utils.py +65 -0
- gpjax/mean_functions.py +8 -7
- gpjax/objectives.py +22 -21
- gpjax/parameters.py +11 -23
- gpjax/variational_families.py +93 -67
- {gpjax-0.11.2.dist-info → gpjax-0.12.2.dist-info}/METADATA +50 -18
- gpjax-0.12.2.dist-info/RECORD +52 -0
- gpjax/lower_cholesky.py +0 -69
- gpjax-0.11.2.dist-info/RECORD +0 -49
- {gpjax-0.11.2.dist-info → gpjax-0.12.2.dist-info}/WHEEL +0 -0
- {gpjax-0.11.2.dist-info → gpjax-0.12.2.dist-info}/licenses/LICENSE.txt +0 -0
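The headline change in this release is the removal of the CoLA dependency: a new in-house `gpjax/linalg` package (operators, operations, utils) replaces `cola`, and the standalone `gpjax/lower_cholesky.py` module is deleted. The sketch below summarises the migration as it appears in the per-file diffs that follow; the call signatures are inferred from those diffs rather than from any documentation, so treat it as an assumption-laden sketch, not a reference:

```python
import jax.numpy as jnp

# 0.12.x replacements for the cola-based calls removed below (names taken
# verbatim from the diffs; exact semantics are assumed, not documented here).
from gpjax.linalg import Dense, psd, solve
from gpjax.linalg.operations import diag, logdet, lower_cholesky
from gpjax.linalg.utils import add_jitter

K = jnp.array([[2.0, 0.5], [0.5, 1.5]])
y = jnp.array([1.0, -1.0])

Sigma = psd(Dense(add_jitter(K, 1e-6)))  # was: PSD(...) plus I_like(K) * jitter
alpha = solve(Sigma, y)                  # was: cola.solve(Sigma, y, Cholesky())
L = lower_cholesky(Sigma)                # was: gpjax.lower_cholesky.lower_cholesky
ld = logdet(Sigma)                       # was: cola.logdet(Sigma, Cholesky(), Cholesky())
v = diag(Sigma)                          # was: cola.diag(Sigma)
```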
gpjax/__init__.py
CHANGED

```diff
@@ -40,10 +40,9 @@ __license__ = "MIT"
 __description__ = "Gaussian processes in JAX and Flax"
 __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
 __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
-__version__ = "0.11.2"
+__version__ = "0.12.2"
 
 __all__ = [
-    "base",
     "gps",
     "integrators",
     "kernels",
@@ -55,8 +54,6 @@ __all__ = [
     "Dataset",
     "cite",
     "fit",
-    "Module",
-    "param_field",
     "fit_lbfgs",
     "fit_scipy",
 ]
```
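Alongside the version bump, `base`, `Module`, and `param_field` leave the top-level namespace. A quick sanity check against 0.12.2, using only what the diff above shows:

```python
import gpjax as gpx

assert gpx.__version__ == "0.12.2"
# `Module` and `param_field` were dropped from __all__ in this release:
assert "Module" not in gpx.__all__
assert "param_field" not in gpx.__all__
```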
gpjax/distributions.py
CHANGED

```diff
@@ -17,11 +17,6 @@
 from beartype.typing import (
     Optional,
 )
-import cola
-from cola.linalg.decompositions import Cholesky
-from cola.ops import (
-    LinearOperator,
-)
 from jax import vmap
 import jax.numpy as jnp
 import jax.random as jr
@@ -30,7 +25,14 @@ from numpyro.distributions import constraints
 from numpyro.distributions.distribution import Distribution
 from numpyro.distributions.util import is_prng_key
 
-from gpjax.lower_cholesky import lower_cholesky
+from gpjax.linalg.operations import (
+    diag,
+    logdet,
+    lower_cholesky,
+    solve,
+)
+from gpjax.linalg.operators import LinearOperator
+from gpjax.linalg.utils import psd
 from gpjax.typing import (
     Array,
     ScalarFloat,
@@ -47,7 +49,7 @@ class GaussianDistribution(Distribution):
         validate_args=None,
     ):
         self.loc = loc
-        self.scale = cola.PSD(scale)
+        self.scale = psd(scale)
         batch_shape = ()
         event_shape = jnp.shape(self.loc)
         super().__init__(batch_shape, event_shape, validate_args=validate_args)
@@ -76,13 +78,12 @@
     @property
     def variance(self) -> Float[Array, " N"]:
         r"""Calculates the variance."""
-        return cola.diag(self.scale)
+        return diag(self.scale)
 
     def entropy(self) -> ScalarFloat:
         r"""Calculates the entropy of the distribution."""
         return 0.5 * (
-            self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi))
-            + cola.logdet(self.scale, Cholesky(), Cholesky())
+            self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi)) + logdet(self.scale)
         )
 
     def median(self) -> Float[Array, " N"]:
@@ -104,7 +105,7 @@
 
     def stddev(self) -> Float[Array, " N"]:
         r"""Calculates the standard deviation."""
-        return jnp.sqrt(cola.diag(self.scale))
+        return jnp.sqrt(diag(self.scale))
 
     # @property
     # def event_shape(self) -> Tuple:
@@ -129,9 +130,7 @@
 
         # compute the pdf, -1/2[ n log(2π) + log|Σ| + (y - µ)ᵀΣ⁻¹(y - µ) ]
         return -0.5 * (
-            n * jnp.log(2.0 * jnp.pi)
-            + cola.logdet(sigma, Cholesky(), Cholesky())
-            + diff.T @ cola.solve(sigma, diff, Cholesky())
+            n * jnp.log(2.0 * jnp.pi) + logdet(sigma) + diff.T @ solve(sigma, diff)
         )
 
     # def _sample_n(self, key: KeyArray, n: int) -> Float[Array, "n N"]:
@@ -219,53 +218,14 @@ def _kl_divergence(q: GaussianDistribution, p: GaussianDistribution) -> ScalarFloat:
 
     # trace term, tr[Σp⁻¹ Σq] = tr[(LpLpᵀ)⁻¹(LqLqᵀ)] = tr[(Lp⁻¹Lq)(Lp⁻¹Lq)ᵀ] = (fr[LqLp⁻¹])²
     trace = _frobenius_norm_squared(
-        cola.solve(sqrt_p, sqrt_q.to_dense(), Cholesky())
+        solve(sqrt_p, sqrt_q.to_dense())
     )  # TODO: Not most efficient, given the `to_dense()` call (e.g., consider diagonal p and q). Need to abstract solving linear operator against another linear operator.
 
     # Mahalanobis term, (μp - μq)ᵀ Σp⁻¹ (μp - μq) = tr [(μp - μq)ᵀ [LpLpᵀ]⁻¹ (μp - μq)] = (fr[Lp⁻¹(μp - μq)])²
-    mahalanobis = jnp.sum(jnp.square(cola.solve(sqrt_p, diff, Cholesky())))
+    mahalanobis = jnp.sum(jnp.square(solve(sqrt_p, diff)))
 
     # KL[q(x)||p(x)] = [ [(μp - μq)ᵀ Σp⁻¹ (μp - μq)] - n - log|Σq| + log|Σp| + tr[Σp⁻¹ Σq] ] / 2
-    return (
-        mahalanobis
-        - n_dim
-        - cola.logdet(sigma_q, Cholesky(), Cholesky())
-        + cola.logdet(sigma_p, Cholesky(), Cholesky())
-        + trace
-    ) / 2.0
-
-
-# def _check_loc_scale(loc: Optional[Any], scale: Optional[Any]) -> None:
-#     r"""Checks that the inputs are correct."""
-#     if loc is None and scale is None:
-#         raise ValueError("At least one of `loc` or `scale` must be specified.")
-
-#     if loc is not None and loc.ndim < 1:
-#         raise ValueError("The parameter `loc` must have at least one dimension.")
-
-#     if scale is not None and len(scale.shape) < 2:  # scale.ndim < 2:
-#         raise ValueError(
-#             "The `scale` must have at least two dimensions, but "
-#             f"`scale.shape = {scale.shape}`."
-#         )
-
-#     if scale is not None and not isinstance(scale, LinearOperator):
-#         raise ValueError(
-#             f"The `scale` must be a CoLA LinearOperator but got {type(scale)}"
-#         )
-
-#     if scale is not None and (scale.shape[-1] != scale.shape[-2]):
-#         raise ValueError(
-#             f"The `scale` must be a square matrix, but `scale.shape = {scale.shape}`."
-#         )
-
-#     if loc is not None:
-#         num_dims = loc.shape[-1]
-#         if scale is not None and (scale.shape[-1] != num_dims):
-#             raise ValueError(
-#                 f"Shapes are not compatible: `loc.shape = {loc.shape}` and "
-#                 f"`scale.shape = {scale.shape}`."
-#             )
+    return (mahalanobis - n_dim - logdet(sigma_q) + logdet(sigma_p) + trace) / 2.0
 
 
 __all__ = [
```
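The `GaussianDistribution` interface itself is unchanged; only its backend moved from `cola` to `gpjax.linalg`. A minimal sketch of constructing one under the new API, assuming `psd` and `Dense` compose as the diff above indicates:

```python
import jax.numpy as jnp
from gpjax.distributions import GaussianDistribution
from gpjax.linalg import Dense, psd

loc = jnp.zeros(2)
cov = jnp.array([[1.0, 0.3], [0.3, 1.0]])

# __init__ re-annotates the operator itself via self.scale = psd(scale)
dist = GaussianDistribution(loc, psd(Dense(cov)))

var = dist.variance   # diag(Σ), via gpjax.linalg.operations.diag
ent = dist.entropy()  # 0.5 * (n * (1 + log 2π) + logdet(Σ))
```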
gpjax/fit.py
CHANGED

```diff
@@ -48,6 +48,7 @@ def fit(  # noqa: PLR0913
     train_data: Dataset,
     optim: ox.GradientTransformation,
     params_bijection: tp.Union[dict[Parameter, Transform], None] = DEFAULT_BIJECTION,
+    trainable: nnx.filterlib.Filter = Parameter,
     key: KeyArray = jr.PRNGKey(42),
     num_iters: int = 100,
     batch_size: int = -1,
@@ -65,7 +66,7 @@
         >>> import jax.random as jr
         >>> import optax as ox
         >>> import gpjax as gpx
-        >>> from gpjax.parameters import PositiveReal, Static
+        >>> from gpjax.parameters import PositiveReal
         >>>
         >>> # (1) Create a dataset:
         >>> X = jnp.linspace(0.0, 10.0, 100)[:, None]
@@ -75,10 +76,10 @@
         >>> class LinearModel(nnx.Module):
         >>>     def __init__(self, weight: float, bias: float):
         >>>         self.weight = PositiveReal(weight)
-        >>>         self.bias = Static(bias)
+        >>>         self.bias = bias
         >>>
         >>>     def __call__(self, x):
-        >>>         return self.weight.value * x + self.bias.value
+        >>>         return self.weight.value * x + self.bias
         >>>
         >>> model = LinearModel(weight=1.0, bias=1.0)
         >>>
@@ -100,6 +101,8 @@
         train_data (Dataset): The training data to be used for the optimisation.
         optim (GradientTransformation): The Optax optimiser that is to be used for
             learning a parameter set.
+        trainable (nnx.filterlib.Filter): Filter to determine which parameters are trainable.
+            Defaults to nnx.Param (all Parameter instances).
         num_iters (int): The number of optimisation steps to run. Defaults
             to 100.
         batch_size (int): The size of the mini-batch to use. Defaults to -1
@@ -127,7 +130,7 @@
     _check_verbose(verbose)
 
     # Model state filtering
-    graphdef, params, *static_state = nnx.split(model, Parameter, ...)
+    graphdef, params, *static_state = nnx.split(model, trainable, ...)
 
     # Parameters bijection to unconstrained space
     if params_bijection is not None:
@@ -182,6 +185,7 @@ def fit_scipy(  # noqa: PLR0913
     model: Model,
     objective: Objective,
     train_data: Dataset,
+    trainable: nnx.filterlib.Filter = Parameter,
     max_iters: int = 500,
     verbose: bool = True,
     safe: bool = True,
@@ -210,7 +214,7 @@
     _check_verbose(verbose)
 
     # Model state filtering
-    graphdef, params, *static_state = nnx.split(model, Parameter, ...)
+    graphdef, params, *static_state = nnx.split(model, trainable, ...)
 
     # Parameters bijection to unconstrained space
     params = transform(params, DEFAULT_BIJECTION, inverse=True)
@@ -258,6 +262,7 @@ def fit_lbfgs(
     objective: Objective,
     train_data: Dataset,
     params_bijection: tp.Union[dict[Parameter, Transform], None] = DEFAULT_BIJECTION,
+    trainable: nnx.filterlib.Filter = Parameter,
     max_iters: int = 100,
     safe: bool = True,
     max_linesearch_steps: int = 32,
@@ -290,7 +295,7 @@
     _check_num_iters(max_iters)
 
     # Model state filtering
-    graphdef, params, *static_state = nnx.split(model, Parameter, ...)
+    graphdef, params, *static_state = nnx.split(model, trainable, ...)
 
     # Parameters bijection to unconstrained space
     if params_bijection is not None:
```
gpjax/gps.py
CHANGED

```diff
@@ -16,14 +16,9 @@
 from abc import abstractmethod
 
 import beartype.typing as tp
-from cola.annotations import PSD
-from cola.linalg.algorithm_base import Algorithm
-from cola.linalg.decompositions.decompositions import Cholesky
-from cola.linalg.inverse.inv import solve
-from cola.ops.operators import I_like
-from flax import nnx
 import jax.numpy as jnp
 import jax.random as jr
+from flax import nnx
 from jaxtyping import (
     Float,
     Num,
@@ -38,12 +33,17 @@ from gpjax.likelihoods import (
     Gaussian,
     NonGaussian,
 )
-from gpjax.lower_cholesky import lower_cholesky
+from gpjax.linalg import (
+    Dense,
+    psd,
+    solve,
+)
+from gpjax.linalg.operations import lower_cholesky
+from gpjax.linalg.utils import add_jitter
 from gpjax.mean_functions import AbstractMeanFunction
 from gpjax.parameters import (
     Parameter,
     Real,
-    Static,
 )
 from gpjax.typing import (
     Array,
@@ -77,7 +77,7 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
         self.mean_function = mean_function
         self.jitter = jitter
 
-    def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> GaussianDistribution:
+    def __call__(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution:
         r"""Evaluate the Gaussian process at the given points.
 
         The output of this function is a
@@ -90,17 +90,16 @@
         `__call__` method and should instead define a `predict` method.
 
         Args:
-            *args (Any): The arguments to pass to the GP's `predict` method.
-            **kwargs (Any): The keyword arguments to pass to the GP's `predict` method.
+            test_inputs: Input locations where the GP should be evaluated.
 
         Returns:
             GaussianDistribution: A multivariate normal random variable representation
                 of the Gaussian process.
         """
-        return self.predict(*args, **kwargs)
+        return self.predict(test_inputs)
 
     @abstractmethod
-    def predict(self, *args: tp.Any, **kwargs: tp.Any) -> GaussianDistribution:
+    def predict(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution:
         r"""Evaluate the predictive distribution.
 
         Compute the latent function's multivariate normal distribution for a
@@ -108,8 +107,7 @@
         this method must be implemented.
 
         Args:
-            *args (Any): The arguments to the predict method.
-            **kwargs (Any): Keyword arguments to the predict method.
+            test_inputs: Input locations where the GP should be evaluated.
 
         Returns:
             GaussianDistribution: A multivariate normal random variable representation
@@ -248,13 +246,12 @@ class Prior(AbstractPrior[M, K]):
             GaussianDistribution: A multivariate normal random variable representation
                 of the Gaussian process.
         """
-        x = test_inputs
-        mx = self.mean_function(x)
-        Kxx = self.kernel.gram(x)
-        Kxx += I_like(Kxx) * self.jitter
-        Kxx = PSD(Kxx)
+        mean_at_test = self.mean_function(test_inputs)
+        Kxx = self.kernel.gram(test_inputs)
+        Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter)
+        Kxx = psd(Dense(Kxx_dense))
 
-        return GaussianDistribution(jnp.atleast_1d(mx.squeeze()), Kxx)
+        return GaussianDistribution(jnp.atleast_1d(mean_at_test.squeeze()), Kxx)
 
     def sample_approx(
         self,
@@ -315,15 +312,13 @@
         if (not isinstance(num_samples, int)) or num_samples <= 0:
             raise ValueError("num_samples must be a positive integer")
 
-        # sample fourier features
         fourier_feature_fn = _build_fourier_features_fn(self, num_features, key)
 
-        # sample feature weights
-        feature_weights = jr.normal(key, [num_samples, 2 * num_features])  # [B, L]
+        feature_weights = jr.normal(key, [num_samples, 2 * num_features])
 
         def sample_fn(test_inputs: Float[Array, "N D"]) -> Float[Array, "N B"]:
-            feature_evals = fourier_feature_fn(test_inputs)
-            evaluated_sample = jnp.inner(feature_evals, feature_weights)
+            feature_evals = fourier_feature_fn(test_inputs)
+            evaluated_sample = jnp.inner(feature_evals, feature_weights)
             return self.mean_function(test_inputs) + evaluated_sample
 
         return sample_fn
@@ -360,7 +355,9 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
         self.likelihood = likelihood
         self.jitter = jitter
 
-    def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> GaussianDistribution:
+    def __call__(
+        self, test_inputs: Num[Array, "N D"], train_data: Dataset
+    ) -> GaussianDistribution:
         r"""Evaluate the Gaussian process posterior at the given points.
 
         The output of this function is a
@@ -369,28 +366,30 @@
         evaluated and the distribution can be sampled.
 
         Under the hood, `__call__` is calling the objects `predict` method. For this
-        reasons, classes inheriting the `AbstractPrior` class, should not overwrite the
+        reasons, classes inheriting the `AbstractPosterior` class, should not overwrite the
         `__call__` method and should instead define a `predict` method.
 
         Args:
-            *args (Any): The arguments to pass to the GP's `predict` method.
-            **kwargs (Any): The keyword arguments to pass to the GP's `predict` method.
+            test_inputs: Input locations where the GP should be evaluated.
+            train_data: Training dataset to condition on.
 
         Returns:
             GaussianDistribution: A multivariate normal random variable representation
                 of the Gaussian process.
         """
-        return self.predict(*args, **kwargs)
+        return self.predict(test_inputs, train_data)
 
     @abstractmethod
-    def predict(self, *args: tp.Any, **kwargs: tp.Any) -> GaussianDistribution:
+    def predict(
+        self, test_inputs: Num[Array, "N D"], train_data: Dataset
+    ) -> GaussianDistribution:
         r"""Compute the latent function's multivariate normal distribution for a
-        given set of parameters. For any class inheriting the `AbstractPrior` class,
+        given set of parameters. For any class inheriting the `AbstractPosterior` class,
         this method must be implemented.
 
         Args:
-            *args (Any): The arguments to the predict method.
-            **kwargs (Any): Keyword arguments to the predict method.
+            test_inputs: Input locations where the GP should be evaluated.
+            train_data: Training dataset to condition on.
 
         Returns:
             GaussianDistribution: A multivariate normal random variable representation
@@ -504,24 +503,25 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
 
         # Precompute Gram matrix, Kxx, at training inputs, x
         Kxx = self.prior.kernel.gram(x)
-        Kxx += I_like(Kxx) * self.jitter
+        Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter)
+        Kxx = Dense(Kxx_dense)
 
-        # Σ = Kxx + Io²
-        Sigma = Kxx + I_like(Kxx) * obs_noise
-        Sigma = PSD(Sigma)
+        Sigma_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * obs_noise
+        Sigma = psd(Dense(Sigma_dense))
+        L_sigma = lower_cholesky(Sigma)
 
         mean_t = self.prior.mean_function(t)
         Ktt = self.prior.kernel.gram(t)
         Kxt = self.prior.kernel.cross_covariance(x, t)
-        Sigma_inv_Kxt = solve(Sigma, Kxt, Cholesky())
 
-        # μt + Ktx (Kxx + Io²)⁻¹ (y - μx)
-        mean = mean_t + jnp.matmul(Sigma_inv_Kxt.T, y - mx)
+        L_inv_Kxt = solve(L_sigma, Kxt)
+        L_inv_y_diff = solve(L_sigma, y - mx)
+
+        mean = mean_t + jnp.matmul(L_inv_Kxt.T, L_inv_y_diff)
 
-        # Ktt - Ktx (Kxx + Io²)⁻¹ Kxt
-        covariance = Ktt - jnp.matmul(Kxt.T, Sigma_inv_Kxt)
-        covariance += I_like(covariance) * self.prior.jitter
-        covariance = PSD(covariance)
+        covariance = Ktt.to_dense() - jnp.matmul(L_inv_Kxt.T, L_inv_Kxt)
+        covariance = add_jitter(covariance, self.prior.jitter)
+        covariance = psd(Dense(covariance))
 
         return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
 
@@ -531,7 +531,6 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
         train_data: Dataset,
         key: KeyArray,
         num_features: int | None = 100,
-        solver_algorithm: tp.Optional[Algorithm] = Cholesky(),
     ) -> FunctionalSample:
         r"""Draw approximate samples from the Gaussian process posterior.
 
@@ -565,11 +564,6 @@
             key (KeyArray): The random seed used for the sample(s).
             num_features (int): The number of features used when approximating the
                 kernel.
-            solver_algorithm (Optional[Algorithm], optional): The algorithm to use for the solves of
-                the inverse of the covariance matrix. See the
-                [CoLA documentation](https://cola.readthedocs.io/en/latest/package/cola.linalg.html#algorithms)
-                for which solver to pick. For PSD matrices, CoLA currently recommends Cholesky() for small
-                matrices and CG() for larger matrices. Select Auto() to let CoLA decide. Defaults to Cholesky().
 
         Returns:
             FunctionalSample: A function representing an approximate sample from the Gaussian
@@ -581,31 +575,25 @@
         # sample fourier features
         fourier_feature_fn = _build_fourier_features_fn(self.prior, num_features, key)
 
-        # sample fourier weights
-        fourier_weights = jr.normal(key, [num_samples, 2 * num_features])  # [B, L]
+        fourier_weights = jr.normal(key, [num_samples, 2 * num_features])
 
-        # sample weights v for canonical features
-        # v = Σ⁻¹ (y + ε - ɸ⍵) for Σ = Kxx + Io² and ε ~ N(0, o²)
         obs_var = self.likelihood.obs_stddev.value**2
-        Kxx = self.prior.kernel.gram(train_data.X)
-        Sigma = Kxx + I_like(Kxx) * (obs_var + self.jitter)
-        eps = jnp.sqrt(obs_var) * jr.normal(key, [train_data.n, num_samples])
-        y = train_data.y - self.prior.mean_function(train_data.X)
+        Kxx = self.prior.kernel.gram(train_data.X)
+        Sigma = Dense(add_jitter(Kxx.to_dense(), obs_var + self.jitter))
+        eps = jnp.sqrt(obs_var) * jr.normal(key, [train_data.n, num_samples])
+        y = train_data.y - self.prior.mean_function(train_data.X)
         Phi = fourier_feature_fn(train_data.X)
         canonical_weights = solve(
             Sigma,
             y + eps - jnp.inner(Phi, fourier_weights),
-            solver_algorithm,
         )  # [N, B]
 
         def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]:
-            fourier_features = fourier_feature_fn(test_inputs)
-            weight_space_contribution = jnp.inner(
-                fourier_features, fourier_weights
-            )  # [n, B]
+            fourier_features = fourier_feature_fn(test_inputs)
+            weight_space_contribution = jnp.inner(fourier_features, fourier_weights)
             canonical_features = self.prior.kernel.cross_covariance(
                 test_inputs, train_data.X
-        )
+            )
             function_space_contribution = jnp.matmul(
                 canonical_features, canonical_weights
             )
@@ -657,7 +645,7 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):
 
         # TODO: static or intermediate?
         self.latent = latent if isinstance(latent, Parameter) else Real(latent)
-        self.key = Static(key)
+        self.key = key
 
     def predict(
         self, test_inputs: Num[Array, "N D"], train_data: Dataset
@@ -689,8 +677,8 @@
 
         # Precompute lower triangular of Gram matrix, Lx, at training inputs, x
         Kxx = kernel.gram(x)
-        Kxx += I_like(Kxx) * self.prior.jitter
-        Kxx = PSD(Kxx)
+        Kxx_dense = add_jitter(Kxx.to_dense(), self.prior.jitter)
+        Kxx = psd(Dense(Kxx_dense))
         Lx = lower_cholesky(Kxx)
 
         # Unpack test inputs
@@ -702,7 +690,7 @@
         mean_t = mean_function(t)
 
         # Lx⁻¹ Kxt
-        Lx_inv_Kxt = solve(Lx, Ktx.T, Cholesky())
+        Lx_inv_Kxt = solve(Lx, Ktx.T)
 
         # Whitened function values, wx, corresponding to the inputs, x
         wx = self.latent.value
@@ -711,9 +699,9 @@
         mean = mean_t + jnp.matmul(Lx_inv_Kxt.T, wx)
 
         # Ktt - Ktx Kxx⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently.
-        covariance = Ktt - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt)
-        covariance += I_like(covariance) * self.prior.jitter
-        covariance = PSD(covariance)
+        covariance = Ktt.to_dense() - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt)
+        covariance = add_jitter(covariance, self.prior.jitter)
+        covariance = psd(Dense(covariance))
 
         return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
 
```
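The user-facing consequence of the new signatures: priors and posteriors are called with explicit arguments instead of `*args`/`**kwargs`, and `sample_approx` no longer takes a `solver_algorithm`. A sketch reusing the `prior`, `posterior`, `D`, and `key` from the fitting example above; the keyword names for `sample_approx` are assumed from the signature shown in the diff:

```python
import jax.numpy as jnp

X_test = jnp.linspace(-1.0, 11.0, 50)[:, None]

prior_dist = prior(X_test)        # AbstractPrior.__call__(test_inputs)
post_dist = posterior(X_test, D)  # AbstractPosterior.__call__(test_inputs, train_data)

# solver_algorithm (a cola Algorithm) is gone; the Cholesky path is built in:
sample_fn = posterior.sample_approx(num_samples=5, train_data=D, key=key)
samples = sample_fn(X_test)       # shape [50, 5]
```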
gpjax/kernels/approximations/rff.py
CHANGED

```diff
@@ -7,7 +7,6 @@ from jaxtyping import Float
 from gpjax.kernels.base import AbstractKernel
 from gpjax.kernels.computations import BasisFunctionComputation
 from gpjax.kernels.stationary.base import StationaryKernel
-from gpjax.parameters import Static
 from gpjax.typing import (
     Array,
     KeyArray,
@@ -66,10 +65,8 @@ class RFF(AbstractKernel):
                 "Please specify the n_dims argument for the base kernel."
             )
 
-        self.frequencies = Static(
-            self.base_kernel.spectral_density.sample(
-                key=key, sample_shape=(self.num_basis_fns, n_dims)
-            )
+        self.frequencies = self.base_kernel.spectral_density.sample(
+            key=key, sample_shape=(self.num_basis_fns, n_dims)
         )
         self.name = f"{self.base_kernel.name} (RFF)"
 
```
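With `Static` gone, `RFF.frequencies` is now a plain JAX array rather than a wrapped parameter. A small sketch, assuming the constructor arguments keep their pre-0.12 names:

```python
import jax.random as jr
from gpjax.kernels import RFF, RBF

rff = RFF(base_kernel=RBF(n_dims=1), num_basis_fns=32, key=jr.PRNGKey(0))
print(rff.frequencies.shape)  # (32, 1): a raw sample from the spectral density,
                              # no longer wrapped in Static(...)
```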
gpjax/kernels/base.py
CHANGED

```diff
@@ -17,7 +17,6 @@ import abc
 import functools as ft
 
 import beartype.typing as tp
-from cola.ops.operator_base import LinearOperator
 from flax import nnx
 import jax.numpy as jnp
 from jaxtyping import (
@@ -29,10 +28,10 @@ from gpjax.kernels.computations import (
     AbstractKernelComputation,
     DenseKernelComputation,
 )
+from gpjax.linalg import LinearOperator
 from gpjax.parameters import (
     Parameter,
     Real,
-    Static,
 )
 from gpjax.typing import (
     Array,
@@ -221,9 +220,7 @@ class Constant(AbstractKernel):
     def __init__(
         self,
         active_dims: tp.Union[list[int], slice, None] = None,
-        constant: tp.Union[
-            ScalarFloat, Parameter[ScalarFloat], Static[ScalarFloat]
-        ] = jnp.array(0.0),
+        constant: tp.Union[ScalarFloat, Parameter[ScalarFloat]] = jnp.array(0.0),
         compute_engine: AbstractKernelComputation = DenseKernelComputation(),
     ):
         if isinstance(constant, Parameter):
```
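The `Constant` kernel's `constant` argument now accepts a raw scalar or a `Parameter`, with the `Static[ScalarFloat]` option removed. A minimal sketch:

```python
import jax.numpy as jnp
from gpjax.kernels.base import Constant

k = Constant(constant=jnp.array(2.0))         # plain scalar is accepted directly
print(k(jnp.array([0.0]), jnp.array([1.0])))  # 2.0 for any input pair
```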
gpjax/kernels/computations/base.py
CHANGED

```diff
@@ -16,11 +16,6 @@
 import abc
 import typing as tp
 
-from cola.annotations import PSD
-from cola.ops.operators import (
-    Dense,
-    Diagonal,
-)
 from jax import vmap
 from jaxtyping import (
     Float,
@@ -28,6 +23,11 @@ from jaxtyping import (
 )
 
 import gpjax
+from gpjax.linalg import (
+    Dense,
+    Diagonal,
+    psd,
+)
 from gpjax.typing import Array
 
 K = tp.TypeVar("K", bound="gpjax.kernels.base.AbstractKernel")  # noqa: F821
@@ -69,7 +69,7 @@ class AbstractKernelComputation:
             The Gram covariance of the kernel function as a linear operator.
         """
         Kxx = self.cross_covariance(kernel, x, x)
-        return PSD(Dense(Kxx))
+        return psd(Dense(Kxx))
 
     @abc.abstractmethod
     def _cross_covariance(
@@ -93,7 +93,7 @@
         return self._cross_covariance(kernel, x, y)
 
     def _diagonal(self, kernel: K, inputs: Num[Array, "N D"]) -> Diagonal:
-        return PSD(Diagonal(vmap(lambda x: kernel(x, x))(inputs)))
+        return psd(Diagonal(vmap(lambda x: kernel(x, x))(inputs)))
 
     def diagonal(self, kernel: K, inputs: Num[Array, "N D"]) -> Diagonal:
         r"""For a given kernel, compute the elementwise diagonal of the
```
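Gram and diagonal computations now return `gpjax.linalg` operators annotated PSD via `psd(...)` rather than CoLA's `PSD(...)`. A sketch of driving the compute engine directly, assuming the `gram`/`diagonal` entry points behave as the diff shows:

```python
import jax.numpy as jnp
from gpjax.kernels import RBF
from gpjax.kernels.computations import DenseKernelComputation

X = jnp.linspace(0.0, 1.0, 10)[:, None]
engine = DenseKernelComputation()

K = engine.gram(RBF(), X)       # psd(Dense(...)) linear operator
d = engine.diagonal(RBF(), X)   # psd(Diagonal(...)) linear operator
K_dense = K.to_dense()          # materialise when a plain array is needed
```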
gpjax/kernels/computations/basis_functions.py
CHANGED

```diff
@@ -1,18 +1,19 @@
 import typing as tp
 
-from cola.annotations import PSD
-from cola.ops.operators import Dense
 import jax.numpy as jnp
 from jaxtyping import Float
 
 import gpjax
 from gpjax.kernels.computations.base import AbstractKernelComputation
+from gpjax.linalg import (
+    Dense,
+    Diagonal,
+    psd,
+)
 from gpjax.typing import Array
 
 K = tp.TypeVar("K", bound="gpjax.kernels.approximations.RFF")  # noqa: F821
 
-from cola.ops import Diagonal
-
 # TODO: Use low rank linear operator!
 
 
@@ -28,7 +29,7 @@ class BasisFunctionComputation(AbstractKernelComputation):
 
     def _gram(self, kernel: K, inputs: Float[Array, "N D"]) -> Dense:
         z1 = self.compute_features(kernel, inputs)
-        return PSD(Dense(self.scaling(kernel) * jnp.matmul(z1, z1.T)))
+        return psd(Dense(self.scaling(kernel) * jnp.matmul(z1, z1.T)))
 
     def diagonal(self, kernel: K, inputs: Float[Array, "N D"]) -> Diagonal:
         r"""For a given kernel, compute the elementwise diagonal of the
@@ -56,7 +57,7 @@
         Returns:
             A matrix of shape $N \times L$ representing the random fourier features where $L = 2M$.
         """
-        frequencies = kernel.frequencies.value
+        frequencies = kernel.frequencies
         scaling_factor = kernel.base_kernel.lengthscale.value
         z = jnp.matmul(x, (frequencies / scaling_factor).T)
         z = jnp.concatenate([jnp.cos(z), jnp.sin(z)], axis=-1)
```