gpjax 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpjax/__init__.py +1 -1
- gpjax/distributions.py +101 -111
- gpjax/fit.py +2 -2
- gpjax/kernels/approximations/rff.py +1 -1
- gpjax/kernels/stationary/base.py +2 -2
- gpjax/kernels/stationary/matern12.py +2 -2
- gpjax/kernels/stationary/matern32.py +2 -2
- gpjax/kernels/stationary/matern52.py +2 -2
- gpjax/kernels/stationary/rbf.py +3 -3
- gpjax/kernels/stationary/utils.py +3 -5
- gpjax/likelihoods.py +36 -35
- gpjax/mean_functions.py +3 -2
- gpjax/numpyro_extras.py +106 -0
- gpjax/objectives.py +4 -6
- gpjax/parameters.py +15 -13
- gpjax/variational_families.py +5 -1
- {gpjax-0.10.2.dist-info → gpjax-0.11.0.dist-info}/METADATA +2 -61
- {gpjax-0.10.2.dist-info → gpjax-0.11.0.dist-info}/RECORD +20 -19
- {gpjax-0.10.2.dist-info → gpjax-0.11.0.dist-info}/WHEEL +0 -0
- {gpjax-0.10.2.dist-info → gpjax-0.11.0.dist-info}/licenses/LICENSE.txt +0 -0
gpjax/__init__.py
CHANGED
````diff
@@ -39,7 +39,7 @@ __license__ = "MIT"
 __description__ = "Didactic Gaussian processes in JAX"
 __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
 __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
-__version__ = "0.10.2"
+__version__ = "0.11.0"
 
 __all__ = [
     "base",
````
gpjax/distributions.py
CHANGED
````diff
@@ -15,77 +15,76 @@
 
 
 from beartype.typing import (
-    Any,
     Optional,
-    Tuple,
-    TypeVar,
 )
 import cola
+from cola.linalg.decompositions import Cholesky
 from cola.ops import (
-    Identity,
     LinearOperator,
 )
 from jax import vmap
 import jax.numpy as jnp
 import jax.random as jr
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax as tfp
+from numpyro.distributions import constraints
+from numpyro.distributions.distribution import Distribution
+from numpyro.distributions.util import is_prng_key
 
 from gpjax.lower_cholesky import lower_cholesky
 from gpjax.typing import (
     Array,
-    KeyArray,
     ScalarFloat,
 )
 
-tfd = tfp.distributions
-
-from cola.linalg.decompositions import Cholesky
-
 
-class GaussianDistribution(tfd.Distribution):
-    ...
-    # TODO: Consider `distrax.transformed.Transformed` object. Can we create a LinearOperator to `distrax.bijector` representation
-    # and modify `distrax.MultivariateNormalFromBijector`?
-    # TODO: Consider natural and expectation parameterisations in future work.
-    # TODO: we don't really need to inherit from `tfd.Distribution` here
+class GaussianDistribution(Distribution):
+    support = constraints.real_vector
 
     def __init__(
         self,
-        loc: Optional[Float[Array, " N"]] = None,
-        scale: Optional[LinearOperator] = None,
-        ...
-        _check_loc_scale(loc, scale)
+        loc: Optional[Float[Array, " N"]],
+        scale: Optional[LinearOperator],
+        validate_args=None,
+    ):
+        self.loc = loc
+        self.scale = cola.PSD(scale)
+        batch_shape = ()
+        event_shape = jnp.shape(self.loc)
+        super().__init__(batch_shape, event_shape, validate_args=validate_args)
 
-        ...
+    def sample(self, key, sample_shape=()):
+        assert is_prng_key(key)
+        # Obtain covariance root.
+        covariance_root = lower_cholesky(self.scale)
 
-        ...
+        # Gather n samples from standard normal distribution Z = [z₁, ..., zₙ]ᵀ.
+        white_noise = jr.normal(
+            key, shape=sample_shape + self.batch_shape + self.event_shape
+        )
 
-        ...
+        # xᵢ ~ N(loc, cov) <=> xᵢ = loc + sqrt zᵢ, where zᵢ ~ N(0, I).
+        def affine_transformation(_x):
+            return self.loc + covariance_root @ _x
 
-        if scale is None:
-            scale = Identity(shape=(num_dims, num_dims), dtype=loc.dtype)
-
-        self.loc = loc
-        self.scale = cola.PSD(scale)
+        return vmap(affine_transformation)(white_noise)
 
+    @property
     def mean(self) -> Float[Array, " N"]:
         r"""Calculates the mean."""
         return self.loc
 
+    @property
+    def variance(self) -> Float[Array, " N"]:
+        r"""Calculates the variance."""
+        return cola.diag(self.scale)
+
+    def entropy(self) -> ScalarFloat:
+        r"""Calculates the entropy of the distribution."""
+        return 0.5 * (
+            self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi))
+            + cola.logdet(self.scale, Cholesky(), Cholesky())
+        )
+
     def median(self) -> Float[Array, " N"]:
         r"""Calculates the median."""
         return self.loc
@@ -98,25 +97,19 @@ class GaussianDistribution(tfd.Distribution):
         r"""Calculates the covariance matrix."""
         return self.scale.to_dense()
 
-    ...
+    @property
+    def covariance_matrix(self) -> Float[Array, "N N"]:
+        r"""Calculates the covariance matrix."""
+        return self.covariance()
 
     def stddev(self) -> Float[Array, " N"]:
         r"""Calculates the standard deviation."""
         return jnp.sqrt(cola.diag(self.scale))
 
-    @property
-    def event_shape(self) -> Tuple:
-        r"""Returns the event shape."""
-        return self.loc.shape[-1:]
-
-    def entropy(self) -> ScalarFloat:
-        r"""Calculates the entropy of the distribution."""
-        return 0.5 * (
-            self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi))
-            + cola.logdet(self.scale, Cholesky(), Cholesky())
-        )
+    # @property
+    # def event_shape(self) -> Tuple:
+    #     r"""Returns the event shape."""
+    #     return self.loc.shape[-1:]
 
     def log_prob(self, y: Float[Array, " N"]) -> ScalarFloat:
         r"""Calculates the log pdf of the multivariate Gaussian.
@@ -141,42 +134,39 @@ class GaussianDistribution(tfd.Distribution):
             + diff.T @ cola.solve(sigma, diff, Cholesky())
         )
 
-    def _sample_n(self, key: KeyArray, n: int) -> Float[Array, "n N"]:
-        r"""Samples from the distribution.
+    # def _sample_n(self, key: KeyArray, n: int) -> Float[Array, "n N"]:
+    #     r"""Samples from the distribution.
 
-        Args:
-            key (KeyArray): The key to use for sampling.
+    #     Args:
+    #         key (KeyArray): The key to use for sampling.
 
-        Returns:
-            The samples as an array of shape (n_samples, n_points).
-        """
-        # Obtain covariance root.
-        sqrt = lower_cholesky(self.scale)
+    #     Returns:
+    #         The samples as an array of shape (n_samples, n_points).
+    #     """
+    #     # Obtain covariance root.
+    #     sqrt = lower_cholesky(self.scale)
 
-        # Gather n samples from standard normal distribution Z = [z₁, ..., zₙ]ᵀ.
-        Z = jr.normal(key, shape=(n, *self.event_shape))
+    #     # Gather n samples from standard normal distribution Z = [z₁, ..., zₙ]ᵀ.
+    #     Z = jr.normal(key, shape=(n, *self.event_shape))
 
-        # xᵢ ~ N(loc, cov) <=> xᵢ = loc + sqrt zᵢ, where zᵢ ~ N(0, I).
-        def affine_transformation(x):
-            return self.loc + sqrt @ x
+    #     # xᵢ ~ N(loc, cov) <=> xᵢ = loc + sqrt zᵢ, where zᵢ ~ N(0, I).
+    #     def affine_transformation(x):
+    #         return self.loc + sqrt @ x
 
-        return vmap(affine_transformation)(Z)
+    #     return vmap(affine_transformation)(Z)
 
-    def sample(
-        self, seed: KeyArray, sample_shape: Tuple[int, ...]
-    ):  # pylint: disable=useless-super-delegation
-        r"""See `Distribution.sample`."""
-        return self._sample_n(
-            seed, sample_shape[0]
-        )  # TODO this looks weird, why ignore the second entry?
+    # def sample(
+    #     self, seed: KeyArray, sample_shape: Tuple[int, ...]
+    # ):  # pylint: disable=useless-super-delegation
+    #     r"""See `Distribution.sample`."""
+    #     return self._sample_n(
+    #         seed, sample_shape[0]
+    #     )  # TODO this looks weird, why ignore the second entry?
 
     def kl_divergence(self, other: "GaussianDistribution") -> ScalarFloat:
         return _kl_divergence(self, other)
 
 
-DistrT = TypeVar("DistrT", bound=tfd.Distribution)
-
-
 def _check_and_return_dimension(
     q: GaussianDistribution, p: GaussianDistribution
 ) -> int:
@@ -245,37 +235,37 @@ def _kl_divergence(q: GaussianDistribution, p: GaussianDistribution) -> ScalarFloat:
     ) / 2.0
 
 
-def _check_loc_scale(loc: Optional[Any], scale: Optional[Any]) -> None:
-    r"""Checks that the inputs are correct."""
-    if loc is None and scale is None:
-        raise ValueError("At least one of `loc` or `scale` must be specified.")
-
-    if loc is not None and loc.ndim < 1:
-        raise ValueError("The parameter `loc` must have at least one dimension.")
-
-    if scale is not None and len(scale.shape) < 2:  # scale.ndim < 2:
-        raise ValueError(
-            "The `scale` must have at least two dimensions, but "
-            f"`scale.shape = {scale.shape}`."
-        )
-
-    if scale is not None and not isinstance(scale, LinearOperator):
-        raise ValueError(
-            f"The `scale` must be a CoLA LinearOperator but got {type(scale)}"
-        )
-
-    if scale is not None and (scale.shape[-1] != scale.shape[-2]):
-        raise ValueError(
-            f"The `scale` must be a square matrix, but `scale.shape = {scale.shape}`."
-        )
-
-    if loc is not None:
-        num_dims = loc.shape[-1]
-        if scale is not None and (scale.shape[-1] != num_dims):
-            raise ValueError(
-                f"Shapes are not compatible: `loc.shape = {loc.shape}` and "
-                f"`scale.shape = {scale.shape}`."
-            )
+# def _check_loc_scale(loc: Optional[Any], scale: Optional[Any]) -> None:
+#     r"""Checks that the inputs are correct."""
+#     if loc is None and scale is None:
+#         raise ValueError("At least one of `loc` or `scale` must be specified.")
+
+#     if loc is not None and loc.ndim < 1:
+#         raise ValueError("The parameter `loc` must have at least one dimension.")
+
+#     if scale is not None and len(scale.shape) < 2:  # scale.ndim < 2:
+#         raise ValueError(
+#             "The `scale` must have at least two dimensions, but "
+#             f"`scale.shape = {scale.shape}`."
+#         )
+
+#     if scale is not None and not isinstance(scale, LinearOperator):
+#         raise ValueError(
+#             f"The `scale` must be a CoLA LinearOperator but got {type(scale)}"
+#         )
+
+#     if scale is not None and (scale.shape[-1] != scale.shape[-2]):
+#         raise ValueError(
+#             f"The `scale` must be a square matrix, but `scale.shape = {scale.shape}`."
+#         )
+
+#     if loc is not None:
+#         num_dims = loc.shape[-1]
+#         if scale is not None and (scale.shape[-1] != num_dims):
+#             raise ValueError(
+#                 f"Shapes are not compatible: `loc.shape = {loc.shape}` and "
+#                 f"`scale.shape = {scale.shape}`."
+#             )
 
 
 __all__ = [
````
gpjax/fit.py
CHANGED
````diff
@@ -20,9 +20,9 @@ import jax
 from jax.flatten_util import ravel_pytree
 import jax.numpy as jnp
 import jax.random as jr
+from numpyro.distributions.transforms import Transform
 import optax as ox
 from scipy.optimize import minimize
-from tensorflow_probability.substrates.jax.bijectors import Bijector
 
 from gpjax.dataset import Dataset
 from gpjax.objectives import Objective
@@ -47,7 +47,7 @@ def fit( # noqa: PLR0913
     objective: Objective,
     train_data: Dataset,
     optim: ox.GradientTransformation,
-    params_bijection: tp.Union[dict[Parameter, Bijector], None] = DEFAULT_BIJECTION,
+    params_bijection: tp.Union[dict[Parameter, Transform], None] = DEFAULT_BIJECTION,
     key: KeyArray = jr.PRNGKey(42),
     num_iters: int = 100,
     batch_size: int = -1,
````

gpjax/kernels/approximations/rff.py
CHANGED

````diff
@@ -68,7 +68,7 @@ class RFF(AbstractKernel):
 
         self.frequencies = Static(
             self.base_kernel.spectral_density.sample(
-                seed=key, sample_shape=(self.num_basis_fns, n_dims)
+                key=key, sample_shape=(self.num_basis_fns, n_dims)
             )
         )
         self.name = f"{self.base_kernel.name} (RFF)"
````
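`fit` now accepts numpyro `Transform`s instead of TFP `Bijector`s. A sketch of overriding one entry of the default bijection, assuming gpjax 0.11.0; note that `DEFAULT_BIJECTION` is keyed by parameter tag strings such as `"positive"`, and the choice of `ExpTransform` here is purely illustrative:

```python
import numpyro.distributions.transforms as npt

from gpjax.parameters import DEFAULT_BIJECTION

# Optimise positive parameters under an exp/log reparameterisation
# instead of the default softplus; all other tags keep their defaults.
params_bijection = {**DEFAULT_BIJECTION, "positive": npt.ExpTransform()}

# This dict is then passed through gpx.fit(..., params_bijection=params_bijection, ...).
```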
gpjax/kernels/stationary/base.py
CHANGED
````diff
@@ -18,7 +18,7 @@ import beartype.typing as tp
 from flax import nnx
 import jax.numpy as jnp
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax.distributions as tfd
+import numpyro.distributions as npd
 
 from gpjax.kernels.base import AbstractKernel
 from gpjax.kernels.computations import (
@@ -92,7 +92,7 @@ class StationaryKernel(AbstractKernel):
         self.variance = tp.cast(PositiveReal[ScalarFloat], self.variance)
 
     @property
-    def spectral_density(self) -> tfd.Distribution:
+    def spectral_density(self) -> npd.Normal | npd.StudentT:
         r"""The spectral density of the kernel.
 
         Returns:
````

gpjax/kernels/stationary/matern12.py
CHANGED

````diff
@@ -15,7 +15,7 @@
 
 import jax.numpy as jnp
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax.distributions as tfd
+import numpyro.distributions as npd
 
 from gpjax.kernels.stationary.base import StationaryKernel
 from gpjax.kernels.stationary.utils import (
@@ -48,5 +48,5 @@ class Matern12(StationaryKernel):
         return K.squeeze()
 
     @property
-    def spectral_density(self) -> tfd.Distribution:
+    def spectral_density(self) -> npd.StudentT:
         return build_student_t_distribution(nu=1)
````

gpjax/kernels/stationary/matern32.py
CHANGED

````diff
@@ -15,7 +15,7 @@
 
 import jax.numpy as jnp
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax.distributions as tfd
+import numpyro.distributions as npd
 
 from gpjax.kernels.stationary.base import StationaryKernel
 from gpjax.kernels.stationary.utils import (
@@ -54,5 +54,5 @@ class Matern32(StationaryKernel):
         return K.squeeze()
 
     @property
-    def spectral_density(self) -> tfd.Distribution:
+    def spectral_density(self) -> npd.StudentT:
         return build_student_t_distribution(nu=3)
````

gpjax/kernels/stationary/matern52.py
CHANGED

````diff
@@ -15,7 +15,7 @@
 
 import jax.numpy as jnp
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax.distributions as tfd
+import numpyro.distributions as npd
 
 from gpjax.kernels.stationary.base import StationaryKernel
 from gpjax.kernels.stationary.utils import (
@@ -53,5 +53,5 @@ class Matern52(StationaryKernel):
         return K.squeeze()
 
     @property
-    def spectral_density(self) -> tfd.Distribution:
+    def spectral_density(self) -> npd.StudentT:
         return build_student_t_distribution(nu=5)
````
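With these changes, `spectral_density` returns a numpyro distribution rather than a TFP one. A small sketch, assuming gpjax 0.11.0 (default kernel construction; attribute names follow numpyro's `StudentT`):

```python
import gpjax as gpx

sd = gpx.kernels.Matern32().spectral_density  # numpyro StudentT with df = 3
print(sd.df, sd.log_prob(0.0))
```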
gpjax/kernels/stationary/rbf.py
CHANGED
````diff
@@ -15,7 +15,7 @@
 
 import jax.numpy as jnp
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax.distributions as tfd
+import numpyro.distributions as npd
 
 from gpjax.kernels.stationary.base import StationaryKernel
 from gpjax.kernels.stationary.utils import squared_distance
@@ -44,5 +44,5 @@ class RBF(StationaryKernel):
         return K.squeeze()
 
     @property
-    def spectral_density(self) -> tfd.Normal:
-        return tfd.Normal(loc=0.0, scale=1.0)
+    def spectral_density(self) -> npd.Normal:
+        return npd.Normal(0.0, 1.0)
````

gpjax/kernels/stationary/utils.py
CHANGED

````diff
@@ -14,17 +14,15 @@
 # ==============================================================================
 import jax.numpy as jnp
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax as tfp
+import numpyro.distributions as npd
 
 from gpjax.typing import (
     Array,
     ScalarFloat,
 )
 
-tfd = tfp.distributions
 
-
-def build_student_t_distribution(nu: int) -> tfd.Distribution:
+def build_student_t_distribution(nu: int) -> npd.StudentT:
     r"""Build a Student's t distribution with a fixed smoothness parameter.
 
     For a fixed half-integer smoothness parameter, compute the spectral density of a
@@ -37,7 +35,7 @@ def build_student_t_distribution(nu: int) -> tfd.Distribution:
     -------
         tfp.Distribution: A Student's t distribution with the same smoothness parameter.
     """
-    dist = tfd.StudentT(df=nu, loc=0.0, scale=1.0)
+    dist = npd.StudentT(df=nu, loc=0.0, scale=1.0)
     return dist
 
 
````
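numpyro distributions draw samples via `sample(key, sample_shape=...)`, which is exactly the call the RFF approximation above now makes against the spectral density. A sketch, assuming gpjax 0.11.0 (the shape `(10, 2)` stands in for `(num_basis_fns, n_dims)`):

```python
import jax.random as jr

from gpjax.kernels.stationary.utils import build_student_t_distribution

spectral = build_student_t_distribution(nu=1)
freqs = spectral.sample(jr.PRNGKey(0), sample_shape=(10, 2))  # shape (10, 2)
```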
gpjax/likelihoods.py
CHANGED
````diff
@@ -19,7 +19,7 @@ from jax import vmap
 import jax.numpy as jnp
 import jax.scipy as jsp
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax as tfp
+import numpyro.distributions as npd
 
 from gpjax.distributions import GaussianDistribution
 from gpjax.integrators import (
@@ -36,9 +36,6 @@ from gpjax.typing import (
     ScalarFloat,
 )
 
-tfb = tfp.bijectors
-tfd = tfp.distributions
-
 
 class AbstractLikelihood(nnx.Module):
     r"""Abstract base class for likelihoods.
@@ -62,7 +59,7 @@ class AbstractLikelihood(nnx.Module):
         self.num_datapoints = num_datapoints
         self.integrator = integrator
 
-    def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> tfd.Distribution:
+    def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> npd.Distribution:
         r"""Evaluate the likelihood function at a given predictive distribution.
 
         Args:
@@ -76,7 +73,7 @@ class AbstractLikelihood(nnx.Module):
         return self.predict(*args, **kwargs)
 
     @abc.abstractmethod
-    def predict(self, *args: tp.Any, **kwargs: tp.Any) -> tfd.Distribution:
+    def predict(self, *args: tp.Any, **kwargs: tp.Any) -> npd.Distribution:
         r"""Evaluate the likelihood function at a given predictive distribution.
 
         Args:
@@ -85,19 +82,19 @@ class AbstractLikelihood(nnx.Module):
             `predict` method.
 
         Returns:
-            tfd.Distribution: The predictive distribution.
+            npd.Distribution: The predictive distribution.
         """
         raise NotImplementedError
 
     @abc.abstractmethod
-    def link_function(self, f: Float[Array, "..."]) -> tfd.Distribution:
+    def link_function(self, f: Float[Array, "..."]) -> npd.Distribution:
         r"""Return the link function of the likelihood function.
 
         Args:
             f (Float[Array, "..."]): the latent Gaussian process values.
 
         Returns:
-            tfd.Distribution: The distribution of observations, y, given values of the
+            npd.Distribution: The distribution of observations, y, given values of the
                 Gaussian process, f.
         """
         raise NotImplementedError
@@ -157,20 +154,20 @@ class Gaussian(AbstractLikelihood):
 
         super().__init__(num_datapoints, integrator)
 
-    def link_function(self, f: Float[Array, "..."]) -> tfd.Normal:
+    def link_function(self, f: Float[Array, "..."]) -> npd.Normal:
         r"""The link function of the Gaussian likelihood.
 
         Args:
             f (Float[Array, "..."]): Function values.
 
         Returns:
-            tfd.Normal: The likelihood function.
+            npd.Normal: The likelihood function.
         """
-        return tfd.Normal(...)
+        return npd.Normal(loc=f, scale=self.obs_stddev.value.astype(f.dtype))
 
     def predict(
-        self, dist: tp.Union[...]
-    ) -> ...:
+        self, dist: tp.Union[npd.MultivariateNormal, GaussianDistribution]
+    ) -> npd.MultivariateNormal:
         r"""Evaluate the Gaussian likelihood.
 
         Evaluate the Gaussian likelihood function at a given predictive
@@ -179,75 +176,79 @@ class Gaussian(AbstractLikelihood):
         distribution's covariance matrix.
 
         Args:
-            dist (...): The Gaussian process posterior,
+            dist (npd.Distribution): The Gaussian process posterior,
                 evaluated at a finite set of test points.
 
         Returns:
-            tfd.Distribution: The predictive distribution.
+            npd.Distribution: The predictive distribution.
         """
         n_data = dist.event_shape[0]
-        cov = dist.covariance()
+        cov = dist.covariance_matrix
         noisy_cov = cov.at[jnp.diag_indices(n_data)].add(self.obs_stddev.value**2)
 
-        return ...
+        return npd.MultivariateNormal(dist.mean, noisy_cov)
 
 
 class Bernoulli(AbstractLikelihood):
-    def link_function(self, f: Float[Array, "..."]) -> tfd.Bernoulli:
+    def link_function(self, f: Float[Array, "..."]) -> npd.BernoulliProbs:
         r"""The probit link function of the Bernoulli likelihood.
 
         Args:
             f (Float[Array, "..."]): Function values.
 
         Returns:
-            tfd.Bernoulli: The likelihood function.
+            npd.Bernoulli: The likelihood function.
         """
-        return tfd.Bernoulli(probs=inv_probit(f))
+        return npd.Bernoulli(probs=inv_probit(f))
 
-    def predict(...) -> ...:
+    def predict(
+        self, dist: tp.Union[npd.MultivariateNormal, GaussianDistribution]
+    ) -> npd.BernoulliProbs:
         r"""Evaluate the pointwise predictive distribution.
 
         Evaluate the pointwise predictive distribution, given a Gaussian
         process posterior and likelihood parameters.
 
         Args:
-            dist (...): The Gaussian process posterior, evaluated
-                at a finite set of test points.
+            dist ([npd.MultivariateNormal, GaussianDistribution].): The Gaussian
+                process posterior, evaluated at a finite set of test points.
 
         Returns:
-            tfd.Bernoulli: The pointwise predictive distribution.
+            npd.Bernoulli: The pointwise predictive distribution.
         """
-        variance = jnp.diag(dist.covariance())
-        mean = dist.mean()
+        variance = jnp.diag(dist.covariance_matrix)
+        mean = dist.mean.ravel()
         return self.link_function(mean / jnp.sqrt(1.0 + variance))
 
 
 class Poisson(AbstractLikelihood):
-    def link_function(self, f: Float[Array, "..."]) -> tfd.Poisson:
+    def link_function(self, f: Float[Array, "..."]) -> npd.Poisson:
         r"""The link function of the Poisson likelihood.
 
         Args:
             f (Float[Array, "..."]): Function values.
 
         Returns:
-            tfd.Poisson: The likelihood function.
+            npd.Poisson: The likelihood function.
         """
-        return tfd.Poisson(rate=jnp.exp(f))
+        return npd.Poisson(rate=jnp.exp(f))
 
-    def predict(...) -> ...:
+    def predict(
+        self, dist: tp.Union[npd.MultivariateNormal, GaussianDistribution]
+    ) -> npd.Poisson:
         r"""Evaluate the pointwise predictive distribution.
 
         Evaluate the pointwise predictive distribution, given a Gaussian
         process posterior and likelihood parameters.
 
         Args:
-            dist (...): The Gaussian process posterior, evaluated
-                at a finite set of test points.
+            dist (tp.Union[npd.MultivariateNormal, GaussianDistribution]): The Gaussian
+                process posterior, evaluated at a finite set of test points.
 
         Returns:
-            tfd.Poisson: The pointwise predictive distribution.
+            npd.Poisson: The pointwise predictive distribution.
         """
-        return self.link_function(dist.mean())
+        return self.link_function(dist.mean)
 
 
 def inv_probit(x: Float[Array, " *N"]) -> Float[Array, " *N"]:
````
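The likelihoods now return numpyro distributions, whose `mean` and `variance` are properties rather than methods, so downstream calls such as `predictive_dist.mean()` become attribute accesses. A sketch of the Gaussian case, assuming gpjax 0.11.0 (`num_datapoints=10` and the zero latent values are arbitrary):

```python
import jax.numpy as jnp

import gpjax as gpx

likelihood = gpx.likelihoods.Gaussian(num_datapoints=10)
link = likelihood.link_function(jnp.zeros(3))  # numpyro Normal centred on f
print(link.mean, link.variance)               # properties, not method calls
```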
gpjax/mean_functions.py
CHANGED
````diff
@@ -28,7 +28,7 @@ from jaxtyping import (
 from gpjax.parameters import (
     Parameter,
     Real,
-    Static
+    Static,
 )
 from gpjax.typing import (
     Array,
@@ -131,7 +131,8 @@ class Constant(AbstractMeanFunction):
     """
 
     def __init__(
-        self, constant: tp.Union[ScalarFloat, Float[Array, " O"], Parameter, Static] = 0.0
+        self,
+        constant: tp.Union[ScalarFloat, Float[Array, " O"], Parameter, Static] = 0.0,
     ):
         if isinstance(constant, Parameter) or isinstance(constant, Static):
             self.constant = constant
````
gpjax/numpyro_extras.py
ADDED
```python
import math

import jax
import jax.numpy as jnp
from numpyro.distributions.transforms import Transform

# -----------------------------------------------------------------------------
# Implementation: FillTriangularTransform
# -----------------------------------------------------------------------------


class FillTriangularTransform(Transform):
    """
    Transform that maps a vector of length n(n+1)/2 to an n x n lower triangular matrix.

    The ordering is assumed to be:
        (0,0), (1,0), (1,1), (2,0), (2,1), (2,2), ..., (n-1, n-1)
    """

    # Note: The base class provides `inv` through _InverseTransform wrapping _inverse.

    def __call__(self, x):
        """
        Forward transformation.

        Parameters
        ----------
        x : array_like, shape (..., L)
            Input vector with L = n(n+1)/2 for some integer n.

        Returns
        -------
        y : array_like, shape (..., n, n)
            Lower-triangular matrix (with zeros in the upper triangle) filled in
            row-major order (i.e. [ (0,0), (1,0), (1,1), ... ]).
        """
        L = x.shape[-1]
        # Use static (Python) math.sqrt to compute n. This avoids tracer issues.
        n = int((-1 + math.sqrt(1 + 8 * L)) // 2)
        if n * (n + 1) // 2 != L:
            raise ValueError("Last dimension must equal n(n+1)/2 for some integer n.")

        def fill_single(vec):
            out = jnp.zeros((n, n), dtype=vec.dtype)
            row, col = jnp.tril_indices(n)
            return out.at[row, col].set(vec)

        if x.ndim == 1:
            return fill_single(x)
        else:
            batch_shape = x.shape[:-1]
            flat_x = x.reshape((-1, L))
            out = jax.vmap(fill_single)(flat_x)
            return out.reshape(batch_shape + (n, n))

    def _inverse(self, y):
        """
        Inverse transformation.

        Parameters
        ----------
        y : array_like, shape (..., n, n)
            Lower triangular matrix.

        Returns
        -------
        x : array_like, shape (..., n(n+1)/2)
            The vector containing the elements from the lower-triangular portion of y.
        """
        if y.ndim < 2:
            raise ValueError("Input to inverse must be at least two-dimensional.")
        n = y.shape[-1]
        if y.shape[-2] != n:
            raise ValueError(
                "Input matrix must be square; got shape %s" % str(y.shape[-2:])
            )

        row, col = jnp.tril_indices(n)

        def inv_single(mat):
            return mat[row, col]

        if y.ndim == 2:
            return inv_single(y)
        else:
            batch_shape = y.shape[:-2]
            flat_y = y.reshape((-1, n, n))
            out = jax.vmap(inv_single)(flat_y)
            return out.reshape(batch_shape + (n * (n + 1) // 2,))

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # Since the transform simply reorders the vector into a matrix, the Jacobian determinant is 1.
        return jnp.zeros(x.shape[:-1])

    @property
    def sign(self):
        # The reordering transformation has a positive derivative everywhere.
        return 1.0

    # Implement tree_flatten and tree_unflatten because base Transform expects them.
    def tree_flatten(self):
        # This transform is stateless.
        return (), {}

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls()
```
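A round-trip sketch of the new transform (gpjax 0.11.0): a length-n(n+1)/2 vector is laid out row-major into the lower triangle, and the base-class `inv` reads it back out.

```python
import jax.numpy as jnp

from gpjax.numpyro_extras import FillTriangularTransform

t = FillTriangularTransform()
x = jnp.arange(1.0, 7.0)              # length 6 = 3 * (3 + 1) / 2, so n = 3
L = t(x)                              # [[1, 0, 0], [2, 3, 0], [4, 5, 6]]
assert jnp.allclose(t.inv(L), x)      # the inverse recovers the vector
```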
gpjax/objectives.py
CHANGED
````diff
@@ -13,7 +13,7 @@ from jax import vmap
 import jax.numpy as jnp
 import jax.scipy as jsp
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax as tfp
+import numpyro.distributions as npd
 import typing_extensions as tpe
 
 from gpjax.dataset import Dataset
@@ -29,8 +29,6 @@ from gpjax.typing import (
 )
 from gpjax.variational_families import AbstractVariationalFamily
 
-tfd = tfp.distributions
-
 VF = TypeVar("VF", bound=AbstractVariationalFamily)
 
 
@@ -175,7 +173,7 @@ def conjugate_loocv(posterior: ConjugatePosterior, data: Dataset) -> ScalarFloat
     loocv_means = mx + (y - mx) - Sigma_inv_y / Sigma_inv_diag
     loocv_stds = jnp.sqrt(1.0 / Sigma_inv_diag)
 
-    loocv_posterior = tfd.Normal(loc=loocv_means, scale=loocv_stds)
+    loocv_posterior = npd.Normal(loc=loocv_means, scale=loocv_stds)
     return jnp.sum(loocv_posterior.log_prob(y))
 
 
@@ -232,7 +230,7 @@ def log_posterior_density(
     likelihood = posterior.likelihood.link_function(fx)
 
     # Whitened latent function values prior, p(wx | θ) = N(0, I)
-    latent_prior = tfd.Normal(loc=0.0, scale=1.0)
+    latent_prior = npd.Normal(loc=0.0, scale=1.0)
     return likelihood.log_prob(y).sum() + latent_prior.log_prob(wx).sum()
 
 
@@ -305,7 +303,7 @@ def variational_expectation(
     # inputs, x
     def q_moments(x):
         qx = q(x)
-        return qx.mean(), qx.covariance()
+        return qx.mean.squeeze(), qx.covariance().squeeze()
 
     mean, variance = vmap(q_moments)(x[:, None])
 
````
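Both `Normal` replacements rely on numpyro's distribution broadcasting elementwise, so summing `log_prob` over a vector reproduces the old `tfd.Normal` behaviour in `conjugate_loocv` and `log_posterior_density`. A small standalone check, assuming numpyro is installed:

```python
import jax.numpy as jnp
import numpyro.distributions as npd

dist = npd.Normal(loc=jnp.zeros(3), scale=jnp.ones(3))
total_log_prob = dist.log_prob(jnp.array([0.1, -0.2, 0.3])).sum()  # scalar
```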
gpjax/parameters.py
CHANGED
````diff
@@ -5,7 +5,9 @@ from jax.experimental import checkify
 import jax.numpy as jnp
 import jax.tree_util as jtu
 from jax.typing import ArrayLike
-import tensorflow_probability.substrates.jax.bijectors as tfb
+import numpyro.distributions.transforms as npt
+
+from gpjax.numpyro_extras import FillTriangularTransform
 
 T = tp.TypeVar("T", bound=tp.Union[ArrayLike, list[float]])
 ParameterTag = str
@@ -13,7 +15,7 @@ ParameterTag = str
 
 def transform(
     params: nnx.State,
-    params_bijection: tp.Dict[str, tfb.Bijector],
+    params_bijection: tp.Dict[str, npt.Transform],
     inverse: bool = False,
 ) -> nnx.State:
     r"""Transforms parameters using a bijector.
@@ -22,7 +24,7 @@ def transform(
     ```pycon
     >>> from gpjax.parameters import PositiveReal, transform
     >>> import jax.numpy as jnp
-    >>> import tensorflow_probability.substrates.jax.bijectors as tfb
+    >>> import numpyro.distributions.transforms as npt
     >>> from flax import nnx
     >>> params = nnx.State(
     >>>     {
@@ -30,7 +32,7 @@ def transform(
     >>>         "b": PositiveReal(jnp.array([2.0])),
     >>>     }
     >>> )
-    >>> params_bijection = {'positive': tfb.Softplus()}
+    >>> params_bijection = {'positive': npt.SoftplusTransform()}
     >>> transformed_params = transform(params, params_bijection)
     >>> print(transformed_params["a"].value)
     [1.3132617]
@@ -47,11 +49,11 @@ def transform(
     """
 
     def _inner(param):
-        bijector = params_bijection.get(param._tag, tfb.Identity())
+        bijector = params_bijection.get(param._tag, npt.IdentityTransform())
         if inverse:
-            transformed_value = bijector.inverse(param.value)
+            transformed_value = bijector.inv(param.value)
         else:
-            transformed_value = bijector.forward(param.value)
+            transformed_value = bijector(param.value)
 
         param = param.replace(transformed_value)
         return param
@@ -104,7 +106,7 @@ class SigmoidBounded(Parameter[T]):
         # Only perform validation in non-JIT contexts
         if (
             not isinstance(value, jnp.ndarray)
-            or ...
+            or getattr(value, "aval", None) is not None
         ):
             _safe_assert(
                 _check_in_bounds,
@@ -133,17 +135,17 @@ class LowerTriangular(Parameter[T]):
         # Only perform validation in non-JIT contexts
         if (
             not isinstance(value, jnp.ndarray)
-            or ...
+            or getattr(value, "aval", None) is not None
         ):
             _safe_assert(_check_is_square, self.value)
             _safe_assert(_check_is_lower_triangular, self.value)
 
 
 DEFAULT_BIJECTION = {
-    "positive": tfb.Softplus(),
-    "real": tfb.Identity(),
-    "sigmoid": tfb.Sigmoid(),
-    "lower_triangular": tfb.FillTriangular(),
+    "positive": npt.SoftplusTransform(),
+    "real": npt.IdentityTransform(),
+    "sigmoid": npt.SigmoidTransform(),
+    "lower_triangular": FillTriangularTransform(),
 }
````
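A sketch of `transform` with the numpyro-based defaults, assuming gpjax 0.11.0. Forward application maps values through the transform's `__call__`; `inverse=True` goes through `Transform.inv`. The `"noise"` state key is arbitrary (dispatch happens on the parameter's tag, here `"positive"`):

```python
import jax.numpy as jnp
from flax import nnx

from gpjax.parameters import DEFAULT_BIJECTION, PositiveReal, transform

params = nnx.State({"noise": PositiveReal(jnp.array([1.0]))})
constrained = transform(params, DEFAULT_BIJECTION)                   # softplus(1.0) ≈ 1.3133
unconstrained = transform(constrained, DEFAULT_BIJECTION, inverse=True)
```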
gpjax/variational_families.py
CHANGED
````diff
@@ -22,6 +22,7 @@ from cola.linalg.inverse.inv import solve
 from cola.ops.operators import (
     Dense,
     I_like,
+    Identity,
     Triangular,
 )
 from flax import nnx
@@ -296,7 +297,10 @@ class WhitenedVariationalGaussian(VariationalGaussian[L]):
 
         # Compute whitened KL divergence
         qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S)
-        pu = GaussianDistribution(...)
+        pu_S = Identity(shape=(self.num_inducing, self.num_inducing), dtype=mu.dtype)
+        pu = GaussianDistribution(
+            loc=jnp.zeros_like(jnp.atleast_1d(mu.squeeze())), scale=pu_S
+        )
         return qu.kl_divergence(pu)
 
     def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
````
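The whitened prior p(u) = N(0, I) must now be built explicitly with a CoLA `Identity` operator, since the new `GaussianDistribution` no longer defaults a missing `scale` to the identity. A sketch of the same KL computation in isolation, assuming gpjax 0.11.0 (the size `m` and the `Dense` covariance are illustrative):

```python
import jax.numpy as jnp
from cola.ops import Dense, Identity

from gpjax.distributions import GaussianDistribution

m = 4
qu = GaussianDistribution(loc=jnp.zeros(m), scale=Dense(0.5 * jnp.eye(m)))
pu = GaussianDistribution(
    loc=jnp.zeros(m), scale=Identity(shape=(m, m), dtype=jnp.float32)
)
kl = qu.kl_divergence(pu)  # KL(q || N(0, I))
```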
{gpjax-0.10.2.dist-info → gpjax-0.11.0.dist-info}/METADATA
CHANGED

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: gpjax
-Version: 0.10.2
+Version: 0.11.0
 Summary: Gaussian processes in JAX.
 Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
 Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
@@ -24,8 +24,8 @@ Requires-Dist: jax>=0.5.0
 Requires-Dist: jaxlib>=0.5.0
 Requires-Dist: jaxtyping>0.2.10
 Requires-Dist: numpy>=2.0.0
+Requires-Dist: numpyro
 Requires-Dist: optax>0.2.1
-Requires-Dist: tensorflow-probability>=0.24.0
 Requires-Dist: tqdm>4.66.2
 Description-Content-Type: text/markdown
 
@@ -138,65 +138,6 @@ jupytext --to notebook example.py
 jupytext --to py:percent example.ipynb
 ```
 
-# Simple example
-
-Let us import some dependencies and simulate a toy dataset $\mathcal{D}$.
-
-```python
-from jax import config
-
-config.update("jax_enable_x64", True)
-
-import gpjax as gpx
-from jax import grad, jit
-import jax.numpy as jnp
-import jax.random as jr
-import optax as ox
-
-key = jr.key(123)
-
-f = lambda x: 10 * jnp.sin(x)
-
-n = 50
-x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(n,1)).sort()
-y = f(x) + jr.normal(key, shape=(n,1))
-D = gpx.Dataset(X=x, y=y)
-
-# Construct the prior
-meanf = gpx.mean_functions.Zero()
-kernel = gpx.kernels.RBF()
-prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
-
-# Define a likelihood
-likelihood = gpx.likelihoods.Gaussian(num_datapoints = n)
-
-# Construct the posterior
-posterior = prior * likelihood
-
-# Define an optimiser
-optimiser = ox.adam(learning_rate=1e-2)
-
-# Obtain Type 2 MLEs of the hyperparameters
-opt_posterior, history = gpx.fit(
-    model=posterior,
-    objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
-    train_data=D,
-    optim=optimiser,
-    num_iters=500,
-    safe=True,
-    key=key,
-)
-
-# Infer the predictive posterior distribution
-xtest = jnp.linspace(-3., 3., 100).reshape(-1, 1)
-latent_dist = opt_posterior(xtest, D)
-predictive_dist = opt_posterior.likelihood(latent_dist)
-
-# Obtain the predictive mean and standard deviation
-pred_mean = predictive_dist.mean()
-pred_std = predictive_dist.stddev()
-```
-
 # Installation
 
 ## Stable version
````
{gpjax-0.10.2.dist-info → gpjax-0.11.0.dist-info}/RECORD
CHANGED

````diff
@@ -1,22 +1,23 @@
-gpjax/__init__.py,sha256=...
+gpjax/__init__.py,sha256=wXJtQa_3W7wZEw_t1Dk0uHUzNQQDv8QzsVbnwXCMXcQ,1654
 gpjax/citation.py,sha256=f2Hzj5MLyCE7l0hHAzsEQoTORZH5hgV_eis4uoBiWvE,3811
 gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
-gpjax/distributions.py,sha256=...
-gpjax/fit.py,sha256=...
+gpjax/distributions.py,sha256=8LWmfmRVHOX29Uy8PkKFi2UhcCiunuu-4TMI_5-krHc,9299
+gpjax/fit.py,sha256=STwpeqSuu2pgT6uZU7xd7koPZbAjPDzhcZ8nHfozR7Q,11538
 gpjax/gps.py,sha256=97lYGrsmsufQxKEd8qz5wPNvui6FKXTF_Ps-sMFIjnY,31246
 gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
-gpjax/likelihoods.py,sha256=...
+gpjax/likelihoods.py,sha256=VcCibgihaskmvNJT4kuPa7ehgjlnR9LgMz_2KJJvHY0,9296
 gpjax/lower_cholesky.py,sha256=3pnHaBrlGckFsrfYJ9Lsbd0pGmO7NIXdyY4aGm48MpY,1952
-gpjax/mean_functions.py,sha256=...
-gpjax/objectives.py,sha256=...
-gpjax/parameters.py,sha256=...
+gpjax/mean_functions.py,sha256=gIPz7exEhish3yeJQxZp5Q_jlf2-gCE-KVAnL2Rumkc,6489
+gpjax/numpyro_extras.py,sha256=-vWJ7SpZVNhSdCjjrlxIkovMFrM1IzpsMJK3B4LioGE,3411
+gpjax/objectives.py,sha256=I_ZqnwTNYIAUAZ9KQNenIl0ish1jDOXb7KaNmjz3Su4,15340
+gpjax/parameters.py,sha256=Vj1xzrziSLxfBSqyc-BacyKBwkbE9Sjq4b1HV5HZiOg,6507
 gpjax/scan.py,sha256=jStQvwkE9MGttB89frxam1kaeXdWih7cVxkGywyaeHQ,5365
 gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
-gpjax/variational_families.py,sha256=...
+gpjax/variational_families.py,sha256=Y9J1H91tXPm_hMy3ri_PgjAxqc_3r-BqKV83HRvB_m4,28295
 gpjax/kernels/__init__.py,sha256=WZanH0Tpdkt0f7VfMqnalm_VZAMVwBqeOVaICNj6xQU,1901
 gpjax/kernels/base.py,sha256=wXsrpm5ofy9S5MNgUkJk4lx2umcIJL6dDNhXY7cmTGk,11616
 gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfjH1dP626yNMA,68
-gpjax/kernels/approximations/rff.py,sha256=...
+gpjax/kernels/approximations/rff.py,sha256=VbitjNuahFE5_IvCj1A0SxHhJXU0O0Qq0FMMVq8xA3E,4125
 gpjax/kernels/computations/__init__.py,sha256=uTVkqvnZVesFLDN92h0ZR0jfR69Eo2WyjOlmSYmCPJ8,1379
 gpjax/kernels/computations/base.py,sha256=zzabLN_-FkTWo6cBYjzjvUGYa7vrYyHxyrhQZxLgHBk,3651
 gpjax/kernels/computations/basis_functions.py,sha256=zY4rUDZDLOYvQPY9xosRmCLPdiXfbzAN5GICjQhFOis,2528
@@ -32,17 +33,17 @@ gpjax/kernels/nonstationary/arccosine.py,sha256=UCTVJEhTZFQjARGFsYMImLnTDyTyxobI
 gpjax/kernels/nonstationary/linear.py,sha256=UKDHFCQzKWDMYo76qcb5-ujjnP2_iL-1tcN017xjK48,2562
 gpjax/kernels/nonstationary/polynomial.py,sha256=7SDMfEcBCqnRn9xyj4iGcYLNvYJZiveN3uLZ_h12p10,3257
 gpjax/kernels/stationary/__init__.py,sha256=j4BMTaQlIx2kNAT1Dkf4iO2rm-f7_oSVWNrk1bN0tqE,1406
-gpjax/kernels/stationary/base.py,sha256=...
-gpjax/kernels/stationary/matern12.py,sha256=...
-gpjax/kernels/stationary/matern32.py,sha256=...
-gpjax/kernels/stationary/matern52.py,sha256=...
+gpjax/kernels/stationary/base.py,sha256=FlsXMsXyZ5cI80jbsIo8Jv-H6gsV3C7v6plIhyCl-GI,7042
+gpjax/kernels/stationary/matern12.py,sha256=DGjqw6VveYsyy0TrufyJJvCei7p9slnm2f0TgRGG7_U,1773
+gpjax/kernels/stationary/matern32.py,sha256=laLsJWJozJzpYHBzlkPUq0rWxz1eWEwGC36P2nPJuaQ,1966
+gpjax/kernels/stationary/matern52.py,sha256=VSByD2sb7k-DzRFjaz31P3Rtc4bPPhHvMshrxZNFnns,2019
 gpjax/kernels/stationary/periodic.py,sha256=IAbCxURtJEHGdmYzbdrsqRZ3zJ8F8tGQF9O7sggafZk,3598
 gpjax/kernels/stationary/powered_exponential.py,sha256=8qT91IWKJK7PpEtFcX4MVu1ahWMOFOZierPko4JCjKA,3776
 gpjax/kernels/stationary/rational_quadratic.py,sha256=dYONp3i4rnKj3ET8UyxAKXv6UOl8uOFT3lCutleSvo4,3496
-gpjax/kernels/stationary/rbf.py,sha256=...
-gpjax/kernels/stationary/utils.py,sha256=...
+gpjax/kernels/stationary/rbf.py,sha256=euHUs6FdfRICQcabAWE4MX-7GEDr2TxgZWdFQiXr9Bw,1690
+gpjax/kernels/stationary/utils.py,sha256=6BI9EBcCzeeKx-XH-MfW1ORmtU__tPX5zyvfLhpkBsU,2180
 gpjax/kernels/stationary/white.py,sha256=TkdXXZCCjDs7JwR_gj5uvn2s1wyfRbe1vyHhUMJ8jjI,2212
-gpjax-0.10.2.dist-info/METADATA,sha256=...
-gpjax-0.10.2.dist-info/WHEEL,sha256=...
-gpjax-0.10.2.dist-info/licenses/LICENSE.txt,sha256=...
-gpjax-0.10.2.dist-info/RECORD,,
+gpjax-0.11.0.dist-info/METADATA,sha256=eSWVc5y9WNrUmKpaOVq1CcHjrKjMwlmSvwovN9h9aCk,8558
+gpjax-0.11.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+gpjax-0.11.0.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
+gpjax-0.11.0.dist-info/RECORD,,
````

{gpjax-0.10.2.dist-info → gpjax-0.11.0.dist-info}/WHEEL
File without changes

{gpjax-0.10.2.dist-info → gpjax-0.11.0.dist-info}/licenses/LICENSE.txt
File without changes