gpjax 0.10.2__py3-none-any.whl → 0.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpjax/__init__.py +4 -2
- gpjax/distributions.py +101 -111
- gpjax/fit.py +108 -5
- gpjax/kernels/approximations/rff.py +1 -1
- gpjax/kernels/nonstationary/arccosine.py +6 -3
- gpjax/kernels/nonstationary/linear.py +3 -3
- gpjax/kernels/nonstationary/polynomial.py +6 -3
- gpjax/kernels/stationary/base.py +8 -5
- gpjax/kernels/stationary/matern12.py +2 -2
- gpjax/kernels/stationary/matern32.py +2 -2
- gpjax/kernels/stationary/matern52.py +2 -2
- gpjax/kernels/stationary/rbf.py +3 -3
- gpjax/kernels/stationary/utils.py +3 -5
- gpjax/likelihoods.py +40 -39
- gpjax/mean_functions.py +4 -3
- gpjax/numpyro_extras.py +106 -0
- gpjax/objectives.py +4 -6
- gpjax/parameters.py +31 -13
- gpjax/variational_families.py +5 -1
- {gpjax-0.10.2.dist-info → gpjax-0.11.1.dist-info}/METADATA +2 -61
- {gpjax-0.10.2.dist-info → gpjax-0.11.1.dist-info}/RECORD +23 -22
- {gpjax-0.10.2.dist-info → gpjax-0.11.1.dist-info}/WHEEL +0 -0
- {gpjax-0.10.2.dist-info → gpjax-0.11.1.dist-info}/licenses/LICENSE.txt +0 -0
gpjax/__init__.py
CHANGED

@@ -32,14 +32,15 @@ from gpjax.citation import cite
 from gpjax.dataset import Dataset
 from gpjax.fit import (
     fit,
+    fit_lbfgs,
     fit_scipy,
 )

 __license__ = "MIT"
-__description__ = "…"
+__description__ = "Gaussian processes in JAX and Flax"
 __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
 __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
-__version__ = "0.10.2"
+__version__ = "0.11.1"

 __all__ = [
     "base",
@@ -56,5 +57,6 @@ __all__ = [
     "fit",
     "Module",
     "param_field",
+    "fit_lbfgs",
     "fit_scipy",
 ]
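The upshot of this change is a third public training entry point next to `fit` and `fit_scipy`. A minimal sketch of the new import surface, using nothing beyond what the diff above establishes:

import gpjax as gpx

# The new LBFGS trainer sits alongside the existing entry points.
from gpjax import fit, fit_lbfgs, fit_scipy

assert "fit_lbfgs" in gpx.__all__
print(gpx.__version__)  # "0.11.1"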
gpjax/distributions.py
CHANGED

@@ -15,77 +15,76 @@


 from beartype.typing import (
-    Any,
     Optional,
-    Tuple,
-    TypeVar,
 )
 import cola
+from cola.linalg.decompositions import Cholesky
 from cola.ops import (
-    Identity,
     LinearOperator,
 )
 from jax import vmap
 import jax.numpy as jnp
 import jax.random as jr
 from jaxtyping import Float
-…
+from numpyro.distributions import constraints
+from numpyro.distributions.distribution import Distribution
+from numpyro.distributions.util import is_prng_key

 from gpjax.lower_cholesky import lower_cholesky
 from gpjax.typing import (
     Array,
-    KeyArray,
     ScalarFloat,
 )

-tfd = tfp.distributions
-
-from cola.linalg.decompositions import Cholesky
-

-class GaussianDistribution(…
-…
-    # TODO: Consider `distrax.transformed.Transformed` object. Can we create a LinearOperator to `distrax.bijector` representation
-    # and modify `distrax.MultivariateNormalFromBijector`?
-    # TODO: Consider natural and expectation parameterisations in future work.
-    # TODO: we don't really need to inherit from `tfd.Distribution` here
+class GaussianDistribution(Distribution):
+    support = constraints.real_vector

     def __init__(
         self,
-        loc: Optional[Float[Array, " N"]]…
-        scale: Optional[LinearOperator]…
-…
-        _check_loc_scale(loc, scale)
+        loc: Optional[Float[Array, " N"]],
+        scale: Optional[LinearOperator],
+        validate_args=None,
+    ):
+        self.loc = loc
+        self.scale = cola.PSD(scale)
+        batch_shape = ()
+        event_shape = jnp.shape(self.loc)
+        super().__init__(batch_shape, event_shape, validate_args=validate_args)

-…
+    def sample(self, key, sample_shape=()):
+        assert is_prng_key(key)
+        # Obtain covariance root.
+        covariance_root = lower_cholesky(self.scale)

-…
+        # Gather n samples from standard normal distribution Z = [z₁, ..., zₙ]ᵀ.
+        white_noise = jr.normal(
+            key, shape=sample_shape + self.batch_shape + self.event_shape
+        )

-        #…
-        loc…
+        # xᵢ ~ N(loc, cov) <=> xᵢ = loc + sqrt zᵢ, where zᵢ ~ N(0, I).
+        def affine_transformation(_x):
+            return self.loc + covariance_root @ _x

-        if scale is None:
-            scale = Identity(shape=(num_dims, num_dims), dtype=loc.dtype)
-
-        self.loc = loc
-        self.scale = cola.PSD(scale)
+        return vmap(affine_transformation)(white_noise)

+    @property
     def mean(self) -> Float[Array, " N"]:
         r"""Calculates the mean."""
         return self.loc

+    @property
+    def variance(self) -> Float[Array, " N"]:
+        r"""Calculates the variance."""
+        return cola.diag(self.scale)
+
+    def entropy(self) -> ScalarFloat:
+        r"""Calculates the entropy of the distribution."""
+        return 0.5 * (
+            self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi))
+            + cola.logdet(self.scale, Cholesky(), Cholesky())
+        )
+
     def median(self) -> Float[Array, " N"]:
         r"""Calculates the median."""
         return self.loc
@@ -98,25 +97,19 @@ class GaussianDistribution(tfd.Distribution):
         r"""Calculates the covariance matrix."""
         return self.scale.to_dense()

-…
+    @property
+    def covariance_matrix(self) -> Float[Array, "N N"]:
+        r"""Calculates the covariance matrix."""
+        return self.covariance()

     def stddev(self) -> Float[Array, " N"]:
         r"""Calculates the standard deviation."""
         return jnp.sqrt(cola.diag(self.scale))

-    @property
-    def event_shape(self) -> Tuple:
-…
-
-    def entropy(self) -> ScalarFloat:
-        r"""Calculates the entropy of the distribution."""
-        return 0.5 * (
-            self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi))
-            + cola.logdet(self.scale, Cholesky(), Cholesky())
-        )
+    # @property
+    # def event_shape(self) -> Tuple:
+    #     r"""Returns the event shape."""
+    #     return self.loc.shape[-1:]

     def log_prob(self, y: Float[Array, " N"]) -> ScalarFloat:
         r"""Calculates the log pdf of the multivariate Gaussian.
@@ -141,42 +134,39 @@ class GaussianDistribution(tfd.Distribution):
             + diff.T @ cola.solve(sigma, diff, Cholesky())
         )

-    def _sample_n(self, key: KeyArray, n: int) -> Float[Array, "n N"]:
-…
+    # def _sample_n(self, key: KeyArray, n: int) -> Float[Array, "n N"]:
+    #     r"""Samples from the distribution.

-…
+    #     Args:
+    #         key (KeyArray): The key to use for sampling.

-…
+    #     Returns:
+    #         The samples as an array of shape (n_samples, n_points).
+    #     """
+    #     # Obtain covariance root.
+    #     sqrt = lower_cholesky(self.scale)

-…
+    #     # Gather n samples from standard normal distribution Z = [z₁, ..., zₙ]ᵀ.
+    #     Z = jr.normal(key, shape=(n, *self.event_shape))

-…
+    #     # xᵢ ~ N(loc, cov) <=> xᵢ = loc + sqrt zᵢ, where zᵢ ~ N(0, I).
+    #     def affine_transformation(x):
+    #         return self.loc + sqrt @ x

-…
+    #     return vmap(affine_transformation)(Z)

-    def sample(
-…
-    ):  # pylint: disable=useless-super-delegation
-…
+    # def sample(
+    #     self, seed: KeyArray, sample_shape: Tuple[int, ...]
+    # ):  # pylint: disable=useless-super-delegation
+    #     r"""See `Distribution.sample`."""
+    #     return self._sample_n(
+    #         seed, sample_shape[0]
+    #     )  # TODO this looks weird, why ignore the second entry?

     def kl_divergence(self, other: "GaussianDistribution") -> ScalarFloat:
         return _kl_divergence(self, other)


-DistrT = TypeVar("DistrT", bound=tfd.Distribution)
-
-
 def _check_and_return_dimension(
     q: GaussianDistribution, p: GaussianDistribution
 ) -> int:
@@ -245,37 +235,37 @@ def _kl_divergence(q: GaussianDistribution, p: GaussianDistribution) -> ScalarFloat:
     ) / 2.0


-def _check_loc_scale(loc: Optional[Any], scale: Optional[Any]) -> None:
-…
+# def _check_loc_scale(loc: Optional[Any], scale: Optional[Any]) -> None:
+#     r"""Checks that the inputs are correct."""
+#     if loc is None and scale is None:
+#         raise ValueError("At least one of `loc` or `scale` must be specified.")
+
+#     if loc is not None and loc.ndim < 1:
+#         raise ValueError("The parameter `loc` must have at least one dimension.")
+
+#     if scale is not None and len(scale.shape) < 2:  # scale.ndim < 2:
+#         raise ValueError(
+#             "The `scale` must have at least two dimensions, but "
+#             f"`scale.shape = {scale.shape}`."
+#         )
+
+#     if scale is not None and not isinstance(scale, LinearOperator):
+#         raise ValueError(
+#             f"The `scale` must be a CoLA LinearOperator but got {type(scale)}"
+#         )
+
+#     if scale is not None and (scale.shape[-1] != scale.shape[-2]):
+#         raise ValueError(
+#             f"The `scale` must be a square matrix, but `scale.shape = {scale.shape}`."
+#         )
+
+#     if loc is not None:
+#         num_dims = loc.shape[-1]
+#         if scale is not None and (scale.shape[-1] != num_dims):
+#             raise ValueError(
+#                 f"Shapes are not compatible: `loc.shape = {loc.shape}` and "
+#                 f"`scale.shape = {scale.shape}`."
+#             )


 __all__ = [
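`GaussianDistribution` now subclasses numpyro's `Distribution`, so construction and sampling follow numpyro conventions: the PRNG key is passed to `sample` together with a `sample_shape` tuple, and `mean`/`variance` are properties. A minimal usage sketch, assuming cola's `Dense` operator for the scale (any PSD `LinearOperator` should behave the same way):

import jax.numpy as jnp
import jax.random as jr
from cola.ops import Dense

from gpjax.distributions import GaussianDistribution

# A 3-dimensional Gaussian with identity covariance, wrapped as a CoLA operator.
loc = jnp.zeros(3)
scale = Dense(jnp.eye(3))

dist = GaussianDistribution(loc=loc, scale=scale)

# numpyro-style sampling: the key comes first, sample_shape is a tuple.
samples = dist.sample(jr.PRNGKey(0), sample_shape=(5,))
print(samples.shape)       # (5, 3)
print(dist.log_prob(loc))  # scalar log-density at the mean
print(dist.variance)       # now a property, per the diff above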
gpjax/fit.py
CHANGED

@@ -15,14 +15,14 @@

 import typing as tp

-from flax import nnx
 import jax
-from jax.flatten_util import ravel_pytree
 import jax.numpy as jnp
 import jax.random as jr
 import optax as ox
+from flax import nnx
+from jax.flatten_util import ravel_pytree
+from numpyro.distributions.transforms import Transform
 from scipy.optimize import minimize
-from tensorflow_probability.substrates.jax.bijectors import Bijector

 from gpjax.dataset import Dataset
 from gpjax.objectives import Objective
@@ -47,7 +47,7 @@ def fit(  # noqa: PLR0913
     objective: Objective,
     train_data: Dataset,
     optim: ox.GradientTransformation,
-    params_bijection: tp.Union[dict[Parameter, …
+    params_bijection: tp.Union[dict[Parameter, Transform], None] = DEFAULT_BIJECTION,
     key: KeyArray = jr.PRNGKey(42),
     num_iters: int = 100,
     batch_size: int = -1,
@@ -127,7 +127,6 @@ def fit(  # noqa: PLR0913
     _check_verbose(verbose)

     # Model state filtering
-
     graphdef, params, *static_state = nnx.split(model, Parameter, ...)

     # Parameters bijection to unconstrained space
@@ -253,6 +252,110 @@ def fit_scipy(  # noqa: PLR0913
     return model, history


+def fit_lbfgs(
+    *,
+    model: Model,
+    objective: Objective,
+    train_data: Dataset,
+    params_bijection: tp.Union[dict[Parameter, Transform], None] = DEFAULT_BIJECTION,
+    max_iters: int = 100,
+    safe: bool = True,
+    max_linesearch_steps: int = 32,
+    gtol: float = 1e-5,
+) -> tuple[Model, jax.Array]:
+    r"""Train a Module model with respect to a supplied Objective function.
+
+    Uses Optax's LBFGS implementation and a jax.lax.while loop.
+
+    Args:
+        model: the model Module to be optimised.
+        objective: The objective function that we are optimising with
+            respect to.
+        train_data (Dataset): The training data to be used for the optimisation.
+        max_iters (int): The maximum number of optimisation steps to run. Defaults
+            to 500.
+        safe (bool): Whether to check the types of the inputs.
+        max_linesearch_steps (int): The maximum number of linesearch steps to use
+            for finding the stepsize.
+        gtol (float): Terminate the optimisation if the L2 norm of the gradient is
+            below this threshold.
+
+    Returns:
+        A tuple comprising the optimised model and final loss.
+    """
+    if safe:
+        # Check inputs
+        _check_model(model)
+        _check_train_data(train_data)
+        _check_num_iters(max_iters)
+
+    # Model state filtering
+    graphdef, params, *static_state = nnx.split(model, Parameter, ...)
+
+    # Parameters bijection to unconstrained space
+    if params_bijection is not None:
+        params = transform(params, params_bijection, inverse=True)
+
+    # Loss definition
+    def loss(params: nnx.State) -> ScalarFloat:
+        params = transform(params, params_bijection)
+        model = nnx.merge(graphdef, params, *static_state)
+        return objective(model, train_data)
+
+    # Initialise optimiser
+    optim = ox.lbfgs(
+        linesearch=ox.scale_by_zoom_linesearch(
+            max_linesearch_steps=max_linesearch_steps,
+            initial_guess_strategy="one",
+        )
+    )
+    opt_state = optim.init(params)
+    loss_value_and_grad = ox.value_and_grad_from_state(loss)
+
+    # Optimisation step.
+    def step(carry):
+        params, opt_state = carry
+
+        # Using optax's value_and_grad_from_state is more efficient given LBFGS uses a linesearch
+        # See https://optax.readthedocs.io/en/latest/api/utilities.html#optax.value_and_grad_from_state
+        loss_val, loss_gradient = loss_value_and_grad(params, state=opt_state)
+        updates, opt_state = optim.update(
+            loss_gradient,
+            opt_state,
+            params,
+            value=loss_val,
+            grad=loss_gradient,
+            value_fn=loss,
+        )
+        params = ox.apply_updates(params, updates)
+
+        return params, opt_state
+
+    def continue_fn(carry):
+        _, opt_state = carry
+        n = ox.tree_utils.tree_get(opt_state, "count")
+        g = ox.tree_utils.tree_get(opt_state, "grad")
+        g_l2_norm = ox.tree_utils.tree_l2_norm(g)
+        return (n == 0) | ((n < max_iters) & (g_l2_norm >= gtol))
+
+    # Optimisation loop
+    params, opt_state = jax.lax.while_loop(
+        continue_fn,
+        step,
+        (params, opt_state),
+    )
+    final_loss = ox.tree_utils.tree_get(opt_state, "value")
+
+    # Parameters bijection to constrained space
+    if params_bijection is not None:
+        params = transform(params, params_bijection)
+
+    # Reconstruct model
+    model = nnx.merge(graphdef, params, *static_state)
+
+    return model, final_loss
+
+
 def get_batch(train_data: Dataset, batch_size: int, key: KeyArray) -> Dataset:
     """Batch the data into mini-batches. Sampling is done with replacement.
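The heart of `fit_lbfgs` is Optax's LBFGS transform driven by a `jax.lax.while_loop`, with `optax.value_and_grad_from_state` re-using the value and gradient already computed during the linesearch. A stripped-down sketch of the same recipe on a toy quadratic (illustrative only, not GPJax code):

import jax
import jax.numpy as jnp
import optax as ox


def loss(params):
    # Toy quadratic with minimum at (1, -2).
    return jnp.sum((params - jnp.array([1.0, -2.0])) ** 2)


optim = ox.lbfgs(
    linesearch=ox.scale_by_zoom_linesearch(
        max_linesearch_steps=32, initial_guess_strategy="one"
    )
)
params = jnp.zeros(2)
opt_state = optim.init(params)

# Re-uses the value/grad cached in the optimiser state where possible.
value_and_grad = ox.value_and_grad_from_state(loss)


def step(carry):
    params, opt_state = carry
    value, grad = value_and_grad(params, state=opt_state)
    updates, opt_state = optim.update(
        grad, opt_state, params, value=value, grad=grad, value_fn=loss
    )
    return ox.apply_updates(params, updates), opt_state


def continue_fn(carry):
    _, opt_state = carry
    n = ox.tree_utils.tree_get(opt_state, "count")
    g_norm = ox.tree_utils.tree_l2_norm(ox.tree_utils.tree_get(opt_state, "grad"))
    # Always take the first step; otherwise stop on iteration cap or small gradient.
    return (n == 0) | ((n < 100) & (g_norm >= 1e-5))


params, opt_state = jax.lax.while_loop(continue_fn, step, (params, opt_state))
print(params)  # approximately [1.0, -2.0]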
gpjax/kernels/approximations/rff.py
CHANGED

@@ -68,7 +68,7 @@ class RFF(AbstractKernel):

         self.frequencies = Static(
             self.base_kernel.spectral_density.sample(
-…
+                key=key, sample_shape=(self.num_basis_fns, n_dims)
             )
         )
         self.name = f"{self.base_kernel.name} (RFF)"
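The one-line change here follows from `spectral_density` now returning a numpyro distribution: numpyro's `Distribution.sample` takes the PRNG key first and a `sample_shape` tuple. In isolation, the sampling call looks like this (values illustrative):

import jax.random as jr
import numpyro.distributions as npd

key = jr.PRNGKey(0)
# One row per basis function, one column per input dimension.
freqs = npd.StudentT(df=3.0, loc=0.0, scale=1.0).sample(key, sample_shape=(8, 2))
print(freqs.shape)  # (8, 2)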
gpjax/kernels/nonstationary/arccosine.py
CHANGED

@@ -23,7 +23,10 @@ from gpjax.kernels.computations import (
     AbstractKernelComputation,
     DenseKernelComputation,
 )
-from gpjax.parameters import …
+from gpjax.parameters import (
+    NonNegativeReal,
+    PositiveReal,
+)
 from gpjax.typing import (
     Array,
     ScalarArray,
@@ -91,9 +94,9 @@ class ArcCosine(AbstractKernel):
         if isinstance(variance, nnx.Variable):
             self.variance = variance
         else:
-            self.variance = …
+            self.variance = NonNegativeReal(variance)
         if tp.TYPE_CHECKING:
-            self.variance = tp.cast(…
+            self.variance = tp.cast(NonNegativeReal[ScalarArray], self.variance)

         if isinstance(bias_variance, nnx.Variable):
             self.bias_variance = bias_variance
gpjax/kernels/nonstationary/linear.py
CHANGED

@@ -23,7 +23,7 @@ from gpjax.kernels.computations import (
     AbstractKernelComputation,
     DenseKernelComputation,
 )
-from gpjax.parameters import …
+from gpjax.parameters import NonNegativeReal
 from gpjax.typing import (
     Array,
     ScalarArray,
@@ -64,9 +64,9 @@ class Linear(AbstractKernel):
         if isinstance(variance, nnx.Variable):
             self.variance = variance
         else:
-            self.variance = …
+            self.variance = NonNegativeReal(variance)
         if tp.TYPE_CHECKING:
-            self.variance = tp.cast(…
+            self.variance = tp.cast(NonNegativeReal[ScalarArray], self.variance)

     def __call__(
         self,
gpjax/kernels/nonstationary/polynomial.py
CHANGED

@@ -23,7 +23,10 @@ from gpjax.kernels.computations import (
     AbstractKernelComputation,
     DenseKernelComputation,
 )
-from gpjax.parameters import …
+from gpjax.parameters import (
+    NonNegativeReal,
+    PositiveReal,
+)
 from gpjax.typing import (
     Array,
     ScalarArray,
@@ -76,9 +79,9 @@ class Polynomial(AbstractKernel):
         if isinstance(variance, nnx.Variable):
             self.variance = variance
         else:
-            self.variance = …
+            self.variance = NonNegativeReal(variance)
         if tp.TYPE_CHECKING:
-            self.variance = tp.cast(…
+            self.variance = tp.cast(NonNegativeReal[ScalarArray], self.variance)

         self.name = f"Polynomial (degree {self.degree})"
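`ArcCosine`, `Linear`, and `Polynomial` all make the same substitution: a raw `variance` value is now wrapped in `NonNegativeReal`. A hedged sketch of kernel construction under the new parameter type, assuming `NonNegativeReal`, like other GPJax parameters, is an `nnx.Variable` subclass (as the `isinstance` branches above imply):

import gpjax as gpx
from gpjax.parameters import NonNegativeReal

# Plain floats are wrapped into NonNegativeReal by the constructor...
k1 = gpx.kernels.Linear(variance=1.0)

# ...or the parameter can be passed pre-wrapped and is used as-is.
k2 = gpx.kernels.Polynomial(degree=2, variance=NonNegativeReal(2.0))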
gpjax/kernels/stationary/base.py
CHANGED

@@ -18,14 +18,17 @@ import beartype.typing as tp
 from flax import nnx
 import jax.numpy as jnp
 from jaxtyping import Float
-import …
+import numpyro.distributions as npd

 from gpjax.kernels.base import AbstractKernel
 from gpjax.kernels.computations import (
     AbstractKernelComputation,
     DenseKernelComputation,
 )
-from gpjax.parameters import …
+from gpjax.parameters import (
+    NonNegativeReal,
+    PositiveReal,
+)
 from gpjax.typing import (
     Array,
     ScalarArray,
@@ -85,14 +88,14 @@ class StationaryKernel(AbstractKernel):
         if isinstance(variance, nnx.Variable):
             self.variance = variance
         else:
-            self.variance = …
+            self.variance = NonNegativeReal(variance)

         # static typing
         if tp.TYPE_CHECKING:
-            self.variance = tp.cast(…
+            self.variance = tp.cast(NonNegativeReal[ScalarFloat], self.variance)

     @property
-    def spectral_density(self) -> …
+    def spectral_density(self) -> npd.Normal | npd.StudentT:
         r"""The spectral density of the kernel.

         Returns:
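Since the property now returns a numpyro distribution, downstream consumers such as the RFF approximation above can rely on numpyro's interface. A quick sketch of querying it, relying only on the default kernel constructors (the Matérn variants appear in the diffs below):

import gpjax as gpx
import numpyro.distributions as npd

# RBF exposes a standard normal spectral density; Matérn kernels a Student's t.
assert isinstance(gpx.kernels.RBF().spectral_density, npd.Normal)
assert isinstance(gpx.kernels.Matern32().spectral_density, npd.StudentT)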
gpjax/kernels/stationary/matern12.py
CHANGED

@@ -15,7 +15,7 @@

 import jax.numpy as jnp
 from jaxtyping import Float
-import …
+import numpyro.distributions as npd

 from gpjax.kernels.stationary.base import StationaryKernel
 from gpjax.kernels.stationary.utils import (
@@ -48,5 +48,5 @@ class Matern12(StationaryKernel):
         return K.squeeze()

     @property
-    def spectral_density(self) -> …
+    def spectral_density(self) -> npd.StudentT:
         return build_student_t_distribution(nu=1)
gpjax/kernels/stationary/matern32.py
CHANGED

@@ -15,7 +15,7 @@

 import jax.numpy as jnp
 from jaxtyping import Float
-import …
+import numpyro.distributions as npd

 from gpjax.kernels.stationary.base import StationaryKernel
 from gpjax.kernels.stationary.utils import (
@@ -54,5 +54,5 @@ class Matern32(StationaryKernel):
         return K.squeeze()

     @property
-    def spectral_density(self) -> …
+    def spectral_density(self) -> npd.StudentT:
         return build_student_t_distribution(nu=3)
gpjax/kernels/stationary/matern52.py
CHANGED

@@ -15,7 +15,7 @@

 import jax.numpy as jnp
 from jaxtyping import Float
-import …
+import numpyro.distributions as npd

 from gpjax.kernels.stationary.base import StationaryKernel
 from gpjax.kernels.stationary.utils import (
@@ -53,5 +53,5 @@ class Matern52(StationaryKernel):
         return K.squeeze()

     @property
-    def spectral_density(self) -> …
+    def spectral_density(self) -> npd.StudentT:
         return build_student_t_distribution(nu=5)
gpjax/kernels/stationary/rbf.py
CHANGED

@@ -15,7 +15,7 @@

 import jax.numpy as jnp
 from jaxtyping import Float
-import …
+import numpyro.distributions as npd

 from gpjax.kernels.stationary.base import StationaryKernel
 from gpjax.kernels.stationary.utils import squared_distance
@@ -44,5 +44,5 @@ class RBF(StationaryKernel):
         return K.squeeze()

     @property
-    def spectral_density(self) -> …
-        return …
+    def spectral_density(self) -> npd.Normal:
+        return npd.Normal(0.0, 1.0)
gpjax/kernels/stationary/utils.py
CHANGED

@@ -14,17 +14,15 @@
 # ==============================================================================
 import jax.numpy as jnp
 from jaxtyping import Float
-import …
+import numpyro.distributions as npd

 from gpjax.typing import (
     Array,
     ScalarFloat,
 )

-tfd = tfp.distributions

-
-def build_student_t_distribution(nu: int) -> tfd.Distribution:
+def build_student_t_distribution(nu: int) -> npd.StudentT:
     r"""Build a Student's t distribution with a fixed smoothness parameter.

     For a fixed half-integer smoothness parameter, compute the spectral density of a
@@ -37,7 +35,7 @@ def build_student_t_distribution(nu: int) -> tfd.Distribution:
     -------
     tfp.Distribution: A Student's t distribution with the same smoothness parameter.
     """
-    dist = …
+    dist = npd.StudentT(df=nu, loc=0.0, scale=1.0)
     return dist
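As a usage check, the returned numpyro `StudentT` exposes the `log_prob` and `sample` surface the stationary kernels rely on; a brief sketch:

import jax.numpy as jnp
import jax.random as jr

from gpjax.kernels.stationary.utils import build_student_t_distribution

dist = build_student_t_distribution(nu=3)
print(dist.log_prob(jnp.array(0.5)))           # scalar log-density
print(dist.sample(jr.PRNGKey(0), (4,)).shape)  # (4,)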