gpjax 0.12.0__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpjax/__init__.py +1 -4
- gpjax/fit.py +11 -6
- gpjax/gps.py +35 -33
- gpjax/kernels/approximations/rff.py +4 -6
- gpjax/kernels/base.py +2 -5
- gpjax/kernels/computations/basis_functions.py +1 -1
- gpjax/kernels/computations/eigen.py +1 -1
- gpjax/kernels/non_euclidean/graph.py +10 -11
- gpjax/kernels/nonstationary/arccosine.py +13 -21
- gpjax/kernels/nonstationary/polynomial.py +7 -8
- gpjax/kernels/stationary/base.py +1 -30
- gpjax/kernels/stationary/matern12.py +1 -1
- gpjax/kernels/stationary/matern32.py +1 -1
- gpjax/kernels/stationary/matern52.py +1 -1
- gpjax/kernels/stationary/periodic.py +3 -6
- gpjax/kernels/stationary/powered_exponential.py +3 -8
- gpjax/kernels/stationary/rational_quadratic.py +5 -8
- gpjax/likelihoods.py +11 -14
- gpjax/linalg/utils.py +32 -0
- gpjax/mean_functions.py +9 -8
- gpjax/objectives.py +4 -3
- gpjax/parameters.py +0 -10
- gpjax/variational_families.py +65 -45
- {gpjax-0.12.0.dist-info → gpjax-0.13.0.dist-info}/METADATA +21 -21
- gpjax-0.13.0.dist-info/RECORD +52 -0
- gpjax-0.12.0.dist-info/RECORD +0 -52
- {gpjax-0.12.0.dist-info → gpjax-0.13.0.dist-info}/WHEEL +0 -0
- {gpjax-0.12.0.dist-info → gpjax-0.13.0.dist-info}/licenses/LICENSE.txt +0 -0
gpjax/__init__.py
CHANGED
@@ -40,10 +40,9 @@ __license__ = "MIT"
 __description__ = "Gaussian processes in JAX and Flax"
 __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
 __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
-__version__ = "0.12.0"
+__version__ = "0.13.0"

 __all__ = [
-    "base",
     "gps",
     "integrators",
     "kernels",
@@ -55,8 +54,6 @@ __all__ = [
     "Dataset",
     "cite",
     "fit",
-    "Module",
-    "param_field",
     "fit_lbfgs",
     "fit_scipy",
 ]
gpjax/fit.py
CHANGED
@@ -48,6 +48,7 @@ def fit(  # noqa: PLR0913
     train_data: Dataset,
     optim: ox.GradientTransformation,
     params_bijection: tp.Union[dict[Parameter, Transform], None] = DEFAULT_BIJECTION,
+    trainable: nnx.filterlib.Filter = Parameter,
     key: KeyArray = jr.PRNGKey(42),
     num_iters: int = 100,
     batch_size: int = -1,
@@ -65,7 +66,7 @@ def fit(  # noqa: PLR0913
     >>> import jax.random as jr
     >>> import optax as ox
     >>> import gpjax as gpx
-    >>> from gpjax.parameters import PositiveReal
+    >>> from gpjax.parameters import PositiveReal
     >>>
     >>> # (1) Create a dataset:
     >>> X = jnp.linspace(0.0, 10.0, 100)[:, None]
@@ -75,10 +76,10 @@ def fit(  # noqa: PLR0913
     >>> class LinearModel(nnx.Module):
     >>>     def __init__(self, weight: float, bias: float):
     >>>         self.weight = PositiveReal(weight)
-    >>>         self.bias =
+    >>>         self.bias = bias
     >>>
     >>>     def __call__(self, x):
-    >>>         return self.weight.value * x + self.bias
+    >>>         return self.weight.value * x + self.bias
     >>>
     >>> model = LinearModel(weight=1.0, bias=1.0)
     >>>
@@ -100,6 +101,8 @@ def fit(  # noqa: PLR0913
         train_data (Dataset): The training data to be used for the optimisation.
         optim (GradientTransformation): The Optax optimiser that is to be used for
             learning a parameter set.
+        trainable (nnx.filterlib.Filter): Filter to determine which parameters are trainable.
+            Defaults to nnx.Param (all Parameter instances).
         num_iters (int): The number of optimisation steps to run. Defaults
             to 100.
         batch_size (int): The size of the mini-batch to use. Defaults to -1
@@ -127,7 +130,7 @@ def fit(  # noqa: PLR0913
     _check_verbose(verbose)

     # Model state filtering
-    graphdef, params, *static_state = nnx.split(model,
+    graphdef, params, *static_state = nnx.split(model, trainable, ...)

     # Parameters bijection to unconstrained space
     if params_bijection is not None:
@@ -182,6 +185,7 @@ def fit_scipy(  # noqa: PLR0913
     model: Model,
     objective: Objective,
     train_data: Dataset,
+    trainable: nnx.filterlib.Filter = Parameter,
     max_iters: int = 500,
     verbose: bool = True,
     safe: bool = True,
@@ -210,7 +214,7 @@ def fit_scipy(  # noqa: PLR0913
     _check_verbose(verbose)

     # Model state filtering
-    graphdef, params, *static_state = nnx.split(model,
+    graphdef, params, *static_state = nnx.split(model, trainable, ...)

     # Parameters bijection to unconstrained space
     params = transform(params, DEFAULT_BIJECTION, inverse=True)
@@ -258,6 +262,7 @@ def fit_lbfgs(
     objective: Objective,
     train_data: Dataset,
     params_bijection: tp.Union[dict[Parameter, Transform], None] = DEFAULT_BIJECTION,
+    trainable: nnx.filterlib.Filter = Parameter,
     max_iters: int = 100,
     safe: bool = True,
     max_linesearch_steps: int = 32,
@@ -290,7 +295,7 @@ def fit_lbfgs(
     _check_num_iters(max_iters)

     # Model state filtering
-    graphdef, params, *static_state = nnx.split(model,
+    graphdef, params, *static_state = nnx.split(model, trainable, ...)

     # Parameters bijection to unconstrained space
     if params_bijection is not None:
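For reference, the new `trainable` argument is an NNX filter that is passed straight to `nnx.split`, so only matching parameters are optimised while everything else is carried along as static state. A minimal sketch, re-using the docstring's `LinearModel` example and assuming the `fit` signature shown above (the loss function and data below are illustrative, not taken from this diff):

import jax.numpy as jnp
import jax.random as jr
import optax as ox
from flax import nnx

import gpjax as gpx
from gpjax.parameters import PositiveReal


class LinearModel(nnx.Module):
    def __init__(self, weight: float, bias: float):
        self.weight = PositiveReal(weight)  # a GPJax parameter (trainable)
        self.bias = bias                    # a plain float attribute

    def __call__(self, x):
        return self.weight.value * x + self.bias


X = jnp.linspace(0.0, 10.0, 100)[:, None]
y = 2.0 * X + 1.0 + 10 * jr.normal(jr.PRNGKey(0), X.shape)
D = gpx.Dataset(X=X, y=y)

model = LinearModel(weight=1.0, bias=1.0)
loss = lambda m, data: jnp.mean((data.y - m(data.X)) ** 2)

# Only parameters matching the filter are updated; here only PositiveReal leaves.
trained_model, history = gpx.fit(
    model=model,
    objective=loss,
    train_data=D,
    optim=ox.sgd(0.001),
    trainable=PositiveReal,
    num_iters=100,
)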
gpjax/gps.py
CHANGED
@@ -35,16 +35,15 @@ from gpjax.likelihoods import (
 )
 from gpjax.linalg import (
     Dense,
-    Identity,
     psd,
     solve,
 )
 from gpjax.linalg.operations import lower_cholesky
+from gpjax.linalg.utils import add_jitter
 from gpjax.mean_functions import AbstractMeanFunction
 from gpjax.parameters import (
     Parameter,
     Real,
-    Static,
 )
 from gpjax.typing import (
     Array,
@@ -78,7 +77,7 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
         self.mean_function = mean_function
         self.jitter = jitter

-    def __call__(self,
+    def __call__(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution:
         r"""Evaluate the Gaussian process at the given points.

         The output of this function is a
@@ -91,17 +90,16 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
         `__call__` method and should instead define a `predict` method.

         Args:
-
-            **kwargs (Any): The keyword arguments to pass to the GP's `predict` method.
+            test_inputs: Input locations where the GP should be evaluated.

         Returns:
             GaussianDistribution: A multivariate normal random variable representation
                 of the Gaussian process.
         """
-        return self.predict(
+        return self.predict(test_inputs)

     @abstractmethod
-    def predict(self,
+    def predict(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution:
         r"""Evaluate the predictive distribution.

         Compute the latent function's multivariate normal distribution for a
@@ -109,8 +107,7 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
         this method must be implemented.

         Args:
-
-            **kwargs (Any): Keyword arguments to the predict method.
+            test_inputs: Input locations where the GP should be evaluated.

         Returns:
             GaussianDistribution: A multivariate normal random variable representation
@@ -249,13 +246,12 @@ class Prior(AbstractPrior[M, K]):
             GaussianDistribution: A multivariate normal random variable representation
                 of the Gaussian process.
         """
-
-
-
-        Kxx_dense = Kxx.to_dense() + Identity(Kxx.shape).to_dense() * self.jitter
+        mean_at_test = self.mean_function(test_inputs)
+        Kxx = self.kernel.gram(test_inputs)
+        Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter)
         Kxx = psd(Dense(Kxx_dense))

-        return GaussianDistribution(jnp.atleast_1d(
+        return GaussianDistribution(jnp.atleast_1d(mean_at_test.squeeze()), Kxx)

     def sample_approx(
         self,
@@ -359,7 +355,9 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
         self.likelihood = likelihood
         self.jitter = jitter

-    def __call__(
+    def __call__(
+        self, test_inputs: Num[Array, "N D"], train_data: Dataset
+    ) -> GaussianDistribution:
         r"""Evaluate the Gaussian process posterior at the given points.

         The output of this function is a
@@ -368,28 +366,30 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
         evaluated and the distribution can be sampled.

         Under the hood, `__call__` is calling the objects `predict` method. For this
-        reasons, classes inheriting the `
+        reasons, classes inheriting the `AbstractPosterior` class, should not overwrite the
         `__call__` method and should instead define a `predict` method.

         Args:
-
-
+            test_inputs: Input locations where the GP should be evaluated.
+            train_data: Training dataset to condition on.

         Returns:
             GaussianDistribution: A multivariate normal random variable representation
                 of the Gaussian process.
         """
-        return self.predict(
+        return self.predict(test_inputs, train_data)

     @abstractmethod
-    def predict(
+    def predict(
+        self, test_inputs: Num[Array, "N D"], train_data: Dataset
+    ) -> GaussianDistribution:
         r"""Compute the latent function's multivariate normal distribution for a
-        given set of parameters. For any class inheriting the `
+        given set of parameters. For any class inheriting the `AbstractPosterior` class,
         this method must be implemented.

         Args:
-
-
+            test_inputs: Input locations where the GP should be evaluated.
+            train_data: Training dataset to condition on.

         Returns:
             GaussianDistribution: A multivariate normal random variable representation
@@ -503,22 +503,24 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):

         # Precompute Gram matrix, Kxx, at training inputs, x
         Kxx = self.prior.kernel.gram(x)
-        Kxx_dense =
+        Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter)
         Kxx = Dense(Kxx_dense)

         Sigma_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * obs_noise
         Sigma = psd(Dense(Sigma_dense))
+        L_sigma = lower_cholesky(Sigma)

         mean_t = self.prior.mean_function(t)
         Ktt = self.prior.kernel.gram(t)
         Kxt = self.prior.kernel.cross_covariance(x, t)
-        Sigma_inv_Kxt = solve(Sigma, Kxt)

-
+        L_inv_Kxt = solve(L_sigma, Kxt)
+        L_inv_y_diff = solve(L_sigma, y - mx)
+
+        mean = mean_t + jnp.matmul(L_inv_Kxt.T, L_inv_y_diff)

-
-        covariance =
-        covariance += jnp.eye(covariance.shape[0]) * self.prior.jitter
+        covariance = Ktt.to_dense() - jnp.matmul(L_inv_Kxt.T, L_inv_Kxt)
+        covariance = add_jitter(covariance, self.prior.jitter)
         covariance = psd(Dense(covariance))

         return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
@@ -577,7 +579,7 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):

         obs_var = self.likelihood.obs_stddev.value**2
         Kxx = self.prior.kernel.gram(train_data.X)
-        Sigma = Kxx
+        Sigma = Dense(add_jitter(Kxx.to_dense(), obs_var + self.jitter))
         eps = jnp.sqrt(obs_var) * jr.normal(key, [train_data.n, num_samples])
         y = train_data.y - self.prior.mean_function(train_data.X)
         Phi = fourier_feature_fn(train_data.X)
@@ -643,7 +645,7 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):

         # TODO: static or intermediate?
         self.latent = latent if isinstance(latent, Parameter) else Real(latent)
-        self.key =
+        self.key = key

     def predict(
         self, test_inputs: Num[Array, "N D"], train_data: Dataset
@@ -675,7 +677,7 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):

         # Precompute lower triangular of Gram matrix, Lx, at training inputs, x
         Kxx = kernel.gram(x)
-        Kxx_dense = Kxx.to_dense()
+        Kxx_dense = add_jitter(Kxx.to_dense(), self.prior.jitter)
         Kxx = psd(Dense(Kxx_dense))
         Lx = lower_cholesky(Kxx)

@@ -698,7 +700,7 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):

         # Ktt - Ktx Kxx⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently.
         covariance = Ktt.to_dense() - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt)
-        covariance
+        covariance = add_jitter(covariance, self.prior.jitter)
         covariance = psd(Dense(covariance))

         return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
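The `add_jitter` helper that these hunks call comes from the new `gpjax/linalg/utils.py` (+32 lines in the summary above), whose body is not shown in this diff. A minimal sketch consistent with how it is used here, i.e. `add_jitter(matrix, jitter)` returning the matrix with jitter added to its diagonal, would be:

import jax.numpy as jnp
from jaxtyping import Float

from gpjax.typing import Array, ScalarFloat


def add_jitter(matrix: Float[Array, "N N"], jitter: ScalarFloat) -> Float[Array, "N N"]:
    # Sketch only: add a small positive constant to the diagonal so that
    # downstream Cholesky factorisations remain numerically stable.
    return matrix + jnp.eye(matrix.shape[0]) * jitter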
gpjax/kernels/approximations/rff.py
CHANGED
@@ -1,13 +1,13 @@
 """Compute Random Fourier Feature (RFF) kernel approximations."""

 import beartype.typing as tp
+from flax import nnx
 import jax.random as jr
 from jaxtyping import Float

 from gpjax.kernels.base import AbstractKernel
 from gpjax.kernels.computations import BasisFunctionComputation
 from gpjax.kernels.stationary.base import StationaryKernel
-from gpjax.parameters import Static
 from gpjax.typing import (
     Array,
     KeyArray,
@@ -55,7 +55,7 @@ class RFF(AbstractKernel):
         self._check_valid_base_kernel(base_kernel)
         self.base_kernel = base_kernel
         self.num_basis_fns = num_basis_fns
-        self.frequencies = frequencies
+        self.frequencies = nnx.data(frequencies)
         self.compute_engine = compute_engine

         if self.frequencies is None:
@@ -66,10 +66,8 @@ class RFF(AbstractKernel):
                 "Please specify the n_dims argument for the base kernel."
             )

-            self.frequencies =
-            self.
-                key=key, sample_shape=(self.num_basis_fns, n_dims)
-            )
+            self.frequencies = self.base_kernel.spectral_density.sample(
+                key=key, sample_shape=(self.num_basis_fns, n_dims)
             )
         self.name = f"{self.base_kernel.name} (RFF)"
gpjax/kernels/base.py
CHANGED
@@ -32,7 +32,6 @@ from gpjax.linalg import LinearOperator
 from gpjax.parameters import (
     Parameter,
     Real,
-    Static,
 )
 from gpjax.typing import (
     Array,
@@ -221,9 +220,7 @@ class Constant(AbstractKernel):
     def __init__(
         self,
         active_dims: tp.Union[list[int], slice, None] = None,
-        constant: tp.Union[
-            ScalarFloat, Parameter[ScalarFloat], Static[ScalarFloat]
-        ] = jnp.array(0.0),
+        constant: tp.Union[ScalarFloat, Parameter[ScalarFloat]] = jnp.array(0.0),
         compute_engine: AbstractKernelComputation = DenseKernelComputation(),
     ):
         if isinstance(constant, Parameter):
@@ -256,7 +253,7 @@ class CombinationKernel(AbstractKernel):
         compute_engine: AbstractKernelComputation = DenseKernelComputation(),
     ):
         # Add kernels to a list, flattening out instances of this class therein, as in GPFlow kernels.
-        kernels_list: list[AbstractKernel] = []
+        kernels_list: list[AbstractKernel] = nnx.List([])
        for kernel in kernels:
             if not isinstance(kernel, AbstractKernel):
                 raise TypeError("can only combine Kernel instances")  # pragma: no cover
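For context, storing the children in an `nnx.List` rather than a plain Python list is what lets Flax NNX's graph utilities see the child kernels as sub-modules. User-facing kernel composition is unchanged; a minimal sketch, assuming the usual GPJax kernel API:

import gpjax as gpx

# Adding or multiplying kernels still builds a CombinationKernel as before;
# only the internal storage of the child kernels changed in this release.
kernel = gpx.kernels.RBF() + gpx.kernels.Matern52()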
gpjax/kernels/computations/basis_functions.py
CHANGED
@@ -57,7 +57,7 @@ class BasisFunctionComputation(AbstractKernelComputation):
         Returns:
             A matrix of shape $N \times L$ representing the random fourier features where $L = 2M$.
         """
-        frequencies = kernel.frequencies
+        frequencies = kernel.frequencies
         scaling_factor = kernel.base_kernel.lengthscale.value
         z = jnp.matmul(x, (frequencies / scaling_factor).T)
         z = jnp.concatenate([jnp.cos(z), jnp.sin(z)], axis=-1)
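As a standalone illustration of the feature map in the hunk above (the leading normalisation constant lies outside the lines shown, so it is omitted here):

import jax.numpy as jnp


def rff_features(x, frequencies, lengthscale):
    # Project the inputs onto the sampled frequencies, then stack cosine and
    # sine components to form the 2M-dimensional random Fourier feature matrix.
    z = jnp.matmul(x, (frequencies / lengthscale).T)
    return jnp.concatenate([jnp.cos(z), jnp.sin(z)], axis=-1)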
gpjax/kernels/computations/eigen.py
CHANGED
@@ -42,7 +42,7 @@ class EigenKernelComputation(AbstractKernelComputation):
         # Transform the eigenvalues of the graph Laplacian according to the
         # RBF kernel's SPDE form.
         S = jnp.power(
-            kernel.eigenvalues
+            kernel.eigenvalues
             + 2
             * kernel.smoothness.value
             / kernel.lengthscale.value
gpjax/kernels/non_euclidean/graph.py
CHANGED
@@ -30,7 +30,6 @@ from gpjax.kernels.stationary.base import StationaryKernel
 from gpjax.parameters import (
     Parameter,
     PositiveReal,
-    Static,
 )
 from gpjax.typing import (
     Array,
@@ -55,9 +54,9 @@ class GraphKernel(StationaryKernel):
     """

     num_vertex: tp.Union[ScalarInt, None]
-    laplacian:
-    eigenvalues:
-    eigenvectors:
+    laplacian: Float[Array, "N N"]
+    eigenvalues: Float[Array, "N 1"]
+    eigenvectors: Float[Array, "N N"]
     name: str = "Graph Matérn"

     def __init__(
@@ -91,11 +90,11 @@ class GraphKernel(StationaryKernel):
         else:
             self.smoothness = PositiveReal(smoothness)

-        self.laplacian =
-        evals, eigenvectors = jnp.linalg.eigh(self.laplacian
-        self.eigenvectors =
-        self.eigenvalues =
-        self.num_vertex = self.eigenvalues.
+        self.laplacian = laplacian
+        evals, eigenvectors = jnp.linalg.eigh(self.laplacian)
+        self.eigenvectors = eigenvectors
+        self.eigenvalues = evals.reshape(-1, 1)
+        self.num_vertex = self.eigenvalues.shape[0]

         super().__init__(active_dims, lengthscale, variance, n_dims, compute_engine)
@@ -107,7 +106,7 @@ class GraphKernel(StationaryKernel):
         S,
         **kwargs,
     ):
-        Kxx = (jax_gather_nd(self.eigenvectors
-            jax_gather_nd(self.eigenvectors
+        Kxx = (jax_gather_nd(self.eigenvectors, x) * S.squeeze()) @ jnp.transpose(
+            jax_gather_nd(self.eigenvectors, y)
         )  # shape (n,n)
         return Kxx.squeeze()
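For orientation, the constructor now stores the Laplacian and its eigendecomposition directly as arrays. A minimal usage sketch, assuming only the `laplacian` keyword shown in the hunk above and leaving the remaining constructor arguments at their defaults:

import jax.numpy as jnp
from gpjax.kernels import GraphKernel

# Graph Laplacian of a three-vertex path graph (illustrative data).
L = jnp.array(
    [
        [1.0, -1.0, 0.0],
        [-1.0, 2.0, -1.0],
        [0.0, -1.0, 1.0],
    ]
)

kernel = GraphKernel(laplacian=L)
# After __init__, the eigendecomposition is available on the kernel:
# kernel.eigenvalues has shape (3, 1), kernel.eigenvectors has shape (3, 3),
# and kernel.num_vertex == 3.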
gpjax/kernels/nonstationary/arccosine.py
CHANGED
@@ -25,7 +25,6 @@ from gpjax.kernels.computations import (
 )
 from gpjax.parameters import (
     NonNegativeReal,
-    PositiveReal,
 )
 from gpjax.typing import (
     Array,
@@ -82,30 +81,13 @@ class ArcCosine(AbstractKernel):

         self.order = order

-
-            self.weight_variance = weight_variance
-        else:
-            self.weight_variance = PositiveReal(weight_variance)
-        if tp.TYPE_CHECKING:
-            self.weight_variance = tp.cast(
-                PositiveReal[WeightVariance], self.weight_variance
-            )
+        self.weight_variance = weight_variance

         if isinstance(variance, nnx.Variable):
             self.variance = variance
         else:
             self.variance = NonNegativeReal(variance)
-
-            self.variance = tp.cast(NonNegativeReal[ScalarArray], self.variance)
-
-        if isinstance(bias_variance, nnx.Variable):
-            self.bias_variance = bias_variance
-        else:
-            self.bias_variance = PositiveReal(bias_variance)
-        if tp.TYPE_CHECKING:
-            self.bias_variance = tp.cast(
-                PositiveReal[ScalarArray], self.bias_variance
-            )
+        self.bias_variance = bias_variance

         self.name = f"ArcCosine (order {self.order})"

@@ -141,7 +123,17 @@ class ArcCosine(AbstractKernel):
         Returns:
             ScalarFloat: The value of the weighted product between the two arguments``.
         """
-
+        weight_var = (
+            self.weight_variance.value
+            if hasattr(self.weight_variance, "value")
+            else self.weight_variance
+        )
+        bias_var = (
+            self.bias_variance.value
+            if hasattr(self.bias_variance, "value")
+            else self.bias_variance
+        )
+        return jnp.inner(weight_var * x, y) + bias_var

     def _J(self, theta: ScalarFloat) -> ScalarFloat:
         r"""Evaluate the angular dependency function corresponding to the desired order.
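The `value`-unwrapping expressions used here (and in the `polynomial.py`, `periodic.py`, and `powered_exponential.py` hunks below) all follow one pattern: accept either a raw float/array or a `Parameter`-style object and read its `.value` only when present. As a standalone sketch:

def _as_value(attr):
    # Unwrap Parameter / nnx.Variable-like objects; pass plain floats and arrays through.
    return attr.value if hasattr(attr, "value") else attr


# e.g. weight_var = _as_value(self.weight_variance); bias_var = _as_value(self.bias_variance)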
gpjax/kernels/nonstationary/polynomial.py
CHANGED
@@ -69,12 +69,9 @@ class Polynomial(AbstractKernel):

         self.degree = degree

-
-
-
-        self.shift = PositiveReal(shift)
-        if tp.TYPE_CHECKING:
-            self.shift = tp.cast(PositiveReal[ScalarArray], self.shift)
+        self.shift = shift
+        if tp.TYPE_CHECKING and not isinstance(shift, nnx.Variable):
+            self.shift = tp.cast(PositiveReal[ScalarArray], self.shift)

         if isinstance(variance, nnx.Variable):
             self.variance = variance
@@ -88,7 +85,9 @@ class Polynomial(AbstractKernel):
     def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat:
         x = self.slice_input(x)
         y = self.slice_input(y)
-
-
+        shift_val = self.shift.value if hasattr(self.shift, "value") else self.shift
+        variance_val = (
+            self.variance.value if hasattr(self.variance, "value") else self.variance
         )
+        K = jnp.power(shift_val + variance_val * jnp.dot(x, y), self.degree)
         return K.squeeze()
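Written out with plain floats in place of the parameter objects, the kernel evaluated in the hunk above is simply:

import jax.numpy as jnp


def polynomial_k(x, y, shift=1.0, variance=1.0, degree=2):
    # k(x, y) = (shift + variance * <x, y>) ** degree
    return jnp.power(shift + variance * jnp.dot(x, y), degree)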
gpjax/kernels/stationary/base.py
CHANGED
@@ -127,7 +127,7 @@ def _check_lengthscale_dims_compat(
     """

     if isinstance(lengthscale, nnx.Variable):
-        return
+        return _check_lengthscale_dims_compat(lengthscale.value, n_dims)

     lengthscale = jnp.asarray(lengthscale)
     ls_shape = jnp.shape(lengthscale)
@@ -146,35 +146,6 @@ def _check_lengthscale_dims_compat(
     return n_dims


-def _check_lengthscale_dims_compat_old(
-    lengthscale: tp.Union[LengthscaleCompatible, nnx.Variable[Lengthscale]],
-    n_dims: tp.Union[int, None],
-):
-    r"""Check that the lengthscale is compatible with n_dims.
-
-    If possible, infer the number of input dimensions from the lengthscale.
-    """
-
-    if isinstance(lengthscale, nnx.Variable):
-        return _check_lengthscale_dims_compat_old(lengthscale.value, n_dims)
-
-    lengthscale = jnp.asarray(lengthscale)
-    ls_shape = jnp.shape(lengthscale)
-
-    if ls_shape == ():
-        return lengthscale, n_dims
-    elif ls_shape != () and n_dims is None:
-        return lengthscale, ls_shape[0]
-    elif ls_shape != () and n_dims is not None:
-        if ls_shape != (n_dims,):
-            raise ValueError(
-                "Expected `lengthscale` to be compatible with the number "
-                f"of input dimensions. Got `lengthscale` with shape {ls_shape}, "
-                f"but the number of input dimensions is {n_dims}."
-            )
-        return lengthscale, n_dims
-
-
 def _check_lengthscale(lengthscale: tp.Any):
     """Check that the lengthscale is a valid value."""
gpjax/kernels/stationary/matern32.py
CHANGED
@@ -32,7 +32,7 @@ class Matern32(StationaryKernel):
     lengthscale parameter $\ell$ and variance $\sigma^2$.

     $$
-    k(x, y) = \sigma^2 \exp \Bigg(1+ \frac{\sqrt{3}\lvert x-y \rvert}{\ell
+    k(x, y) = \sigma^2 \exp \Bigg(1+ \frac{\sqrt{3}\lvert x-y \rvert}{\ell} \ \Bigg)\exp\Bigg(-\frac{\sqrt{3}\lvert x-y\rvert}{\ell^2} \Bigg)
     $$
     """
gpjax/kernels/stationary/matern52.py
CHANGED
@@ -33,7 +33,7 @@ class Matern52(StationaryKernel):
     lengthscale parameter $\ell$ and variance $\sigma^2$.

     $$
-    k(x, y) = \sigma^2 \exp \Bigg(1+ \frac{\sqrt{5}\lvert x-y \rvert}{\ell
+    k(x, y) = \sigma^2 \exp \Bigg(1+ \frac{\sqrt{5}\lvert x-y \rvert}{\ell} + \frac{5\lvert x - y \rvert^2}{3\ell^2} \Bigg)\exp\Bigg(-\frac{\sqrt{5}\lvert x-y\rvert}{\ell^2} \Bigg)
     $$
     """
gpjax/kernels/stationary/periodic.py
CHANGED
@@ -23,7 +23,6 @@ from gpjax.kernels.computations import (
     DenseKernelComputation,
 )
 from gpjax.kernels.stationary.base import StationaryKernel
-from gpjax.parameters import PositiveReal
 from gpjax.typing import (
     Array,
     ScalarArray,
@@ -72,10 +71,7 @@ class Periodic(StationaryKernel):
             covariance matrix.
         """

-
-            self.period = period
-        else:
-            self.period = PositiveReal(period)
+        self.period = period

         super().__init__(active_dims, lengthscale, variance, n_dims, compute_engine)
@@ -84,8 +80,9 @@ class Periodic(StationaryKernel):
     ) -> Float[Array, ""]:
         x = self.slice_input(x)
         y = self.slice_input(y)
+        period_val = self.period.value if hasattr(self.period, "value") else self.period
         sine_squared = (
-            jnp.sin(jnp.pi * (x - y) /
+            jnp.sin(jnp.pi * (x - y) / period_val) / self.lengthscale.value
         ) ** 2
         K = self.variance.value * jnp.exp(-0.5 * jnp.sum(sine_squared, axis=0))
         return K.squeeze()
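With plain floats in place of the parameter objects, the computation above reduces to the standard periodic kernel:

import jax.numpy as jnp


def periodic_k(x, y, variance=1.0, lengthscale=1.0, period=1.0):
    # k(x, y) = sigma^2 * exp(-0.5 * sum_i (sin(pi (x_i - y_i) / p) / ell)^2)
    sine_squared = (jnp.sin(jnp.pi * (x - y) / period) / lengthscale) ** 2
    return variance * jnp.exp(-0.5 * jnp.sum(sine_squared, axis=0))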
gpjax/kernels/stationary/powered_exponential.py
CHANGED
@@ -24,7 +24,6 @@ from gpjax.kernels.computations import (
 )
 from gpjax.kernels.stationary.base import StationaryKernel
 from gpjax.kernels.stationary.utils import euclidean_distance
-from gpjax.parameters import SigmoidBounded
 from gpjax.typing import (
     Array,
     ScalarArray,
@@ -76,10 +75,7 @@ class PoweredExponential(StationaryKernel):
         compute_engine: the computation engine that the kernel uses to compute the
             covariance matrix.
         """
-
-            self.power = power
-        else:
-            self.power = SigmoidBounded(power)
+        self.power = power

         super().__init__(active_dims, lengthscale, variance, n_dims, compute_engine)
@@ -88,7 +84,6 @@ class PoweredExponential(StationaryKernel):
     ) -> Float[Array, ""]:
         x = self.slice_input(x) / self.lengthscale.value
         y = self.slice_input(y) / self.lengthscale.value
-
-
-        )
+        power_val = self.power.value if hasattr(self.power, "value") else self.power
+        K = self.variance.value * jnp.exp(-(euclidean_distance(x, y) ** power_val))
         return K.squeeze()
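And the same evaluation written with plain floats, using the `euclidean_distance` helper imported in this module:

import jax.numpy as jnp

from gpjax.kernels.stationary.utils import euclidean_distance


def powered_exponential_k(x, y, variance=1.0, lengthscale=1.0, power=1.0):
    # k(x, y) = sigma^2 * exp(-(||x - y|| / ell) ** power)
    x = x / lengthscale
    y = y / lengthscale
    return variance * jnp.exp(-(euclidean_distance(x, y) ** power))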