gpjax 0.10.2__py3-none-any.whl → 0.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gpjax/__init__.py CHANGED
@@ -32,14 +32,15 @@ from gpjax.citation import cite
 from gpjax.dataset import Dataset
 from gpjax.fit import (
     fit,
+    fit_lbfgs,
     fit_scipy,
 )

 __license__ = "MIT"
-__description__ = "Didactic Gaussian processes in JAX"
+__description__ = "Gaussian processes in JAX and Flax"
 __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
 __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
-__version__ = "0.10.2"
+__version__ = "0.11.1"

 __all__ = [
     "base",
@@ -56,5 +57,6 @@ __all__ = [
     "fit",
     "Module",
     "param_field",
+    "fit_lbfgs",
     "fit_scipy",
 ]
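Net effect of this hunk: `fit_lbfgs` joins `fit` and `fit_scipy` in the public namespace, and the version string moves to 0.11.1. A minimal import check, using nothing beyond what the diff exports:

```python
import gpjax
from gpjax import fit, fit_lbfgs, fit_scipy  # all listed in __all__ as of 0.11.1

print(gpjax.__version__)  # "0.11.1"
```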
gpjax/distributions.py CHANGED
@@ -15,77 +15,76 @@


 from beartype.typing import (
-    Any,
     Optional,
-    Tuple,
-    TypeVar,
 )
 import cola
+from cola.linalg.decompositions import Cholesky
 from cola.ops import (
-    Identity,
     LinearOperator,
 )
 from jax import vmap
 import jax.numpy as jnp
 import jax.random as jr
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax as tfp
+from numpyro.distributions import constraints
+from numpyro.distributions.distribution import Distribution
+from numpyro.distributions.util import is_prng_key

 from gpjax.lower_cholesky import lower_cholesky
 from gpjax.typing import (
     Array,
-    KeyArray,
     ScalarFloat,
 )

-tfd = tfp.distributions
-
-from cola.linalg.decompositions import Cholesky
-

-class GaussianDistribution(tfd.Distribution):
-    r"""Multivariate Gaussian distribution with a linear operator scale matrix."""
-
-    # TODO: Consider `distrax.transformed.Transformed` object. Can we create a LinearOperator to `distrax.bijector` representation
-    # and modify `distrax.MultivariateNormalFromBijector`?
-    # TODO: Consider natural and expectation parameterisations in future work.
-    # TODO: we don't really need to inherit from `tfd.Distribution` here
+class GaussianDistribution(Distribution):
+    support = constraints.real_vector

     def __init__(
         self,
-        loc: Optional[Float[Array, " N"]] = None,
-        scale: Optional[LinearOperator] = None,
-    ) -> None:
-        r"""Initialises the distribution.
-
-        Args:
-            loc: the mean of the distribution as an array of shape (n_points,).
-            scale: the scale matrix of the distribution as a LinearOperator object.
-        """
-        _check_loc_scale(loc, scale)
+        loc: Optional[Float[Array, " N"]],
+        scale: Optional[LinearOperator],
+        validate_args=None,
+    ):
+        self.loc = loc
+        self.scale = cola.PSD(scale)
+        batch_shape = ()
+        event_shape = jnp.shape(self.loc)
+        super().__init__(batch_shape, event_shape, validate_args=validate_args)

-        # Find dimensionality of the distribution.
-        if loc is not None:
-            num_dims = loc.shape[-1]
+    def sample(self, key, sample_shape=()):
+        assert is_prng_key(key)
+        # Obtain covariance root.
+        covariance_root = lower_cholesky(self.scale)

-        elif scale is not None:
-            num_dims = scale.shape[-1]
+        # Gather n samples from standard normal distribution Z = [z₁, ..., zₙ]ᵀ.
+        white_noise = jr.normal(
+            key, shape=sample_shape + self.batch_shape + self.event_shape
+        )

-        # Set the location to zero vector if unspecified.
-        if loc is None:
-            loc = jnp.zeros((num_dims,))
+        # xᵢ ~ N(loc, cov) <=> xᵢ = loc + sqrt zᵢ, where zᵢ ~ N(0, I).
+        def affine_transformation(_x):
+            return self.loc + covariance_root @ _x

-        # If not specified, set the scale to the identity matrix.
-        if scale is None:
-            scale = Identity(shape=(num_dims, num_dims), dtype=loc.dtype)
-
-        self.loc = loc
-        self.scale = cola.PSD(scale)
+        return vmap(affine_transformation)(white_noise)

+    @property
     def mean(self) -> Float[Array, " N"]:
         r"""Calculates the mean."""
         return self.loc

+    @property
+    def variance(self) -> Float[Array, " N"]:
+        r"""Calculates the variance."""
+        return cola.diag(self.scale)
+
+    def entropy(self) -> ScalarFloat:
+        r"""Calculates the entropy of the distribution."""
+        return 0.5 * (
+            self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi))
+            + cola.logdet(self.scale, Cholesky(), Cholesky())
+        )
+
     def median(self) -> Float[Array, " N"]:
         r"""Calculates the median."""
         return self.loc
@@ -98,25 +97,19 @@ class GaussianDistribution(tfd.Distribution):
         r"""Calculates the covariance matrix."""
         return self.scale.to_dense()

-    def variance(self) -> Float[Array, " N"]:
-        r"""Calculates the variance."""
-        return cola.diag(self.scale)
+    @property
+    def covariance_matrix(self) -> Float[Array, "N N"]:
+        r"""Calculates the covariance matrix."""
+        return self.covariance()

     def stddev(self) -> Float[Array, " N"]:
         r"""Calculates the standard deviation."""
         return jnp.sqrt(cola.diag(self.scale))

-    @property
-    def event_shape(self) -> Tuple:
-        r"""Returns the event shape."""
-        return self.loc.shape[-1:]
-
-    def entropy(self) -> ScalarFloat:
-        r"""Calculates the entropy of the distribution."""
-        return 0.5 * (
-            self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi))
-            + cola.logdet(self.scale, Cholesky(), Cholesky())
-        )
+    # @property
+    # def event_shape(self) -> Tuple:
+    #     r"""Returns the event shape."""
+    #     return self.loc.shape[-1:]

     def log_prob(self, y: Float[Array, " N"]) -> ScalarFloat:
         r"""Calculates the log pdf of the multivariate Gaussian.
@@ -141,42 +134,39 @@ class GaussianDistribution(tfd.Distribution):
             + diff.T @ cola.solve(sigma, diff, Cholesky())
         )

-    def _sample_n(self, key: KeyArray, n: int) -> Float[Array, "n N"]:
-        r"""Samples from the distribution.
+    # def _sample_n(self, key: KeyArray, n: int) -> Float[Array, "n N"]:
+    #     r"""Samples from the distribution.

-        Args:
-            key (KeyArray): The key to use for sampling.
+    #     Args:
+    #         key (KeyArray): The key to use for sampling.

-        Returns:
-            The samples as an array of shape (n_samples, n_points).
-        """
-        # Obtain covariance root.
-        sqrt = lower_cholesky(self.scale)
+    #     Returns:
+    #         The samples as an array of shape (n_samples, n_points).
+    #     """
+    #     # Obtain covariance root.
+    #     sqrt = lower_cholesky(self.scale)

-        # Gather n samples from standard normal distribution Z = [z₁, ..., zₙ]ᵀ.
-        Z = jr.normal(key, shape=(n, *self.event_shape))
+    #     # Gather n samples from standard normal distribution Z = [z₁, ..., zₙ]ᵀ.
+    #     Z = jr.normal(key, shape=(n, *self.event_shape))

-        # xᵢ ~ N(loc, cov) <=> xᵢ = loc + sqrt zᵢ, where zᵢ ~ N(0, I).
-        def affine_transformation(x):
-            return self.loc + sqrt @ x
+    #     # xᵢ ~ N(loc, cov) <=> xᵢ = loc + sqrt zᵢ, where zᵢ ~ N(0, I).
+    #     def affine_transformation(x):
+    #         return self.loc + sqrt @ x

-        return vmap(affine_transformation)(Z)
+    #     return vmap(affine_transformation)(Z)

-    def sample(
-        self, seed: KeyArray, sample_shape: Tuple[int, ...]
-    ):  # pylint: disable=useless-super-delegation
-        r"""See `Distribution.sample`."""
-        return self._sample_n(
-            seed, sample_shape[0]
-        )  # TODO this looks weird, why ignore the second entry?
+    # def sample(
+    #     self, seed: KeyArray, sample_shape: Tuple[int, ...]
+    # ):  # pylint: disable=useless-super-delegation
+    #     r"""See `Distribution.sample`."""
+    #     return self._sample_n(
+    #         seed, sample_shape[0]
+    #     )  # TODO this looks weird, why ignore the second entry?

     def kl_divergence(self, other: "GaussianDistribution") -> ScalarFloat:
         return _kl_divergence(self, other)


-DistrT = TypeVar("DistrT", bound=tfd.Distribution)
-
-
 def _check_and_return_dimension(
     q: GaussianDistribution, p: GaussianDistribution
 ) -> int:
@@ -245,37 +235,37 @@ def _kl_divergence(q: GaussianDistribution, p: GaussianDistribution) -> ScalarFloat:
     ) / 2.0


-def _check_loc_scale(loc: Optional[Any], scale: Optional[Any]) -> None:
-    r"""Checks that the inputs are correct."""
-    if loc is None and scale is None:
-        raise ValueError("At least one of `loc` or `scale` must be specified.")
-
-    if loc is not None and loc.ndim < 1:
-        raise ValueError("The parameter `loc` must have at least one dimension.")
-
-    if scale is not None and len(scale.shape) < 2:  # scale.ndim < 2:
-        raise ValueError(
-            "The `scale` must have at least two dimensions, but "
-            f"`scale.shape = {scale.shape}`."
-        )
-
-    if scale is not None and not isinstance(scale, LinearOperator):
-        raise ValueError(
-            f"The `scale` must be a CoLA LinearOperator but got {type(scale)}"
-        )
-
-    if scale is not None and (scale.shape[-1] != scale.shape[-2]):
-        raise ValueError(
-            f"The `scale` must be a square matrix, but `scale.shape = {scale.shape}`."
-        )
-
-    if loc is not None:
-        num_dims = loc.shape[-1]
-        if scale is not None and (scale.shape[-1] != num_dims):
-            raise ValueError(
-                f"Shapes are not compatible: `loc.shape = {loc.shape}` and "
-                f"`scale.shape = {scale.shape}`."
-            )
+# def _check_loc_scale(loc: Optional[Any], scale: Optional[Any]) -> None:
+#     r"""Checks that the inputs are correct."""
+#     if loc is None and scale is None:
+#         raise ValueError("At least one of `loc` or `scale` must be specified.")
+
+#     if loc is not None and loc.ndim < 1:
+#         raise ValueError("The parameter `loc` must have at least one dimension.")
+
+#     if scale is not None and len(scale.shape) < 2:  # scale.ndim < 2:
+#         raise ValueError(
+#             "The `scale` must have at least two dimensions, but "
+#             f"`scale.shape = {scale.shape}`."
+#         )
+
+#     if scale is not None and not isinstance(scale, LinearOperator):
+#         raise ValueError(
+#             f"The `scale` must be a CoLA LinearOperator but got {type(scale)}"
+#         )
+
+#     if scale is not None and (scale.shape[-1] != scale.shape[-2]):
+#         raise ValueError(
+#             f"The `scale` must be a square matrix, but `scale.shape = {scale.shape}`."
+#         )
+
+#     if loc is not None:
+#         num_dims = loc.shape[-1]
+#         if scale is not None and (scale.shape[-1] != num_dims):
+#             raise ValueError(
+#                 f"Shapes are not compatible: `loc.shape = {loc.shape}` and "
+#                 f"`scale.shape = {scale.shape}`."
+#             )


 __all__ = [
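In summary, `GaussianDistribution` now subclasses numpyro's `Distribution`: `loc` and `scale` are both required, `variance` and `covariance_matrix` become properties, and sampling takes a JAX PRNG key plus a `sample_shape` tuple in place of TFP's `seed`/`_sample_n` pair. A minimal usage sketch, assuming a dense CoLA operator is an acceptable `scale`; only the class itself comes from this diff:

```python
import cola
import jax.numpy as jnp
import jax.random as jr

from gpjax.distributions import GaussianDistribution

loc = jnp.zeros(3)
scale = cola.ops.Dense(jnp.eye(3))  # wrapped as cola.PSD(scale) inside __init__

dist = GaussianDistribution(loc=loc, scale=scale)

key = jr.PRNGKey(0)
samples = dist.sample(key, sample_shape=(5,))  # shape (5, 3)
log_p = dist.log_prob(jnp.ones(3))             # scalar log-density
var = dist.variance                            # a property, not a method, in 0.11
```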
gpjax/fit.py CHANGED
@@ -15,14 +15,14 @@

 import typing as tp

-from flax import nnx
 import jax
-from jax.flatten_util import ravel_pytree
 import jax.numpy as jnp
 import jax.random as jr
 import optax as ox
+from flax import nnx
+from jax.flatten_util import ravel_pytree
+from numpyro.distributions.transforms import Transform
 from scipy.optimize import minimize
-from tensorflow_probability.substrates.jax.bijectors import Bijector

 from gpjax.dataset import Dataset
 from gpjax.objectives import Objective
@@ -47,7 +47,7 @@ def fit(  # noqa: PLR0913
     objective: Objective,
     train_data: Dataset,
     optim: ox.GradientTransformation,
-    params_bijection: tp.Union[dict[Parameter, Bijector], None] = DEFAULT_BIJECTION,
+    params_bijection: tp.Union[dict[Parameter, Transform], None] = DEFAULT_BIJECTION,
     key: KeyArray = jr.PRNGKey(42),
     num_iters: int = 100,
     batch_size: int = -1,
@@ -127,7 +127,6 @@ def fit(  # noqa: PLR0913
     _check_verbose(verbose)

     # Model state filtering
-
     graphdef, params, *static_state = nnx.split(model, Parameter, ...)

     # Parameters bijection to unconstrained space
@@ -253,6 +252,110 @@ def fit_scipy(  # noqa: PLR0913
     return model, history


+def fit_lbfgs(
+    *,
+    model: Model,
+    objective: Objective,
+    train_data: Dataset,
+    params_bijection: tp.Union[dict[Parameter, Transform], None] = DEFAULT_BIJECTION,
+    max_iters: int = 100,
+    safe: bool = True,
+    max_linesearch_steps: int = 32,
+    gtol: float = 1e-5,
+) -> tuple[Model, jax.Array]:
+    r"""Train a Module model with respect to a supplied Objective function.
+
+    Uses Optax's LBFGS implementation and a jax.lax.while loop.
+
+    Args:
+        model: the model Module to be optimised.
+        objective: The objective function that we are optimising with
+            respect to.
+        train_data (Dataset): The training data to be used for the optimisation.
+        max_iters (int): The maximum number of optimisation steps to run. Defaults
+            to 500.
+        safe (bool): Whether to check the types of the inputs.
+        max_linesearch_steps (int): The maximum number of linesearch steps to use
+            for finding the stepsize.
+        gtol (float): Terminate the optimisation if the L2 norm of the gradient is
+            below this threshold.
+
+    Returns:
+        A tuple comprising the optimised model and final loss.
+    """
+    if safe:
+        # Check inputs
+        _check_model(model)
+        _check_train_data(train_data)
+        _check_num_iters(max_iters)
+
+    # Model state filtering
+    graphdef, params, *static_state = nnx.split(model, Parameter, ...)
+
+    # Parameters bijection to unconstrained space
+    if params_bijection is not None:
+        params = transform(params, params_bijection, inverse=True)
+
+    # Loss definition
+    def loss(params: nnx.State) -> ScalarFloat:
+        params = transform(params, params_bijection)
+        model = nnx.merge(graphdef, params, *static_state)
+        return objective(model, train_data)
+
+    # Initialise optimiser
+    optim = ox.lbfgs(
+        linesearch=ox.scale_by_zoom_linesearch(
+            max_linesearch_steps=max_linesearch_steps,
+            initial_guess_strategy="one",
+        )
+    )
+    opt_state = optim.init(params)
+    loss_value_and_grad = ox.value_and_grad_from_state(loss)
+
+    # Optimisation step.
+    def step(carry):
+        params, opt_state = carry
+
+        # Using optax's value_and_grad_from_state is more efficient given LBFGS uses a linesearch
+        # See https://optax.readthedocs.io/en/latest/api/utilities.html#optax.value_and_grad_from_state
+        loss_val, loss_gradient = loss_value_and_grad(params, state=opt_state)
+        updates, opt_state = optim.update(
+            loss_gradient,
+            opt_state,
+            params,
+            value=loss_val,
+            grad=loss_gradient,
+            value_fn=loss,
+        )
+        params = ox.apply_updates(params, updates)
+
+        return params, opt_state
+
+    def continue_fn(carry):
+        _, opt_state = carry
+        n = ox.tree_utils.tree_get(opt_state, "count")
+        g = ox.tree_utils.tree_get(opt_state, "grad")
+        g_l2_norm = ox.tree_utils.tree_l2_norm(g)
+        return (n == 0) | ((n < max_iters) & (g_l2_norm >= gtol))
+
+    # Optimisation loop
+    params, opt_state = jax.lax.while_loop(
+        continue_fn,
+        step,
+        (params, opt_state),
+    )
+    final_loss = ox.tree_utils.tree_get(opt_state, "value")
+
+    # Parameters bijection to constrained space
+    if params_bijection is not None:
+        params = transform(params, params_bijection)
+
+    # Reconstruct model
+    model = nnx.merge(graphdef, params, *static_state)
+
+    return model, final_loss
+
+
 def get_batch(train_data: Dataset, batch_size: int, key: KeyArray) -> Dataset:
     """Batch the data into mini-batches. Sampling is done with replacement.

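For orientation, a sketch of how the new entry point is called. The keyword-only `fit_lbfgs` signature is taken from the hunk above; the surrounding model construction (Prior, Gaussian likelihood, `conjugate_mll` objective) follows the GPJax regression docs and is assumed rather than part of this diff:

```python
import jax.numpy as jnp
import jax.random as jr

import gpjax as gpx

key = jr.PRNGKey(0)
x = jnp.linspace(0.0, 1.0, 50).reshape(-1, 1)
y = jnp.sin(10.0 * x) + 0.1 * jr.normal(key, x.shape)
data = gpx.Dataset(X=x, y=y)

prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF())
posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=data.n)

# Runs inside jax.lax.while_loop; stops at max_iters or once the gradient
# L2 norm drops below gtol.
opt_posterior, final_loss = gpx.fit_lbfgs(
    model=posterior,
    objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
    train_data=data,
    max_iters=200,
    gtol=1e-6,
)
```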
@@ -68,7 +68,7 @@ class RFF(AbstractKernel):

         self.frequencies = Static(
             self.base_kernel.spectral_density.sample(
-                seed=key, sample_shape=(self.num_basis_fns, n_dims)
+                key=key, sample_shape=(self.num_basis_fns, n_dims)
             )
         )
         self.name = f"{self.base_kernel.name} (RFF)"
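The `seed` → `key` rename follows from the backend switch: numpyro distributions draw samples with `sample(key, sample_shape=())`, whereas the TFP substrate took a `seed` keyword. A standalone sketch of the call the RFF kernel now makes, with a standard normal standing in for the base kernel's spectral density:

```python
import jax.random as jr
import numpyro.distributions as npd

key = jr.PRNGKey(0)
# Mirrors the hunk above: a PRNG key plus a (num_basis_fns, n_dims) sample shape.
frequencies = npd.Normal(0.0, 1.0).sample(key=key, sample_shape=(10, 2))
print(frequencies.shape)  # (10, 2)
```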
@@ -23,7 +23,10 @@ from gpjax.kernels.computations import (
     AbstractKernelComputation,
     DenseKernelComputation,
 )
-from gpjax.parameters import PositiveReal
+from gpjax.parameters import (
+    NonNegativeReal,
+    PositiveReal,
+)
 from gpjax.typing import (
     Array,
     ScalarArray,
@@ -91,9 +94,9 @@ class ArcCosine(AbstractKernel):
         if isinstance(variance, nnx.Variable):
             self.variance = variance
         else:
-            self.variance = PositiveReal(variance)
+            self.variance = NonNegativeReal(variance)
             if tp.TYPE_CHECKING:
-                self.variance = tp.cast(PositiveReal[ScalarArray], self.variance)
+                self.variance = tp.cast(NonNegativeReal[ScalarArray], self.variance)

         if isinstance(bias_variance, nnx.Variable):
             self.bias_variance = bias_variance
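The same `PositiveReal` → `NonNegativeReal` swap for the `variance` parameter recurs in the Linear, Polynomial, and stationary-kernel hunks below; the practical difference is that a variance of exactly zero becomes a legal value. A brief sketch, assuming the kernel constructors keep their documented keyword names (only `NonNegativeReal` itself appears in this diff):

```python
import gpjax as gpx
from gpjax.parameters import NonNegativeReal

# Plain floats are wrapped by the kernel itself ...
k1 = gpx.kernels.ArcCosine(variance=1.0)                   # stored as NonNegativeReal

# ... or the parameter can be passed explicitly, e.g. starting at the boundary.
k2 = gpx.kernels.ArcCosine(variance=NonNegativeReal(0.0))  # accepted via the nnx.Variable branch
```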
@@ -23,7 +23,7 @@ from gpjax.kernels.computations import (
     AbstractKernelComputation,
     DenseKernelComputation,
 )
-from gpjax.parameters import PositiveReal
+from gpjax.parameters import NonNegativeReal
 from gpjax.typing import (
     Array,
     ScalarArray,
@@ -64,9 +64,9 @@ class Linear(AbstractKernel):
         if isinstance(variance, nnx.Variable):
             self.variance = variance
         else:
-            self.variance = PositiveReal(variance)
+            self.variance = NonNegativeReal(variance)
             if tp.TYPE_CHECKING:
-                self.variance = tp.cast(PositiveReal[ScalarArray], self.variance)
+                self.variance = tp.cast(NonNegativeReal[ScalarArray], self.variance)

     def __call__(
         self,
@@ -23,7 +23,10 @@ from gpjax.kernels.computations import (
     AbstractKernelComputation,
     DenseKernelComputation,
 )
-from gpjax.parameters import PositiveReal
+from gpjax.parameters import (
+    NonNegativeReal,
+    PositiveReal,
+)
 from gpjax.typing import (
     Array,
     ScalarArray,
@@ -76,9 +79,9 @@ class Polynomial(AbstractKernel):
         if isinstance(variance, nnx.Variable):
             self.variance = variance
         else:
-            self.variance = PositiveReal(variance)
+            self.variance = NonNegativeReal(variance)
             if tp.TYPE_CHECKING:
-                self.variance = tp.cast(PositiveReal[ScalarArray], self.variance)
+                self.variance = tp.cast(NonNegativeReal[ScalarArray], self.variance)

         self.name = f"Polynomial (degree {self.degree})"

@@ -18,14 +18,17 @@ import beartype.typing as tp
 from flax import nnx
 import jax.numpy as jnp
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax.distributions as tfd
+import numpyro.distributions as npd

 from gpjax.kernels.base import AbstractKernel
 from gpjax.kernels.computations import (
     AbstractKernelComputation,
     DenseKernelComputation,
 )
-from gpjax.parameters import PositiveReal
+from gpjax.parameters import (
+    NonNegativeReal,
+    PositiveReal,
+)
 from gpjax.typing import (
     Array,
     ScalarArray,
@@ -85,14 +88,14 @@ class StationaryKernel(AbstractKernel):
         if isinstance(variance, nnx.Variable):
             self.variance = variance
         else:
-            self.variance = PositiveReal(variance)
+            self.variance = NonNegativeReal(variance)

         # static typing
         if tp.TYPE_CHECKING:
-            self.variance = tp.cast(PositiveReal[ScalarFloat], self.variance)
+            self.variance = tp.cast(NonNegativeReal[ScalarFloat], self.variance)

     @property
-    def spectral_density(self) -> tfd.Distribution:
+    def spectral_density(self) -> npd.Normal | npd.StudentT:
         r"""The spectral density of the kernel.

         Returns:
@@ -15,7 +15,7 @@

 import jax.numpy as jnp
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax.distributions as tfd
+import numpyro.distributions as npd

 from gpjax.kernels.stationary.base import StationaryKernel
 from gpjax.kernels.stationary.utils import (
@@ -48,5 +48,5 @@ class Matern12(StationaryKernel):
         return K.squeeze()

     @property
-    def spectral_density(self) -> tfd.Distribution:
+    def spectral_density(self) -> npd.StudentT:
         return build_student_t_distribution(nu=1)
@@ -15,7 +15,7 @@

 import jax.numpy as jnp
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax.distributions as tfd
+import numpyro.distributions as npd

 from gpjax.kernels.stationary.base import StationaryKernel
 from gpjax.kernels.stationary.utils import (
@@ -54,5 +54,5 @@ class Matern32(StationaryKernel):
         return K.squeeze()

     @property
-    def spectral_density(self) -> tfd.Distribution:
+    def spectral_density(self) -> npd.StudentT:
         return build_student_t_distribution(nu=3)
@@ -15,7 +15,7 @@

 import jax.numpy as jnp
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax.distributions as tfd
+import numpyro.distributions as npd

 from gpjax.kernels.stationary.base import StationaryKernel
 from gpjax.kernels.stationary.utils import (
@@ -53,5 +53,5 @@ class Matern52(StationaryKernel):
         return K.squeeze()

     @property
-    def spectral_density(self) -> tfd.Distribution:
+    def spectral_density(self) -> npd.StudentT:
         return build_student_t_distribution(nu=5)
@@ -15,7 +15,7 @@

 import jax.numpy as jnp
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax as tfp
+import numpyro.distributions as npd

 from gpjax.kernels.stationary.base import StationaryKernel
 from gpjax.kernels.stationary.utils import squared_distance
@@ -44,5 +44,5 @@ class RBF(StationaryKernel):
         return K.squeeze()

     @property
-    def spectral_density(self) -> tfp.distributions.Normal:
-        return tfp.distributions.Normal(0.0, 1.0)
+    def spectral_density(self) -> npd.Normal:
+        return npd.Normal(0.0, 1.0)
@@ -14,17 +14,15 @@
 # ==============================================================================
 import jax.numpy as jnp
 from jaxtyping import Float
-import tensorflow_probability.substrates.jax as tfp
+import numpyro.distributions as npd

 from gpjax.typing import (
     Array,
     ScalarFloat,
 )

-tfd = tfp.distributions

-
-def build_student_t_distribution(nu: int) -> tfd.Distribution:
+def build_student_t_distribution(nu: int) -> npd.StudentT:
     r"""Build a Student's t distribution with a fixed smoothness parameter.

     For a fixed half-integer smoothness parameter, compute the spectral density of a
@@ -37,7 +35,7 @@ def build_student_t_distribution(nu: int) -> tfd.Distribution:
     -------
         tfp.Distribution: A Student's t distribution with the same smoothness parameter.
     """
-    dist = tfd.StudentT(df=nu, loc=0.0, scale=1.0)
+    dist = npd.StudentT(df=nu, loc=0.0, scale=1.0)
     return dist

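Taken together, these kernel hunks replace every TFP spectral density with its numpyro counterpart: the Matérn 1/2, 3/2, and 5/2 kernels map to Student's t densities with df = 1, 3, 5, and the RBF to a standard normal. A short sketch of the resulting objects, using nothing beyond the functions shown above:

```python
import jax.random as jr
import numpyro.distributions as npd

from gpjax.kernels.stationary.utils import build_student_t_distribution

key = jr.PRNGKey(0)
matern32_density = build_student_t_distribution(nu=3)  # npd.StudentT(df=3, loc=0.0, scale=1.0)
rbf_density = npd.Normal(0.0, 1.0)

# Both expose the numpyro sampling interface used by the RFF approximation.
w1 = matern32_density.sample(key, sample_shape=(5, 1))
w2 = rbf_density.sample(key, sample_shape=(5, 1))
```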