gpjax 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gpjax/__init__.py CHANGED
@@ -39,7 +39,7 @@ __license__ = "MIT"
 __description__ = "Didactic Gaussian processes in JAX"
 __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
 __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
-__version__ = "0.10.2"
+__version__ = "0.11.0"
 
 __all__ = [
     "base",
gpjax/distributions.py CHANGED
@@ -15,77 +15,76 @@
 
 
 from beartype.typing import (
-    Any,
     Optional,
-    Tuple,
-    TypeVar,
 )
 import cola
+from cola.linalg.decompositions import Cholesky
 from cola.ops import (
-    Identity,
     LinearOperator,
 )
 from jax import vmap
 import jax.numpy as jnp
 import jax.random as jr
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax as tfp
+from numpyro.distributions import constraints
+from numpyro.distributions.distribution import Distribution
+from numpyro.distributions.util import is_prng_key
 
 from gpjax.lower_cholesky import lower_cholesky
 from gpjax.typing import (
     Array,
-    KeyArray,
     ScalarFloat,
 )
 
-tfd = tfp.distributions
-
-from cola.linalg.decompositions import Cholesky
-
 
-class GaussianDistribution(tfd.Distribution):
-    r"""Multivariate Gaussian distribution with a linear operator scale matrix."""
-
-    # TODO: Consider `distrax.transformed.Transformed` object. Can we create a LinearOperator to `distrax.bijector` representation
-    # and modify `distrax.MultivariateNormalFromBijector`?
-    # TODO: Consider natural and expectation parameterisations in future work.
-    # TODO: we don't really need to inherit from `tfd.Distribution` here
+class GaussianDistribution(Distribution):
+    support = constraints.real_vector
 
     def __init__(
         self,
-        loc: Optional[Float[Array, " N"]] = None,
-        scale: Optional[LinearOperator] = None,
-    ) -> None:
-        r"""Initialises the distribution.
-
-        Args:
-            loc: the mean of the distribution as an array of shape (n_points,).
-            scale: the scale matrix of the distribution as a LinearOperator object.
-        """
-        _check_loc_scale(loc, scale)
+        loc: Optional[Float[Array, " N"]],
+        scale: Optional[LinearOperator],
+        validate_args=None,
+    ):
+        self.loc = loc
+        self.scale = cola.PSD(scale)
+        batch_shape = ()
+        event_shape = jnp.shape(self.loc)
+        super().__init__(batch_shape, event_shape, validate_args=validate_args)
 
-        # Find dimensionality of the distribution.
-        if loc is not None:
-            num_dims = loc.shape[-1]
+    def sample(self, key, sample_shape=()):
+        assert is_prng_key(key)
+        # Obtain covariance root.
+        covariance_root = lower_cholesky(self.scale)
 
-        elif scale is not None:
-            num_dims = scale.shape[-1]
+        # Gather n samples from standard normal distribution Z = [z₁, ..., zₙ]ᵀ.
+        white_noise = jr.normal(
+            key, shape=sample_shape + self.batch_shape + self.event_shape
+        )
 
-        # Set the location to zero vector if unspecified.
-        if loc is None:
-            loc = jnp.zeros((num_dims,))
+        # xᵢ ~ N(loc, cov) <=> xᵢ = loc + sqrt zᵢ, where zᵢ ~ N(0, I).
+        def affine_transformation(_x):
+            return self.loc + covariance_root @ _x
 
-        # If not specified, set the scale to the identity matrix.
-        if scale is None:
-            scale = Identity(shape=(num_dims, num_dims), dtype=loc.dtype)
-
-        self.loc = loc
-        self.scale = cola.PSD(scale)
+        return vmap(affine_transformation)(white_noise)
 
+    @property
     def mean(self) -> Float[Array, " N"]:
         r"""Calculates the mean."""
         return self.loc
 
+    @property
+    def variance(self) -> Float[Array, " N"]:
+        r"""Calculates the variance."""
+        return cola.diag(self.scale)
+
+    def entropy(self) -> ScalarFloat:
+        r"""Calculates the entropy of the distribution."""
+        return 0.5 * (
+            self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi))
+            + cola.logdet(self.scale, Cholesky(), Cholesky())
+        )
+
     def median(self) -> Float[Array, " N"]:
         r"""Calculates the median."""
         return self.loc
@@ -98,25 +97,19 @@ class GaussianDistribution(tfd.Distribution):
         r"""Calculates the covariance matrix."""
         return self.scale.to_dense()
 
-    def variance(self) -> Float[Array, " N"]:
-        r"""Calculates the variance."""
-        return cola.diag(self.scale)
+    @property
+    def covariance_matrix(self) -> Float[Array, "N N"]:
+        r"""Calculates the covariance matrix."""
+        return self.covariance()
 
     def stddev(self) -> Float[Array, " N"]:
         r"""Calculates the standard deviation."""
         return jnp.sqrt(cola.diag(self.scale))
 
-    @property
-    def event_shape(self) -> Tuple:
-        r"""Returns the event shape."""
-        return self.loc.shape[-1:]
-
-    def entropy(self) -> ScalarFloat:
-        r"""Calculates the entropy of the distribution."""
-        return 0.5 * (
-            self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi))
-            + cola.logdet(self.scale, Cholesky(), Cholesky())
-        )
+    # @property
+    # def event_shape(self) -> Tuple:
+    #     r"""Returns the event shape."""
+    #     return self.loc.shape[-1:]
 
     def log_prob(self, y: Float[Array, " N"]) -> ScalarFloat:
         r"""Calculates the log pdf of the multivariate Gaussian.
@@ -141,42 +134,39 @@ class GaussianDistribution(tfd.Distribution):
             + diff.T @ cola.solve(sigma, diff, Cholesky())
         )
 
-    def _sample_n(self, key: KeyArray, n: int) -> Float[Array, "n N"]:
-        r"""Samples from the distribution.
+    # def _sample_n(self, key: KeyArray, n: int) -> Float[Array, "n N"]:
+    #     r"""Samples from the distribution.
 
-        Args:
-            key (KeyArray): The key to use for sampling.
+    #     Args:
+    #         key (KeyArray): The key to use for sampling.
 
-        Returns:
-            The samples as an array of shape (n_samples, n_points).
-        """
-        # Obtain covariance root.
-        sqrt = lower_cholesky(self.scale)
+    #     Returns:
+    #         The samples as an array of shape (n_samples, n_points).
+    #     """
+    #     # Obtain covariance root.
+    #     sqrt = lower_cholesky(self.scale)
 
-        # Gather n samples from standard normal distribution Z = [z₁, ..., zₙ]ᵀ.
-        Z = jr.normal(key, shape=(n, *self.event_shape))
+    #     # Gather n samples from standard normal distribution Z = [z₁, ..., zₙ]ᵀ.
+    #     Z = jr.normal(key, shape=(n, *self.event_shape))
 
-        # xᵢ ~ N(loc, cov) <=> xᵢ = loc + sqrt zᵢ, where zᵢ ~ N(0, I).
-        def affine_transformation(x):
-            return self.loc + sqrt @ x
+    #     # xᵢ ~ N(loc, cov) <=> xᵢ = loc + sqrt zᵢ, where zᵢ ~ N(0, I).
+    #     def affine_transformation(x):
+    #         return self.loc + sqrt @ x
 
-        return vmap(affine_transformation)(Z)
+    #     return vmap(affine_transformation)(Z)
 
-    def sample(
-        self, seed: KeyArray, sample_shape: Tuple[int, ...]
-    ):  # pylint: disable=useless-super-delegation
-        r"""See `Distribution.sample`."""
-        return self._sample_n(
-            seed, sample_shape[0]
-        )  # TODO this looks weird, why ignore the second entry?
+    # def sample(
+    #     self, seed: KeyArray, sample_shape: Tuple[int, ...]
+    # ):  # pylint: disable=useless-super-delegation
+    #     r"""See `Distribution.sample`."""
+    #     return self._sample_n(
+    #         seed, sample_shape[0]
+    #     )  # TODO this looks weird, why ignore the second entry?
 
     def kl_divergence(self, other: "GaussianDistribution") -> ScalarFloat:
         return _kl_divergence(self, other)
 
 
-DistrT = TypeVar("DistrT", bound=tfd.Distribution)
-
-
 def _check_and_return_dimension(
     q: GaussianDistribution, p: GaussianDistribution
 ) -> int:
@@ -245,37 +235,37 @@ def _kl_divergence(q: GaussianDistribution, p: GaussianDistribution) -> ScalarFl
     ) / 2.0
 
 
-def _check_loc_scale(loc: Optional[Any], scale: Optional[Any]) -> None:
-    r"""Checks that the inputs are correct."""
-    if loc is None and scale is None:
-        raise ValueError("At least one of `loc` or `scale` must be specified.")
-
-    if loc is not None and loc.ndim < 1:
-        raise ValueError("The parameter `loc` must have at least one dimension.")
-
-    if scale is not None and len(scale.shape) < 2:  # scale.ndim < 2:
-        raise ValueError(
-            "The `scale` must have at least two dimensions, but "
-            f"`scale.shape = {scale.shape}`."
-        )
-
-    if scale is not None and not isinstance(scale, LinearOperator):
-        raise ValueError(
-            f"The `scale` must be a CoLA LinearOperator but got {type(scale)}"
-        )
-
-    if scale is not None and (scale.shape[-1] != scale.shape[-2]):
-        raise ValueError(
-            f"The `scale` must be a square matrix, but `scale.shape = {scale.shape}`."
-        )
-
-    if loc is not None:
-        num_dims = loc.shape[-1]
-        if scale is not None and (scale.shape[-1] != num_dims):
-            raise ValueError(
-                f"Shapes are not compatible: `loc.shape = {loc.shape}` and "
-                f"`scale.shape = {scale.shape}`."
-            )
+# def _check_loc_scale(loc: Optional[Any], scale: Optional[Any]) -> None:
+#     r"""Checks that the inputs are correct."""
+#     if loc is None and scale is None:
+#         raise ValueError("At least one of `loc` or `scale` must be specified.")
+
+#     if loc is not None and loc.ndim < 1:
+#         raise ValueError("The parameter `loc` must have at least one dimension.")
+
+#     if scale is not None and len(scale.shape) < 2:  # scale.ndim < 2:
+#         raise ValueError(
+#             "The `scale` must have at least two dimensions, but "
+#             f"`scale.shape = {scale.shape}`."
+#         )
+
+#     if scale is not None and not isinstance(scale, LinearOperator):
+#         raise ValueError(
+#             f"The `scale` must be a CoLA LinearOperator but got {type(scale)}"
+#         )
+
+#     if scale is not None and (scale.shape[-1] != scale.shape[-2]):
+#         raise ValueError(
+#             f"The `scale` must be a square matrix, but `scale.shape = {scale.shape}`."
+#         )
+
+#     if loc is not None:
+#         num_dims = loc.shape[-1]
+#         if scale is not None and (scale.shape[-1] != num_dims):
+#             raise ValueError(
+#                 f"Shapes are not compatible: `loc.shape = {loc.shape}` and "
+#                 f"`scale.shape = {scale.shape}`."
+#             )
 
 
 __all__ = [
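For readers updating call sites, here is a minimal sketch of the new distribution API (the 2×2 dense covariance is illustrative): `sample` now follows the numpyro convention of a PRNG key plus a `sample_shape` tuple, and `mean`/`variance` are properties rather than methods.

```python
import jax.numpy as jnp
import jax.random as jr
from cola.ops.operators import Dense

from gpjax.distributions import GaussianDistribution

# An illustrative dense 2x2 covariance wrapped as a CoLA operator;
# __init__ tags it PSD internally via cola.PSD.
cov = Dense(jnp.array([[1.0, 0.5], [0.5, 1.0]]))
dist = GaussianDistribution(loc=jnp.zeros(2), scale=cov)

# numpyro-style sampling: a PRNG key plus an explicit sample_shape.
samples = dist.sample(jr.PRNGKey(0), sample_shape=(100,))
assert samples.shape == (100, 2)

# `mean` and `variance` are now properties (they were methods under TFP).
print(dist.mean, dist.variance)
```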
gpjax/fit.py CHANGED
@@ -20,9 +20,9 @@ import jax
 from jax.flatten_util import ravel_pytree
 import jax.numpy as jnp
 import jax.random as jr
+from numpyro.distributions.transforms import Transform
 import optax as ox
 from scipy.optimize import minimize
-from tensorflow_probability.substrates.jax.bijectors import Bijector
 
 from gpjax.dataset import Dataset
 from gpjax.objectives import Objective
@@ -47,7 +47,7 @@ def fit( # noqa: PLR0913
     objective: Objective,
     train_data: Dataset,
     optim: ox.GradientTransformation,
-    params_bijection: tp.Union[dict[Parameter, Bijector], None] = DEFAULT_BIJECTION,
+    params_bijection: tp.Union[dict[Parameter, Transform], None] = DEFAULT_BIJECTION,
     key: KeyArray = jr.PRNGKey(42),
     num_iters: int = 100,
     batch_size: int = -1,
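A quick, runnable check of the new contract: every entry of `DEFAULT_BIJECTION` is now a numpyro `Transform`, matching the updated `params_bijection` annotation above, so calls that rely on the default need no change.

```python
from numpyro.distributions.transforms import Transform

from gpjax.parameters import DEFAULT_BIJECTION

# All default bijections are numpyro Transforms as of 0.11.0
# (including the new FillTriangularTransform, see gpjax/numpyro_extras.py below).
assert all(isinstance(t, Transform) for t in DEFAULT_BIJECTION.values())
```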
gpjax/kernels/approximations/rff.py CHANGED
@@ -68,7 +68,7 @@ class RFF(AbstractKernel):
 
         self.frequencies = Static(
             self.base_kernel.spectral_density.sample(
-                seed=key, sample_shape=(self.num_basis_fns, n_dims)
+                key=key, sample_shape=(self.num_basis_fns, n_dims)
             )
         )
         self.name = f"{self.base_kernel.name} (RFF)"
gpjax/kernels/stationary/base.py CHANGED
@@ -18,7 +18,7 @@ import beartype.typing as tp
 from flax import nnx
 import jax.numpy as jnp
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax.distributions as tfd
+import numpyro.distributions as npd
 
 from gpjax.kernels.base import AbstractKernel
 from gpjax.kernels.computations import (
@@ -92,7 +92,7 @@ class StationaryKernel(AbstractKernel):
         self.variance = tp.cast(PositiveReal[ScalarFloat], self.variance)
 
     @property
-    def spectral_density(self) -> tfd.Distribution:
+    def spectral_density(self) -> npd.Normal | npd.StudentT:
         r"""The spectral density of the kernel.
 
         Returns:
gpjax/kernels/stationary/matern12.py CHANGED
@@ -15,7 +15,7 @@
 
 import jax.numpy as jnp
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax.distributions as tfd
+import numpyro.distributions as npd
 
 from gpjax.kernels.stationary.base import StationaryKernel
 from gpjax.kernels.stationary.utils import (
@@ -48,5 +48,5 @@ class Matern12(StationaryKernel):
         return K.squeeze()
 
     @property
-    def spectral_density(self) -> tfd.Distribution:
+    def spectral_density(self) -> npd.StudentT:
         return build_student_t_distribution(nu=1)
gpjax/kernels/stationary/matern32.py CHANGED
@@ -15,7 +15,7 @@
 
 import jax.numpy as jnp
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax.distributions as tfd
+import numpyro.distributions as npd
 
 from gpjax.kernels.stationary.base import StationaryKernel
 from gpjax.kernels.stationary.utils import (
@@ -54,5 +54,5 @@ class Matern32(StationaryKernel):
         return K.squeeze()
 
     @property
-    def spectral_density(self) -> tfd.Distribution:
+    def spectral_density(self) -> npd.StudentT:
         return build_student_t_distribution(nu=3)
gpjax/kernels/stationary/matern52.py CHANGED
@@ -15,7 +15,7 @@
 
 import jax.numpy as jnp
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax.distributions as tfd
+import numpyro.distributions as npd
 
 from gpjax.kernels.stationary.base import StationaryKernel
 from gpjax.kernels.stationary.utils import (
@@ -53,5 +53,5 @@ class Matern52(StationaryKernel):
         return K.squeeze()
 
     @property
-    def spectral_density(self) -> tfd.Distribution:
+    def spectral_density(self) -> npd.StudentT:
         return build_student_t_distribution(nu=5)
gpjax/kernels/stationary/rbf.py CHANGED
@@ -15,7 +15,7 @@
 
 import jax.numpy as jnp
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax as tfp
+import numpyro.distributions as npd
 
 from gpjax.kernels.stationary.base import StationaryKernel
 from gpjax.kernels.stationary.utils import squared_distance
@@ -44,5 +44,5 @@ class RBF(StationaryKernel):
         return K.squeeze()
 
     @property
-    def spectral_density(self) -> tfp.distributions.Normal:
-        return tfp.distributions.Normal(0.0, 1.0)
+    def spectral_density(self) -> npd.Normal:
+        return npd.Normal(0.0, 1.0)
gpjax/kernels/stationary/utils.py CHANGED
@@ -14,17 +14,15 @@
 # ==============================================================================
 import jax.numpy as jnp
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax as tfp
+import numpyro.distributions as npd
 
 from gpjax.typing import (
     Array,
     ScalarFloat,
 )
 
-tfd = tfp.distributions
 
-
-def build_student_t_distribution(nu: int) -> tfd.Distribution:
+def build_student_t_distribution(nu: int) -> npd.StudentT:
     r"""Build a Student's t distribution with a fixed smoothness parameter.
 
     For a fixed half-integer smoothness parameter, compute the spectral density of a
@@ -37,7 +35,7 @@ def build_student_t_distribution(nu: int) -> tfd.Distribution:
     -------
         tfp.Distribution: A Student's t distribution with the same smoothness parameter.
     """
-    dist = tfd.StudentT(df=nu, loc=0.0, scale=1.0)
+    dist = npd.StudentT(df=nu, loc=0.0, scale=1.0)
     return dist
 
 
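A small sketch of why the RFF hunk above renames `seed=` to `key=`: numpyro's `Distribution.sample(key, sample_shape=())` takes the PRNG key as its first argument. The shapes below are illustrative.

```python
import jax.random as jr

from gpjax.kernels.stationary.utils import build_student_t_distribution

# Spectral density of a Matérn-3/2 kernel, now a numpyro StudentT.
density = build_student_t_distribution(nu=3)

# numpyro convention: `key` first, then `sample_shape`.
freqs = density.sample(jr.PRNGKey(0), sample_shape=(10, 1))
assert freqs.shape == (10, 1)
```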
gpjax/likelihoods.py CHANGED
@@ -19,7 +19,7 @@ from jax import vmap
 import jax.numpy as jnp
 import jax.scipy as jsp
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax as tfp
+import numpyro.distributions as npd
 
 from gpjax.distributions import GaussianDistribution
 from gpjax.integrators import (
@@ -36,9 +36,6 @@ from gpjax.typing import (
     ScalarFloat,
 )
 
-tfb = tfp.bijectors
-tfd = tfp.distributions
-
 
 class AbstractLikelihood(nnx.Module):
     r"""Abstract base class for likelihoods.
@@ -62,7 +59,7 @@ class AbstractLikelihood(nnx.Module):
         self.num_datapoints = num_datapoints
         self.integrator = integrator
 
-    def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> tfd.Distribution:
+    def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> npd.Distribution:
         r"""Evaluate the likelihood function at a given predictive distribution.
 
         Args:
@@ -76,7 +73,7 @@ class AbstractLikelihood(nnx.Module):
         return self.predict(*args, **kwargs)
 
     @abc.abstractmethod
-    def predict(self, *args: tp.Any, **kwargs: tp.Any) -> tfd.Distribution:
+    def predict(self, *args: tp.Any, **kwargs: tp.Any) -> npd.Distribution:
         r"""Evaluate the likelihood function at a given predictive distribution.
 
         Args:
@@ -85,19 +82,19 @@ class AbstractLikelihood(nnx.Module):
                 `predict` method.
 
         Returns:
-            tfd.Distribution: The predictive distribution.
+            npd.Distribution: The predictive distribution.
         """
         raise NotImplementedError
 
     @abc.abstractmethod
-    def link_function(self, f: Float[Array, "..."]) -> tfd.Distribution:
+    def link_function(self, f: Float[Array, "..."]) -> npd.Distribution:
         r"""Return the link function of the likelihood function.
 
         Args:
             f (Float[Array, "..."]): the latent Gaussian process values.
 
         Returns:
-            tfd.Distribution: The distribution of observations, y, given values of the
+            npd.Distribution: The distribution of observations, y, given values of the
                 Gaussian process, f.
         """
         raise NotImplementedError
@@ -157,20 +154,20 @@ class Gaussian(AbstractLikelihood):
 
         super().__init__(num_datapoints, integrator)
 
-    def link_function(self, f: Float[Array, "..."]) -> tfd.Normal:
+    def link_function(self, f: Float[Array, "..."]) -> npd.Normal:
         r"""The link function of the Gaussian likelihood.
 
         Args:
             f (Float[Array, "..."]): Function values.
 
         Returns:
-            tfd.Normal: The likelihood function.
+            npd.Normal: The likelihood function.
         """
-        return tfd.Normal(loc=f, scale=self.obs_stddev.value.astype(f.dtype))
+        return npd.Normal(loc=f, scale=self.obs_stddev.value.astype(f.dtype))
 
     def predict(
-        self, dist: tp.Union[tfd.MultivariateNormalTriL, GaussianDistribution]
-    ) -> tfd.MultivariateNormalFullCovariance:
+        self, dist: tp.Union[npd.MultivariateNormal, GaussianDistribution]
+    ) -> npd.MultivariateNormal:
         r"""Evaluate the Gaussian likelihood.
 
         Evaluate the Gaussian likelihood function at a given predictive
@@ -179,75 +176,79 @@ class Gaussian(AbstractLikelihood):
         distribution's covariance matrix.
 
         Args:
-            dist (tfd.Distribution): The Gaussian process posterior,
+            dist (npd.Distribution): The Gaussian process posterior,
                 evaluated at a finite set of test points.
 
         Returns:
-            tfd.Distribution: The predictive distribution.
+            npd.Distribution: The predictive distribution.
         """
         n_data = dist.event_shape[0]
-        cov = dist.covariance()
+        cov = dist.covariance_matrix
         noisy_cov = cov.at[jnp.diag_indices(n_data)].add(self.obs_stddev.value**2)
 
-        return tfd.MultivariateNormalFullCovariance(dist.mean(), noisy_cov)
+        return npd.MultivariateNormal(dist.mean, noisy_cov)
 
 
 class Bernoulli(AbstractLikelihood):
-    def link_function(self, f: Float[Array, "..."]) -> tfd.Distribution:
+    def link_function(self, f: Float[Array, "..."]) -> npd.BernoulliProbs:
         r"""The probit link function of the Bernoulli likelihood.
 
         Args:
             f (Float[Array, "..."]): Function values.
 
         Returns:
-            tfd.Distribution: The likelihood function.
+            npd.Bernoulli: The likelihood function.
         """
-        return tfd.Bernoulli(probs=inv_probit(f))
+        return npd.Bernoulli(probs=inv_probit(f))
 
-    def predict(self, dist: tfd.Distribution) -> tfd.Distribution:
+    def predict(
+        self, dist: tp.Union[npd.MultivariateNormal, GaussianDistribution]
+    ) -> npd.BernoulliProbs:
         r"""Evaluate the pointwise predictive distribution.
 
         Evaluate the pointwise predictive distribution, given a Gaussian
         process posterior and likelihood parameters.
 
         Args:
-            dist (tfd.Distribution): The Gaussian process posterior, evaluated
-                at a finite set of test points.
+            dist (tp.Union[npd.MultivariateNormal, GaussianDistribution]): The Gaussian
+                process posterior, evaluated at a finite set of test points.
 
         Returns:
-            tfd.Distribution: The pointwise predictive distribution.
+            npd.Bernoulli: The pointwise predictive distribution.
         """
-        variance = jnp.diag(dist.covariance())
-        mean = dist.mean().ravel()
+        variance = jnp.diag(dist.covariance_matrix)
+        mean = dist.mean.ravel()
         return self.link_function(mean / jnp.sqrt(1.0 + variance))
 
 
 class Poisson(AbstractLikelihood):
-    def link_function(self, f: Float[Array, "..."]) -> tfd.Distribution:
+    def link_function(self, f: Float[Array, "..."]) -> npd.Poisson:
         r"""The link function of the Poisson likelihood.
 
         Args:
             f (Float[Array, "..."]): Function values.
 
         Returns:
-            tfd.Distribution: The likelihood function.
+            npd.Poisson: The likelihood function.
         """
-        return tfd.Poisson(rate=jnp.exp(f))
+        return npd.Poisson(rate=jnp.exp(f))
 
-    def predict(self, dist: tfd.Distribution) -> tfd.Distribution:
+    def predict(
+        self, dist: tp.Union[npd.MultivariateNormal, GaussianDistribution]
+    ) -> npd.Poisson:
         r"""Evaluate the pointwise predictive distribution.
 
         Evaluate the pointwise predictive distribution, given a Gaussian
         process posterior and likelihood parameters.
 
         Args:
-            dist (tfd.Distribution): The Gaussian process posterior, evaluated
-                at a finite set of test points.
+            dist (tp.Union[npd.MultivariateNormal, GaussianDistribution]): The Gaussian
+                process posterior, evaluated at a finite set of test points.
 
         Returns:
-            tfd.Distribution: The pointwise predictive distribution.
+            npd.Poisson: The pointwise predictive distribution.
         """
-        return self.link_function(dist.mean())
+        return self.link_function(dist.mean)
 
 
 def inv_probit(x: Float[Array, " *N"]) -> Float[Array, " *N"]:
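To see the downstream effect, here is a sketch of a Gaussian predictive call; the two-point latent posterior and the `obs_stddev` keyword value are illustrative. The returned distribution is a numpyro `MultivariateNormal`, so summary statistics are properties rather than methods.

```python
import jax.numpy as jnp
from cola.ops.operators import Dense

from gpjax.distributions import GaussianDistribution
from gpjax.likelihoods import Gaussian

# Illustrative latent posterior over two test points.
latent = GaussianDistribution(loc=jnp.zeros(2), scale=Dense(jnp.eye(2)))
likelihood = Gaussian(num_datapoints=2, obs_stddev=jnp.array(0.1))

predictive = likelihood.predict(latent)  # an npd.MultivariateNormal
mean = predictive.mean                   # property, not predictive.mean()
cov = predictive.covariance_matrix       # property, not predictive.covariance()
```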
gpjax/mean_functions.py CHANGED
@@ -28,7 +28,7 @@ from jaxtyping import (
 from gpjax.parameters import (
     Parameter,
     Real,
-    Static
+    Static,
 )
 from gpjax.typing import (
     Array,
@@ -131,7 +131,8 @@ class Constant(AbstractMeanFunction):
     """
 
     def __init__(
-        self, constant: tp.Union[ScalarFloat, Float[Array, " O"], Parameter, Static] = 0.0
+        self,
+        constant: tp.Union[ScalarFloat, Float[Array, " O"], Parameter, Static] = 0.0,
     ):
         if isinstance(constant, Parameter) or isinstance(constant, Static):
             self.constant = constant
gpjax/numpyro_extras.py ADDED
@@ -0,0 +1,106 @@
+import math
+
+import jax
+import jax.numpy as jnp
+from numpyro.distributions.transforms import Transform
+
+# -----------------------------------------------------------------------------
+# Implementation: FillTriangularTransform
+# -----------------------------------------------------------------------------
+
+
+class FillTriangularTransform(Transform):
+    """
+    Transform that maps a vector of length n(n+1)/2 to an n x n lower triangular matrix.
+    The ordering is assumed to be:
+        (0,0), (1,0), (1,1), (2,0), (2,1), (2,2), ..., (n-1, n-1)
+    """
+
+    # Note: The base class provides `inv` through _InverseTransform wrapping _inverse.
+
+    def __call__(self, x):
+        """
+        Forward transformation.
+
+        Parameters
+        ----------
+        x : array_like, shape (..., L)
+            Input vector with L = n(n+1)/2 for some integer n.
+
+        Returns
+        -------
+        y : array_like, shape (..., n, n)
+            Lower-triangular matrix (with zeros in the upper triangle) filled in
+            row-major order (i.e. [ (0,0), (1,0), (1,1), ... ]).
+        """
+        L = x.shape[-1]
+        # Use static (Python) math.sqrt to compute n. This avoids tracer issues.
+        n = int((-1 + math.sqrt(1 + 8 * L)) // 2)
+        if n * (n + 1) // 2 != L:
+            raise ValueError("Last dimension must equal n(n+1)/2 for some integer n.")
+
+        def fill_single(vec):
+            out = jnp.zeros((n, n), dtype=vec.dtype)
+            row, col = jnp.tril_indices(n)
+            return out.at[row, col].set(vec)
+
+        if x.ndim == 1:
+            return fill_single(x)
+        else:
+            batch_shape = x.shape[:-1]
+            flat_x = x.reshape((-1, L))
+            out = jax.vmap(fill_single)(flat_x)
+            return out.reshape(batch_shape + (n, n))
+
+    def _inverse(self, y):
+        """
+        Inverse transformation.
+
+        Parameters
+        ----------
+        y : array_like, shape (..., n, n)
+            Lower triangular matrix.
+
+        Returns
+        -------
+        x : array_like, shape (..., n(n+1)/2)
+            The vector containing the elements from the lower-triangular portion of y.
+        """
+        if y.ndim < 2:
+            raise ValueError("Input to inverse must be at least two-dimensional.")
+        n = y.shape[-1]
+        if y.shape[-2] != n:
+            raise ValueError(
+                "Input matrix must be square; got shape %s" % str(y.shape[-2:])
+            )
+
+        row, col = jnp.tril_indices(n)
+
+        def inv_single(mat):
+            return mat[row, col]
+
+        if y.ndim == 2:
+            return inv_single(y)
+        else:
+            batch_shape = y.shape[:-2]
+            flat_y = y.reshape((-1, n, n))
+            out = jax.vmap(inv_single)(flat_y)
+            return out.reshape(batch_shape + (n * (n + 1) // 2,))
+
+    def log_abs_det_jacobian(self, x, y, intermediates=None):
+        # Since the transform simply reorders the vector into a matrix, the Jacobian determinant is 1.
+        return jnp.zeros(x.shape[:-1])
+
+    @property
+    def sign(self):
+        # The reordering transformation has a positive derivative everywhere.
+        return 1.0
+
+    # Implement tree_flatten and tree_unflatten because base Transform expects them.
+    def tree_flatten(self):
+        # This transform is stateless.
+        return (), {}
+
+    @classmethod
+    def tree_unflatten(cls, aux_data, children):
+        return cls()
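A quick round-trip sketch of the new transform, using an illustrative length-6 vector (so n = 3):

```python
import jax.numpy as jnp

from gpjax.numpyro_extras import FillTriangularTransform

ft = FillTriangularTransform()

# 6 = 3 * (3 + 1) / 2, so the forward map yields a 3x3 lower-triangular matrix:
# [[1. 0. 0.]
#  [2. 3. 0.]
#  [4. 5. 6.]]
vec = jnp.arange(1.0, 7.0)
L = ft(vec)

# The inverse (exposed as `inv` by the numpyro base class) recovers the vector.
assert jnp.allclose(ft.inv(L), vec)
```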
gpjax/objectives.py CHANGED
@@ -13,7 +13,7 @@ from jax import vmap
 import jax.numpy as jnp
 import jax.scipy as jsp
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax as tfp
+import numpyro.distributions as npd
 import typing_extensions as tpe
 
 from gpjax.dataset import Dataset
@@ -29,8 +29,6 @@ from gpjax.typing import (
 )
 from gpjax.variational_families import AbstractVariationalFamily
 
-tfd = tfp.distributions
-
 VF = TypeVar("VF", bound=AbstractVariationalFamily)
 
 
@@ -175,7 +173,7 @@ def conjugate_loocv(posterior: ConjugatePosterior, data: Dataset) -> ScalarFloat
     loocv_means = mx + (y - mx) - Sigma_inv_y / Sigma_inv_diag
     loocv_stds = jnp.sqrt(1.0 / Sigma_inv_diag)
 
-    loocv_posterior = tfd.Normal(loc=loocv_means, scale=loocv_stds)
+    loocv_posterior = npd.Normal(loc=loocv_means, scale=loocv_stds)
     return jnp.sum(loocv_posterior.log_prob(y))
 
 
@@ -232,7 +230,7 @@ def log_posterior_density(
     likelihood = posterior.likelihood.link_function(fx)
 
     # Whitened latent function values prior, p(wx | θ) = N(0, I)
-    latent_prior = tfd.Normal(loc=0.0, scale=1.0)
+    latent_prior = npd.Normal(loc=0.0, scale=1.0)
     return likelihood.log_prob(y).sum() + latent_prior.log_prob(wx).sum()
 
 
@@ -305,7 +303,7 @@ def variational_expectation(
     # inputs, x
     def q_moments(x):
         qx = q(x)
-        return qx.mean().squeeze(), qx.covariance().squeeze()
+        return qx.mean.squeeze(), qx.covariance().squeeze()
 
     mean, variance = vmap(q_moments)(x[:, None])
 
gpjax/parameters.py CHANGED
@@ -5,7 +5,9 @@ from jax.experimental import checkify
 import jax.numpy as jnp
 import jax.tree_util as jtu
 from jax.typing import ArrayLike
-import tensorflow_probability.substrates.jax.bijectors as tfb
+import numpyro.distributions.transforms as npt
+
+from gpjax.numpyro_extras import FillTriangularTransform
 
 T = tp.TypeVar("T", bound=tp.Union[ArrayLike, list[float]])
 ParameterTag = str
@@ -13,7 +15,7 @@ ParameterTag = str
 
 def transform(
     params: nnx.State,
-    params_bijection: tp.Dict[str, tfb.Bijector],
+    params_bijection: tp.Dict[str, npt.Transform],
     inverse: bool = False,
 ) -> nnx.State:
     r"""Transforms parameters using a bijector.
@@ -22,7 +24,7 @@ def transform(
     ```pycon
     >>> from gpjax.parameters import PositiveReal, transform
    >>> import jax.numpy as jnp
-    >>> import tensorflow_probability.substrates.jax.bijectors as tfb
+    >>> import numpyro.distributions.transforms as npt
    >>> from flax import nnx
    >>> params = nnx.State(
    >>>     {
@@ -30,7 +32,7 @@ def transform(
    >>>         "b": PositiveReal(jnp.array([2.0])),
    >>>     }
    >>> )
-    >>> params_bijection = {'positive': tfb.Softplus()}
+    >>> params_bijection = {'positive': npt.SoftplusTransform()}
    >>> transformed_params = transform(params, params_bijection)
    >>> print(transformed_params["a"].value)
    [1.3132617]
@@ -47,11 +49,11 @@ def transform(
     """
 
     def _inner(param):
-        bijector = params_bijection.get(param._tag, tfb.Identity())
+        bijector = params_bijection.get(param._tag, npt.IdentityTransform())
         if inverse:
-            transformed_value = bijector.inverse(param.value)
+            transformed_value = bijector.inv(param.value)
         else:
-            transformed_value = bijector.forward(param.value)
+            transformed_value = bijector(param.value)
 
         param = param.replace(transformed_value)
         return param
@@ -104,7 +106,7 @@ class SigmoidBounded(Parameter[T]):
         # Only perform validation in non-JIT contexts
         if (
             not isinstance(value, jnp.ndarray)
-            or not getattr(value, "aval", None) is None
+            or getattr(value, "aval", None) is not None
         ):
             _safe_assert(
                 _check_in_bounds,
@@ -133,17 +135,17 @@ class LowerTriangular(Parameter[T]):
         # Only perform validation in non-JIT contexts
         if (
             not isinstance(value, jnp.ndarray)
-            or not getattr(value, "aval", None) is None
+            or getattr(value, "aval", None) is not None
         ):
             _safe_assert(_check_is_square, self.value)
             _safe_assert(_check_is_lower_triangular, self.value)
 
 
 DEFAULT_BIJECTION = {
-    "positive": tfb.Softplus(),
-    "real": tfb.Identity(),
-    "sigmoid": tfb.Sigmoid(low=0.0, high=1.0),
-    "lower_triangular": tfb.FillTriangular(),
+    "positive": npt.SoftplusTransform(),
+    "real": npt.IdentityTransform(),
+    "sigmoid": npt.SigmoidTransform(),
+    "lower_triangular": FillTriangularTransform(),
 }
 
 
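Mirroring the updated docstring example, a sketch of the retargeted `transform` helper; numpyro transforms are invoked by calling them, and `.inv` replaces TFP's `.inverse`:

```python
import jax.numpy as jnp
import numpyro.distributions.transforms as npt
from flax import nnx

from gpjax.parameters import PositiveReal, transform

params = nnx.State({"a": PositiveReal(jnp.array([1.0]))})
bijections = {"positive": npt.SoftplusTransform()}

# Forward: the transform is applied by calling it.
transformed = transform(params, bijections)
print(transformed["a"].value)  # [1.3132617], i.e. softplus(1.0)

# Inverse: dispatches to `bijector.inv(...)` under the hood.
restored = transform(transformed, bijections, inverse=True)
```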
gpjax/variational_families.py CHANGED
@@ -22,6 +22,7 @@ from cola.linalg.inverse.inv import solve
 from cola.ops.operators import (
     Dense,
     I_like,
+    Identity,
     Triangular,
 )
 from flax import nnx
@@ -296,7 +297,10 @@ class WhitenedVariationalGaussian(VariationalGaussian[L]):
 
         # Compute whitened KL divergence
         qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S)
-        pu = GaussianDistribution(loc=jnp.zeros_like(jnp.atleast_1d(mu.squeeze())))
+        pu_S = Identity(shape=(self.num_inducing, self.num_inducing), dtype=mu.dtype)
+        pu = GaussianDistribution(
+            loc=jnp.zeros_like(jnp.atleast_1d(mu.squeeze())), scale=pu_S
+        )
         return qu.kl_divergence(pu)
 
     def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
{gpjax-0.10.2.dist-info → gpjax-0.11.0.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: gpjax
-Version: 0.10.2
+Version: 0.11.0
 Summary: Gaussian processes in JAX.
 Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
 Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
@@ -24,8 +24,8 @@ Requires-Dist: jax>=0.5.0
 Requires-Dist: jaxlib>=0.5.0
 Requires-Dist: jaxtyping>0.2.10
 Requires-Dist: numpy>=2.0.0
+Requires-Dist: numpyro
 Requires-Dist: optax>0.2.1
-Requires-Dist: tensorflow-probability>=0.24.0
 Requires-Dist: tqdm>4.66.2
 Description-Content-Type: text/markdown
 
@@ -138,65 +138,6 @@ jupytext --to notebook example.py
 jupytext --to py:percent example.ipynb
 ```
 
-# Simple example
-
-Let us import some dependencies and simulate a toy dataset $\mathcal{D}$.
-
-```python
-from jax import config
-
-config.update("jax_enable_x64", True)
-
-import gpjax as gpx
-from jax import grad, jit
-import jax.numpy as jnp
-import jax.random as jr
-import optax as ox
-
-key = jr.key(123)
-
-f = lambda x: 10 * jnp.sin(x)
-
-n = 50
-x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(n,1)).sort()
-y = f(x) + jr.normal(key, shape=(n,1))
-D = gpx.Dataset(X=x, y=y)
-
-# Construct the prior
-meanf = gpx.mean_functions.Zero()
-kernel = gpx.kernels.RBF()
-prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
-
-# Define a likelihood
-likelihood = gpx.likelihoods.Gaussian(num_datapoints = n)
-
-# Construct the posterior
-posterior = prior * likelihood
-
-# Define an optimiser
-optimiser = ox.adam(learning_rate=1e-2)
-
-# Obtain Type 2 MLEs of the hyperparameters
-opt_posterior, history = gpx.fit(
-    model=posterior,
-    objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
-    train_data=D,
-    optim=optimiser,
-    num_iters=500,
-    safe=True,
-    key=key,
-)
-
-# Infer the predictive posterior distribution
-xtest = jnp.linspace(-3., 3., 100).reshape(-1, 1)
-latent_dist = opt_posterior(xtest, D)
-predictive_dist = opt_posterior.likelihood(latent_dist)
-
-# Obtain the predictive mean and standard deviation
-pred_mean = predictive_dist.mean()
-pred_std = predictive_dist.stddev()
-```
-
 # Installation
 
 ## Stable version
{gpjax-0.10.2.dist-info → gpjax-0.11.0.dist-info}/RECORD RENAMED
@@ -1,22 +1,23 @@
-gpjax/__init__.py,sha256=F9GVk18tdmvwiDEHZNo_4Wr0TkmPhWIEwl3KzEWQcaQ,1654
+gpjax/__init__.py,sha256=wXJtQa_3W7wZEw_t1Dk0uHUzNQQDv8QzsVbnwXCMXcQ,1654
 gpjax/citation.py,sha256=f2Hzj5MLyCE7l0hHAzsEQoTORZH5hgV_eis4uoBiWvE,3811
 gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
-gpjax/distributions.py,sha256=X48FJr3reop9maherdMVt7-XZOm2f26T8AJt_IKM_oE,9339
-gpjax/fit.py,sha256=OHv8jUHxa1ndpqMERSDRtYtUDzubk9rMPVIhfCiIH5Q,11551
+gpjax/distributions.py,sha256=8LWmfmRVHOX29Uy8PkKFi2UhcCiunuu-4TMI_5-krHc,9299
+gpjax/fit.py,sha256=STwpeqSuu2pgT6uZU7xd7koPZbAjPDzhcZ8nHfozR7Q,11538
 gpjax/gps.py,sha256=97lYGrsmsufQxKEd8qz5wPNvui6FKXTF_Ps-sMFIjnY,31246
 gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
-gpjax/likelihoods.py,sha256=DOyV1L0ompkpeImMTiOOiWLJfqSqvDX_acOumuFqPEc,9234
+gpjax/likelihoods.py,sha256=VcCibgihaskmvNJT4kuPa7ehgjlnR9LgMz_2KJJvHY0,9296
 gpjax/lower_cholesky.py,sha256=3pnHaBrlGckFsrfYJ9Lsbd0pGmO7NIXdyY4aGm48MpY,1952
-gpjax/mean_functions.py,sha256=BpeFkR3Eqa3O_FGp9BtSu9HKNSYZ8M08VtyfPfWbwRg,6479
-gpjax/objectives.py,sha256=XwkPyL_iovTNKpKGVNt0Lt2_OMTJitSPhuyCtUrJpbc,15383
-gpjax/parameters.py,sha256=6VKq6wBzEUtx-GXniC8fEqjTNrTC1YwIOw66QguW6UM,6457
+gpjax/mean_functions.py,sha256=gIPz7exEhish3yeJQxZp5Q_jlf2-gCE-KVAnL2Rumkc,6489
+gpjax/numpyro_extras.py,sha256=-vWJ7SpZVNhSdCjjrlxIkovMFrM1IzpsMJK3B4LioGE,3411
+gpjax/objectives.py,sha256=I_ZqnwTNYIAUAZ9KQNenIl0ish1jDOXb7KaNmjz3Su4,15340
+gpjax/parameters.py,sha256=Vj1xzrziSLxfBSqyc-BacyKBwkbE9Sjq4b1HV5HZiOg,6507
 gpjax/scan.py,sha256=jStQvwkE9MGttB89frxam1kaeXdWih7cVxkGywyaeHQ,5365
 gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
-gpjax/variational_families.py,sha256=s1rk7PtNTjQPabmVu-jBsuJBoqsxAAXwKFZJOEswkNQ,28161
+gpjax/variational_families.py,sha256=Y9J1H91tXPm_hMy3ri_PgjAxqc_3r-BqKV83HRvB_m4,28295
 gpjax/kernels/__init__.py,sha256=WZanH0Tpdkt0f7VfMqnalm_VZAMVwBqeOVaICNj6xQU,1901
 gpjax/kernels/base.py,sha256=wXsrpm5ofy9S5MNgUkJk4lx2umcIJL6dDNhXY7cmTGk,11616
 gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfjH1dP626yNMA,68
-gpjax/kernels/approximations/rff.py,sha256=4kD1uocjHmxkLgvf4DxB4_Gy7iefdPgnWiZB3jDiExI,4126
+gpjax/kernels/approximations/rff.py,sha256=VbitjNuahFE5_IvCj1A0SxHhJXU0O0Qq0FMMVq8xA3E,4125
 gpjax/kernels/computations/__init__.py,sha256=uTVkqvnZVesFLDN92h0ZR0jfR69Eo2WyjOlmSYmCPJ8,1379
 gpjax/kernels/computations/base.py,sha256=zzabLN_-FkTWo6cBYjzjvUGYa7vrYyHxyrhQZxLgHBk,3651
 gpjax/kernels/computations/basis_functions.py,sha256=zY4rUDZDLOYvQPY9xosRmCLPdiXfbzAN5GICjQhFOis,2528
@@ -32,17 +33,17 @@ gpjax/kernels/nonstationary/arccosine.py,sha256=UCTVJEhTZFQjARGFsYMImLnTDyTyxobI
 gpjax/kernels/nonstationary/linear.py,sha256=UKDHFCQzKWDMYo76qcb5-ujjnP2_iL-1tcN017xjK48,2562
 gpjax/kernels/nonstationary/polynomial.py,sha256=7SDMfEcBCqnRn9xyj4iGcYLNvYJZiveN3uLZ_h12p10,3257
 gpjax/kernels/stationary/__init__.py,sha256=j4BMTaQlIx2kNAT1Dkf4iO2rm-f7_oSVWNrk1bN0tqE,1406
-gpjax/kernels/stationary/base.py,sha256=pQNkMo-E4bIT4tNfb7JvFJZC6fIIXNErsT1iQopFlAA,7063
-gpjax/kernels/stationary/matern12.py,sha256=b2vQCUqhd9NJK84L2RYjpI597uxy_7xgwsjS35Gc958,1807
-gpjax/kernels/stationary/matern32.py,sha256=ZVYbUIQhvKpriC7abH8wV6Pk-mRoxtl3e2YYwH-KM5Y,2000
-gpjax/kernels/stationary/matern52.py,sha256=xfMYbY7MXxgMECtA2qVT5I8HoDGzGxygUvduGT3_Gvs,2053
+gpjax/kernels/stationary/base.py,sha256=FlsXMsXyZ5cI80jbsIo8Jv-H6gsV3C7v6plIhyCl-GI,7042
+gpjax/kernels/stationary/matern12.py,sha256=DGjqw6VveYsyy0TrufyJJvCei7p9slnm2f0TgRGG7_U,1773
+gpjax/kernels/stationary/matern32.py,sha256=laLsJWJozJzpYHBzlkPUq0rWxz1eWEwGC36P2nPJuaQ,1966
+gpjax/kernels/stationary/matern52.py,sha256=VSByD2sb7k-DzRFjaz31P3Rtc4bPPhHvMshrxZNFnns,2019
 gpjax/kernels/stationary/periodic.py,sha256=IAbCxURtJEHGdmYzbdrsqRZ3zJ8F8tGQF9O7sggafZk,3598
 gpjax/kernels/stationary/powered_exponential.py,sha256=8qT91IWKJK7PpEtFcX4MVu1ahWMOFOZierPko4JCjKA,3776
 gpjax/kernels/stationary/rational_quadratic.py,sha256=dYONp3i4rnKj3ET8UyxAKXv6UOl8uOFT3lCutleSvo4,3496
-gpjax/kernels/stationary/rbf.py,sha256=G13gg5phO7ite7D9QgoCy7gB2_y0FM6GZhgFW4RL6Xw,1734
-gpjax/kernels/stationary/utils.py,sha256=Xa9EEnxgFqEi08ZSFAZYYHhJ85_3Ac-ZUyUk18B63M4,2225
+gpjax/kernels/stationary/rbf.py,sha256=euHUs6FdfRICQcabAWE4MX-7GEDr2TxgZWdFQiXr9Bw,1690
+gpjax/kernels/stationary/utils.py,sha256=6BI9EBcCzeeKx-XH-MfW1ORmtU__tPX5zyvfLhpkBsU,2180
 gpjax/kernels/stationary/white.py,sha256=TkdXXZCCjDs7JwR_gj5uvn2s1wyfRbe1vyHhUMJ8jjI,2212
-gpjax-0.10.2.dist-info/METADATA,sha256=mqIBMOMKKiI9qkM_uFHSuPEXY17Jd6bOL5EM2hPiaok,9970
-gpjax-0.10.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-gpjax-0.10.2.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
-gpjax-0.10.2.dist-info/RECORD,,
+gpjax-0.11.0.dist-info/METADATA,sha256=eSWVc5y9WNrUmKpaOVq1CcHjrKjMwlmSvwovN9h9aCk,8558
+gpjax-0.11.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+gpjax-0.11.0.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
+gpjax-0.11.0.dist-info/RECORD,,
File without changes