gpjax 0.12.0__py3-none-any.whl → 0.12.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gpjax/__init__.py CHANGED
@@ -40,10 +40,9 @@ __license__ = "MIT"
  __description__ = "Gaussian processes in JAX and Flax"
  __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
  __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
- __version__ = "0.12.0"
+ __version__ = "0.12.2"

  __all__ = [
- "base",
  "gps",
  "integrators",
  "kernels",
@@ -55,8 +54,6 @@ __all__ = [
  "Dataset",
  "cite",
  "fit",
- "Module",
- "param_field",
  "fit_lbfgs",
  "fit_scipy",
  ]
gpjax/fit.py CHANGED
@@ -48,6 +48,7 @@ def fit( # noqa: PLR0913
  train_data: Dataset,
  optim: ox.GradientTransformation,
  params_bijection: tp.Union[dict[Parameter, Transform], None] = DEFAULT_BIJECTION,
+ trainable: nnx.filterlib.Filter = Parameter,
  key: KeyArray = jr.PRNGKey(42),
  num_iters: int = 100,
  batch_size: int = -1,
@@ -65,7 +66,7 @@ def fit( # noqa: PLR0913
  >>> import jax.random as jr
  >>> import optax as ox
  >>> import gpjax as gpx
- >>> from gpjax.parameters import PositiveReal, Static
+ >>> from gpjax.parameters import PositiveReal
  >>>
  >>> # (1) Create a dataset:
  >>> X = jnp.linspace(0.0, 10.0, 100)[:, None]
@@ -75,10 +76,10 @@ def fit( # noqa: PLR0913
  >>> class LinearModel(nnx.Module):
  >>> def __init__(self, weight: float, bias: float):
  >>> self.weight = PositiveReal(weight)
- >>> self.bias = Static(bias)
+ >>> self.bias = bias
  >>>
  >>> def __call__(self, x):
- >>> return self.weight.value * x + self.bias.value
+ >>> return self.weight.value * x + self.bias
  >>>
  >>> model = LinearModel(weight=1.0, bias=1.0)
  >>>
@@ -100,6 +101,8 @@ def fit( # noqa: PLR0913
  train_data (Dataset): The training data to be used for the optimisation.
  optim (GradientTransformation): The Optax optimiser that is to be used for
  learning a parameter set.
+ trainable (nnx.filterlib.Filter): Filter to determine which parameters are trainable.
+ Defaults to nnx.Param (all Parameter instances).
  num_iters (int): The number of optimisation steps to run. Defaults
  to 100.
  batch_size (int): The size of the mini-batch to use. Defaults to -1
@@ -127,7 +130,7 @@ def fit( # noqa: PLR0913
  _check_verbose(verbose)

  # Model state filtering
- graphdef, params, *static_state = nnx.split(model, Parameter, ...)
+ graphdef, params, *static_state = nnx.split(model, trainable, ...)

  # Parameters bijection to unconstrained space
  if params_bijection is not None:
@@ -182,6 +185,7 @@ def fit_scipy( # noqa: PLR0913
  model: Model,
  objective: Objective,
  train_data: Dataset,
+ trainable: nnx.filterlib.Filter = Parameter,
  max_iters: int = 500,
  verbose: bool = True,
  safe: bool = True,
@@ -210,7 +214,7 @@ def fit_scipy( # noqa: PLR0913
  _check_verbose(verbose)

  # Model state filtering
- graphdef, params, *static_state = nnx.split(model, Parameter, ...)
+ graphdef, params, *static_state = nnx.split(model, trainable, ...)

  # Parameters bijection to unconstrained space
  params = transform(params, DEFAULT_BIJECTION, inverse=True)
@@ -258,6 +262,7 @@ def fit_lbfgs(
  objective: Objective,
  train_data: Dataset,
  params_bijection: tp.Union[dict[Parameter, Transform], None] = DEFAULT_BIJECTION,
+ trainable: nnx.filterlib.Filter = Parameter,
  max_iters: int = 100,
  safe: bool = True,
  max_linesearch_steps: int = 32,
@@ -290,7 +295,7 @@ def fit_lbfgs(
  _check_num_iters(max_iters)

  # Model state filtering
- graphdef, params, *static_state = nnx.split(model, Parameter, ...)
+ graphdef, params, *static_state = nnx.split(model, trainable, ...)

  # Parameters bijection to unconstrained space
  if params_bijection is not None:
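The most visible change in `fit.py` is the new `trainable` argument, an `nnx` filter that selects which variables `nnx.split` treats as optimisable state (the signature default is `Parameter`, although the new docstring text refers to `nnx.Param`). A minimal sketch of how the filter might be used is below; the model-construction calls follow the public GPJax API, but the objective lambda and defaults shown here are assumptions rather than part of this diff.

```python
# Sketch (assumed usage, not part of the diff): restrict optimisation to a
# subset of parameters via the new `trainable` filter added in 0.12.2.
import jax.numpy as jnp
import jax.random as jr
import optax as ox

import gpjax as gpx
from gpjax.parameters import PositiveReal

key = jr.PRNGKey(0)
X = jnp.linspace(0.0, 10.0, 50)[:, None]
y = jnp.sin(X) + 0.1 * jr.normal(key, X.shape)
D = gpx.Dataset(X=X, y=y)

prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF())
posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)

# Only PositiveReal parameters (lengthscale, variance, ...) are optimised;
# every other Parameter is held at its initial value.
opt_posterior, history = gpx.fit(
    model=posterior,
    objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
    train_data=D,
    optim=ox.adam(learning_rate=1e-2),
    trainable=PositiveReal,
    num_iters=100,
    key=key,
)
```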
gpjax/gps.py CHANGED
@@ -16,9 +16,9 @@
  from abc import abstractmethod

  import beartype.typing as tp
- from flax import nnx
  import jax.numpy as jnp
  import jax.random as jr
+ from flax import nnx
  from jaxtyping import (
  Float,
  Num,
@@ -35,16 +35,15 @@ from gpjax.likelihoods import (
  )
  from gpjax.linalg import (
  Dense,
- Identity,
  psd,
  solve,
  )
  from gpjax.linalg.operations import lower_cholesky
+ from gpjax.linalg.utils import add_jitter
  from gpjax.mean_functions import AbstractMeanFunction
  from gpjax.parameters import (
  Parameter,
  Real,
- Static,
  )
  from gpjax.typing import (
  Array,
@@ -78,7 +77,7 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
  self.mean_function = mean_function
  self.jitter = jitter

- def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> GaussianDistribution:
+ def __call__(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution:
  r"""Evaluate the Gaussian process at the given points.

  The output of this function is a
@@ -91,17 +90,16 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
  `__call__` method and should instead define a `predict` method.

  Args:
- *args (Any): The arguments to pass to the GP's `predict` method.
- **kwargs (Any): The keyword arguments to pass to the GP's `predict` method.
+ test_inputs: Input locations where the GP should be evaluated.

  Returns:
  GaussianDistribution: A multivariate normal random variable representation
  of the Gaussian process.
  """
- return self.predict(*args, **kwargs)
+ return self.predict(test_inputs)

  @abstractmethod
- def predict(self, *args: tp.Any, **kwargs: tp.Any) -> GaussianDistribution:
+ def predict(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution:
  r"""Evaluate the predictive distribution.

  Compute the latent function's multivariate normal distribution for a
@@ -109,8 +107,7 @@ class AbstractPrior(nnx.Module, tp.Generic[M, K]):
  this method must be implemented.

  Args:
- *args (Any): Arguments to the predict method.
- **kwargs (Any): Keyword arguments to the predict method.
+ test_inputs: Input locations where the GP should be evaluated.

  Returns:
  GaussianDistribution: A multivariate normal random variable representation
@@ -249,13 +246,12 @@ class Prior(AbstractPrior[M, K]):
  GaussianDistribution: A multivariate normal random variable representation
  of the Gaussian process.
  """
- x = test_inputs
- mx = self.mean_function(x)
- Kxx = self.kernel.gram(x)
- Kxx_dense = Kxx.to_dense() + Identity(Kxx.shape).to_dense() * self.jitter
+ mean_at_test = self.mean_function(test_inputs)
+ Kxx = self.kernel.gram(test_inputs)
+ Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter)
  Kxx = psd(Dense(Kxx_dense))

- return GaussianDistribution(jnp.atleast_1d(mx.squeeze()), Kxx)
+ return GaussianDistribution(jnp.atleast_1d(mean_at_test.squeeze()), Kxx)

  def sample_approx(
  self,
@@ -359,7 +355,9 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
  self.likelihood = likelihood
  self.jitter = jitter

- def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> GaussianDistribution:
+ def __call__(
+ self, test_inputs: Num[Array, "N D"], train_data: Dataset
+ ) -> GaussianDistribution:
  r"""Evaluate the Gaussian process posterior at the given points.

  The output of this function is a
@@ -368,28 +366,30 @@ class AbstractPosterior(nnx.Module, tp.Generic[P, L]):
  evaluated and the distribution can be sampled.

  Under the hood, `__call__` is calling the objects `predict` method. For this
- reasons, classes inheriting the `AbstractPrior` class, should not overwrite the
+ reasons, classes inheriting the `AbstractPosterior` class, should not overwrite the
  `__call__` method and should instead define a `predict` method.

  Args:
- *args (Any): The arguments to pass to the GP's `predict` method.
- **kwargs (Any): The keyword arguments to pass to the GP's `predict` method.
+ test_inputs: Input locations where the GP should be evaluated.
+ train_data: Training dataset to condition on.

  Returns:
  GaussianDistribution: A multivariate normal random variable representation
  of the Gaussian process.
  """
- return self.predict(*args, **kwargs)
+ return self.predict(test_inputs, train_data)

  @abstractmethod
- def predict(self, *args: tp.Any, **kwargs: tp.Any) -> GaussianDistribution:
+ def predict(
+ self, test_inputs: Num[Array, "N D"], train_data: Dataset
+ ) -> GaussianDistribution:
  r"""Compute the latent function's multivariate normal distribution for a
- given set of parameters. For any class inheriting the `AbstractPrior` class,
+ given set of parameters. For any class inheriting the `AbstractPosterior` class,
  this method must be implemented.

  Args:
- *args (Any): Arguments to the predict method.
- **kwargs (Any): Keyword arguments to the predict method.
+ test_inputs: Input locations where the GP should be evaluated.
+ train_data: Training dataset to condition on.

  Returns:
  GaussianDistribution: A multivariate normal random variable representation
@@ -503,22 +503,24 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):

  # Precompute Gram matrix, Kxx, at training inputs, x
  Kxx = self.prior.kernel.gram(x)
- Kxx_dense = Kxx.to_dense() + Identity(Kxx.shape).to_dense() * self.jitter
+ Kxx_dense = add_jitter(Kxx.to_dense(), self.jitter)
  Kxx = Dense(Kxx_dense)

  Sigma_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * obs_noise
  Sigma = psd(Dense(Sigma_dense))
+ L_sigma = lower_cholesky(Sigma)

  mean_t = self.prior.mean_function(t)
  Ktt = self.prior.kernel.gram(t)
  Kxt = self.prior.kernel.cross_covariance(x, t)
- Sigma_inv_Kxt = solve(Sigma, Kxt)

- mean = mean_t + jnp.matmul(Sigma_inv_Kxt.T, y - mx)
+ L_inv_Kxt = solve(L_sigma, Kxt)
+ L_inv_y_diff = solve(L_sigma, y - mx)
+
+ mean = mean_t + jnp.matmul(L_inv_Kxt.T, L_inv_y_diff)

- # Ktt - Ktx (Kxx + Io²)⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently.
- covariance = Ktt.to_dense() - jnp.matmul(Kxt.T, Sigma_inv_Kxt)
- covariance += jnp.eye(covariance.shape[0]) * self.prior.jitter
+ covariance = Ktt.to_dense() - jnp.matmul(L_inv_Kxt.T, L_inv_Kxt)
+ covariance = add_jitter(covariance, self.prior.jitter)
  covariance = psd(Dense(covariance))

  return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
@@ -577,7 +579,7 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):

  obs_var = self.likelihood.obs_stddev.value**2
  Kxx = self.prior.kernel.gram(train_data.X)
- Sigma = Kxx + jnp.eye(Kxx.shape[0]) * (obs_var + self.jitter)
+ Sigma = Dense(add_jitter(Kxx.to_dense(), obs_var + self.jitter))
  eps = jnp.sqrt(obs_var) * jr.normal(key, [train_data.n, num_samples])
  y = train_data.y - self.prior.mean_function(train_data.X)
  Phi = fourier_feature_fn(train_data.X)
@@ -643,7 +645,7 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):

  # TODO: static or intermediate?
  self.latent = latent if isinstance(latent, Parameter) else Real(latent)
- self.key = Static(key)
+ self.key = key

  def predict(
  self, test_inputs: Num[Array, "N D"], train_data: Dataset
@@ -675,7 +677,7 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):

  # Precompute lower triangular of Gram matrix, Lx, at training inputs, x
  Kxx = kernel.gram(x)
- Kxx_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * self.prior.jitter
+ Kxx_dense = add_jitter(Kxx.to_dense(), self.prior.jitter)
  Kxx = psd(Dense(Kxx_dense))
  Lx = lower_cholesky(Kxx)

@@ -698,7 +700,7 @@ class NonConjugatePosterior(AbstractPosterior[P, NGL]):

  # Ktt - Ktx Kxx⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently.
  covariance = Ktt.to_dense() - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt)
- covariance += jnp.eye(covariance.shape[0]) * self.prior.jitter
+ covariance = add_jitter(covariance, self.prior.jitter)
  covariance = psd(Dense(covariance))

  return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
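`ConjugatePosterior.predict` now reuses a single Cholesky factor of Sigma = Kxx + sigma^2 I for both the predictive mean and the Schur complement, rather than solving against Sigma directly. The identity it relies on can be checked with plain `jax.numpy`; the sketch below is illustrative only and does not use GPJax's linear operators.

```python
# Numerical check of the identity behind the new predict path: with Sigma = L Lᵀ,
#   Ktx Sigma⁻¹ Kxt = (L⁻¹Kxt)ᵀ (L⁻¹Kxt)   and   Ktx Sigma⁻¹ r = (L⁻¹Kxt)ᵀ (L⁻¹ r),
# so triangular solves against L replace a full solve against Sigma.
import jax.numpy as jnp
import jax.random as jr
import jax.scipy.linalg as jsl

Sigma = (lambda A: A @ A.T + 5.0 * jnp.eye(5))(jr.normal(jr.PRNGKey(0), (5, 5)))
Kxt = jr.normal(jr.PRNGKey(1), (5, 3))   # stand-in for the cross-covariance
r = jr.normal(jr.PRNGKey(2), (5, 1))     # stand-in for y - m

L = jnp.linalg.cholesky(Sigma)
L_inv_Kxt = jsl.solve_triangular(L, Kxt, lower=True)
L_inv_r = jsl.solve_triangular(L, r, lower=True)

direct = Kxt.T @ jnp.linalg.solve(Sigma, r)
via_chol = L_inv_Kxt.T @ L_inv_r
assert jnp.allclose(direct, via_chol, atol=1e-5)
```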
gpjax/kernels/approximations/rff.py CHANGED
@@ -7,7 +7,6 @@ from jaxtyping import Float
  from gpjax.kernels.base import AbstractKernel
  from gpjax.kernels.computations import BasisFunctionComputation
  from gpjax.kernels.stationary.base import StationaryKernel
- from gpjax.parameters import Static
  from gpjax.typing import (
  Array,
  KeyArray,
@@ -66,10 +65,8 @@ class RFF(AbstractKernel):
  "Please specify the n_dims argument for the base kernel."
  )

- self.frequencies = Static(
- self.base_kernel.spectral_density.sample(
- key=key, sample_shape=(self.num_basis_fns, n_dims)
- )
+ self.frequencies = self.base_kernel.spectral_density.sample(
+ key=key, sample_shape=(self.num_basis_fns, n_dims)
  )
  self.name = f"{self.base_kernel.name} (RFF)"

gpjax/kernels/base.py CHANGED
@@ -32,7 +32,6 @@ from gpjax.linalg import LinearOperator
  from gpjax.parameters import (
  Parameter,
  Real,
- Static,
  )
  from gpjax.typing import (
  Array,
@@ -221,9 +220,7 @@ class Constant(AbstractKernel):
  def __init__(
  self,
  active_dims: tp.Union[list[int], slice, None] = None,
- constant: tp.Union[
- ScalarFloat, Parameter[ScalarFloat], Static[ScalarFloat]
- ] = jnp.array(0.0),
+ constant: tp.Union[ScalarFloat, Parameter[ScalarFloat]] = jnp.array(0.0),
  compute_engine: AbstractKernelComputation = DenseKernelComputation(),
  ):
  if isinstance(constant, Parameter):
gpjax/kernels/computations/basis_functions.py CHANGED
@@ -57,7 +57,7 @@ class BasisFunctionComputation(AbstractKernelComputation):
  Returns:
  A matrix of shape $N \times L$ representing the random fourier features where $L = 2M$.
  """
- frequencies = kernel.frequencies.value
+ frequencies = kernel.frequencies
  scaling_factor = kernel.base_kernel.lengthscale.value
  z = jnp.matmul(x, (frequencies / scaling_factor).T)
  z = jnp.concatenate([jnp.cos(z), jnp.sin(z)], axis=-1)
gpjax/kernels/computations/eigen.py CHANGED
@@ -42,7 +42,7 @@ class EigenKernelComputation(AbstractKernelComputation):
  # Transform the eigenvalues of the graph Laplacian according to the
  # RBF kernel's SPDE form.
  S = jnp.power(
- kernel.eigenvalues.value
+ kernel.eigenvalues
  + 2
  * kernel.smoothness.value
  / kernel.lengthscale.value
gpjax/kernels/non_euclidean/graph.py CHANGED
@@ -30,7 +30,6 @@ from gpjax.kernels.stationary.base import StationaryKernel
  from gpjax.parameters import (
  Parameter,
  PositiveReal,
- Static,
  )
  from gpjax.typing import (
  Array,
@@ -55,9 +54,9 @@ class GraphKernel(StationaryKernel):
  """

  num_vertex: tp.Union[ScalarInt, None]
- laplacian: Static[Float[Array, "N N"]]
- eigenvalues: Static[Float[Array, "N 1"]]
- eigenvectors: Static[Float[Array, "N N"]]
+ laplacian: Float[Array, "N N"]
+ eigenvalues: Float[Array, "N 1"]
+ eigenvectors: Float[Array, "N N"]
  name: str = "Graph Matérn"

  def __init__(
@@ -91,11 +90,11 @@ class GraphKernel(StationaryKernel):
  else:
  self.smoothness = PositiveReal(smoothness)

- self.laplacian = Static(laplacian)
- evals, eigenvectors = jnp.linalg.eigh(self.laplacian.value)
- self.eigenvectors = Static(eigenvectors)
- self.eigenvalues = Static(evals.reshape(-1, 1))
- self.num_vertex = self.eigenvalues.value.shape[0]
+ self.laplacian = laplacian
+ evals, eigenvectors = jnp.linalg.eigh(self.laplacian)
+ self.eigenvectors = eigenvectors
+ self.eigenvalues = evals.reshape(-1, 1)
+ self.num_vertex = self.eigenvalues.shape[0]

  super().__init__(active_dims, lengthscale, variance, n_dims, compute_engine)

@@ -107,7 +106,7 @@ class GraphKernel(StationaryKernel):
  S,
  **kwargs,
  ):
- Kxx = (jax_gather_nd(self.eigenvectors.value, x) * S.squeeze()) @ jnp.transpose(
- jax_gather_nd(self.eigenvectors.value, y)
+ Kxx = (jax_gather_nd(self.eigenvectors, x) * S.squeeze()) @ jnp.transpose(
+ jax_gather_nd(self.eigenvectors, y)
  ) # shape (n,n)
  return Kxx.squeeze()
gpjax/kernels/nonstationary/arccosine.py CHANGED
@@ -25,7 +25,6 @@ from gpjax.kernels.computations import (
  )
  from gpjax.parameters import (
  NonNegativeReal,
- PositiveReal,
  )
  from gpjax.typing import (
  Array,
@@ -82,30 +81,13 @@ class ArcCosine(AbstractKernel):

  self.order = order

- if isinstance(weight_variance, nnx.Variable):
- self.weight_variance = weight_variance
- else:
- self.weight_variance = PositiveReal(weight_variance)
- if tp.TYPE_CHECKING:
- self.weight_variance = tp.cast(
- PositiveReal[WeightVariance], self.weight_variance
- )
+ self.weight_variance = weight_variance

  if isinstance(variance, nnx.Variable):
  self.variance = variance
  else:
  self.variance = NonNegativeReal(variance)
- if tp.TYPE_CHECKING:
- self.variance = tp.cast(NonNegativeReal[ScalarArray], self.variance)
-
- if isinstance(bias_variance, nnx.Variable):
- self.bias_variance = bias_variance
- else:
- self.bias_variance = PositiveReal(bias_variance)
- if tp.TYPE_CHECKING:
- self.bias_variance = tp.cast(
- PositiveReal[ScalarArray], self.bias_variance
- )
+ self.bias_variance = bias_variance

  self.name = f"ArcCosine (order {self.order})"

@@ -141,7 +123,17 @@ class ArcCosine(AbstractKernel):
  Returns:
  ScalarFloat: The value of the weighted product between the two arguments``.
  """
- return jnp.inner(self.weight_variance.value * x, y) + self.bias_variance.value
+ weight_var = (
+ self.weight_variance.value
+ if hasattr(self.weight_variance, "value")
+ else self.weight_variance
+ )
+ bias_var = (
+ self.bias_variance.value
+ if hasattr(self.bias_variance, "value")
+ else self.bias_variance
+ )
+ return jnp.inner(weight_var * x, y) + bias_var

  def _J(self, theta: ScalarFloat) -> ScalarFloat:
  r"""Evaluate the angular dependency function corresponding to the desired order.
gpjax/kernels/nonstationary/polynomial.py CHANGED
@@ -69,12 +69,9 @@ class Polynomial(AbstractKernel):

  self.degree = degree

- if isinstance(shift, nnx.Variable):
- self.shift = shift
- else:
- self.shift = PositiveReal(shift)
- if tp.TYPE_CHECKING:
- self.shift = tp.cast(PositiveReal[ScalarArray], self.shift)
+ self.shift = shift
+ if tp.TYPE_CHECKING and not isinstance(shift, nnx.Variable):
+ self.shift = tp.cast(PositiveReal[ScalarArray], self.shift)

  if isinstance(variance, nnx.Variable):
  self.variance = variance
@@ -88,7 +85,9 @@ class Polynomial(AbstractKernel):
  def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat:
  x = self.slice_input(x)
  y = self.slice_input(y)
- K = jnp.power(
- self.shift.value + self.variance.value * jnp.dot(x, y), self.degree
+ shift_val = self.shift.value if hasattr(self.shift, "value") else self.shift
+ variance_val = (
+ self.variance.value if hasattr(self.variance, "value") else self.variance
  )
+ K = jnp.power(shift_val + variance_val * jnp.dot(x, y), self.degree)
  return K.squeeze()
gpjax/kernels/stationary/periodic.py CHANGED
@@ -23,7 +23,6 @@ from gpjax.kernels.computations import (
  DenseKernelComputation,
  )
  from gpjax.kernels.stationary.base import StationaryKernel
- from gpjax.parameters import PositiveReal
  from gpjax.typing import (
  Array,
  ScalarArray,
@@ -72,10 +71,7 @@ class Periodic(StationaryKernel):
  covariance matrix.
  """

- if isinstance(period, nnx.Variable):
- self.period = period
- else:
- self.period = PositiveReal(period)
+ self.period = period

  super().__init__(active_dims, lengthscale, variance, n_dims, compute_engine)

@@ -84,8 +80,9 @@ class Periodic(StationaryKernel):
  ) -> Float[Array, ""]:
  x = self.slice_input(x)
  y = self.slice_input(y)
+ period_val = self.period.value if hasattr(self.period, "value") else self.period
  sine_squared = (
- jnp.sin(jnp.pi * (x - y) / self.period.value) / self.lengthscale.value
+ jnp.sin(jnp.pi * (x - y) / period_val) / self.lengthscale.value
  ) ** 2
  K = self.variance.value * jnp.exp(-0.5 * jnp.sum(sine_squared, axis=0))
  return K.squeeze()
gpjax/kernels/stationary/powered_exponential.py CHANGED
@@ -24,7 +24,6 @@ from gpjax.kernels.computations import (
  )
  from gpjax.kernels.stationary.base import StationaryKernel
  from gpjax.kernels.stationary.utils import euclidean_distance
- from gpjax.parameters import SigmoidBounded
  from gpjax.typing import (
  Array,
  ScalarArray,
@@ -76,10 +75,7 @@ class PoweredExponential(StationaryKernel):
  compute_engine: the computation engine that the kernel uses to compute the
  covariance matrix.
  """
- if isinstance(power, nnx.Variable):
- self.power = power
- else:
- self.power = SigmoidBounded(power)
+ self.power = power

  super().__init__(active_dims, lengthscale, variance, n_dims, compute_engine)

@@ -88,7 +84,6 @@ class PoweredExponential(StationaryKernel):
  ) -> Float[Array, ""]:
  x = self.slice_input(x) / self.lengthscale.value
  y = self.slice_input(y) / self.lengthscale.value
- K = self.variance.value * jnp.exp(
- -(euclidean_distance(x, y) ** self.power.value)
- )
+ power_val = self.power.value if hasattr(self.power, "value") else self.power
+ K = self.variance.value * jnp.exp(-(euclidean_distance(x, y) ** power_val))
  return K.squeeze()
gpjax/kernels/stationary/rational_quadratic.py CHANGED
@@ -23,7 +23,6 @@ from gpjax.kernels.computations import (
  )
  from gpjax.kernels.stationary.base import StationaryKernel
  from gpjax.kernels.stationary.utils import squared_distance
- from gpjax.parameters import PositiveReal
  from gpjax.typing import (
  Array,
  ScalarArray,
@@ -70,17 +69,15 @@ class RationalQuadratic(StationaryKernel):
  compute_engine: The computation engine that the kernel uses to compute the
  covariance matrix.
  """
- if isinstance(alpha, nnx.Variable):
- self.alpha = alpha
- else:
- self.alpha = PositiveReal(alpha)
+ self.alpha = alpha

  super().__init__(active_dims, lengthscale, variance, n_dims, compute_engine)

  def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarFloat:
  x = self.slice_input(x) / self.lengthscale.value
  y = self.slice_input(y) / self.lengthscale.value
- K = self.variance.value * (
- 1 + 0.5 * squared_distance(x, y) / self.alpha.value
- ) ** (-self.alpha.value)
+ alpha_val = self.alpha.value if hasattr(self.alpha, "value") else self.alpha
+ K = self.variance.value * (1 + 0.5 * squared_distance(x, y) / alpha_val) ** (
+ -alpha_val
+ )
  return K.squeeze()
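The kernels in the hunks above (ArcCosine, Polynomial, Periodic, PoweredExponential, RationalQuadratic) now store hyperparameters such as `weight_variance`, `shift`, `period`, `power` and `alpha` exactly as passed, and unwrap `.value` only when the attribute provides it (the `hasattr(..., "value")` checks). A hedged sketch of what this means at construction time, assuming the public `Periodic` constructor accepts `period` as shown:

```python
# Illustrative sketch (not from the diff): a raw float is used as-is and is not
# a trainable Parameter; wrapping it in a Parameter such as PositiveReal keeps
# it trainable. Both forms evaluate identically.
import jax.numpy as jnp

from gpjax.kernels import Periodic
from gpjax.parameters import PositiveReal

x, y = jnp.array([0.3]), jnp.array([1.2])

k_fixed = Periodic(period=2.0)                    # stored as a plain float
k_trainable = Periodic(period=PositiveReal(2.0))  # stored as a Parameter

# Same covariance value; only the second exposes `period` to gpx.fit.
assert jnp.allclose(k_fixed(x, y), k_trainable(x, y))
```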
gpjax/likelihoods.py CHANGED
@@ -29,7 +29,6 @@ from gpjax.integrators import (
  )
  from gpjax.parameters import (
  NonNegativeReal,
- Static,
  )
  from gpjax.typing import (
  Array,
@@ -59,27 +58,27 @@ class AbstractLikelihood(nnx.Module):
  self.num_datapoints = num_datapoints
  self.integrator = integrator

- def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> npd.Distribution:
+ def __call__(
+ self, dist: tp.Union[npd.MultivariateNormal, GaussianDistribution]
+ ) -> npd.Distribution:
  r"""Evaluate the likelihood function at a given predictive distribution.

  Args:
- *args (Any): Arguments to be passed to the likelihood's `predict` method.
- **kwargs (Any): Keyword arguments to be passed to the likelihood's
- `predict` method.
+ dist: The predictive distribution to evaluate the likelihood at.

  Returns:
  The predictive distribution.
  """
- return self.predict(*args, **kwargs)
+ return self.predict(dist)

  @abc.abstractmethod
- def predict(self, *args: tp.Any, **kwargs: tp.Any) -> npd.Distribution:
+ def predict(
+ self, dist: tp.Union[npd.MultivariateNormal, GaussianDistribution]
+ ) -> npd.Distribution:
  r"""Evaluate the likelihood function at a given predictive distribution.

  Args:
- *args (Any): Arguments to be passed to the likelihood's `predict` method.
- **kwargs (Any): Keyword arguments to be passed to the likelihood's
- `predict` method.
+ dist: The predictive distribution to evaluate the likelihood at.

  Returns:
  npd.Distribution: The predictive distribution.
@@ -133,9 +132,7 @@ class Gaussian(AbstractLikelihood):
  def __init__(
  self,
  num_datapoints: int,
- obs_stddev: tp.Union[
- ScalarFloat, Float[Array, "#N"], NonNegativeReal, Static
- ] = 1.0,
+ obs_stddev: tp.Union[ScalarFloat, Float[Array, "#N"], NonNegativeReal] = 1.0,
  integrator: AbstractIntegrator = AnalyticalGaussianIntegrator(),
  ):
  r"""Initializes the Gaussian likelihood.
@@ -148,7 +145,7 @@ class Gaussian(AbstractLikelihood):
  likelihoods. Must be an instance of `AbstractIntegrator`. For the Gaussian likelihood, this defaults to
  the `AnalyticalGaussianIntegrator`, as the expected log likelihood can be computed analytically.
  """
- if not isinstance(obs_stddev, (NonNegativeReal, Static)):
+ if not isinstance(obs_stddev, NonNegativeReal):
  obs_stddev = NonNegativeReal(jnp.asarray(obs_stddev))
  self.obs_stddev = obs_stddev

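With `__call__` and `predict` now taking a single distribution argument, the usual prediction flow reads as in the sketch below (assumed usage consistent with the signatures in this hunk; the model-building calls follow the public GPJax API):

```python
# Sketch: push the latent GP posterior through the Gaussian likelihood to get
# the predictive distribution over observations.
import jax.numpy as jnp

import gpjax as gpx

X = jnp.linspace(0.0, 1.0, 20)[:, None]
y = jnp.sin(X)
D = gpx.Dataset(X=X, y=y)

prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF())
likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
posterior = prior * likelihood

x_test = jnp.linspace(0.0, 1.0, 5)[:, None]
latent_dist = posterior(x_test, D)         # GaussianDistribution over f(x_test)
predictive_dist = likelihood(latent_dist)  # observation noise folded in
```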
gpjax/linalg/utils.py CHANGED
@@ -1,5 +1,8 @@
  """Utility functions for the linear algebra module."""

+ import jax.numpy as jnp
+ from jaxtyping import Array
+
  from gpjax.linalg.operators import LinearOperator


@@ -31,3 +34,32 @@ def psd(A: LinearOperator) -> LinearOperator:
  A.annotations = set()
  A.annotations.add(PSD)
  return A
+
+
+ def add_jitter(matrix: Array, jitter: float | Array = 1e-6) -> Array:
+ """Add jitter to the diagonal of a matrix for numerical stability.
+
+ This function adds a small positive value (jitter) to the diagonal elements
+ of a square matrix to improve numerical stability, particularly for
+ Cholesky decompositions and matrix inversions.
+
+ Args:
+ matrix: A square matrix to which jitter will be added.
+ jitter: The jitter value to add to the diagonal. Defaults to 1e-6.
+
+ Returns:
+ The matrix with jitter added to its diagonal.
+
+ Examples:
+ >>> import jax.numpy as jnp
+ >>> from gpjax.linalg.utils import add_jitter
+ >>> matrix = jnp.array([[1.0, 0.5], [0.5, 1.0]])
+ >>> jittered_matrix = add_jitter(matrix, jitter=0.01)
+ """
+ if matrix.ndim != 2:
+ raise ValueError(f"Expected 2D matrix, got {matrix.ndim}D array")
+
+ if matrix.shape[0] != matrix.shape[1]:
+ raise ValueError(f"Expected square matrix, got shape {matrix.shape}")
+
+ return matrix + jnp.eye(matrix.shape[0]) * jitter
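The new `add_jitter` helper centralises the `K + jnp.eye(n) * jitter` stabilisation that the rest of this release applies before Cholesky factorisation. The typical call pattern, mirroring the `gps.py` and `variational_families.py` hunks, is sketched below using only imports that appear elsewhere in this diff.

```python
# Illustrative sketch of the pattern adopted throughout 0.12.2: jitter the dense
# Gram matrix, re-wrap it, tag it PSD, then take its Cholesky factor.
import jax.numpy as jnp

from gpjax.linalg import Dense, psd
from gpjax.linalg.operations import lower_cholesky
from gpjax.linalg.utils import add_jitter

K = jnp.array([[1.0, 0.999999], [0.999999, 1.0]])  # nearly singular Gram matrix
K_stable = psd(Dense(add_jitter(K, 1e-6)))
L = lower_cholesky(K_stable)  # succeeds thanks to the jittered diagonal
```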
gpjax/mean_functions.py CHANGED
@@ -27,8 +27,6 @@ from jaxtyping import (

  from gpjax.parameters import (
  Parameter,
- Real,
- Static,
  )
  from gpjax.typing import (
  Array,
@@ -132,12 +130,12 @@ class Constant(AbstractMeanFunction):

  def __init__(
  self,
- constant: tp.Union[ScalarFloat, Float[Array, " O"], Parameter, Static] = 0.0,
+ constant: tp.Union[ScalarFloat, Float[Array, " O"], Parameter] = 0.0,
  ):
- if isinstance(constant, Parameter) or isinstance(constant, Static):
+ if isinstance(constant, Parameter):
  self.constant = constant
  else:
- self.constant = Real(jnp.array(constant))
+ self.constant = jnp.array(constant)

  def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N O"]:
  r"""Evaluate the mean function at the given points.
@@ -148,7 +146,10 @@ class Constant(AbstractMeanFunction):
  Returns:
  Float[Array, "1"]: The evaluated mean function.
  """
- return jnp.ones((x.shape[0], 1)) * self.constant.value
+ if isinstance(self.constant, Parameter):
+ return jnp.ones((x.shape[0], 1)) * self.constant.value
+ else:
+ return jnp.ones((x.shape[0], 1)) * self.constant


  class Zero(Constant):
@@ -160,7 +161,7 @@ class Zero(Constant):
  """

  def __init__(self):
- super().__init__(constant=Static(jnp.array(0.0)))
+ super().__init__(constant=0.0)


  class CombinationMeanFunction(AbstractMeanFunction):
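`Constant` (and therefore `Zero`) no longer wraps plain inputs in a parameter: a raw float or array is stored as a fixed `jnp.array`, while passing a `Parameter` keeps the constant trainable. A small illustrative sketch:

```python
# Sketch: fixed vs. trainable constant mean; both evaluate identically.
import jax.numpy as jnp

from gpjax.mean_functions import Constant
from gpjax.parameters import Real

x = jnp.linspace(0.0, 1.0, 4)[:, None]

fixed_mean = Constant(constant=2.0)            # stored as jnp.array(2.0), not trainable
trainable_mean = Constant(constant=Real(2.0))  # stored as a Parameter, trainable

assert jnp.allclose(fixed_mean(x), trainable_mean(x))  # both give a column of 2.0s
```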
gpjax/objectives.py CHANGED
@@ -20,6 +20,7 @@ from gpjax.linalg import (
  psd,
  solve,
  )
+ from gpjax.linalg.utils import add_jitter
  from gpjax.typing import (
  Array,
  ScalarFloat,
@@ -97,7 +98,7 @@ def conjugate_mll(posterior: ConjugatePosterior, data: Dataset) -> ScalarFloat:

  # Σ = (Kxx + Io²) = LLᵀ
  Kxx = posterior.prior.kernel.gram(x)
- Kxx_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * posterior.prior.jitter
+ Kxx_dense = add_jitter(Kxx.to_dense(), posterior.prior.jitter)
  Sigma_dense = Kxx_dense + jnp.eye(Kxx.shape[0]) * obs_noise
  Sigma = psd(Dense(Sigma_dense))

@@ -213,7 +214,7 @@ def log_posterior_density(

  # Gram matrix
  Kxx = posterior.prior.kernel.gram(x)
- Kxx_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * posterior.prior.jitter
+ Kxx_dense = add_jitter(Kxx.to_dense(), posterior.prior.jitter)
  Kxx = psd(Dense(Kxx_dense))
  Lx = lower_cholesky(Kxx)

@@ -349,7 +350,7 @@ def collapsed_elbo(variational_family: VF, data: Dataset) -> ScalarFloat:
  noise = variational_family.posterior.likelihood.obs_stddev.value**2
  z = variational_family.inducing_inputs.value
  Kzz = kernel.gram(z)
- Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * variational_family.jitter
+ Kzz_dense = add_jitter(Kzz.to_dense(), variational_family.jitter)
  Kzz = psd(Dense(Kzz_dense))
  Kzx = kernel.cross_covariance(z, x)
  Kxx_diag = vmap(kernel, in_axes=(0, 0))(x, x)
gpjax/parameters.py CHANGED
@@ -122,16 +122,6 @@ class SigmoidBounded(Parameter[T]):
  )


- class Static(nnx.Variable[T]):
- """Static parameter that is not trainable."""
-
- def __init__(self, value: T, tag: ParameterTag = "static", **kwargs):
- _check_is_arraylike(value)
-
- super().__init__(value=jnp.asarray(value), tag=tag, **kwargs)
- self._tag = tag
-
-
  class LowerTriangular(Parameter[T]):
  """Parameter that is a lower triangular matrix."""

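`Static` is removed from `gpjax.parameters` in 0.12.2; throughout this diff, formerly `Static` values (RFF frequencies, the graph Laplacian eigendecomposition, the latent key, natural and expectation parameters) become either plain attributes or `Real` parameters. The sketch below is a hypothetical migration for downstream code that imported `Static`; the class and attribute names are invented for illustration.

```python
# Hypothetical migration: a user module that previously wrapped fixed arrays in
# Static can store them directly; only Parameter-typed attributes are picked up
# by the `trainable` filter that gpx.fit splits out.
from flax import nnx
import jax.numpy as jnp

from gpjax.parameters import PositiveReal


class FeatureMap(nnx.Module):  # hypothetical example class
    def __init__(self, frequencies, lengthscale: float = 1.0):
        # 0.12.0: self.frequencies = Static(frequencies)
        self.frequencies = jnp.asarray(frequencies)   # fixed, not trainable
        self.lengthscale = PositiveReal(lengthscale)  # trainable Parameter

    def __call__(self, x):
        return jnp.cos(x @ (self.frequencies / self.lengthscale.value).T)
```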
gpjax/variational_families.py CHANGED
@@ -40,11 +40,11 @@ from gpjax.linalg import (
  psd,
  solve,
  )
+ from gpjax.linalg.utils import add_jitter
  from gpjax.mean_functions import AbstractMeanFunction
  from gpjax.parameters import (
  LowerTriangular,
  Real,
- Static,
  )
  from gpjax.typing import (
  Array,
@@ -110,11 +110,10 @@ class AbstractVariationalGaussian(AbstractVariationalFamily[L]):
  inducing_inputs: tp.Union[
  Float[Array, "N D"],
  Real,
- Static,
  ],
  jitter: ScalarFloat = 1e-6,
  ):
- if not isinstance(inducing_inputs, (Real, Static)):
+ if not isinstance(inducing_inputs, Real):
  inducing_inputs = Real(inducing_inputs)

  self.inducing_inputs = inducing_inputs
@@ -177,25 +176,31 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
  approximation and the GP prior.
  """
  # Unpack variational parameters
- mu = self.variational_mean.value
- sqrt = self.variational_root_covariance.value
- z = self.inducing_inputs.value
+ variational_mean = self.variational_mean.value
+ variational_sqrt = self.variational_root_covariance.value
+ inducing_inputs = self.inducing_inputs.value

  # Unpack mean function and kernel
  mean_function = self.posterior.prior.mean_function
  kernel = self.posterior.prior.kernel

- muz = mean_function(z)
- Kzz = kernel.gram(z)
- Kzz = psd(Dense(Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter))
+ inducing_mean = mean_function(inducing_inputs)
+ Kzz = kernel.gram(inducing_inputs)
+ Kzz = psd(Dense(add_jitter(Kzz.to_dense(), self.jitter)))

- sqrt = Triangular(sqrt)
- S = sqrt @ sqrt.T
+ variational_sqrt_triangular = Triangular(variational_sqrt)
+ variational_covariance = (
+ variational_sqrt_triangular @ variational_sqrt_triangular.T
+ )

- qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S)
- pu = GaussianDistribution(loc=jnp.atleast_1d(muz.squeeze()), scale=Kzz)
+ q_inducing = GaussianDistribution(
+ loc=jnp.atleast_1d(variational_mean.squeeze()), scale=variational_covariance
+ )
+ p_inducing = GaussianDistribution(
+ loc=jnp.atleast_1d(inducing_mean.squeeze()), scale=Kzz
+ )

- return qu.kl_divergence(pu)
+ return q_inducing.kl_divergence(p_inducing)

  def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
  r"""Compute the predictive distribution of the GP at the test inputs t.
@@ -215,26 +220,26 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
  the test inputs.
  """
  # Unpack variational parameters
- mu = self.variational_mean.value
- sqrt = self.variational_root_covariance.value
- z = self.inducing_inputs.value
+ variational_mean = self.variational_mean.value
+ variational_sqrt = self.variational_root_covariance.value
+ inducing_inputs = self.inducing_inputs.value

  # Unpack mean function and kernel
  mean_function = self.posterior.prior.mean_function
  kernel = self.posterior.prior.kernel

- Kzz = kernel.gram(z)
- Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+ Kzz = kernel.gram(inducing_inputs)
+ Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
  Kzz = psd(Dense(Kzz_dense))
  Lz = lower_cholesky(Kzz)
- muz = mean_function(z)
+ inducing_mean = mean_function(inducing_inputs)

  # Unpack test inputs
- t = test_inputs
+ test_points = test_inputs

- Ktt = kernel.gram(t)
- Kzt = kernel.cross_covariance(z, t)
- mut = mean_function(t)
+ Ktt = kernel.gram(test_points)
+ Kzt = kernel.cross_covariance(inducing_inputs, test_points)
+ test_mean = mean_function(test_points)

  # Lz⁻¹ Kzt
  Lz_inv_Kzt = solve(Lz, Kzt)
@@ -243,10 +248,10 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
  Kzz_inv_Kzt = solve(Lz.T, Lz_inv_Kzt)

  # Ktz Kzz⁻¹ sqrt
- Ktz_Kzz_inv_sqrt = jnp.matmul(Kzz_inv_Kzt.T, sqrt)
+ Ktz_Kzz_inv_sqrt = jnp.matmul(Kzz_inv_Kzt.T, variational_sqrt)

  # μt + Ktz Kzz⁻¹ (μ - μz)
- mean = mut + jnp.matmul(Kzz_inv_Kzt.T, mu - muz)
+ mean = test_mean + jnp.matmul(Kzz_inv_Kzt.T, variational_mean - inducing_mean)

  # Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt [recall S = sqrt sqrtᵀ]
  covariance = (
@@ -254,7 +259,10 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
  - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
  + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T)
  )
- covariance += jnp.eye(covariance.shape[0]) * self.jitter
+ if hasattr(covariance, "to_dense"):
+ covariance = covariance.to_dense()
+ covariance = add_jitter(covariance, self.jitter)
+ covariance = Dense(covariance)

  return GaussianDistribution(
  loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
@@ -329,7 +337,7 @@ class WhitenedVariationalGaussian(VariationalGaussian[L]):
  kernel = self.posterior.prior.kernel

  Kzz = kernel.gram(z)
- Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+ Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
  Kzz = psd(Dense(Kzz_dense))
  Lz = lower_cholesky(Kzz)

@@ -355,7 +363,10 @@ class WhitenedVariationalGaussian(VariationalGaussian[L]):
  - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
  + jnp.matmul(Ktz_Lz_invT_sqrt, Ktz_Lz_invT_sqrt.T)
  )
- covariance += jnp.eye(covariance.shape[0]) * self.jitter
+ if hasattr(covariance, "to_dense"):
+ covariance = covariance.to_dense()
+ covariance = add_jitter(covariance, self.jitter)
+ covariance = Dense(covariance)

  return GaussianDistribution(
  loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
@@ -390,8 +401,8 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
  if natural_matrix is None:
  natural_matrix = -0.5 * jnp.eye(self.num_inducing)

- self.natural_vector = Static(natural_vector)
- self.natural_matrix = Static(natural_matrix)
+ self.natural_vector = Real(natural_vector)
+ self.natural_matrix = Real(natural_matrix)

  def prior_kl(self) -> ScalarFloat:
  r"""Compute the KL-divergence between our current variational approximation
@@ -422,7 +433,7 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):

  # S⁻¹ = -2θ₂
  S_inv = -2 * natural_matrix
- S_inv += jnp.eye(m) * self.jitter
+ S_inv = add_jitter(S_inv, self.jitter)

  # Compute L⁻¹, where LLᵀ = S, via a trick found in the NumPyro source code and https://nbviewer.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril:
  sqrt_inv = jnp.swapaxes(
@@ -441,7 +452,7 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):

  muz = mean_function(z)
  Kzz = kernel.gram(z)
- Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+ Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
  Kzz = psd(Dense(Kzz_dense))

  qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S)
@@ -476,7 +487,7 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):

  # S⁻¹ = -2θ₂
  S_inv = -2 * natural_matrix
- S_inv += jnp.eye(m) * self.jitter
+ S_inv = add_jitter(S_inv, self.jitter)

  # Compute L⁻¹, where LLᵀ = S, via a trick found in the NumPyro source code and https://nbviewer.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril:
  sqrt_inv = jnp.swapaxes(
@@ -493,7 +504,7 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
  mu = jnp.matmul(S, natural_vector)

  Kzz = kernel.gram(z)
- Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+ Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
  Kzz = psd(Dense(Kzz_dense))
  Lz = lower_cholesky(Kzz)
  muz = mean_function(z)
@@ -520,7 +531,10 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian[L]):
  - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
  + jnp.matmul(Ktz_Kzz_inv_L, Ktz_Kzz_inv_L.T)
  )
- covariance += jnp.eye(covariance.shape[0]) * self.jitter
+ if hasattr(covariance, "to_dense"):
+ covariance = covariance.to_dense()
+ covariance = add_jitter(covariance, self.jitter)
+ covariance = Dense(covariance)

  return GaussianDistribution(
  loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
@@ -556,8 +570,8 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
  if expectation_matrix is None:
  expectation_matrix = jnp.eye(self.num_inducing)

- self.expectation_vector = Static(expectation_vector)
- self.expectation_matrix = Static(expectation_matrix)
+ self.expectation_vector = Real(expectation_vector)
+ self.expectation_matrix = Real(expectation_matrix)

  def prior_kl(self) -> ScalarFloat:
  r"""Evaluate the prior KL-divergence.
@@ -595,12 +609,12 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
  # S = η₂ - η₁ η₁ᵀ
  S = expectation_matrix - jnp.outer(mu, mu)
  S = psd(Dense(S))
- S_dense = S.to_dense() + jnp.eye(S.shape[0]) * self.jitter
+ S_dense = add_jitter(S.to_dense(), self.jitter)
  S = psd(Dense(S_dense))

  muz = mean_function(z)
  Kzz = kernel.gram(z)
- Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+ Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
  Kzz = psd(Dense(Kzz_dense))

  qu = GaussianDistribution(loc=jnp.atleast_1d(mu.squeeze()), scale=S)
@@ -640,14 +654,14 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):

  # S = η₂ - η₁ η₁ᵀ
  S = expectation_matrix - jnp.matmul(mu, mu.T)
- S = Dense(S + jnp.eye(S.shape[0]) * self.jitter)
+ S = Dense(add_jitter(S, self.jitter))
  S = psd(S)

  # S = sqrt sqrtᵀ
  sqrt = lower_cholesky(S)

  Kzz = kernel.gram(z)
- Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+ Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
  Kzz = psd(Dense(Kzz_dense))
  Lz = lower_cholesky(Kzz)
  muz = mean_function(z)
@@ -677,7 +691,10 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian[L]):
  - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
  + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T)
  )
- covariance += jnp.eye(covariance.shape[0]) * self.jitter
+ if hasattr(covariance, "to_dense"):
+ covariance = covariance.to_dense()
+ covariance = add_jitter(covariance, self.jitter)
+ covariance = Dense(covariance)

  return GaussianDistribution(
  loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
@@ -734,7 +751,7 @@ class CollapsedVariationalGaussian(AbstractVariationalGaussian[GL]):

  Kzx = kernel.cross_covariance(z, x)
  Kzz = kernel.gram(z)
- Kzz_dense = Kzz.to_dense() + jnp.eye(Kzz.shape[0]) * self.jitter
+ Kzz_dense = add_jitter(Kzz.to_dense(), self.jitter)
  Kzz = psd(Dense(Kzz_dense))

  # Lz Lzᵀ = Kzz
@@ -780,7 +797,10 @@ class CollapsedVariationalGaussian(AbstractVariationalGaussian[GL]):
  - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
  + jnp.matmul(L_inv_Lz_inv_Kzt.T, L_inv_Lz_inv_Kzt)
  )
- covariance += jnp.eye(covariance.shape[0]) * self.jitter
+ if hasattr(covariance, "to_dense"):
+ covariance = covariance.to_dense()
+ covariance = add_jitter(covariance, self.jitter)
+ covariance = Dense(covariance)

  return GaussianDistribution(
  loc=jnp.atleast_1d(mean.squeeze()), scale=covariance
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: gpjax
- Version: 0.12.0
+ Version: 0.12.2
  Summary: Gaussian processes in JAX.
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
  Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
@@ -17,7 +17,7 @@ Classifier: Programming Language :: Python :: 3.12
  Classifier: Programming Language :: Python :: 3.13
  Classifier: Programming Language :: Python :: Implementation :: CPython
  Classifier: Programming Language :: Python :: Implementation :: PyPy
- Requires-Python: <=3.13,>=3.10
+ Requires-Python: >=3.10
  Requires-Dist: beartype>0.16.1
  Requires-Dist: flax>=0.10.0
  Requires-Dist: jax>=0.5.0
@@ -60,7 +60,7 @@ Requires-Dist: mkdocs-jupyter>=0.24.3; extra == 'docs'
  Requires-Dist: mkdocs-literate-nav>=0.6.0; extra == 'docs'
  Requires-Dist: mkdocs-material>=9.5.12; extra == 'docs'
  Requires-Dist: mkdocs>=1.5.3; extra == 'docs'
- Requires-Dist: mkdocstrings[python]<0.28.0; extra == 'docs'
+ Requires-Dist: mkdocstrings[python]<0.31.0; extra == 'docs'
  Requires-Dist: nbconvert>=7.16.2; extra == 'docs'
  Requires-Dist: networkx>=3.0; extra == 'docs'
  Requires-Dist: pandas>=1.5.3; extra == 'docs'
@@ -126,18 +126,9 @@ Channel](https://join.slack.com/t/gpjax/shared_invite/zt-3cesiykcx-nzajjRdnV3ohw
  where we can discuss the development of GPJax and broader support for Gaussian
  process modelling.

-
- ## Governance
-
- GPJax was founded by [Thomas Pinder](https://github.com/thomaspinder). Today, the
- project's gardeners are [daniel-dodd@](https://github.com/daniel-dodd),
- [henrymoss@](https://github.com/henrymoss), [st--@](https://github.com/st--), and
- [thomaspinder@](https://github.com/thomaspinder), listed in alphabetical order. The full
- governance structure of GPJax is detailed [here](docs/GOVERNANCE.md). We appreciate all
- [the contributors to
- GPJax](https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors) who have
- helped to shape GPJax into the package it is today.
-
+ We appreciate all [the contributors to
+ GPJax](https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors) who have helped to shape
+ GPJax into the package it is today.

  # Supported methods and interfaces

@@ -218,13 +209,14 @@ configuration in development mode.
  ```bash
  git clone https://github.com/JaxGaussianProcesses/GPJax.git
  cd GPJax
- uv sync --extra dev
+ hatch env create
+ hatch shell
  ```

  > We recommend you check your installation passes the supplied unit tests:
  >
  > ```python
- > uv run pytest --beartype-packages='gpjax'
+ > hatch run dev:test
  > ```

  # Citing GPJax
@@ -1,52 +1,52 @@
- gpjax/__init__.py,sha256=FSrKDFSQ7xDqwQGBWwEPqqjvYxEbhUPPestKLoAPjWA,1686
+ gpjax/__init__.py,sha256=RzwpixFXn6HNHLVLy4LVXhFUk2c-_ce6n1gjZ2B93F0,1641
  gpjax/citation.py,sha256=pwFS8h1J-LE5ieRS0zDyuwhmQHNxkFHYE7iSMlVNmQc,3928
  gpjax/dataset.py,sha256=NsToLKq4lOsHnfLfukrUIRKvhOEuoUk8aHTF0oAqRbU,4079
  gpjax/distributions.py,sha256=iKmeQ_NN2CIjRiuOeJlwEGASzGROi4ZCerVi1uY7zRM,7758
- gpjax/fit.py,sha256=R4TIPvBNHYSg9vBVp6is_QYENldRLIU_FklGE85C-aA,15046
- gpjax/gps.py,sha256=-Log0pcU8qmB5fUxfzoNjD0S64gpiypAjFzjGXX6w7I,30301
+ gpjax/fit.py,sha256=I2sJVuKZii_d7MEcelHIivfM8ExYGMgdBuKKOT7Dw-A,15326
+ gpjax/gps.py,sha256=ipaeYMnPffhKK_JsEHe4fF8GmolQIjXB1YbyfUIL8H4,30118
  gpjax/integrators.py,sha256=eyJPqWNPKj6pKP5da0fEj4HW7BVyevqeGrurEuy_XPw,5694
- gpjax/likelihoods.py,sha256=99oTZoWld1M7vxgGM0pNY5Hnt2Ajd2lQNqawzrLmwtk,9308
- gpjax/mean_functions.py,sha256=-sVYO1_LWE8f34rllUOuaT5sgGGAdxo99v5kRo2d4oM,6490
+ gpjax/likelihoods.py,sha256=xwnSQpn6Aa-FPpEoDn_3xpBdPQAmHP97jP-9iJmT4G8,9087
+ gpjax/mean_functions.py,sha256=KiHQXI-b7o0Vi5KQxGm6RNsUjitJc9jEOCq2GrSx4II,6531
  gpjax/numpyro_extras.py,sha256=-vWJ7SpZVNhSdCjjrlxIkovMFrM1IzpsMJK3B4LioGE,3411
- gpjax/objectives.py,sha256=Tm36h8fz_nWkZPlufMQzZWKK1ytrtT9yvvP8YdxYKNw,15359
- gpjax/parameters.py,sha256=qIEqyMKNd2n2Ak15PisCmqhX5qhsoRgng_s4doL96rE,7044
+ gpjax/objectives.py,sha256=GvKbDIPqYjsc9FpiTccmZwRdHr2lCykgfxI9BX9I_GA,15362
+ gpjax/parameters.py,sha256=hnyIKr6uIzd7Kb3KZC9WowR88ruQwUvdcto3cx2ZDv4,6756
  gpjax/scan.py,sha256=jStQvwkE9MGttB89frxam1kaeXdWih7cVxkGywyaeHQ,5365
  gpjax/typing.py,sha256=M3CvWsYtZ3PFUvBvvbRNjpwerNII0w4yGuP0I-sLeYI,1705
- gpjax/variational_families.py,sha256=rE3LarwIAkvDvLlWrz8Ww6BUBz88YHdV4ceY97r3IBw,28637
+ gpjax/variational_families.py,sha256=TJGGkwkE805X4PQb-C32FxvD9B_OsFLWf6I-ZZvOUWk,29628
  gpjax/kernels/__init__.py,sha256=WZanH0Tpdkt0f7VfMqnalm_VZAMVwBqeOVaICNj6xQU,1901
- gpjax/kernels/base.py,sha256=hOUXwarspDFnuI2_QreyIVPdz2fzRVJj4p3Zdu1touw,11606
+ gpjax/kernels/base.py,sha256=4Lx8y3kPX4WqQZGRGAsBkqn_i6FlfoAhSn9Tv415xuQ,11551
  gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfjH1dP626yNMA,68
- gpjax/kernels/approximations/rff.py,sha256=VbitjNuahFE5_IvCj1A0SxHhJXU0O0Qq0FMMVq8xA3E,4125
+ gpjax/kernels/approximations/rff.py,sha256=GbNUmDPEKEKuMwxUcocxl_9IFR3Q9KEPZXzjy_ZD-2w,4043
  gpjax/kernels/computations/__init__.py,sha256=uTVkqvnZVesFLDN92h0ZR0jfR69Eo2WyjOlmSYmCPJ8,1379
  gpjax/kernels/computations/base.py,sha256=L6K0roxZbrYeJKxEw-yaTiK9Mtcv0YtZfWI2Xnau7i8,3616
- gpjax/kernels/computations/basis_functions.py,sha256=MPSo40NEx_ngnSLTa9ntVJzma_jugvm5dMpZd5MtG5M,2490
+ gpjax/kernels/computations/basis_functions.py,sha256=_SFv4Tiwne40bxr1uVYpEjjZgjIQHKseLmss2Zgl1L4,2484
  gpjax/kernels/computations/constant_diagonal.py,sha256=JkQhLj7cK48IhOER4ivkALNhD1oQleKe-Rr9BtUJ6es,1984
  gpjax/kernels/computations/dense.py,sha256=vnW6XKQe4_gzpXRWTctxhgMA9-9TebdtiXzAqh_-j6g,1392
  gpjax/kernels/computations/diagonal.py,sha256=k1KqW0DwWRIBvbb7jzcKktXRfhXbcos3ncWrFplJ4W0,1768
- gpjax/kernels/computations/eigen.py,sha256=w7I7LK42j0ouchHCI1ltXx0lpwqvK1bRb4HclnF3rKs,1936
+ gpjax/kernels/computations/eigen.py,sha256=NTHm-cn-RepYuXFrvXo2ih7Gtu1YR_pAg4Jb7IhE_o8,1930
  gpjax/kernels/non_euclidean/__init__.py,sha256=RT7puRPqCTpyxZ16q596EuOQEQi1LK1v3J9_fWz1NlY,790
- gpjax/kernels/non_euclidean/graph.py,sha256=K4WIdX-dx1SsWuNHZnNjHFw8ElKZxGcReUiA3w4aCOI,4204
+ gpjax/kernels/non_euclidean/graph.py,sha256=xTrx6ro8ubRXgM7Wgg6NmOyyEjEcGhzydY7KXueknCc,4120
  gpjax/kernels/non_euclidean/utils.py,sha256=z42aw8ga0zuREzHawemR9okttgrAUPmq-aN5HMt4SuY,1578
  gpjax/kernels/nonstationary/__init__.py,sha256=YpWQfOy_cqOKc5ezn37vqoK3Z6jznYiJz28BD_8F7AY,930
- gpjax/kernels/nonstationary/arccosine.py,sha256=2WV6aM0Z3-xXZnoPw-77n2CW62n-AZuJy-7AQ9xrMco,5858
+ gpjax/kernels/nonstationary/arccosine.py,sha256=cqb8sqaNwW3fEbrA7MY9OF2KJFTkxHhqwmQtABE3G8w,5408
  gpjax/kernels/nonstationary/linear.py,sha256=UIMoCq2hg6dQKr4J5UGiiPqotBleQuYfy00Ia1NaMOo,2571
- gpjax/kernels/nonstationary/polynomial.py,sha256=arP8DK0jnBOaayDWcFvHF0pdu9FVhwzXdqjnHUAL2VI,3293
+ gpjax/kernels/nonstationary/polynomial.py,sha256=CKc02C7Utgo-hhcOOCcKLdln5lj4vud_8M-JY7SevJ8,3388
  gpjax/kernels/stationary/__init__.py,sha256=j4BMTaQlIx2kNAT1Dkf4iO2rm-f7_oSVWNrk1bN0tqE,1406
  gpjax/kernels/stationary/base.py,sha256=25qDqpZP4gNtzbyzDCW-6u7rJfMqkg0dW88XUmTTupU,7078
  gpjax/kernels/stationary/matern12.py,sha256=DGjqw6VveYsyy0TrufyJJvCei7p9slnm2f0TgRGG7_U,1773
  gpjax/kernels/stationary/matern32.py,sha256=laLsJWJozJzpYHBzlkPUq0rWxz1eWEwGC36P2nPJuaQ,1966
  gpjax/kernels/stationary/matern52.py,sha256=VSByD2sb7k-DzRFjaz31P3Rtc4bPPhHvMshrxZNFnns,2019
- gpjax/kernels/stationary/periodic.py,sha256=IAbCxURtJEHGdmYzbdrsqRZ3zJ8F8tGQF9O7sggafZk,3598
- gpjax/kernels/stationary/powered_exponential.py,sha256=8qT91IWKJK7PpEtFcX4MVu1ahWMOFOZierPko4JCjKA,3776
- gpjax/kernels/stationary/rational_quadratic.py,sha256=dYONp3i4rnKj3ET8UyxAKXv6UOl8uOFT3lCutleSvo4,3496
+ gpjax/kernels/stationary/periodic.py,sha256=f4PhWhKg-pJsEBGzEMK9pdbylO84GPKhzHlBC83ZVWw,3528
+ gpjax/kernels/stationary/powered_exponential.py,sha256=xuFGuIK0mKNMU3iLtZMXZTHXJuMFAMoX7gAtXefCdqU,3679
+ gpjax/kernels/stationary/rational_quadratic.py,sha256=zHo2LVW65T52XET4Hx9JaKO0TfxylV8WRUtP7sUUOx0,3418
  gpjax/kernels/stationary/rbf.py,sha256=euHUs6FdfRICQcabAWE4MX-7GEDr2TxgZWdFQiXr9Bw,1690
  gpjax/kernels/stationary/utils.py,sha256=6BI9EBcCzeeKx-XH-MfW1ORmtU__tPX5zyvfLhpkBsU,2180
  gpjax/kernels/stationary/white.py,sha256=TkdXXZCCjDs7JwR_gj5uvn2s1wyfRbe1vyHhUMJ8jjI,2212
  gpjax/linalg/__init__.py,sha256=F8mxk_9Zc2nFd7Q-unjJ50_6rXEKzZj572WsU_jUKqI,547
  gpjax/linalg/operations.py,sha256=xvhOy5P4FmUCPWjIVNdg1yDXaoFQ48anFUfR-Tnfr6k,6480
  gpjax/linalg/operators.py,sha256=arxRGwcoAy_RqUYqBpZ3XG6OXbjShUl7m8sTpg85npE,11608
- gpjax/linalg/utils.py,sha256=DGX40TDhmfYn7JBxElpBm_9W0cetm0HZUK7B3j74xxo,895
- gpjax-0.12.0.dist-info/METADATA,sha256=8lLQb5SUvWvniry-zBOR3wzm03tXvHe7Lzry_Ho3peE,10562
- gpjax-0.12.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- gpjax-0.12.0.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
- gpjax-0.12.0.dist-info/RECORD,,
+ gpjax/linalg/utils.py,sha256=fKV8G_iKZVhNkNvN20D_dQEi93-8xosGbXBP-v7UEyo,2020
+ gpjax-0.12.2.dist-info/METADATA,sha256=eckQKXiBXi8XbBeJFviBAIPdBGVWGFQg7wVZwMfPPxs,10129
+ gpjax-0.12.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ gpjax-0.12.2.dist-info/licenses/LICENSE.txt,sha256=3umwi0h8wmKXOZO8XwRBwSl3vJt2hpWKEqSrSXLR7-I,1084
+ gpjax-0.12.2.dist-info/RECORD,,