gpjax 0.10.2__py3-none-any.whl → 0.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gpjax/likelihoods.py CHANGED
@@ -19,7 +19,7 @@ from jax import vmap
 import jax.numpy as jnp
 import jax.scipy as jsp
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax as tfp
+import numpyro.distributions as npd

 from gpjax.distributions import GaussianDistribution
 from gpjax.integrators import (
@@ -28,7 +28,7 @@ from gpjax.integrators import (
     GHQuadratureIntegrator,
 )
 from gpjax.parameters import (
-    PositiveReal,
+    NonNegativeReal,
     Static,
 )
 from gpjax.typing import (
@@ -36,9 +36,6 @@ from gpjax.typing import (
     ScalarFloat,
 )

-tfb = tfp.bijectors
-tfd = tfp.distributions
-

 class AbstractLikelihood(nnx.Module):
     r"""Abstract base class for likelihoods.
@@ -62,7 +59,7 @@ class AbstractLikelihood(nnx.Module):
         self.num_datapoints = num_datapoints
         self.integrator = integrator

-    def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> tfd.Distribution:
+    def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> npd.Distribution:
         r"""Evaluate the likelihood function at a given predictive distribution.

         Args:
@@ -76,7 +73,7 @@ class AbstractLikelihood(nnx.Module):
         return self.predict(*args, **kwargs)

     @abc.abstractmethod
-    def predict(self, *args: tp.Any, **kwargs: tp.Any) -> tfd.Distribution:
+    def predict(self, *args: tp.Any, **kwargs: tp.Any) -> npd.Distribution:
         r"""Evaluate the likelihood function at a given predictive distribution.

         Args:
@@ -85,19 +82,19 @@ class AbstractLikelihood(nnx.Module):
                 `predict` method.

         Returns:
-            tfd.Distribution: The predictive distribution.
+            npd.Distribution: The predictive distribution.
         """
         raise NotImplementedError

     @abc.abstractmethod
-    def link_function(self, f: Float[Array, "..."]) -> tfd.Distribution:
+    def link_function(self, f: Float[Array, "..."]) -> npd.Distribution:
         r"""Return the link function of the likelihood function.

         Args:
             f (Float[Array, "..."]): the latent Gaussian process values.

         Returns:
-            tfd.Distribution: The distribution of observations, y, given values of the
+            npd.Distribution: The distribution of observations, y, given values of the
             Gaussian process, f.
         """
         raise NotImplementedError
@@ -137,7 +134,7 @@ class Gaussian(AbstractLikelihood):
         self,
         num_datapoints: int,
         obs_stddev: tp.Union[
-            ScalarFloat, Float[Array, "#N"], PositiveReal, Static
+            ScalarFloat, Float[Array, "#N"], NonNegativeReal, Static
         ] = 1.0,
         integrator: AbstractIntegrator = AnalyticalGaussianIntegrator(),
     ):
@@ -151,26 +148,26 @@ class Gaussian(AbstractLikelihood):
                 likelihoods. Must be an instance of `AbstractIntegrator`. For the Gaussian likelihood, this defaults to
                 the `AnalyticalGaussianIntegrator`, as the expected log likelihood can be computed analytically.
         """
-        if not isinstance(obs_stddev, (PositiveReal, Static)):
-            obs_stddev = PositiveReal(jnp.asarray(obs_stddev))
+        if not isinstance(obs_stddev, (NonNegativeReal, Static)):
+            obs_stddev = NonNegativeReal(jnp.asarray(obs_stddev))
         self.obs_stddev = obs_stddev

         super().__init__(num_datapoints, integrator)

-    def link_function(self, f: Float[Array, "..."]) -> tfd.Normal:
+    def link_function(self, f: Float[Array, "..."]) -> npd.Normal:
         r"""The link function of the Gaussian likelihood.

         Args:
             f (Float[Array, "..."]): Function values.

         Returns:
-            tfd.Normal: The likelihood function.
+            npd.Normal: The likelihood function.
         """
-        return tfd.Normal(loc=f, scale=self.obs_stddev.value.astype(f.dtype))
+        return npd.Normal(loc=f, scale=self.obs_stddev.value.astype(f.dtype))

     def predict(
-        self, dist: tp.Union[tfd.MultivariateNormalTriL, GaussianDistribution]
-    ) -> tfd.MultivariateNormalFullCovariance:
+        self, dist: tp.Union[npd.MultivariateNormal, GaussianDistribution]
+    ) -> npd.MultivariateNormal:
         r"""Evaluate the Gaussian likelihood.

         Evaluate the Gaussian likelihood function at a given predictive
@@ -179,75 +176,79 @@ class Gaussian(AbstractLikelihood):
         distribution's covariance matrix.

         Args:
-            dist (tfd.Distribution): The Gaussian process posterior,
+            dist (npd.Distribution): The Gaussian process posterior,
                 evaluated at a finite set of test points.

         Returns:
-            tfd.Distribution: The predictive distribution.
+            npd.Distribution: The predictive distribution.
         """
         n_data = dist.event_shape[0]
-        cov = dist.covariance()
+        cov = dist.covariance_matrix
         noisy_cov = cov.at[jnp.diag_indices(n_data)].add(self.obs_stddev.value**2)

-        return tfd.MultivariateNormalFullCovariance(dist.mean(), noisy_cov)
+        return npd.MultivariateNormal(dist.mean, noisy_cov)


 class Bernoulli(AbstractLikelihood):
-    def link_function(self, f: Float[Array, "..."]) -> tfd.Distribution:
+    def link_function(self, f: Float[Array, "..."]) -> npd.BernoulliProbs:
         r"""The probit link function of the Bernoulli likelihood.

         Args:
             f (Float[Array, "..."]): Function values.

         Returns:
-            tfd.Distribution: The likelihood function.
+            npd.Bernoulli: The likelihood function.
         """
-        return tfd.Bernoulli(probs=inv_probit(f))
+        return npd.Bernoulli(probs=inv_probit(f))

-    def predict(self, dist: tfd.Distribution) -> tfd.Distribution:
+    def predict(
+        self, dist: tp.Union[npd.MultivariateNormal, GaussianDistribution]
+    ) -> npd.BernoulliProbs:
         r"""Evaluate the pointwise predictive distribution.

         Evaluate the pointwise predictive distribution, given a Gaussian
         process posterior and likelihood parameters.

         Args:
-            dist (tfd.Distribution): The Gaussian process posterior, evaluated
-                at a finite set of test points.
+            dist (tp.Union[npd.MultivariateNormal, GaussianDistribution]): The Gaussian
+                process posterior, evaluated at a finite set of test points.

         Returns:
-            tfd.Distribution: The pointwise predictive distribution.
+            npd.Bernoulli: The pointwise predictive distribution.
         """
-        variance = jnp.diag(dist.covariance())
-        mean = dist.mean().ravel()
+        variance = jnp.diag(dist.covariance_matrix)
+        mean = dist.mean.ravel()
         return self.link_function(mean / jnp.sqrt(1.0 + variance))


 class Poisson(AbstractLikelihood):
-    def link_function(self, f: Float[Array, "..."]) -> tfd.Distribution:
+    def link_function(self, f: Float[Array, "..."]) -> npd.Poisson:
         r"""The link function of the Poisson likelihood.

         Args:
             f (Float[Array, "..."]): Function values.

         Returns:
-            tfd.Distribution: The likelihood function.
+            npd.Poisson: The likelihood function.
         """
-        return tfd.Poisson(rate=jnp.exp(f))
+        return npd.Poisson(rate=jnp.exp(f))

-    def predict(self, dist: tfd.Distribution) -> tfd.Distribution:
+    def predict(
+        self, dist: tp.Union[npd.MultivariateNormal, GaussianDistribution]
+    ) -> npd.Poisson:
         r"""Evaluate the pointwise predictive distribution.

         Evaluate the pointwise predictive distribution, given a Gaussian
         process posterior and likelihood parameters.

         Args:
-            dist (tfd.Distribution): The Gaussian process posterior, evaluated
-                at a finite set of test points.
+            dist (tp.Union[npd.MultivariateNormal, GaussianDistribution]): The Gaussian
+                process posterior, evaluated at a finite set of test points.

         Returns:
-            tfd.Distribution: The pointwise predictive distribution.
+            npd.Poisson: The pointwise predictive distribution.
         """
-        return self.link_function(dist.mean())
+        return self.link_function(dist.mean)


 def inv_probit(x: Float[Array, " *N"]) -> Float[Array, " *N"]:
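
The user-facing effect of the `tfp` → `numpyro` migration is visible throughout this file: numpyro exposes distribution moments as properties rather than methods (`dist.mean` and `dist.covariance_matrix` replace `dist.mean()` and `dist.covariance()`). A minimal sketch of the new calling convention (illustrative, not part of the diff):

```python
import jax.numpy as jnp
import gpjax as gpx

likelihood = gpx.likelihoods.Gaussian(num_datapoints=10)
f = jnp.zeros(10)
dist = likelihood.link_function(f)  # an npd.Normal under 0.11.1

mu = dist.mean      # property in numpyro (was dist.mean() under tfp)
sigma = dist.scale  # the broadcast observation standard deviation
```
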
gpjax/mean_functions.py CHANGED
@@ -28,7 +28,7 @@ from jaxtyping import (
 from gpjax.parameters import (
     Parameter,
     Real,
-    Static
+    Static,
 )
 from gpjax.typing import (
     Array,
@@ -131,7 +131,8 @@ class Constant(AbstractMeanFunction):
     """

     def __init__(
-        self, constant: tp.Union[ScalarFloat, Float[Array, " O"], Parameter, Static] = 0.0
+        self,
+        constant: tp.Union[ScalarFloat, Float[Array, " O"], Parameter, Static] = 0.0,
     ):
         if isinstance(constant, Parameter) or isinstance(constant, Static):
             self.constant = constant
@@ -206,5 +207,5 @@ SumMeanFunction = ft.partial(
     CombinationMeanFunction, operator=ft.partial(jnp.sum, axis=0)
 )
 ProductMeanFunction = ft.partial(
-    CombinationMeanFunction, operator=ft.partial(jnp.sum, axis=0)
+    CombinationMeanFunction, operator=ft.partial(jnp.prod, axis=0)
 )
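
The final hunk is a genuine bug fix: `ProductMeanFunction` previously aliased the summation operator, so it behaved identically to `SumMeanFunction`. A tiny sketch of the behavioural difference, using plain `jnp` (illustrative only):

```python
import jax.numpy as jnp

# Stacked outputs of two constituent mean functions at a single input.
stacked = jnp.array([[2.0], [3.0]])

jnp.sum(stacked, axis=0)   # Array([5.], ...): the old, incorrect combination
jnp.prod(stacked, axis=0)  # Array([6.], ...): the corrected product
```
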
gpjax/numpyro_extras.py ADDED
@@ -0,0 +1,106 @@
+import math
+
+import jax
+import jax.numpy as jnp
+from numpyro.distributions.transforms import Transform
+
+# -----------------------------------------------------------------------------
+# Implementation: FillTriangularTransform
+# -----------------------------------------------------------------------------
+
+
+class FillTriangularTransform(Transform):
+    """
+    Transform that maps a vector of length n(n+1)/2 to an n x n lower triangular matrix.
+    The ordering is assumed to be:
+        (0,0), (1,0), (1,1), (2,0), (2,1), (2,2), ..., (n-1, n-1)
+    """
+
+    # Note: The base class provides `inv` through _InverseTransform wrapping _inverse.
+
+    def __call__(self, x):
+        """
+        Forward transformation.
+
+        Parameters
+        ----------
+        x : array_like, shape (..., L)
+            Input vector with L = n(n+1)/2 for some integer n.
+
+        Returns
+        -------
+        y : array_like, shape (..., n, n)
+            Lower-triangular matrix (with zeros in the upper triangle) filled in
+            row-major order (i.e. [ (0,0), (1,0), (1,1), ... ]).
+        """
+        L = x.shape[-1]
+        # Use static (Python) math.sqrt to compute n. This avoids tracer issues.
+        n = int((-1 + math.sqrt(1 + 8 * L)) // 2)
+        if n * (n + 1) // 2 != L:
+            raise ValueError("Last dimension must equal n(n+1)/2 for some integer n.")
+
+        def fill_single(vec):
+            out = jnp.zeros((n, n), dtype=vec.dtype)
+            row, col = jnp.tril_indices(n)
+            return out.at[row, col].set(vec)
+
+        if x.ndim == 1:
+            return fill_single(x)
+        else:
+            batch_shape = x.shape[:-1]
+            flat_x = x.reshape((-1, L))
+            out = jax.vmap(fill_single)(flat_x)
+            return out.reshape(batch_shape + (n, n))
+
+    def _inverse(self, y):
+        """
+        Inverse transformation.
+
+        Parameters
+        ----------
+        y : array_like, shape (..., n, n)
+            Lower triangular matrix.
+
+        Returns
+        -------
+        x : array_like, shape (..., n(n+1)/2)
+            The vector containing the elements from the lower-triangular portion of y.
+        """
+        if y.ndim < 2:
+            raise ValueError("Input to inverse must be at least two-dimensional.")
+        n = y.shape[-1]
+        if y.shape[-2] != n:
+            raise ValueError(
+                "Input matrix must be square; got shape %s" % str(y.shape[-2:])
+            )
+
+        row, col = jnp.tril_indices(n)
+
+        def inv_single(mat):
+            return mat[row, col]
+
+        if y.ndim == 2:
+            return inv_single(y)
+        else:
+            batch_shape = y.shape[:-2]
+            flat_y = y.reshape((-1, n, n))
+            out = jax.vmap(inv_single)(flat_y)
+            return out.reshape(batch_shape + (n * (n + 1) // 2,))
+
+    def log_abs_det_jacobian(self, x, y, intermediates=None):
+        # Since the transform simply reorders the vector into a matrix, the Jacobian determinant is 1.
+        return jnp.zeros(x.shape[:-1])
+
+    @property
+    def sign(self):
+        # The reordering transformation has a positive derivative everywhere.
+        return 1.0
+
+    # Implement tree_flatten and tree_unflatten because base Transform expects them.
+    def tree_flatten(self):
+        # This transform is stateless.
+        return (), {}
+
+    @classmethod
+    def tree_unflatten(cls, aux_data, children):
+        return cls()
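
A short round-trip sketch of the new transform (illustrative; assumes gpjax 0.11.1 and numpyro are installed). The `inv` property is inherited from numpyro's `Transform` base class:

```python
import jax.numpy as jnp
from gpjax.numpyro_extras import FillTriangularTransform

t = FillTriangularTransform()
vec = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])  # length 6 = n(n+1)/2 with n = 3

L = t(vec)       # 3 x 3 lower-triangular matrix:
                 # [[1. 0. 0.]
                 #  [2. 3. 0.]
                 #  [4. 5. 6.]]
back = t.inv(L)  # recovers the original 6-vector
```
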
gpjax/objectives.py CHANGED
@@ -13,7 +13,7 @@ from jax import vmap
 import jax.numpy as jnp
 import jax.scipy as jsp
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax as tfp
+import numpyro.distributions as npd
 import typing_extensions as tpe

 from gpjax.dataset import Dataset
@@ -29,8 +29,6 @@ from gpjax.typing import (
 )
 from gpjax.variational_families import AbstractVariationalFamily

-tfd = tfp.distributions
-
 VF = TypeVar("VF", bound=AbstractVariationalFamily)


@@ -175,7 +173,7 @@ def conjugate_loocv(posterior: ConjugatePosterior, data: Dataset) -> ScalarFloat
     loocv_means = mx + (y - mx) - Sigma_inv_y / Sigma_inv_diag
     loocv_stds = jnp.sqrt(1.0 / Sigma_inv_diag)

-    loocv_posterior = tfd.Normal(loc=loocv_means, scale=loocv_stds)
+    loocv_posterior = npd.Normal(loc=loocv_means, scale=loocv_stds)
     return jnp.sum(loocv_posterior.log_prob(y))


@@ -232,7 +230,7 @@ def log_posterior_density(
     likelihood = posterior.likelihood.link_function(fx)

     # Whitened latent function values prior, p(wx | θ) = N(0, I)
-    latent_prior = tfd.Normal(loc=0.0, scale=1.0)
+    latent_prior = npd.Normal(loc=0.0, scale=1.0)
     return likelihood.log_prob(y).sum() + latent_prior.log_prob(wx).sum()


@@ -305,7 +303,7 @@ def variational_expectation(
     # inputs, x
     def q_moments(x):
         qx = q(x)
-        return qx.mean().squeeze(), qx.covariance().squeeze()
+        return qx.mean.squeeze(), qx.covariance().squeeze()

     mean, variance = vmap(q_moments)(x[:, None])

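
For the objectives, the `tfd.Normal` → `npd.Normal` swap is drop-in: `log_prob` has the same signature and semantics in both libraries. A quick sketch of the pattern used in `conjugate_loocv` above (illustrative values):

```python
import jax.numpy as jnp
import numpyro.distributions as npd

loocv_posterior = npd.Normal(loc=jnp.zeros(3), scale=jnp.ones(3))
total = jnp.sum(loocv_posterior.log_prob(jnp.zeros(3)))  # sum of pointwise log-densities
```
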
gpjax/parameters.py CHANGED
@@ -5,7 +5,9 @@ from jax.experimental import checkify
 import jax.numpy as jnp
 import jax.tree_util as jtu
 from jax.typing import ArrayLike
-import tensorflow_probability.substrates.jax.bijectors as tfb
+import numpyro.distributions.transforms as npt
+
+from gpjax.numpyro_extras import FillTriangularTransform

 T = tp.TypeVar("T", bound=tp.Union[ArrayLike, list[float]])
 ParameterTag = str
@@ -13,7 +15,7 @@ ParameterTag = str

 def transform(
     params: nnx.State,
-    params_bijection: tp.Dict[str, tfb.Bijector],
+    params_bijection: tp.Dict[str, npt.Transform],
     inverse: bool = False,
 ) -> nnx.State:
     r"""Transforms parameters using a bijector.
@@ -22,7 +24,7 @@ def transform(
     ```pycon
     >>> from gpjax.parameters import PositiveReal, transform
     >>> import jax.numpy as jnp
-    >>> import tensorflow_probability.substrates.jax.bijectors as tfb
+    >>> import numpyro.distributions.transforms as npt
     >>> from flax import nnx
     >>> params = nnx.State(
     >>>     {
@@ -30,7 +32,7 @@ def transform(
     >>>         "b": PositiveReal(jnp.array([2.0])),
     >>>     }
     >>> )
-    >>> params_bijection = {'positive': tfb.Softplus()}
+    >>> params_bijection = {'positive': npt.SoftplusTransform()}
     >>> transformed_params = transform(params, params_bijection)
     >>> print(transformed_params["a"].value)
     [1.3132617]
@@ -47,11 +49,11 @@ def transform(
     """

     def _inner(param):
-        bijector = params_bijection.get(param._tag, tfb.Identity())
+        bijector = params_bijection.get(param._tag, npt.IdentityTransform())
         if inverse:
-            transformed_value = bijector.inverse(param.value)
+            transformed_value = bijector.inv(param.value)
         else:
-            transformed_value = bijector.forward(param.value)
+            transformed_value = bijector(param.value)

         param = param.replace(transformed_value)
         return param
@@ -80,6 +82,14 @@ class Parameter(nnx.Variable[T]):
         self._tag = tag


+class NonNegativeReal(Parameter[T]):
+    """Parameter that is non-negative."""
+
+    def __init__(self, value: T, tag: ParameterTag = "non_negative", **kwargs):
+        super().__init__(value=value, tag=tag, **kwargs)
+        _safe_assert(_check_is_non_negative, self.value)
+
+
 class PositiveReal(Parameter[T]):
     """Parameter that is strictly positive."""

@@ -104,7 +114,7 @@ class SigmoidBounded(Parameter[T]):
         # Only perform validation in non-JIT contexts
         if (
             not isinstance(value, jnp.ndarray)
-            or not getattr(value, "aval", None) is None
+            or getattr(value, "aval", None) is not None
         ):
             _safe_assert(
                 _check_in_bounds,
@@ -133,17 +143,18 @@ class LowerTriangular(Parameter[T]):
         # Only perform validation in non-JIT contexts
         if (
             not isinstance(value, jnp.ndarray)
-            or not getattr(value, "aval", None) is None
+            or getattr(value, "aval", None) is not None
         ):
             _safe_assert(_check_is_square, self.value)
             _safe_assert(_check_is_lower_triangular, self.value)


 DEFAULT_BIJECTION = {
-    "positive": tfb.Softplus(),
-    "real": tfb.Identity(),
-    "sigmoid": tfb.Sigmoid(low=0.0, high=1.0),
-    "lower_triangular": tfb.FillTriangular(),
+    "positive": npt.SoftplusTransform(),
+    "non_negative": npt.SoftplusTransform(),
+    "real": npt.IdentityTransform(),
+    "sigmoid": npt.SigmoidTransform(),
+    "lower_triangular": FillTriangularTransform(),
 }


@@ -162,6 +173,13 @@ def _check_is_arraylike(value: T) -> None:
     )


+@checkify.checkify
+def _check_is_non_negative(value):
+    checkify.check(
+        jnp.all(value >= 0), "value needs to be non-negative, got {value}", value=value
+    )
+
+
 @checkify.checkify
 def _check_is_positive(value):
     checkify.check(
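
A sketch of the migrated `transform` API, adapted from the docstring in the hunk above: numpyro transforms are applied by calling them directly, with `.inv` replacing tfp's `bijector.inverse`, and the new `non_negative` tag maps to `SoftplusTransform` in `DEFAULT_BIJECTION`:

```python
import jax.numpy as jnp
import numpyro.distributions.transforms as npt
from flax import nnx

from gpjax.parameters import NonNegativeReal, transform

params = nnx.State({"noise": NonNegativeReal(jnp.array([1.0]))})
bijection = {"non_negative": npt.SoftplusTransform()}

unconstrained = transform(params, bijection, inverse=True)  # applies bijector.inv
constrained = transform(unconstrained, bijection)           # applies bijector(...)
```
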
gpjax/variational_families.py CHANGED
@@ -22,6 +22,7 @@ from cola.linalg.inverse.inv import solve
 from cola.ops.operators import (
     Dense,
     I_like,
+    Identity,
     Triangular,
 )
 from flax import nnx
@@ -296,7 +297,10 @@ class WhitenedVariationalGaussian(VariationalGaussian[L]):

         # Compute whitened KL divergence
         qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S)
-        pu = GaussianDistribution(loc=jnp.zeros_like(jnp.atleast_1d(mu.squeeze())))
+        pu_S = Identity(shape=(self.num_inducing, self.num_inducing), dtype=mu.dtype)
+        pu = GaussianDistribution(
+            loc=jnp.zeros_like(jnp.atleast_1d(mu.squeeze())), scale=pu_S
+        )
         return qu.kl_divergence(pu)

     def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
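
The second hunk makes the whitened prior's covariance explicit: p(u) = N(0, I) now carries an identity scale operator of the correct size and dtype rather than relying on a default. A minimal sketch of the operator construction (assumes `cola` is installed; `m` is illustrative):

```python
import jax.numpy as jnp
from cola.ops.operators import Identity

m = 4  # stand-in for self.num_inducing
pu_S = Identity(shape=(m, m), dtype=jnp.float32)  # prior scale, matching mu's dtype
pu_S.to_dense()  # dense m x m identity (to_dense() assumed from cola's operator API)
```
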
gpjax-0.10.2.dist-info/METADATA → gpjax-0.11.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: gpjax
-Version: 0.10.2
+Version: 0.11.1
 Summary: Gaussian processes in JAX.
 Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
 Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
@@ -24,8 +24,8 @@ Requires-Dist: jax>=0.5.0
 Requires-Dist: jaxlib>=0.5.0
 Requires-Dist: jaxtyping>0.2.10
 Requires-Dist: numpy>=2.0.0
+Requires-Dist: numpyro
 Requires-Dist: optax>0.2.1
-Requires-Dist: tensorflow-probability>=0.24.0
 Requires-Dist: tqdm>4.66.2
 Description-Content-Type: text/markdown

@@ -138,65 +138,6 @@ jupytext --to notebook example.py
 jupytext --to py:percent example.ipynb
 ```

-# Simple example
-
-Let us import some dependencies and simulate a toy dataset $\mathcal{D}$.
-
-```python
-from jax import config
-
-config.update("jax_enable_x64", True)
-
-import gpjax as gpx
-from jax import grad, jit
-import jax.numpy as jnp
-import jax.random as jr
-import optax as ox
-
-key = jr.key(123)
-
-f = lambda x: 10 * jnp.sin(x)
-
-n = 50
-x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(n,1)).sort()
-y = f(x) + jr.normal(key, shape=(n,1))
-D = gpx.Dataset(X=x, y=y)
-
-# Construct the prior
-meanf = gpx.mean_functions.Zero()
-kernel = gpx.kernels.RBF()
-prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
-
-# Define a likelihood
-likelihood = gpx.likelihoods.Gaussian(num_datapoints = n)
-
-# Construct the posterior
-posterior = prior * likelihood
-
-# Define an optimiser
-optimiser = ox.adam(learning_rate=1e-2)
-
-# Obtain Type 2 MLEs of the hyperparameters
-opt_posterior, history = gpx.fit(
-    model=posterior,
-    objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
-    train_data=D,
-    optim=optimiser,
-    num_iters=500,
-    safe=True,
-    key=key,
-)
-
-# Infer the predictive posterior distribution
-xtest = jnp.linspace(-3., 3., 100).reshape(-1, 1)
-latent_dist = opt_posterior(xtest, D)
-predictive_dist = opt_posterior.likelihood(latent_dist)
-
-# Obtain the predictive mean and standard deviation
-pred_mean = predictive_dist.mean()
-pred_std = predictive_dist.stddev()
-```
-

 # Installation
 ## Stable version
gpjax-0.10.2.dist-info/RECORD → gpjax-0.11.1.dist-info/RECORD CHANGED
@@ -1,22 +1,23 @@
-gpjax/__init__.py,sha256=F9GVk18tdmvwiDEHZNo_4Wr0TkmPhWIEwl3KzEWQcaQ,1654
+gpjax/__init__.py,sha256=TjAAfeZTCEl_zsibA8pV76M1jcHkeFhNfWk_SllfgHY,1686
 gpjax/citation.py,sha256=f2Hzj5MLyCE7l0hHAzsEQoTORZH5hgV_eis4uoBiWvE,3811
 gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
-gpjax/distributions.py,sha256=X48FJr3reop9maherdMVt7-XZOm2f26T8AJt_IKM_oE,9339
-gpjax/fit.py,sha256=OHv8jUHxa1ndpqMERSDRtYtUDzubk9rMPVIhfCiIH5Q,11551
+gpjax/distributions.py,sha256=8LWmfmRVHOX29Uy8PkKFi2UhcCiunuu-4TMI_5-krHc,9299
+gpjax/fit.py,sha256=7L2veA6aRNiozZD8fWa-MVDoYFUKjGJahmvjz8Wp-P0,15046
 gpjax/gps.py,sha256=97lYGrsmsufQxKEd8qz5wPNvui6FKXTF_Ps-sMFIjnY,31246
 gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
-gpjax/likelihoods.py,sha256=DOyV1L0ompkpeImMTiOOiWLJfqSqvDX_acOumuFqPEc,9234
+gpjax/likelihoods.py,sha256=99oTZoWld1M7vxgGM0pNY5Hnt2Ajd2lQNqawzrLmwtk,9308
 gpjax/lower_cholesky.py,sha256=3pnHaBrlGckFsrfYJ9Lsbd0pGmO7NIXdyY4aGm48MpY,1952
-gpjax/mean_functions.py,sha256=BpeFkR3Eqa3O_FGp9BtSu9HKNSYZ8M08VtyfPfWbwRg,6479
-gpjax/objectives.py,sha256=XwkPyL_iovTNKpKGVNt0Lt2_OMTJitSPhuyCtUrJpbc,15383
-gpjax/parameters.py,sha256=6VKq6wBzEUtx-GXniC8fEqjTNrTC1YwIOw66QguW6UM,6457
+gpjax/mean_functions.py,sha256=-sVYO1_LWE8f34rllUOuaT5sgGGAdxo99v5kRo2d4oM,6490
+gpjax/numpyro_extras.py,sha256=-vWJ7SpZVNhSdCjjrlxIkovMFrM1IzpsMJK3B4LioGE,3411
+gpjax/objectives.py,sha256=I_ZqnwTNYIAUAZ9KQNenIl0ish1jDOXb7KaNmjz3Su4,15340
+gpjax/parameters.py,sha256=H-DiXmotdBZCbf-GOjRaJoS_isk3GgFrpKHTq5GpnoA,6998
 gpjax/scan.py,sha256=jStQvwkE9MGttB89frxam1kaeXdWih7cVxkGywyaeHQ,5365
 gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
-gpjax/variational_families.py,sha256=s1rk7PtNTjQPabmVu-jBsuJBoqsxAAXwKFZJOEswkNQ,28161
+gpjax/variational_families.py,sha256=Y9J1H91tXPm_hMy3ri_PgjAxqc_3r-BqKV83HRvB_m4,28295
 gpjax/kernels/__init__.py,sha256=WZanH0Tpdkt0f7VfMqnalm_VZAMVwBqeOVaICNj6xQU,1901
 gpjax/kernels/base.py,sha256=wXsrpm5ofy9S5MNgUkJk4lx2umcIJL6dDNhXY7cmTGk,11616
 gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfjH1dP626yNMA,68
-gpjax/kernels/approximations/rff.py,sha256=4kD1uocjHmxkLgvf4DxB4_Gy7iefdPgnWiZB3jDiExI,4126
+gpjax/kernels/approximations/rff.py,sha256=VbitjNuahFE5_IvCj1A0SxHhJXU0O0Qq0FMMVq8xA3E,4125
 gpjax/kernels/computations/__init__.py,sha256=uTVkqvnZVesFLDN92h0ZR0jfR69Eo2WyjOlmSYmCPJ8,1379
 gpjax/kernels/computations/base.py,sha256=zzabLN_-FkTWo6cBYjzjvUGYa7vrYyHxyrhQZxLgHBk,3651
 gpjax/kernels/computations/basis_functions.py,sha256=zY4rUDZDLOYvQPY9xosRmCLPdiXfbzAN5GICjQhFOis,2528
@@ -28,21 +29,21 @@ gpjax/kernels/non_euclidean/__init__.py,sha256=RT7puRPqCTpyxZ16q596EuOQEQi1LK1v3
 gpjax/kernels/non_euclidean/graph.py,sha256=K4WIdX-dx1SsWuNHZnNjHFw8ElKZxGcReUiA3w4aCOI,4204
 gpjax/kernels/non_euclidean/utils.py,sha256=z42aw8ga0zuREzHawemR9okttgrAUPmq-aN5HMt4SuY,1578
 gpjax/kernels/nonstationary/__init__.py,sha256=YpWQfOy_cqOKc5ezn37vqoK3Z6jznYiJz28BD_8F7AY,930
-gpjax/kernels/nonstationary/arccosine.py,sha256=UCTVJEhTZFQjARGFsYMImLnTDyTyxobIL5f2LiAHkPI,5822
-gpjax/kernels/nonstationary/linear.py,sha256=UKDHFCQzKWDMYo76qcb5-ujjnP2_iL-1tcN017xjK48,2562
-gpjax/kernels/nonstationary/polynomial.py,sha256=7SDMfEcBCqnRn9xyj4iGcYLNvYJZiveN3uLZ_h12p10,3257
+gpjax/kernels/nonstationary/arccosine.py,sha256=2WV6aM0Z3-xXZnoPw-77n2CW62n-AZuJy-7AQ9xrMco,5858
+gpjax/kernels/nonstationary/linear.py,sha256=UIMoCq2hg6dQKr4J5UGiiPqotBleQuYfy00Ia1NaMOo,2571
+gpjax/kernels/nonstationary/polynomial.py,sha256=arP8DK0jnBOaayDWcFvHF0pdu9FVhwzXdqjnHUAL2VI,3293
 gpjax/kernels/stationary/__init__.py,sha256=j4BMTaQlIx2kNAT1Dkf4iO2rm-f7_oSVWNrk1bN0tqE,1406
-gpjax/kernels/stationary/base.py,sha256=pQNkMo-E4bIT4tNfb7JvFJZC6fIIXNErsT1iQopFlAA,7063
-gpjax/kernels/stationary/matern12.py,sha256=b2vQCUqhd9NJK84L2RYjpI597uxy_7xgwsjS35Gc958,1807
-gpjax/kernels/stationary/matern32.py,sha256=ZVYbUIQhvKpriC7abH8wV6Pk-mRoxtl3e2YYwH-KM5Y,2000
-gpjax/kernels/stationary/matern52.py,sha256=xfMYbY7MXxgMECtA2qVT5I8HoDGzGxygUvduGT3_Gvs,2053
+gpjax/kernels/stationary/base.py,sha256=25qDqpZP4gNtzbyzDCW-6u7rJfMqkg0dW88XUmTTupU,7078
+gpjax/kernels/stationary/matern12.py,sha256=DGjqw6VveYsyy0TrufyJJvCei7p9slnm2f0TgRGG7_U,1773
+gpjax/kernels/stationary/matern32.py,sha256=laLsJWJozJzpYHBzlkPUq0rWxz1eWEwGC36P2nPJuaQ,1966
+gpjax/kernels/stationary/matern52.py,sha256=VSByD2sb7k-DzRFjaz31P3Rtc4bPPhHvMshrxZNFnns,2019
 gpjax/kernels/stationary/periodic.py,sha256=IAbCxURtJEHGdmYzbdrsqRZ3zJ8F8tGQF9O7sggafZk,3598
 gpjax/kernels/stationary/powered_exponential.py,sha256=8qT91IWKJK7PpEtFcX4MVu1ahWMOFOZierPko4JCjKA,3776
 gpjax/kernels/stationary/rational_quadratic.py,sha256=dYONp3i4rnKj3ET8UyxAKXv6UOl8uOFT3lCutleSvo4,3496
-gpjax/kernels/stationary/rbf.py,sha256=G13gg5phO7ite7D9QgoCy7gB2_y0FM6GZhgFW4RL6Xw,1734
-gpjax/kernels/stationary/utils.py,sha256=Xa9EEnxgFqEi08ZSFAZYYHhJ85_3Ac-ZUyUk18B63M4,2225
+gpjax/kernels/stationary/rbf.py,sha256=euHUs6FdfRICQcabAWE4MX-7GEDr2TxgZWdFQiXr9Bw,1690
+gpjax/kernels/stationary/utils.py,sha256=6BI9EBcCzeeKx-XH-MfW1ORmtU__tPX5zyvfLhpkBsU,2180
 gpjax/kernels/stationary/white.py,sha256=TkdXXZCCjDs7JwR_gj5uvn2s1wyfRbe1vyHhUMJ8jjI,2212
-gpjax-0.10.2.dist-info/METADATA,sha256=mqIBMOMKKiI9qkM_uFHSuPEXY17Jd6bOL5EM2hPiaok,9970
-gpjax-0.10.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-gpjax-0.10.2.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
-gpjax-0.10.2.dist-info/RECORD,,
+gpjax-0.11.1.dist-info/METADATA,sha256=02crI6D0dsht6XJ8N1ZqNj5ZktmS5NymVfY45pPmEgM,8558
+gpjax-0.11.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+gpjax-0.11.1.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
+gpjax-0.11.1.dist-info/RECORD,,